Xuweiyi committed on
Commit
ee823b7
1 Parent(s): f89fbbb

Upload 94 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE.txt +201 -0
  2. animatediff/.DS_Store +0 -0
  3. animatediff/data/dataset.py +98 -0
  4. animatediff/models/attention.py +398 -0
  5. animatediff/models/motion_module.py +353 -0
  6. animatediff/models/resnet.py +217 -0
  7. animatediff/models/unet.py +526 -0
  8. animatediff/models/unet_blocks.py +764 -0
  9. animatediff/pipelines/pipeline_animation.py +746 -0
  10. animatediff/utils/convert_from_ckpt.py +1216 -0
  11. animatediff/utils/convert_lora_safetensor_to_diffusers.py +154 -0
  12. animatediff/utils/util.py +237 -0
  13. app.py +715 -0
  14. configs/.DS_Store +0 -0
  15. configs/eval0.yaml +4 -0
  16. configs/eval1.yaml +5 -0
  17. configs/inference/.ipynb_checkpoints/inference-v1-checkpoint.yaml +26 -0
  18. configs/inference/inference-v1.yaml +26 -0
  19. configs/inference/inference-v2.yaml +27 -0
  20. configs/prompts/.DS_Store +0 -0
  21. configs/prompts/1-ToonYou.yaml +23 -0
  22. configs/prompts/2-Lyriel.yaml +23 -0
  23. configs/prompts/3-RcnzCartoon.yaml +23 -0
  24. configs/prompts/4-MajicMix.yaml +23 -0
  25. configs/prompts/5-RealisticVision.yaml +23 -0
  26. configs/prompts/6-Tusun.yaml +21 -0
  27. configs/prompts/7-FilmVelvia.yaml +24 -0
  28. configs/prompts/8-GhibliBackground.yaml +21 -0
  29. configs/prompts/unictrl_examples/RealisticVision_v1.yaml +30 -0
  30. configs/prompts/unictrl_examples/RealisticVision_v2.yaml +35 -0
  31. configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml +189 -0
  32. configs/prompts/v2/5-RealisticVision.yaml +23 -0
  33. configs/training/image_finetune.yaml +48 -0
  34. configs/training/training.yaml +66 -0
  35. download_bashscripts/0-MotionModule.sh +2 -0
  36. download_bashscripts/1-ToonYou.sh +2 -0
  37. download_bashscripts/2-Lyriel.sh +2 -0
  38. download_bashscripts/3-RcnzCartoon.sh +2 -0
  39. download_bashscripts/4-MajicMix.sh +2 -0
  40. download_bashscripts/5-RealisticVision.sh +2 -0
  41. download_bashscripts/6-Tusun.sh +3 -0
  42. download_bashscripts/7-FilmVelvia.sh +3 -0
  43. download_bashscripts/8-GhibliBackground.sh +3 -0
  44. environment.yaml +25 -0
  45. models/.DS_Store +0 -0
  46. models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt +0 -0
  47. models/DreamBooth_LoRA/lyriel_v16.safetensors +3 -0
  48. models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors +3 -0
  49. models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors +3 -0
  50. models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae.safetensors +3 -0
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
animatediff/.DS_Store ADDED
Binary file (6.15 kB).
animatediff/data/dataset.py ADDED
@@ -0,0 +1,98 @@
+ import os, io, csv, math, random
+ import numpy as np
+ from einops import rearrange
+ from decord import VideoReader
+
+ import torch
+ import torchvision.transforms as transforms
+ from torch.utils.data.dataset import Dataset
+ from animatediff.utils.util import zero_rank_print
+
+
+
+ class WebVid10M(Dataset):
+     def __init__(
+             self,
+             csv_path, video_folder,
+             sample_size=256, sample_stride=4, sample_n_frames=16,
+             is_image=False,
+         ):
+         zero_rank_print(f"loading annotations from {csv_path} ...")
+         with open(csv_path, 'r') as csvfile:
+             self.dataset = list(csv.DictReader(csvfile))
+         self.length = len(self.dataset)
+         zero_rank_print(f"data scale: {self.length}")
+
+         self.video_folder    = video_folder
+         self.sample_stride   = sample_stride
+         self.sample_n_frames = sample_n_frames
+         self.is_image        = is_image
+
+         sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
+         self.pixel_transforms = transforms.Compose([
+             transforms.RandomHorizontalFlip(),
+             transforms.Resize(sample_size[0]),
+             transforms.CenterCrop(sample_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+         ])
+
+     def get_batch(self, idx):
+         video_dict = self.dataset[idx]
+         videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
+
+         video_dir    = os.path.join(self.video_folder, f"{videoid}.mp4")
+         video_reader = VideoReader(video_dir)
+         video_length = len(video_reader)
+
+         if not self.is_image:
+             clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
+             start_idx   = random.randint(0, video_length - clip_length)
+             batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
+         else:
+             batch_index = [random.randint(0, video_length - 1)]
+
+         pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
+         pixel_values = pixel_values / 255.
+         del video_reader
+
+         if self.is_image:
+             pixel_values = pixel_values[0]
+
+         return pixel_values, name
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, idx):
+         while True:
+             try:
+                 pixel_values, name = self.get_batch(idx)
+                 break
+
+             except Exception as e:
+                 idx = random.randint(0, self.length-1)
+
+         pixel_values = self.pixel_transforms(pixel_values)
+         sample = dict(pixel_values=pixel_values, text=name)
+         return sample
+
+
+
+ if __name__ == "__main__":
+     from animatediff.utils.util import save_videos_grid
+
+     dataset = WebVid10M(
+         csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+         video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+         sample_size=256,
+         sample_stride=4, sample_n_frames=16,
+         is_image=True,
+     )
+     import pdb
+     pdb.set_trace()
+
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
+     for idx, batch in enumerate(dataloader):
+         print(batch["pixel_values"].shape, len(batch["text"]))
+         # for i in range(batch["pixel_values"].shape[0]):
+         #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
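Review note on `get_batch` in `animatediff/data/dataset.py`: the frame-index selection can be looked at in isolation. The following is a minimal, standard-library-only sketch, not the repository's code. The helper name `sample_frame_indices` and its `rng` parameter are hypothetical, and integer floor division stands in for `np.linspace(..., dtype=int)`, which could round differently in rare floating-point edge cases.

```python
import random


def sample_frame_indices(video_length, n_frames=16, stride=4, rng=random):
    """Pick n_frames indices from a random window of the video (hypothetical helper)."""
    # A full clip spans (n_frames - 1) * stride + 1 frames; shorter videos are used whole.
    clip_length = min(video_length, (n_frames - 1) * stride + 1)
    start = rng.randint(0, video_length - clip_length)
    if n_frames == 1:
        return [start]
    # Spread n_frames indices evenly over [start, start + clip_length - 1].
    return [start + (i * (clip_length - 1)) // (n_frames - 1) for i in range(n_frames)]


rng = random.Random(0)
long_clip = sample_frame_indices(100, rng=rng)   # stride-4 indices inside a 100-frame video
short_clip = sample_frame_indices(10, rng=rng)   # video shorter than the clip: indices repeat
```

With the defaults, a video of at least 61 frames yields indices exactly `stride` apart; a shorter video is covered end to end and some frames are sampled more than once.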
animatediff/models/attention.py ADDED
@@ -0,0 +1,398 @@
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers import ModelMixin
+ from diffusers.utils import BaseOutput, USE_PEFT_BACKEND
+ from diffusers.utils.import_utils import is_xformers_available
+ from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm
+
+ from einops import rearrange, repeat
+ import pdb
+
+ @dataclass
+ class Transformer3DModelOutput(BaseOutput):
+     sample: torch.FloatTensor
+
+
+ if is_xformers_available():
+     import xformers
+     import xformers.ops
+ else:
+     xformers = None
+
+
+ class Transformer3DModel(ModelMixin, ConfigMixin):
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: Optional[int] = None,
+         num_layers: int = 1,
+         dropout: float = 0.0,
+         norm_num_groups: int = 32,
+         cross_attention_dim: Optional[int] = None,
+         attention_bias: bool = False,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         use_linear_projection: bool = False,
+         only_cross_attention: bool = False,
+         upcast_attention: bool = False,
+
+         unet_use_cross_frame_attention=None,
+         unet_use_temporal_attention=None,
+     ):
+         super().__init__()
+         self.use_linear_projection = use_linear_projection
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+         inner_dim = num_attention_heads * attention_head_dim
+
+         # Define input layers
+         self.in_channels = in_channels
+
+         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+         if use_linear_projection:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+         else:
+             self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+         # Define transformers blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     cross_attention_dim=cross_attention_dim,
+                     activation_fn=activation_fn,
+                     num_embeds_ada_norm=num_embeds_ada_norm,
+                     attention_bias=attention_bias,
+                     only_cross_attention=only_cross_attention,
+                     upcast_attention=upcast_attention,
+
+                     unet_use_cross_frame_attention=unet_use_cross_frame_attention,
+                     unet_use_temporal_attention=unet_use_temporal_attention,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         # 4. Define output layers
+         if use_linear_projection:
+             self.proj_out = nn.Linear(in_channels, inner_dim)
+         else:
+             self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+         # Input
+         assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+         video_length = hidden_states.shape[2]
+         hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+         # Repeat the text embeddings once per frame; skip when none are given (the default is None).
+         if encoder_hidden_states is not None:
+             encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+         batch, channel, height, width = hidden_states.shape
+         residual = hidden_states
+
+         hidden_states = self.norm(hidden_states)
+         if not self.use_linear_projection:
+             hidden_states = self.proj_in(hidden_states)
+             inner_dim = hidden_states.shape[1]
+             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+         else:
+             inner_dim = hidden_states.shape[1]
+             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+             hidden_states = self.proj_in(hidden_states)
+
+         # Blocks
+         for block in self.transformer_blocks:
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 timestep=timestep,
+                 video_length=video_length
+             )
+
+         # Output
+         if not self.use_linear_projection:
+             hidden_states = (
+                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+             )
+             hidden_states = self.proj_out(hidden_states)
+         else:
+             hidden_states = self.proj_out(hidden_states)
+             hidden_states = (
+                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+             )
+
+         output = hidden_states + residual
+
+         output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+         if not return_dict:
+             return (output,)
+
+         return Transformer3DModelOutput(sample=output)
+
+
+ class BasicTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         dropout=0.0,
+         cross_attention_dim: Optional[int] = None,
+         activation_fn: str = "geglu",
+         num_embeds_ada_norm: Optional[int] = None,
+         attention_bias: bool = False,
+         only_cross_attention: bool = False,
+         upcast_attention: bool = False,
+
+         unet_use_cross_frame_attention = None,
+         unet_use_temporal_attention = None,
+     ):
+         super().__init__()
+         self.only_cross_attention = only_cross_attention
+         self.use_ada_layer_norm = num_embeds_ada_norm is not None
+         self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
+         self.unet_use_temporal_attention = unet_use_temporal_attention
+
+         # SC-Attn
+         assert unet_use_cross_frame_attention is not None
+         if unet_use_cross_frame_attention:
+             # NOTE: SparseCausalAttention2D is neither defined nor imported in this file,
+             # so this branch raises NameError as written.
+             self.attn1 = SparseCausalAttention2D(
+                 query_dim=dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                 upcast_attention=upcast_attention,
+             )
+         else:
+             self.attn1 = Attention(
+                 query_dim=dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )
+         self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+         # Cross-Attn
+         if cross_attention_dim is not None:
+             self.attn2 = Attention(
+                 query_dim=dim,
+                 cross_attention_dim=cross_attention_dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )
+         else:
+             self.attn2 = None
+
+         if cross_attention_dim is not None:
+             self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+         else:
+             self.norm2 = None
+
+         processor = CustomizedAttnProcessor2_0()
+         self.attn1.set_processor(processor)
+         # attn2 is None when cross_attention_dim is None, so guard the call.
+         if self.attn2 is not None:
+             self.attn2.set_processor(processor)
+
+         # Feed-forward
+         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+         self.norm3 = nn.LayerNorm(dim)
+
+         # Temp-Attn
+         assert unet_use_temporal_attention is not None
+         if unet_use_temporal_attention:
+             self.attn_temp = Attention(
+                 query_dim=dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+                 dropout=dropout,
+                 bias=attention_bias,
+                 upcast_attention=upcast_attention,
+             )
+             nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+             self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, *args, **kwargs):
+         if not is_xformers_available():
+             print("Here is how to install it")
+             raise ModuleNotFoundError(
+                 "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                 " xformers",
+                 name="xformers",
+             )
+         elif not torch.cuda.is_available():
+             raise ValueError(
+                 "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                 " available for GPU "
+             )
+         else:
+             try:
+                 # Make sure we can run the memory efficient attention
+                 _ = xformers.ops.memory_efficient_attention(
+                     torch.randn((1, 2, 40), device="cuda"),
+                     torch.randn((1, 2, 40), device="cuda"),
+                     torch.randn((1, 2, 40), device="cuda"),
+                 )
+             except Exception as e:
+                 raise e
+             self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+             if self.attn2 is not None:
+                 self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+             # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+         # SparseCausal-Attention
+         norm_hidden_states = (
+             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+         )
+
+         # if self.only_cross_attention:
+         #     hidden_states = (
+         #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+         #     )
+         # else:
+         #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+         # pdb.set_trace()
+         if self.unet_use_cross_frame_attention:
+             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+         else:
+             hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+         if self.attn2 is not None:
+             # Cross-Attention
+             norm_hidden_states = (
+                 self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+             )
+             hidden_states = (
+                 self.attn2(
+                     norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+                 )
+                 + hidden_states
+             )
+
+         # Feed-forward
+         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+         # Temporal-Attention
+         if self.unet_use_temporal_attention:
+             d = hidden_states.shape[1]
+             hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+             norm_hidden_states = (
+                 self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+             )
+             hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+         return hidden_states
+
+
307
+ class CustomizedAttnProcessor2_0:
308
+ r"""
309
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
310
+ """
311
+
312
+ def __init__(self):
313
+ if not hasattr(F, "scaled_dot_product_attention"):
314
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
315
+
316
+ def __call__(
317
+ self,
318
+ attn: Attention,
319
+ hidden_states: torch.FloatTensor,
320
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
321
+ k_input: Optional[torch.FloatTensor] = None,
322
+ v_input: Optional[torch.FloatTensor] = None,
323
+ attention_mask: Optional[torch.FloatTensor] = None,
324
+ temb: Optional[torch.FloatTensor] = None,
325
+ scale: float = 1.0,
326
+ ) -> torch.FloatTensor:
327
+ residual = hidden_states
328
+ if attn.spatial_norm is not None:
329
+ hidden_states = attn.spatial_norm(hidden_states, temb)
330
+
331
+ input_ndim = hidden_states.ndim
332
+
333
+ if input_ndim == 4:
334
+ batch_size, channel, height, width = hidden_states.shape
335
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
336
+
337
+ batch_size, sequence_length, _ = (
338
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
339
+ )
340
+
341
+ if attention_mask is not None:
342
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
343
+ # scaled_dot_product_attention expects attention_mask shape to be
344
+ # (batch, heads, source_length, target_length)
345
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
346
+
347
+ if attn.group_norm is not None:
348
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
349
+
350
+ args = () if USE_PEFT_BACKEND else (scale,)
351
+ query = attn.to_q(hidden_states, *args)
352
+
353
+ if encoder_hidden_states is None:
354
+ encoder_hidden_states = hidden_states
355
+ elif attn.norm_cross:
356
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
357
+
358
+ if k_input is not None:
359
+ key = attn.to_k(k_input, *args)
360
+ else:
361
+ key = attn.to_k(encoder_hidden_states, *args)
362
+ if v_input is not None:
363
+ value = attn.to_v(v_input, *args)
364
+ else:
365
+ value = attn.to_v(encoder_hidden_states, *args)
366
+
367
+ inner_dim = key.shape[-1]
368
+ head_dim = inner_dim // attn.heads
369
+
370
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
371
+
372
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
373
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
374
+
375
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
376
+ # TODO: add support for attn.scale when we move to Torch 2.1
377
+ hidden_states = F.scaled_dot_product_attention(
378
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
379
+ )
380
+
381
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
382
+ hidden_states = hidden_states.to(query.dtype)
383
+
384
+ # linear proj
385
+ hidden_states = attn.to_out[0](hidden_states, *args)
386
+ # dropout
387
+ hidden_states = attn.to_out[1](hidden_states)
388
+
389
+ if input_ndim == 4:
390
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
391
+
392
+ if attn.residual_connection:
393
+ hidden_states = hidden_states + residual
394
+
395
+ hidden_states = hidden_states / attn.rescale_output_factor
396
+
397
+ return hidden_states
398
+
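Review note on `CustomizedAttnProcessor2_0`: compared with a standard scaled dot-product processor, the visible difference in this diff is the pair of optional `k_input` and `v_input` arguments, which replace the tensor that keys and values are projected from. The sketch below illustrates only that selection logic. It rests on several assumptions: it is single-head, it uses plain Python lists instead of tensors, and it treats the `to_q`/`to_k`/`to_v` projections as identity. The names `softmax`, `sdpa`, and `attend` are illustrative and do not appear in the commit.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the maximum for numerical stability before exponentiating.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def sdpa(query: Matrix, key: Matrix, value: Matrix) -> Matrix:
    # Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
    scale = 1.0 / math.sqrt(len(query[0]))
    out = []
    for q in query:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        weights = softmax(scores)
        out.append(
            [sum(w * v[j] for w, v in zip(weights, value)) for j in range(len(value[0]))]
        )
    return out


def attend(
    hidden_states: Matrix,
    encoder_hidden_states: Optional[Matrix] = None,
    k_input: Optional[Matrix] = None,
    v_input: Optional[Matrix] = None,
) -> Matrix:
    # Self-attention fallback, as in the processor: no encoder states
    # means keys and values default to the hidden states themselves.
    if encoder_hidden_states is None:
        encoder_hidden_states = hidden_states
    # Assumption: projections are identity here; the real processor applies
    # attn.to_k / attn.to_v to whichever source is selected.
    key = k_input if k_input is not None else encoder_hidden_states
    value = v_input if v_input is not None else encoder_hidden_states
    return sdpa(hidden_states, key, value)
```

With `v_input` set to rows that are all identical, every output row equals that row, because the attention weights sum to one. This is a quick way to confirm that values really are read from the override rather than from `hidden_states`.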
animatediff/models/motion_module.py ADDED
@@ -0,0 +1,353 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import torchvision
9
+ import diffusers
10
+ from packaging import version
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers import ModelMixin
14
+ from diffusers.utils import BaseOutput
15
+ from diffusers.utils.import_utils import is_xformers_available
16
+ from diffusers.models.attention import Attention, FeedForward
17
+
18
+ from einops import rearrange, repeat
19
+ import math
20
+
21
+
22
+ def zero_module(module):
23
+ # Zero out the parameters of a module and return it.
24
+ for p in module.parameters():
25
+ p.detach().zero_()
26
+ return module
27
+
28
+
29
+ @dataclass
30
+ class TemporalTransformer3DModelOutput(BaseOutput):
31
+ sample: torch.FloatTensor
32
+
33
+
34
+ if is_xformers_available():
35
+ import xformers
36
+ import xformers.ops
37
+ else:
38
+ xformers = None
39
+
40
+
41
+ def get_motion_module(
42
+ in_channels,
43
+ motion_module_type: str,
44
+ motion_module_kwargs: dict
45
+ ):
46
+ if motion_module_type == "Vanilla":
47
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
48
+ else:
49
+ raise ValueError(f"Unknown motion_module_type: {motion_module_type!r}")
50
+
51
+
52
+ class VanillaTemporalModule(nn.Module):
53
+ def __init__(
54
+ self,
55
+ in_channels,
56
+ num_attention_heads = 8,
57
+ num_transformer_block = 2,
58
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
59
+ cross_frame_attention_mode = None,
60
+ temporal_position_encoding = False,
61
+ temporal_position_encoding_max_len = 24,
62
+ temporal_attention_dim_div = 1,
63
+ zero_initialize = True,
64
+ ):
65
+ super().__init__()
66
+
67
+ self.temporal_transformer = TemporalTransformer3DModel(
68
+ in_channels=in_channels,
69
+ num_attention_heads=num_attention_heads,
70
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
71
+ num_layers=num_transformer_block,
72
+ attention_block_types=attention_block_types,
73
+ cross_frame_attention_mode=cross_frame_attention_mode,
74
+ temporal_position_encoding=temporal_position_encoding,
75
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
76
+ )
77
+
78
+ if zero_initialize:
79
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
80
+
81
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
82
+ video_length = input_tensor.shape[2]
83
+
84
+ if video_length > 1:
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
87
+ output = hidden_states
88
+ else:
89
+ output = input_tensor
90
+
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+
101
+ num_layers,
102
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
103
+ dropout = 0.0,
104
+ norm_num_groups = 32,
105
+ cross_attention_dim = 768,
106
+ activation_fn = "geglu",
107
+ attention_bias = False,
108
+ upcast_attention = False,
109
+
110
+ cross_frame_attention_mode = None,
111
+ temporal_position_encoding = False,
112
+ temporal_position_encoding_max_len = 24,
113
+ ):
114
+ super().__init__()
115
+
116
+ inner_dim = num_attention_heads * attention_head_dim
117
+
118
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
119
+ self.proj_in = nn.Linear(in_channels, inner_dim)
120
+
121
+ self.transformer_blocks = nn.ModuleList(
122
+ [
123
+ TemporalTransformerBlock(
124
+ dim=inner_dim,
125
+ num_attention_heads=num_attention_heads,
126
+ attention_head_dim=attention_head_dim,
127
+ attention_block_types=attention_block_types,
128
+ dropout=dropout,
129
+ norm_num_groups=norm_num_groups,
130
+ cross_attention_dim=cross_attention_dim,
131
+ activation_fn=activation_fn,
132
+ attention_bias=attention_bias,
133
+ upcast_attention=upcast_attention,
134
+ cross_frame_attention_mode=cross_frame_attention_mode,
135
+ temporal_position_encoding=temporal_position_encoding,
136
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
137
+ )
138
+ for d in range(num_layers)
139
+ ]
140
+ )
141
+ self.proj_out = nn.Linear(inner_dim, in_channels)
142
+
143
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
144
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
145
+ video_length = hidden_states.shape[2]
146
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
147
+
148
+ batch, channel, height, weight = hidden_states.shape
149
+ residual = hidden_states
150
+
151
+ hidden_states = self.norm(hidden_states)
152
+ inner_dim = hidden_states.shape[1]
153
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
154
+ hidden_states = self.proj_in(hidden_states)
155
+
156
+ # Transformer Blocks
157
+ for block in self.transformer_blocks:
158
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
159
+
160
+ # output
161
+ hidden_states = self.proj_out(hidden_states)
162
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
163
+
164
+ output = hidden_states + residual
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+
167
+ return output
168
+
169
+
170
+ class TemporalTransformerBlock(nn.Module):
171
+ def __init__(
172
+ self,
173
+ dim,
174
+ num_attention_heads,
175
+ attention_head_dim,
176
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
177
+ dropout = 0.0,
178
+ norm_num_groups = 32,
179
+ cross_attention_dim = 768,
180
+ activation_fn = "geglu",
181
+ attention_bias = False,
182
+ upcast_attention = False,
183
+ cross_frame_attention_mode = None,
184
+ temporal_position_encoding = False,
185
+ temporal_position_encoding_max_len = 24,
186
+ ):
187
+ super().__init__()
188
+
189
+ attention_blocks = []
190
+ norms = []
191
+
192
+ for block_name in attention_block_types:
193
+ attention_blocks.append(
194
+ VersatileAttention(
195
+ attention_mode=block_name.split("_")[0],
196
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
197
+
198
+ query_dim=dim,
199
+ heads=num_attention_heads,
200
+ dim_head=attention_head_dim,
201
+ dropout=dropout,
202
+ bias=attention_bias,
203
+ upcast_attention=upcast_attention,
204
+
205
+ cross_frame_attention_mode=cross_frame_attention_mode,
206
+ temporal_position_encoding=temporal_position_encoding,
207
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
208
+ )
209
+ )
210
+ norms.append(nn.LayerNorm(dim))
211
+
212
+ self.attention_blocks = nn.ModuleList(attention_blocks)
213
+ self.norms = nn.ModuleList(norms)
214
+
215
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
216
+ self.ff_norm = nn.LayerNorm(dim)
217
+
218
+
219
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
220
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
221
+ norm_hidden_states = norm(hidden_states)
222
+ hidden_states = attention_block(
223
+ norm_hidden_states,
224
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
225
+ video_length=video_length,
226
+ ) + hidden_states
227
+
228
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
229
+
230
+ output = hidden_states
231
+ return output
232
+
233
+
234
+ class PositionalEncoding(nn.Module):
235
+ def __init__(
236
+ self,
237
+ d_model,
238
+ dropout = 0.,
239
+ max_len = 24
240
+ ):
241
+ super().__init__()
242
+ self.dropout = nn.Dropout(p=dropout)
243
+ position = torch.arange(max_len).unsqueeze(1)
244
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
245
+ pe = torch.zeros(1, max_len, d_model)
246
+ pe[0, :, 0::2] = torch.sin(position * div_term)
247
+ pe[0, :, 1::2] = torch.cos(position * div_term)
248
+ self.register_buffer('pe', pe)
249
+
250
+ def forward(self, x):
251
+ x = x + self.pe[:, :x.size(1)]
252
+ return self.dropout(x)
253
+
254
+
255
+ class VersatileAttention(Attention):
256
+ def __init__(
257
+ self,
258
+ attention_mode = None,
259
+ cross_frame_attention_mode = None,
260
+ temporal_position_encoding = False,
261
+ temporal_position_encoding_max_len = 24,
262
+ *args, **kwargs
263
+ ):
264
+ super().__init__(*args, **kwargs)
265
+ assert attention_mode == "Temporal"
266
+
267
+ self.attention_mode = attention_mode
268
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
269
+
270
+ self.pos_encoder = PositionalEncoding(
271
+ kwargs["query_dim"],
272
+ dropout=0.,
273
+ max_len=temporal_position_encoding_max_len
274
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
275
+
276
+ def extra_repr(self):
277
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
278
+
279
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
280
+ batch_size, sequence_length, _ = hidden_states.shape
281
+
282
+ if self.attention_mode == "Temporal":
283
+ d = hidden_states.shape[1]
284
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
285
+
286
+ if self.pos_encoder is not None:
287
+ hidden_states = self.pos_encoder(hidden_states)
288
+
289
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
290
+ else:
291
+ raise NotImplementedError
292
+
293
+ encoder_hidden_states = encoder_hidden_states
294
+
295
+ if version.parse(diffusers.__version__) > version.parse("0.11.1"):
296
+ hidden_states = self.processor(self, hidden_states, encoder_hidden_states)
297
+ else:
298
+ if self.group_norm is not None:
299
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
300
+
301
+ query = self.to_q(hidden_states)
302
+ dim = query.shape[-1]
303
+ query = self.head_to_batch_dim(query)
304
+
305
+ if self.added_kv_proj_dim is not None:
306
+ raise NotImplementedError
307
+
308
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
309
+ key = self.to_k(encoder_hidden_states)
310
+ value = self.to_v(encoder_hidden_states)
311
+
312
+ key = self.head_to_batch_dim(key)
313
+ value = self.head_to_batch_dim(value)
314
+
315
+ if attention_mask is not None:
316
+ if attention_mask.shape[-1] != query.shape[1]:
317
+ target_length = query.shape[1]
318
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
319
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
320
+
321
+ # attention, what we cannot get enough of
322
+
323
+ if self._use_memory_efficient_attention_xformers:
324
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
325
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
326
+ hidden_states = hidden_states.to(query.dtype)
327
+ else:
328
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
329
+ hidden_states = self._attention(query, key, value, attention_mask)
330
+ else:
331
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
332
+ else:
333
+ #if "xformers" in self.processor.__class__.__name__.lower():
334
+ # hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attention_mask)
335
+ # # Some versions of xformers return output in fp32, cast it back to the dtype of the input
336
+ # hidden_states = hidden_states.to(query.dtype)
337
+ #else:
338
+ hidden_states = F.scaled_dot_product_attention(
339
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
340
+ )
341
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
342
+ hidden_states = hidden_states.to(query.dtype)
343
+
344
+ # linear proj
345
+ hidden_states = self.to_out[0](hidden_states)
346
+
347
+ # dropout
348
+ hidden_states = self.to_out[1](hidden_states)
349
+
350
+ if self.attention_mode == "Temporal":
351
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
352
+
353
+ return hidden_states
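Review note on `PositionalEncoding`: the module precomputes a fixed sinusoidal table of shape `(1, max_len, d_model)` and adds its first `f` rows to the per-pixel frame sequence. The dependency-free sketch below reproduces the same formula with the standard library so the values can be inspected without PyTorch. The function name `sinusoidal_table` is hypothetical, and the sketch assumes an even `d_model`, which the slicing in the original code also appears to rely on.

```python
import math
from typing import List


def sinusoidal_table(max_len: int, d_model: int) -> List[List[float]]:
    # Same frequencies as the diff: exp(i * (-ln(10000) / d_model)) for even i.
    div_term = [
        math.exp(i * (-math.log(10000.0) / d_model)) for i in range(0, d_model, 2)
    ]
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for j, div in enumerate(div_term):
            table[pos][2 * j] = math.sin(pos * div)      # even channels
            table[pos][2 * j + 1] = math.cos(pos * div)  # odd channels
    return table


if __name__ == "__main__":
    # 24 matches the default temporal_position_encoding_max_len in the diff.
    table = sinusoidal_table(max_len=24, d_model=8)
    print(table[0])
    print(table[1][:2])
```

Because the table has only `max_len` rows, a clip longer than `temporal_position_encoding_max_len` frames cannot be given positions by slicing this buffer, so the configured maximum should be at least the number of frames.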
animatediff/models/resnet.py ADDED
@@ -0,0 +1,217 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
34
+ super().__init__()
35
+ self.channels = channels
36
+ self.out_channels = out_channels or channels
37
+ self.use_conv = use_conv
38
+ self.use_conv_transpose = use_conv_transpose
39
+ self.name = name
40
+
41
+ conv = None
42
+ if use_conv_transpose:
43
+ raise NotImplementedError
44
+ elif use_conv:
45
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
46
+
47
+ def forward(self, hidden_states, output_size=None):
48
+ assert hidden_states.shape[1] == self.channels
49
+
50
+ if self.use_conv_transpose:
51
+ raise NotImplementedError
52
+
53
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
54
+ dtype = hidden_states.dtype
55
+ if dtype == torch.bfloat16:
56
+ hidden_states = hidden_states.to(torch.float32)
57
+
58
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
59
+ if hidden_states.shape[0] >= 64:
60
+ hidden_states = hidden_states.contiguous()
61
+
62
+ # if `output_size` is passed we force the interpolation output
63
+ # size and do not make use of `scale_factor=2`
64
+ if output_size is None:
65
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
66
+ else:
67
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
68
+
69
+ # If the input is bfloat16, we cast back to bfloat16
70
+ if dtype == torch.bfloat16:
71
+ hidden_states = hidden_states.to(dtype)
72
+
73
+ # if self.use_conv:
74
+ # if self.name == "conv":
75
+ # hidden_states = self.conv(hidden_states)
76
+ # else:
77
+ # hidden_states = self.Conv2d_0(hidden_states)
78
+ hidden_states = self.conv(hidden_states)
79
+
80
+ return hidden_states
81
+
82
+
83
+ class Downsample3D(nn.Module):
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ use_inflated_groupnorm=False,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+
138
+ if groups_out is None:
139
+ groups_out = groups
140
+
141
+ assert use_inflated_groupnorm is not None
142
+ if use_inflated_groupnorm:
143
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144
+ else:
145
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
146
+
147
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+
149
+ if temb_channels is not None:
150
+ if self.time_embedding_norm == "default":
151
+ time_emb_proj_out_channels = out_channels
152
+ elif self.time_embedding_norm == "scale_shift":
153
+ time_emb_proj_out_channels = out_channels * 2
154
+ else:
155
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
156
+
157
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
158
+ else:
159
+ self.time_emb_proj = None
160
+
161
+ if use_inflated_groupnorm:
162
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
163
+ else:
164
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
165
+
166
+ self.dropout = torch.nn.Dropout(dropout)
167
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
168
+
169
+ if non_linearity == "swish":
170
+ self.nonlinearity = lambda x: F.silu(x)
171
+ elif non_linearity == "mish":
172
+ self.nonlinearity = Mish()
173
+ elif non_linearity == "silu":
174
+ self.nonlinearity = nn.SiLU()
175
+
176
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
177
+
178
+ self.conv_shortcut = None
179
+ if self.use_in_shortcut:
180
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
181
+
182
+ def forward(self, input_tensor, temb):
183
+ hidden_states = input_tensor
184
+
185
+ hidden_states = self.norm1(hidden_states)
186
+ hidden_states = self.nonlinearity(hidden_states)
187
+
188
+ hidden_states = self.conv1(hidden_states)
189
+
190
+ if temb is not None:
191
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
192
+
193
+ if temb is not None and self.time_embedding_norm == "default":
194
+ hidden_states = hidden_states + temb
195
+
196
+ hidden_states = self.norm2(hidden_states)
197
+
198
+ if temb is not None and self.time_embedding_norm == "scale_shift":
199
+ scale, shift = torch.chunk(temb, 2, dim=1)
200
+ hidden_states = hidden_states * (1 + scale) + shift
201
+
202
+ hidden_states = self.nonlinearity(hidden_states)
203
+
204
+ hidden_states = self.dropout(hidden_states)
205
+ hidden_states = self.conv2(hidden_states)
206
+
207
+ if self.conv_shortcut is not None:
208
+ input_tensor = self.conv_shortcut(input_tensor)
209
+
210
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
211
+
212
+ return output_tensor
213
+
214
+
215
+ class Mish(torch.nn.Module):
216
+ def forward(self, hidden_states):
217
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
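Review note on `InflatedConv3d` and `InflatedGroupNorm`: both reuse a 2D operator on video by folding the frame axis into the batch axis (`"b c f h w -> (b f) c h w"`), applying the 2D op per frame, and unfolding again. The sketch below shows the index arithmetic behind that pattern with nested Python lists, where each leaf stands in for one `(h, w)` feature map. The helper names `fold_frames` and `unfold_frames` are hypothetical, and this is an illustration of the einops pattern rather than the code used in the commit.

```python
from typing import Any, List


def fold_frames(video: List[List[List[Any]]]) -> List[List[Any]]:
    # video[b][c][f] -> flat[b * F + f][c], i.e. "b c f -> (b f) c".
    batch, channels, frames = len(video), len(video[0]), len(video[0][0])
    return [
        [video[b][c][f] for c in range(channels)]
        for b in range(batch)
        for f in range(frames)
    ]


def unfold_frames(flat: List[List[Any]], video_length: int) -> List[List[List[Any]]]:
    # Inverse mapping, "(b f) c -> b c f", given the number of frames f.
    batch = len(flat) // video_length
    channels = len(flat[0])
    return [
        [
            [flat[b * video_length + t][c] for t in range(video_length)]
            for c in range(channels)
        ]
        for b in range(batch)
    ]


if __name__ == "__main__":
    # Two clips, two channels, three frames; labels make the ordering visible.
    video = [
        [[f"b{b}c{c}f{f}" for f in range(3)] for c in range(2)] for b in range(2)
    ]
    flat = fold_frames(video)
    print(flat[1])
    print(unfold_frames(flat, video_length=3) == video)
```

In `(b f)` the batch index is the slow axis, so consecutive rows of the folded tensor are consecutive frames of the same clip. That ordering is what lets the inverse rearrange recover the clip from `f=video_length` alone.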
animatediff/models/unet.py ADDED
@@ -0,0 +1,526 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers import ModelMixin
16
+ from diffusers.utils import BaseOutput, logging
17
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
18
+ from .unet_blocks import (
19
+ CrossAttnDownBlock3D,
20
+ CrossAttnUpBlock3D,
21
+ DownBlock3D,
22
+ UNetMidBlock3DCrossAttn,
23
+ UpBlock3D,
24
+ get_down_block,
25
+ get_up_block,
26
+ )
27
+ from .resnet import InflatedConv3d, InflatedGroupNorm
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ @dataclass
34
+ class UNet3DConditionOutput(BaseOutput):
35
+ sample: torch.FloatTensor
36
+
37
+
38
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
39
+ _supports_gradient_checkpointing = True
40
+
41
+ @register_to_config
42
+ def __init__(
43
+ self,
44
+ sample_size: Optional[int] = None,
45
+ in_channels: int = 4,
46
+ out_channels: int = 4,
47
+ center_input_sample: bool = False,
48
+ flip_sin_to_cos: bool = True,
49
+ freq_shift: int = 0,
50
+ down_block_types: Tuple[str] = (
51
+ "CrossAttnDownBlock3D",
52
+ "CrossAttnDownBlock3D",
53
+ "CrossAttnDownBlock3D",
54
+ "DownBlock3D",
55
+ ),
56
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
57
+ up_block_types: Tuple[str] = (
58
+ "UpBlock3D",
59
+ "CrossAttnUpBlock3D",
60
+ "CrossAttnUpBlock3D",
61
+ "CrossAttnUpBlock3D"
62
+ ),
63
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
64
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
65
+ layers_per_block: int = 2,
66
+ downsample_padding: int = 1,
67
+ mid_block_scale_factor: float = 1,
68
+ act_fn: str = "silu",
69
+ norm_num_groups: int = 32,
70
+ norm_eps: float = 1e-5,
71
+ cross_attention_dim: int = 1280,
72
+ attention_head_dim: Union[int, Tuple[int]] = 8,
73
+ dual_cross_attention: bool = False,
74
+ use_linear_projection: bool = False,
75
+ class_embed_type: Optional[str] = None,
76
+ num_class_embeds: Optional[int] = None,
77
+ upcast_attention: bool = False,
78
+ resnet_time_scale_shift: str = "default",
79
+
80
+ use_inflated_groupnorm=False,
81
+
82
+ # Additional
83
+ use_motion_module = False,
84
+ motion_module_resolutions = ( 1,2,4,8 ),
85
+ motion_module_mid_block = False,
86
+ motion_module_decoder_only = False,
87
+ motion_module_type = None,
88
+ motion_module_kwargs = {},
89
+ unet_use_cross_frame_attention = False,
90
+ unet_use_temporal_attention = False,
91
+ ):
92
+ super().__init__()
93
+
94
+ self.sample_size = sample_size
95
+ time_embed_dim = block_out_channels[0] * 4
96
+
97
+ # input
98
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
99
+
100
+ # time
101
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
102
+ timestep_input_dim = block_out_channels[0]
103
+
104
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
105
+
106
+ # class embedding
107
+ if class_embed_type is None and num_class_embeds is not None:
108
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
109
+ elif class_embed_type == "timestep":
110
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
111
+ elif class_embed_type == "identity":
112
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
113
+ else:
114
+ self.class_embedding = None
115
+
116
+ self.down_blocks = nn.ModuleList([])
117
+ self.mid_block = None
118
+ self.up_blocks = nn.ModuleList([])
119
+
120
+ if isinstance(only_cross_attention, bool):
121
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
122
+
123
+ if isinstance(attention_head_dim, int):
124
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
125
+
126
+ # down
127
+ output_channel = block_out_channels[0]
128
+ for i, down_block_type in enumerate(down_block_types):
129
+ res = 2 ** i
130
+ input_channel = output_channel
131
+ output_channel = block_out_channels[i]
132
+ is_final_block = i == len(block_out_channels) - 1
133
+
134
+ down_block = get_down_block(
135
+ down_block_type,
136
+ num_layers=layers_per_block,
137
+ in_channels=input_channel,
138
+ out_channels=output_channel,
139
+ temb_channels=time_embed_dim,
140
+ add_downsample=not is_final_block,
141
+ resnet_eps=norm_eps,
142
+ resnet_act_fn=act_fn,
143
+ resnet_groups=norm_num_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ attn_num_head_channels=attention_head_dim[i],
146
+ downsample_padding=downsample_padding,
147
+ dual_cross_attention=dual_cross_attention,
148
+ use_linear_projection=use_linear_projection,
149
+ only_cross_attention=only_cross_attention[i],
150
+ upcast_attention=upcast_attention,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+
153
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154
+ unet_use_temporal_attention=unet_use_temporal_attention,
155
+ use_inflated_groupnorm=use_inflated_groupnorm,
156
+
157
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
158
+ motion_module_type=motion_module_type,
159
+ motion_module_kwargs=motion_module_kwargs,
160
+ )
161
+ self.down_blocks.append(down_block)
162
+
163
+ # mid
164
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
165
+ self.mid_block = UNetMidBlock3DCrossAttn(
166
+ in_channels=block_out_channels[-1],
167
+ temb_channels=time_embed_dim,
168
+ resnet_eps=norm_eps,
169
+ resnet_act_fn=act_fn,
170
+ output_scale_factor=mid_block_scale_factor,
171
+ resnet_time_scale_shift=resnet_time_scale_shift,
172
+ cross_attention_dim=cross_attention_dim,
173
+ attn_num_head_channels=attention_head_dim[-1],
174
+ resnet_groups=norm_num_groups,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ upcast_attention=upcast_attention,
178
+
179
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
180
+ unet_use_temporal_attention=unet_use_temporal_attention,
181
+ use_inflated_groupnorm=use_inflated_groupnorm,
182
+
183
+ use_motion_module=use_motion_module and motion_module_mid_block,
184
+ motion_module_type=motion_module_type,
185
+ motion_module_kwargs=motion_module_kwargs,
186
+ )
187
+ else:
188
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
189
+
190
+ # count how many layers upsample the videos
191
+ self.num_upsamplers = 0
192
+
193
+ # up
194
+ reversed_block_out_channels = list(reversed(block_out_channels))
195
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
196
+ only_cross_attention = list(reversed(only_cross_attention))
197
+ output_channel = reversed_block_out_channels[0]
198
+ for i, up_block_type in enumerate(up_block_types):
199
+ res = 2 ** (3 - i)
200
+ is_final_block = i == len(block_out_channels) - 1
201
+
202
+ prev_output_channel = output_channel
203
+ output_channel = reversed_block_out_channels[i]
204
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
205
+
206
+ # add upsample block for all BUT final layer
207
+ if not is_final_block:
208
+ add_upsample = True
209
+ self.num_upsamplers += 1
210
+ else:
211
+ add_upsample = False
212
+
213
+ up_block = get_up_block(
214
+ up_block_type,
215
+ num_layers=layers_per_block + 1,
216
+ in_channels=input_channel,
217
+ out_channels=output_channel,
218
+ prev_output_channel=prev_output_channel,
219
+ temb_channels=time_embed_dim,
220
+ add_upsample=add_upsample,
221
+ resnet_eps=norm_eps,
222
+ resnet_act_fn=act_fn,
223
+ resnet_groups=norm_num_groups,
224
+ cross_attention_dim=cross_attention_dim,
225
+ attn_num_head_channels=reversed_attention_head_dim[i],
226
+ dual_cross_attention=dual_cross_attention,
227
+ use_linear_projection=use_linear_projection,
228
+ only_cross_attention=only_cross_attention[i],
229
+ upcast_attention=upcast_attention,
230
+ resnet_time_scale_shift=resnet_time_scale_shift,
231
+
232
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
233
+ unet_use_temporal_attention=unet_use_temporal_attention,
234
+ use_inflated_groupnorm=use_inflated_groupnorm,
235
+
236
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
237
+ motion_module_type=motion_module_type,
238
+ motion_module_kwargs=motion_module_kwargs,
239
+ )
240
+ self.up_blocks.append(up_block)
241
+ prev_output_channel = output_channel
242
+
243
+ # out
244
+ if use_inflated_groupnorm:
245
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
246
+ else:
247
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
248
+ self.conv_act = nn.SiLU()
249
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
250
+
251
+ def set_attention_slice(self, slice_size):
252
+ r"""
253
+ Enable sliced attention computation.
254
+
255
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
256
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
257
+
258
+ Args:
259
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
260
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
261
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
262
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
263
+ must be a multiple of `slice_size`.
264
+ """
265
+ sliceable_head_dims = []
266
+
267
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
268
+ if hasattr(module, "set_attention_slice"):
269
+ sliceable_head_dims.append(module.sliceable_head_dim)
270
+
271
+ for child in module.children():
272
+ fn_recursive_retrieve_slicable_dims(child)
273
+
274
+ # retrieve number of attention layers
275
+ for module in self.children():
276
+ fn_recursive_retrieve_slicable_dims(module)
277
+
278
+ num_slicable_layers = len(sliceable_head_dims)
279
+
280
+ if slice_size == "auto":
281
+ # half the attention head size is usually a good trade-off between
282
+ # speed and memory
283
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
284
+ elif slice_size == "max":
285
+ # make smallest slice possible
286
+ slice_size = num_slicable_layers * [1]
287
+
288
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
289
+
290
+ if len(slice_size) != len(sliceable_head_dims):
291
+ raise ValueError(
292
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
293
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
294
+ )
295
+
296
+ for i in range(len(slice_size)):
297
+ size = slice_size[i]
298
+ dim = sliceable_head_dims[i]
299
+ if size is not None and size > dim:
300
+ raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
301
+
302
+ # Recursively walk through all the children.
303
+ # Any child that exposes the set_attention_slice method
304
+ # gets the message
305
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
306
+ if hasattr(module, "set_attention_slice"):
307
+ module.set_attention_slice(slice_size.pop())
308
+
309
+ for child in module.children():
310
+ fn_recursive_set_attention_slice(child, slice_size)
311
+
312
+ reversed_slice_size = list(reversed(slice_size))
313
+ for module in self.children():
314
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
315
+
316
+ def _set_gradient_checkpointing(self, module, value=False):
317
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
318
+ module.gradient_checkpointing = value
319
+
320
+ def forward(
321
+ self,
322
+ sample: torch.FloatTensor,
323
+ timestep: Union[torch.Tensor, float, int],
324
+ encoder_hidden_states: torch.Tensor,
325
+ class_labels: Optional[torch.Tensor] = None,
326
+ attention_mask: Optional[torch.Tensor] = None,
327
+
328
+ # support controlnet
329
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
330
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
331
+
332
+ return_dict: bool = True,
333
+ ) -> Union[UNet3DConditionOutput, Tuple]:
334
+ r"""
335
+ Args:
336
+ sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
337
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
338
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
339
+ return_dict (`bool`, *optional*, defaults to `True`):
340
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
341
+
342
+ Returns:
343
+ [`UNet3DConditionOutput`] or `tuple`:
344
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
345
+ returning a tuple, the first element is the sample tensor.
346
+ """
347
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
348
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
349
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
350
+ # on the fly if necessary.
351
+ default_overall_up_factor = 2**self.num_upsamplers
352
+
353
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
354
+ forward_upsample_size = False
355
+ upsample_size = None
356
+
357
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
358
+ logger.info("Forward upsample size to force interpolation output size.")
359
+ forward_upsample_size = True
360
+
361
+ # prepare attention_mask
362
+ if attention_mask is not None:
363
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
364
+ attention_mask = attention_mask.unsqueeze(1)
365
+
366
+ # center input if necessary
367
+ if self.config.center_input_sample:
368
+ sample = 2 * sample - 1.0
369
+
370
+ # time
371
+ timesteps = timestep
372
+ if not torch.is_tensor(timesteps):
373
+ # This would be a good case for the `match` statement (Python 3.10+)
374
+ is_mps = sample.device.type == "mps"
375
+ if isinstance(timestep, float):
376
+ dtype = torch.float32 if is_mps else torch.float64
377
+ else:
378
+ dtype = torch.int32 if is_mps else torch.int64
379
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
380
+ elif len(timesteps.shape) == 0:
381
+ timesteps = timesteps[None].to(sample.device)
382
+
383
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
384
+ timesteps = timesteps.expand(sample.shape[0])
385
+
386
+ t_emb = self.time_proj(timesteps)
387
+
388
+ # time_proj does not contain any weights and will always return f32 tensors
389
+ # but time_embedding might actually be running in fp16, so we need to cast here.
390
+ # there might be better ways to encapsulate this.
391
+ t_emb = t_emb.to(dtype=self.dtype)
392
+ emb = self.time_embedding(t_emb)
393
+
394
+ if self.class_embedding is not None:
395
+ if class_labels is None:
396
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
397
+
398
+ if self.config.class_embed_type == "timestep":
399
+ class_labels = self.time_proj(class_labels)
400
+
401
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
402
+ emb = emb + class_emb
403
+
404
+ # pre-process
405
+ sample = self.conv_in(sample)
406
+
407
+ # down
408
+ down_block_res_samples = (sample,)
409
+ for downsample_block in self.down_blocks:
410
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
411
+ sample, res_samples = downsample_block(
412
+ hidden_states=sample,
413
+ temb=emb,
414
+ encoder_hidden_states=encoder_hidden_states,
415
+ attention_mask=attention_mask,
416
+ )
417
+ else:
418
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
419
+
420
+ down_block_res_samples += res_samples
421
+
422
+ # support controlnet
423
+ down_block_res_samples = list(down_block_res_samples)
424
+ if down_block_additional_residuals is not None:
425
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
426
+ if down_block_additional_residual.dim() == 4: # broadcast over the frame axis
427
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
428
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
429
+
430
+ # mid
431
+ sample = self.mid_block(
432
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
433
+ )
434
+
435
+ # support controlnet
436
+ if mid_block_additional_residual is not None:
437
+ if mid_block_additional_residual.dim() == 4: # broadcast over the frame axis
438
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
439
+ sample = sample + mid_block_additional_residual
440
+
441
+ # up
442
+ for i, upsample_block in enumerate(self.up_blocks):
443
+ is_final_block = i == len(self.up_blocks) - 1
444
+
445
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
446
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
447
+
448
+ # if we have not reached the final block and need to forward the
449
+ # upsample size, we do it here
450
+ if not is_final_block and forward_upsample_size:
451
+ upsample_size = down_block_res_samples[-1].shape[2:]
452
+
453
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
454
+ sample = upsample_block(
455
+ hidden_states=sample,
456
+ temb=emb,
457
+ res_hidden_states_tuple=res_samples,
458
+ encoder_hidden_states=encoder_hidden_states,
459
+ upsample_size=upsample_size,
460
+ attention_mask=attention_mask,
461
+ )
462
+ else:
463
+ sample = upsample_block(
464
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
465
+ )
466
+
467
+ # post-process
468
+ sample = self.conv_norm_out(sample)
469
+ sample = self.conv_act(sample)
470
+ sample = self.conv_out(sample)
471
+
472
+ if not return_dict:
473
+ return (sample,)
474
+
475
+ return UNet3DConditionOutput(sample=sample)
476
+
477
+ @classmethod
478
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
479
+ if subfolder is not None:
480
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
481
+ print(f"loading 3D unet's pretrained weights from {pretrained_model_path} ...")
482
+
483
+ config_file = os.path.join(pretrained_model_path, 'config.json')
484
+ if not os.path.isfile(config_file):
485
+ raise RuntimeError(f"{config_file} does not exist")
486
+ with open(config_file, "r") as f:
487
+ config = json.load(f)
488
+ config["_class_name"] = cls.__name__
489
+ config["down_block_types"] = [
490
+ "CrossAttnDownBlock3D",
491
+ "CrossAttnDownBlock3D",
492
+ "CrossAttnDownBlock3D",
493
+ "DownBlock3D"
494
+ ]
495
+ config["up_block_types"] = [
496
+ "UpBlock3D",
497
+ "CrossAttnUpBlock3D",
498
+ "CrossAttnUpBlock3D",
499
+ "CrossAttnUpBlock3D"
500
+ ]
501
+
502
+ from diffusers.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
503
+ model = cls.from_config(config, **(unet_additional_kwargs or {}))
504
+
505
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
506
+ model_file_safe = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
507
+
508
+ if os.path.isfile(model_file_safe):
509
+ model_file = model_file_safe
510
+
511
+ if not os.path.isfile(model_file):
512
+ raise RuntimeError(f"{model_file} does not exist")
513
+
514
+ if SAFETENSORS_WEIGHTS_NAME in model_file:
515
+ from safetensors.torch import load_file
516
+ state_dict = load_file(model_file)
517
+ else:
518
+ state_dict = torch.load(model_file, map_location="cpu")
519
+
520
+ m, u = model.load_state_dict(state_dict, strict=False)
521
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
522
+
523
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
524
+ print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
525
+
526
+ return model
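
Note on `set_attention_slice` above: the method first expands the user-facing `slice_size` (`"auto"`, `"max"`, a single integer, or a list) into one value per sliceable attention layer, validates it, and only then hands the values out while walking the module tree. The following is a minimal standalone sketch of that normalization step only, written without torch so it can be run in isolation. The function name `normalize_slice_size` and the example head dimensions are hypothetical and are not part of the commit.

```python
from typing import List, Union


def normalize_slice_size(
    slice_size: Union[str, int, List[int]],
    sliceable_head_dims: List[int],
) -> List[int]:
    """Expand slice_size into one entry per sliceable attention layer.

    Mirrors the branching in set_attention_slice; this helper is an
    illustration and does not exist in the committed file.
    """
    num_layers = len(sliceable_head_dims)

    if slice_size == "auto":
        # Half of each head dimension: attention is computed in two steps.
        sizes = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # Smallest possible slice: one slice at a time.
        sizes = num_layers * [1]
    elif isinstance(slice_size, list):
        sizes = list(slice_size)
    else:
        # A single number is repeated for every layer.
        sizes = num_layers * [slice_size]

    if len(sizes) != num_layers:
        raise ValueError(
            f"Got {len(sizes)} slice sizes for {num_layers} sliceable attention layers."
        )

    for size, dim in zip(sizes, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")

    return sizes


if __name__ == "__main__":
    head_dims = [8, 8, 16]  # hypothetical values for illustration
    print(normalize_slice_size("auto", head_dims))  # [4, 4, 8]
    print(normalize_slice_size("max", head_dims))   # [1, 1, 1]
    print(normalize_slice_size(2, head_dims))       # [2, 2, 2]
```

In the committed method the resulting list is reversed and consumed with `pop()` during a second recursive walk, so the first sliceable module visited receives the first entry.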
animatediff/models/unet_blocks.py ADDED
@@ -0,0 +1,764 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import torch.utils.checkpoint  # needed by checkpoint_no_reentrant below
11
+
12
+ def checkpoint_no_reentrant(*args, **kwargs):
13
+ kwargs['use_reentrant'] = False
14
+ return torch.utils.checkpoint.checkpoint(*args, **kwargs)
15
+
16
+ def get_down_block(
17
+ down_block_type,
18
+ num_layers,
19
+ in_channels,
20
+ out_channels,
21
+ temb_channels,
22
+ add_downsample,
23
+ resnet_eps,
24
+ resnet_act_fn,
25
+ attn_num_head_channels,
26
+ resnet_groups=None,
27
+ cross_attention_dim=None,
28
+ downsample_padding=None,
29
+ dual_cross_attention=False,
30
+ use_linear_projection=False,
31
+ only_cross_attention=False,
32
+ upcast_attention=False,
33
+ resnet_time_scale_shift="default",
34
+
35
+ unet_use_cross_frame_attention=False,
36
+ unet_use_temporal_attention=False,
37
+ use_inflated_groupnorm=False,
38
+
39
+ use_motion_module=None,
40
+
41
+ motion_module_type=None,
42
+ motion_module_kwargs=None,
43
+ ):
44
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+
58
+ use_inflated_groupnorm=use_inflated_groupnorm,
59
+
60
+ use_motion_module=use_motion_module,
61
+ motion_module_type=motion_module_type,
62
+ motion_module_kwargs=motion_module_kwargs,
63
+ )
64
+ elif down_block_type == "CrossAttnDownBlock3D":
65
+ if cross_attention_dim is None:
66
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+
85
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
86
+ unet_use_temporal_attention=unet_use_temporal_attention,
87
+ use_inflated_groupnorm=use_inflated_groupnorm,
88
+
89
+ use_motion_module=use_motion_module,
90
+ motion_module_type=motion_module_type,
91
+ motion_module_kwargs=motion_module_kwargs,
92
+ )
93
+ raise ValueError(f"{down_block_type} does not exist.")
94
+
95
+
96
+ def get_up_block(
97
+ up_block_type,
98
+ num_layers,
99
+ in_channels,
100
+ out_channels,
101
+ prev_output_channel,
102
+ temb_channels,
103
+ add_upsample,
104
+ resnet_eps,
105
+ resnet_act_fn,
106
+ attn_num_head_channels,
107
+ resnet_groups=None,
108
+ cross_attention_dim=None,
109
+ dual_cross_attention=False,
110
+ use_linear_projection=False,
111
+ only_cross_attention=False,
112
+ upcast_attention=False,
113
+ resnet_time_scale_shift="default",
114
+
115
+ unet_use_cross_frame_attention=False,
116
+ unet_use_temporal_attention=False,
117
+ use_inflated_groupnorm=False,
118
+
119
+ use_motion_module=None,
120
+ motion_module_type=None,
121
+ motion_module_kwargs=None,
122
+ ):
123
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
124
+ if up_block_type == "UpBlock3D":
125
+ return UpBlock3D(
126
+ num_layers=num_layers,
127
+ in_channels=in_channels,
128
+ out_channels=out_channels,
129
+ prev_output_channel=prev_output_channel,
130
+ temb_channels=temb_channels,
131
+ add_upsample=add_upsample,
132
+ resnet_eps=resnet_eps,
133
+ resnet_act_fn=resnet_act_fn,
134
+ resnet_groups=resnet_groups,
135
+ resnet_time_scale_shift=resnet_time_scale_shift,
136
+
137
+ use_inflated_groupnorm=use_inflated_groupnorm,
138
+
139
+ use_motion_module=use_motion_module,
140
+ motion_module_type=motion_module_type,
141
+ motion_module_kwargs=motion_module_kwargs,
142
+ )
143
+ elif up_block_type == "CrossAttnUpBlock3D":
144
+ if cross_attention_dim is None:
145
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
146
+ return CrossAttnUpBlock3D(
147
+ num_layers=num_layers,
148
+ in_channels=in_channels,
149
+ out_channels=out_channels,
150
+ prev_output_channel=prev_output_channel,
151
+ temb_channels=temb_channels,
152
+ add_upsample=add_upsample,
153
+ resnet_eps=resnet_eps,
154
+ resnet_act_fn=resnet_act_fn,
155
+ resnet_groups=resnet_groups,
156
+ cross_attention_dim=cross_attention_dim,
157
+ attn_num_head_channels=attn_num_head_channels,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention,
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+
164
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
165
+ unet_use_temporal_attention=unet_use_temporal_attention,
166
+ use_inflated_groupnorm=use_inflated_groupnorm,
167
+
168
+ use_motion_module=use_motion_module,
169
+ motion_module_type=motion_module_type,
170
+ motion_module_kwargs=motion_module_kwargs,
171
+ )
172
+ raise ValueError(f"{up_block_type} does not exist.")
173
+
174
+
175
+ class UNetMidBlock3DCrossAttn(nn.Module):
176
+ def __init__(
177
+ self,
178
+ in_channels: int,
179
+ temb_channels: int,
180
+ dropout: float = 0.0,
181
+ num_layers: int = 1,
182
+ resnet_eps: float = 1e-6,
183
+ resnet_time_scale_shift: str = "default",
184
+ resnet_act_fn: str = "swish",
185
+ resnet_groups: int = 32,
186
+ resnet_pre_norm: bool = True,
187
+ attn_num_head_channels=1,
188
+ output_scale_factor=1.0,
189
+ cross_attention_dim=1280,
190
+ dual_cross_attention=False,
191
+ use_linear_projection=False,
192
+ upcast_attention=False,
193
+
194
+ unet_use_cross_frame_attention=False,
195
+ unet_use_temporal_attention=False,
196
+ use_inflated_groupnorm=False,
197
+
198
+ use_motion_module=None,
199
+
200
+ motion_module_type=None,
201
+ motion_module_kwargs=None,
202
+ ):
203
+ super().__init__()
204
+
205
+ self.has_cross_attention = True
206
+ self.attn_num_head_channels = attn_num_head_channels
207
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
208
+
209
+ # there is always at least one resnet
210
+ resnets = [
211
+ ResnetBlock3D(
212
+ in_channels=in_channels,
213
+ out_channels=in_channels,
214
+ temb_channels=temb_channels,
215
+ eps=resnet_eps,
216
+ groups=resnet_groups,
217
+ dropout=dropout,
218
+ time_embedding_norm=resnet_time_scale_shift,
219
+ non_linearity=resnet_act_fn,
220
+ output_scale_factor=output_scale_factor,
221
+ pre_norm=resnet_pre_norm,
222
+
223
+ use_inflated_groupnorm=use_inflated_groupnorm,
224
+ )
225
+ ]
226
+ attentions = []
227
+ motion_modules = []
228
+
229
+ for _ in range(num_layers):
230
+ if dual_cross_attention:
231
+ raise NotImplementedError
232
+ attentions.append(
233
+ Transformer3DModel(
234
+ attn_num_head_channels,
235
+ in_channels // attn_num_head_channels,
236
+ in_channels=in_channels,
237
+ num_layers=1,
238
+ cross_attention_dim=cross_attention_dim,
239
+ norm_num_groups=resnet_groups,
240
+ use_linear_projection=use_linear_projection,
241
+ upcast_attention=upcast_attention,
242
+
243
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
244
+ unet_use_temporal_attention=unet_use_temporal_attention,
245
+ )
246
+ )
247
+ motion_modules.append(
248
+ get_motion_module(
249
+ in_channels=in_channels,
250
+ motion_module_type=motion_module_type,
251
+ motion_module_kwargs=motion_module_kwargs,
252
+ ) if use_motion_module else None
253
+ )
254
+ resnets.append(
255
+ ResnetBlock3D(
256
+ in_channels=in_channels,
257
+ out_channels=in_channels,
258
+ temb_channels=temb_channels,
259
+ eps=resnet_eps,
260
+ groups=resnet_groups,
261
+ dropout=dropout,
262
+ time_embedding_norm=resnet_time_scale_shift,
263
+ non_linearity=resnet_act_fn,
264
+ output_scale_factor=output_scale_factor,
265
+ pre_norm=resnet_pre_norm,
266
+
267
+ use_inflated_groupnorm=use_inflated_groupnorm,
268
+ )
269
+ )
270
+
271
+ self.attentions = nn.ModuleList(attentions)
272
+ self.resnets = nn.ModuleList(resnets)
273
+ self.motion_modules = nn.ModuleList(motion_modules)
274
+
275
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
278
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
279
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
280
+ hidden_states = resnet(hidden_states, temb)
281
+
282
+ return hidden_states
283
+
284
+
285
+ class CrossAttnDownBlock3D(nn.Module):
286
+ def __init__(
287
+ self,
288
+ in_channels: int,
289
+ out_channels: int,
290
+ temb_channels: int,
291
+ dropout: float = 0.0,
292
+ num_layers: int = 1,
293
+ resnet_eps: float = 1e-6,
294
+ resnet_time_scale_shift: str = "default",
295
+ resnet_act_fn: str = "swish",
296
+ resnet_groups: int = 32,
297
+ resnet_pre_norm: bool = True,
298
+ attn_num_head_channels=1,
299
+ cross_attention_dim=1280,
300
+ output_scale_factor=1.0,
301
+ downsample_padding=1,
302
+ add_downsample=True,
303
+ dual_cross_attention=False,
304
+ use_linear_projection=False,
305
+ only_cross_attention=False,
306
+ upcast_attention=False,
307
+
308
+ unet_use_cross_frame_attention=False,
309
+ unet_use_temporal_attention=False,
310
+ use_inflated_groupnorm=False,
311
+
312
+ use_motion_module=None,
313
+
314
+ motion_module_type=None,
315
+ motion_module_kwargs=None,
316
+ ):
317
+ super().__init__()
318
+ resnets = []
319
+ attentions = []
320
+ motion_modules = []
321
+
322
+ self.has_cross_attention = True
323
+ self.attn_num_head_channels = attn_num_head_channels
324
+
325
+ for i in range(num_layers):
326
+ in_channels = in_channels if i == 0 else out_channels
327
+ resnets.append(
328
+ ResnetBlock3D(
329
+ in_channels=in_channels,
330
+ out_channels=out_channels,
331
+ temb_channels=temb_channels,
332
+ eps=resnet_eps,
333
+ groups=resnet_groups,
334
+ dropout=dropout,
335
+ time_embedding_norm=resnet_time_scale_shift,
336
+ non_linearity=resnet_act_fn,
337
+ output_scale_factor=output_scale_factor,
338
+ pre_norm=resnet_pre_norm,
339
+
340
+ use_inflated_groupnorm=use_inflated_groupnorm,
341
+ )
342
+ )
343
+ if dual_cross_attention:
344
+ raise NotImplementedError
345
+ attentions.append(
346
+ Transformer3DModel(
347
+ attn_num_head_channels,
348
+ out_channels // attn_num_head_channels,
349
+ in_channels=out_channels,
350
+ num_layers=1,
351
+ cross_attention_dim=cross_attention_dim,
352
+ norm_num_groups=resnet_groups,
353
+ use_linear_projection=use_linear_projection,
354
+ only_cross_attention=only_cross_attention,
355
+ upcast_attention=upcast_attention,
356
+
357
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
358
+ unet_use_temporal_attention=unet_use_temporal_attention,
359
+ )
360
+ )
361
+ motion_modules.append(
362
+ get_motion_module(
363
+ in_channels=out_channels,
364
+ motion_module_type=motion_module_type,
365
+ motion_module_kwargs=motion_module_kwargs,
366
+ ) if use_motion_module else None
367
+ )
368
+
369
+ self.attentions = nn.ModuleList(attentions)
370
+ self.resnets = nn.ModuleList(resnets)
371
+ self.motion_modules = nn.ModuleList(motion_modules)
372
+
373
+ if add_downsample:
374
+ self.downsamplers = nn.ModuleList(
375
+ [
376
+ Downsample3D(
377
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
378
+ )
379
+ ]
380
+ )
381
+ else:
382
+ self.downsamplers = None
383
+
384
+ self.gradient_checkpointing = False
385
+
386
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
387
+ output_states = ()
388
+
389
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
390
+ if self.training and self.gradient_checkpointing:
391
+
392
+ def create_custom_forward(module, return_dict=None):
393
+ def custom_forward(*inputs):
394
+ if return_dict is not None:
395
+ return module(*inputs, return_dict=return_dict)
396
+ else:
397
+ return module(*inputs)
398
+
399
+ return custom_forward
400
+
401
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb)
402
+ hidden_states = checkpoint_no_reentrant(
403
+ create_custom_forward(attn, return_dict=False),
404
+ hidden_states,
405
+ encoder_hidden_states,
406
+ )[0]
407
+ if motion_module is not None:
408
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
409
+
410
+ else:
411
+ hidden_states = resnet(hidden_states, temb)
412
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
413
+
414
+ # add motion module
415
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
416
+
417
+ output_states += (hidden_states,)
418
+
419
+ if self.downsamplers is not None:
420
+ for downsampler in self.downsamplers:
421
+ hidden_states = downsampler(hidden_states)
422
+
423
+ output_states += (hidden_states,)
424
+
425
+ return hidden_states, output_states
426
+
427
+
428
+ class DownBlock3D(nn.Module):
429
+ def __init__(
430
+ self,
431
+ in_channels: int,
432
+ out_channels: int,
433
+ temb_channels: int,
434
+ dropout: float = 0.0,
435
+ num_layers: int = 1,
436
+ resnet_eps: float = 1e-6,
437
+ resnet_time_scale_shift: str = "default",
438
+ resnet_act_fn: str = "swish",
439
+ resnet_groups: int = 32,
440
+ resnet_pre_norm: bool = True,
441
+ output_scale_factor=1.0,
442
+ add_downsample=True,
443
+ downsample_padding=1,
444
+
445
+ use_inflated_groupnorm=False,
446
+
447
+ use_motion_module=None,
448
+ motion_module_type=None,
449
+ motion_module_kwargs=None,
450
+ ):
451
+ super().__init__()
452
+ resnets = []
453
+ motion_modules = []
454
+
455
+ for i in range(num_layers):
456
+ in_channels = in_channels if i == 0 else out_channels
457
+ resnets.append(
458
+ ResnetBlock3D(
459
+ in_channels=in_channels,
460
+ out_channels=out_channels,
461
+ temb_channels=temb_channels,
462
+ eps=resnet_eps,
463
+ groups=resnet_groups,
464
+ dropout=dropout,
465
+ time_embedding_norm=resnet_time_scale_shift,
466
+ non_linearity=resnet_act_fn,
467
+ output_scale_factor=output_scale_factor,
468
+ pre_norm=resnet_pre_norm,
469
+
470
+ use_inflated_groupnorm=use_inflated_groupnorm,
471
+ )
472
+ )
473
+ motion_modules.append(
474
+ get_motion_module(
475
+ in_channels=out_channels,
476
+ motion_module_type=motion_module_type,
477
+ motion_module_kwargs=motion_module_kwargs,
478
+ ) if use_motion_module else None
479
+ )
480
+
481
+ self.resnets = nn.ModuleList(resnets)
482
+ self.motion_modules = nn.ModuleList(motion_modules)
483
+
484
+ if add_downsample:
485
+ self.downsamplers = nn.ModuleList(
486
+ [
487
+ Downsample3D(
488
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
489
+ )
490
+ ]
491
+ )
492
+ else:
493
+ self.downsamplers = None
494
+
495
+ self.gradient_checkpointing = False
496
+
497
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
498
+ output_states = ()
499
+
500
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
501
+ if self.training and self.gradient_checkpointing:
502
+ def create_custom_forward(module):
503
+ def custom_forward(*inputs):
504
+ return module(*inputs)
505
+
506
+ return custom_forward
507
+
508
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb)
509
+ if motion_module is not None:
510
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
511
+ else:
512
+ hidden_states = resnet(hidden_states, temb)
513
+
514
+ # add motion module
515
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
516
+
517
+ output_states += (hidden_states,)
518
+
519
+ if self.downsamplers is not None:
520
+ for downsampler in self.downsamplers:
521
+ hidden_states = downsampler(hidden_states)
522
+
523
+ output_states += (hidden_states,)
524
+
525
+ return hidden_states, output_states
526
+
527
+
528
+ class CrossAttnUpBlock3D(nn.Module):
529
+ def __init__(
530
+ self,
531
+ in_channels: int,
532
+ out_channels: int,
533
+ prev_output_channel: int,
534
+ temb_channels: int,
535
+ dropout: float = 0.0,
536
+ num_layers: int = 1,
537
+ resnet_eps: float = 1e-6,
538
+ resnet_time_scale_shift: str = "default",
539
+ resnet_act_fn: str = "swish",
540
+ resnet_groups: int = 32,
541
+ resnet_pre_norm: bool = True,
542
+ attn_num_head_channels=1,
543
+ cross_attention_dim=1280,
544
+ output_scale_factor=1.0,
545
+ add_upsample=True,
546
+ dual_cross_attention=False,
547
+ use_linear_projection=False,
548
+ only_cross_attention=False,
549
+ upcast_attention=False,
550
+
551
+ unet_use_cross_frame_attention=False,
552
+ unet_use_temporal_attention=False,
553
+ use_inflated_groupnorm=False,
554
+
555
+ use_motion_module=None,
556
+
557
+ motion_module_type=None,
558
+ motion_module_kwargs=None,
559
+ ):
560
+ super().__init__()
561
+ resnets = []
562
+ attentions = []
563
+ motion_modules = []
564
+
565
+ self.has_cross_attention = True
566
+ self.attn_num_head_channels = attn_num_head_channels
567
+
568
+ for i in range(num_layers):
569
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
570
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
571
+
572
+ resnets.append(
573
+ ResnetBlock3D(
574
+ in_channels=resnet_in_channels + res_skip_channels,
575
+ out_channels=out_channels,
576
+ temb_channels=temb_channels,
577
+ eps=resnet_eps,
578
+ groups=resnet_groups,
579
+ dropout=dropout,
580
+ time_embedding_norm=resnet_time_scale_shift,
581
+ non_linearity=resnet_act_fn,
582
+ output_scale_factor=output_scale_factor,
583
+ pre_norm=resnet_pre_norm,
584
+
585
+ use_inflated_groupnorm=use_inflated_groupnorm,
586
+ )
587
+ )
588
+ if dual_cross_attention:
589
+ raise NotImplementedError
590
+ attentions.append(
591
+ Transformer3DModel(
592
+ attn_num_head_channels,
593
+ out_channels // attn_num_head_channels,
594
+ in_channels=out_channels,
595
+ num_layers=1,
596
+ cross_attention_dim=cross_attention_dim,
597
+ norm_num_groups=resnet_groups,
598
+ use_linear_projection=use_linear_projection,
599
+ only_cross_attention=only_cross_attention,
600
+ upcast_attention=upcast_attention,
601
+
602
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
603
+ unet_use_temporal_attention=unet_use_temporal_attention,
604
+ )
605
+ )
606
+ motion_modules.append(
607
+ get_motion_module(
608
+ in_channels=out_channels,
609
+ motion_module_type=motion_module_type,
610
+ motion_module_kwargs=motion_module_kwargs,
611
+ ) if use_motion_module else None
612
+ )
613
+
614
+ self.attentions = nn.ModuleList(attentions)
615
+ self.resnets = nn.ModuleList(resnets)
616
+ self.motion_modules = nn.ModuleList(motion_modules)
617
+
618
+ if add_upsample:
619
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
620
+ else:
621
+ self.upsamplers = None
622
+
623
+ self.gradient_checkpointing = False
624
+
625
+ def forward(
626
+ self,
627
+ hidden_states,
628
+ res_hidden_states_tuple,
629
+ temb=None,
630
+ encoder_hidden_states=None,
631
+ upsample_size=None,
632
+ attention_mask=None,
633
+ ):
634
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
635
+ # pop res hidden states
636
+ res_hidden_states = res_hidden_states_tuple[-1]
637
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
638
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
639
+
640
+ if self.training and self.gradient_checkpointing:
641
+
642
+ def create_custom_forward(module, return_dict=None):
643
+ def custom_forward(*inputs):
644
+ if return_dict is not None:
645
+ return module(*inputs, return_dict=return_dict)
646
+ else:
647
+ return module(*inputs)
648
+
649
+ return custom_forward
650
+
651
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb)
652
+ hidden_states = checkpoint_no_reentrant(
653
+ create_custom_forward(attn, return_dict=False),
654
+ hidden_states,
655
+ encoder_hidden_states,
656
+ )[0]
657
+ if motion_module is not None:
658
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
659
+
660
+ else:
661
+ hidden_states = resnet(hidden_states, temb)
662
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
663
+
664
+ # add motion module
665
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
666
+
667
+ if self.upsamplers is not None:
668
+ for upsampler in self.upsamplers:
669
+ hidden_states = upsampler(hidden_states, upsample_size)
670
+
671
+ return hidden_states
672
+
673
+
674
+ class UpBlock3D(nn.Module):
675
+ def __init__(
676
+ self,
677
+ in_channels: int,
678
+ prev_output_channel: int,
679
+ out_channels: int,
680
+ temb_channels: int,
681
+ dropout: float = 0.0,
682
+ num_layers: int = 1,
683
+ resnet_eps: float = 1e-6,
684
+ resnet_time_scale_shift: str = "default",
685
+ resnet_act_fn: str = "swish",
686
+ resnet_groups: int = 32,
687
+ resnet_pre_norm: bool = True,
688
+ output_scale_factor=1.0,
689
+ add_upsample=True,
690
+
691
+ use_inflated_groupnorm=False,
692
+
693
+ use_motion_module=None,
694
+ motion_module_type=None,
695
+ motion_module_kwargs=None,
696
+ ):
697
+ super().__init__()
698
+ resnets = []
699
+ motion_modules = []
700
+
701
+ for i in range(num_layers):
702
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
703
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
704
+
705
+ resnets.append(
706
+ ResnetBlock3D(
707
+ in_channels=resnet_in_channels + res_skip_channels,
708
+ out_channels=out_channels,
709
+ temb_channels=temb_channels,
710
+ eps=resnet_eps,
711
+ groups=resnet_groups,
712
+ dropout=dropout,
713
+ time_embedding_norm=resnet_time_scale_shift,
714
+ non_linearity=resnet_act_fn,
715
+ output_scale_factor=output_scale_factor,
716
+ pre_norm=resnet_pre_norm,
717
+
718
+ use_inflated_groupnorm=use_inflated_groupnorm,
719
+ )
720
+ )
721
+ motion_modules.append(
722
+ get_motion_module(
723
+ in_channels=out_channels,
724
+ motion_module_type=motion_module_type,
725
+ motion_module_kwargs=motion_module_kwargs,
726
+ ) if use_motion_module else None
727
+ )
728
+
729
+ self.resnets = nn.ModuleList(resnets)
730
+ self.motion_modules = nn.ModuleList(motion_modules)
731
+
732
+ if add_upsample:
733
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
734
+ else:
735
+ self.upsamplers = None
736
+
737
+ self.gradient_checkpointing = False
738
+
739
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
740
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
741
+ # pop res hidden states
742
+ res_hidden_states = res_hidden_states_tuple[-1]
743
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
744
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
745
+
746
+ if self.training and self.gradient_checkpointing:
747
+ def create_custom_forward(module):
748
+ def custom_forward(*inputs):
749
+ return module(*inputs)
750
+
751
+ return custom_forward
752
+
753
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(resnet), hidden_states, temb)
754
+ if motion_module is not None:
755
+ hidden_states = checkpoint_no_reentrant(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
756
+ else:
757
+ hidden_states = resnet(hidden_states, temb)
758
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
759
+
760
+ if self.upsamplers is not None:
761
+ for upsampler in self.upsamplers:
762
+ hidden_states = upsampler(hidden_states, upsample_size)
763
+
764
+ return hidden_states
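Note on the up blocks above: each resnet layer in `CrossAttnUpBlock3D` and `UpBlock3D` takes the most recent entry of `res_hidden_states_tuple` and concatenates it with `hidden_states` along the channel axis, so its input width is `resnet_in_channels + res_skip_channels`. The sketch below reproduces only that bookkeeping in plain Python. The helper names `up_block_resnet_inputs` and `pop_skip` and the channel sizes are hypothetical, chosen for illustration; they are not part of the commit.

```python
from typing import List, Tuple


def up_block_resnet_inputs(
    in_channels: int,
    prev_output_channel: int,
    out_channels: int,
    num_layers: int,
) -> List[Tuple[int, int]]:
    """Return (hidden_channels, skip_channels) for each resnet layer.

    Mirrors the constructor loop of the up blocks: the first layer receives
    the previous block's output and later layers receive `out_channels`;
    every layer except the last gets a skip tensor with `out_channels`
    channels, and the last one gets `in_channels`.
    """
    plan = []
    for i in range(num_layers):
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        plan.append((resnet_in_channels, res_skip_channels))
    return plan


def pop_skip(res_hidden_states_tuple: tuple) -> Tuple[object, tuple]:
    """Take the newest skip entry, as `forward` does with `[-1]` and `[:-1]`."""
    return res_hidden_states_tuple[-1], res_hidden_states_tuple[:-1]


# Hypothetical channel sizes, used only to make the arithmetic concrete.
plan = up_block_resnet_inputs(
    in_channels=320, prev_output_channel=1280, out_channels=640, num_layers=3
)
# Width of the concatenated tensor each resnet must accept.
concat_channels = [hidden + skip for hidden, skip in plan]
```

Because the skips are consumed from the end of the tuple, the down path has to have pushed them in the opposite order, which is why the down blocks append to `output_states` after every layer and again after downsampling.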
animatediff/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,746 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from diffusers.utils import is_accelerate_available
11
+ from packaging import version
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.models import AutoencoderKL
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
+ from diffusers.schedulers import (
18
+ DDIMScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler,
21
+ EulerDiscreteScheduler,
22
+ LMSDiscreteScheduler,
23
+ PNDMScheduler,
24
+ )
25
+ from diffusers.utils import deprecate, logging, BaseOutput
26
+
27
+ from einops import rearrange
28
+
29
+ from ..models.unet import UNet3DConditionModel
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class AnimationPipelineOutput(BaseOutput):
36
+ videos: Union[torch.Tensor, np.ndarray]
37
+
38
+
39
+ class AnimationPipeline(DiffusionPipeline):
40
+ _optional_components = []
41
+
42
+ def __init__(
43
+ self,
44
+ vae: AutoencoderKL,
45
+ text_encoder: CLIPTextModel,
46
+ tokenizer: CLIPTokenizer,
47
+ unet: UNet3DConditionModel,
48
+ scheduler: Union[
49
+ DDIMScheduler,
50
+ PNDMScheduler,
51
+ LMSDiscreteScheduler,
52
+ EulerDiscreteScheduler,
53
+ EulerAncestralDiscreteScheduler,
54
+ DPMSolverMultistepScheduler,
55
+ ],
56
+ ):
57
+ super().__init__()
58
+
59
+ if (
60
+ hasattr(scheduler.config, "steps_offset")
61
+ and scheduler.config.steps_offset != 1
62
+ ):
63
+ deprecation_message = (
64
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
65
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
66
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
67
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
68
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
69
+ " file"
70
+ )
71
+ deprecate(
72
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
73
+ )
74
+ new_config = dict(scheduler.config)
75
+ new_config["steps_offset"] = 1
76
+ scheduler._internal_dict = FrozenDict(new_config)
77
+
78
+ if (
79
+ hasattr(scheduler.config, "clip_sample")
80
+ and scheduler.config.clip_sample is True
81
+ ):
82
+ deprecation_message = (
83
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
84
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
85
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
86
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
87
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
88
+ )
89
+ deprecate(
90
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
91
+ )
92
+ new_config = dict(scheduler.config)
93
+ new_config["clip_sample"] = False
94
+ scheduler._internal_dict = FrozenDict(new_config)
95
+
96
+ is_unet_version_less_0_9_0 = hasattr(
97
+ unet.config, "_diffusers_version"
98
+ ) and version.parse(
99
+ version.parse(unet.config._diffusers_version).base_version
100
+ ) < version.parse(
101
+ "0.9.0.dev0"
102
+ )
103
+ is_unet_sample_size_less_64 = (
104
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
105
+ )
106
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
107
+ deprecation_message = (
108
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
109
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
110
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
111
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
112
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
113
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
114
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
115
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
116
+ " the `unet/config.json` file"
117
+ )
118
+ deprecate(
119
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
120
+ )
121
+ new_config = dict(unet.config)
122
+ new_config["sample_size"] = 64
123
+ unet._internal_dict = FrozenDict(new_config)
124
+
125
+ self.register_modules(
126
+ vae=vae,
127
+ text_encoder=text_encoder,
128
+ tokenizer=tokenizer,
129
+ unet=unet,
130
+ scheduler=scheduler,
131
+ )
132
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
133
+
134
+ def enable_vae_slicing(self):
135
+ self.vae.enable_slicing()
136
+
137
+ def disable_vae_slicing(self):
138
+ self.vae.disable_slicing()
139
+
140
+ def enable_sequential_cpu_offload(self, gpu_id=0):
141
+ if is_accelerate_available():
142
+ from accelerate import cpu_offload
143
+ else:
144
+ raise ImportError("Please install accelerate via `pip install accelerate`")
145
+
146
+ device = torch.device(f"cuda:{gpu_id}")
147
+
148
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
149
+ if cpu_offloaded_model is not None:
150
+ cpu_offload(cpu_offloaded_model, device)
151
+
152
+ @property
153
+ def _execution_device(self):
154
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
155
+ return self.device
156
+ for module in self.unet.modules():
157
+ if (
158
+ hasattr(module, "_hf_hook")
159
+ and hasattr(module._hf_hook, "execution_device")
160
+ and module._hf_hook.execution_device is not None
161
+ ):
162
+ return torch.device(module._hf_hook.execution_device)
163
+ return self.device
164
+
165
+ def _encode_prompt(
166
+ self,
167
+ prompt,
168
+ device,
169
+ num_videos_per_prompt,
170
+ do_classifier_free_guidance,
171
+ negative_prompt,
172
+ ):
173
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
174
+
175
+ text_inputs = self.tokenizer(
176
+ prompt,
177
+ padding="max_length",
178
+ max_length=self.tokenizer.model_max_length,
179
+ truncation=True,
180
+ return_tensors="pt",
181
+ )
182
+ text_input_ids = text_inputs.input_ids
183
+ untruncated_ids = self.tokenizer(
184
+ prompt, padding="longest", return_tensors="pt"
185
+ ).input_ids
186
+
187
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
188
+ text_input_ids, untruncated_ids
189
+ ):
190
+ removed_text = self.tokenizer.batch_decode(
191
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
192
+ )
193
+ logger.warning(
194
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
195
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
196
+ )
197
+
198
+ if (
199
+ hasattr(self.text_encoder.config, "use_attention_mask")
200
+ and self.text_encoder.config.use_attention_mask
201
+ ):
202
+ attention_mask = text_inputs.attention_mask.to(device)
203
+ else:
204
+ attention_mask = None
205
+
206
+ text_embeddings = self.text_encoder(
207
+ text_input_ids.to(device),
208
+ attention_mask=attention_mask,
209
+ )
210
+ text_embeddings = text_embeddings[0]
211
+
212
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
213
+ bs_embed, seq_len, _ = text_embeddings.shape
214
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
215
+ text_embeddings = text_embeddings.view(
216
+ bs_embed * num_videos_per_prompt, seq_len, -1
217
+ )
218
+
219
+ # get unconditional embeddings for classifier free guidance
220
+ if do_classifier_free_guidance:
221
+ uncond_tokens: List[str]
222
+ if negative_prompt is None:
223
+ uncond_tokens = [""] * batch_size
224
+ elif type(prompt) is not type(negative_prompt):
225
+ raise TypeError(
226
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
227
+ f" {type(prompt)}."
228
+ )
229
+ elif isinstance(negative_prompt, str):
230
+ uncond_tokens = [negative_prompt]
231
+ elif batch_size != len(negative_prompt):
232
+ raise ValueError(
233
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
234
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
235
+ " the batch size of `prompt`."
236
+ )
237
+ else:
238
+ uncond_tokens = negative_prompt
239
+
240
+ max_length = text_input_ids.shape[-1]
241
+ uncond_input = self.tokenizer(
242
+ uncond_tokens,
243
+ padding="max_length",
244
+ max_length=max_length,
245
+ truncation=True,
246
+ return_tensors="pt",
247
+ )
248
+
249
+ if (
250
+ hasattr(self.text_encoder.config, "use_attention_mask")
251
+ and self.text_encoder.config.use_attention_mask
252
+ ):
253
+ attention_mask = uncond_input.attention_mask.to(device)
254
+ else:
255
+ attention_mask = None
256
+
257
+ uncond_embeddings = self.text_encoder(
258
+ uncond_input.input_ids.to(device),
259
+ attention_mask=attention_mask,
260
+ )
261
+ uncond_embeddings = uncond_embeddings[0]
262
+
263
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
264
+ seq_len = uncond_embeddings.shape[1]
265
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
266
+ uncond_embeddings = uncond_embeddings.view(
267
+ batch_size * num_videos_per_prompt, seq_len, -1
268
+ )
269
+
270
+ # For classifier free guidance, we need to do two forward passes.
271
+ # Here we concatenate the unconditional and text embeddings into a single batch
272
+ # to avoid doing two forward passes
273
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
274
+
275
+ return text_embeddings
276
+
277
+ def decode_latents(self, latents):
278
+ video_length = latents.shape[2]
279
+ latents = 1 / 0.18215 * latents
280
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
281
+ # video = self.vae.decode(latents).sample
282
+ video = []
283
+ for frame_idx in tqdm(range(latents.shape[0])):
284
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
285
+ video = torch.cat(video)
286
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
287
+ video = (video / 2 + 0.5).clamp(0, 1)
288
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
289
+ video = video.cpu().float().numpy()
290
+ return video
291
+
292
+ def prepare_extra_step_kwargs(self, generator, eta):
293
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
294
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
295
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
296
+ # and should be between [0, 1]
297
+
298
+ accepts_eta = "eta" in set(
299
+ inspect.signature(self.scheduler.step).parameters.keys()
300
+ )
301
+ extra_step_kwargs = {}
302
+ if accepts_eta:
303
+ extra_step_kwargs["eta"] = eta
304
+
305
+ # check if the scheduler accepts generator
306
+ accepts_generator = "generator" in set(
307
+ inspect.signature(self.scheduler.step).parameters.keys()
308
+ )
309
+ if accepts_generator:
310
+ extra_step_kwargs["generator"] = generator
311
+ return extra_step_kwargs
312
+
313
+ def check_inputs(self, prompt, height, width, callback_steps):
314
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
315
+ raise ValueError(
316
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
317
+ )
318
+
319
+ if height % 8 != 0 or width % 8 != 0:
320
+ raise ValueError(
321
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
322
+ )
323
+
324
+ if (callback_steps is None) or (
325
+ callback_steps is not None
326
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
327
+ ):
328
+ raise ValueError(
329
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
330
+ f" {type(callback_steps)}."
331
+ )
332
+
333
+ def prepare_latents(
334
+ self,
335
+ batch_size,
336
+ num_channels_latents,
337
+ video_length,
338
+ height,
339
+ width,
340
+ dtype,
341
+ device,
342
+ generator,
343
+ latents=None,
344
+ ):
345
+ shape = (
346
+ batch_size,
347
+ num_channels_latents,
348
+ video_length,
349
+ height // self.vae_scale_factor,
350
+ width // self.vae_scale_factor,
351
+ )
352
+ if isinstance(generator, list) and len(generator) != batch_size:
353
+ raise ValueError(
354
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
355
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
356
+ )
357
+ if latents is None:
358
+ rand_device = "cpu" if device.type == "mps" else device
359
+
360
+ if isinstance(generator, list):
361
+ # draw one sample per generator; the concatenation below rebuilds the batch dimension
362
+ shape = (1,) + shape[1:]
363
+ latents = [
364
+ torch.randn(
365
+ shape, generator=generator[i], device=rand_device, dtype=dtype
366
+ )
367
+ for i in range(batch_size)
368
+ ]
369
+ latents = torch.cat(latents, dim=0).to(device)
370
+ else:
371
+ np.random.seed(generator.initial_seed() if generator is not None else 0)
372
+ # np.random.seed(0 if generator is not None else None)
373
+ latents = np.random.standard_normal(shape)
374
+ # latents = torch.randn(
375
+ # shape, generator=generator, device=rand_device, dtype=dtype
376
+ # ).to(device)
377
+ latents = torch.tensor(latents, dtype=dtype).to(device)
378
+ else:
379
+ if latents.shape != shape:
380
+ raise ValueError(
381
+ f"Unexpected latents shape, got {latents.shape}, expected {shape}"
382
+ )
383
+ latents = latents.to(device)
384
+
385
+ # scale the initial noise by the standard deviation required by the scheduler
386
+ latents = latents * self.scheduler.init_noise_sigma
387
+ return latents
388
+
389
+ @torch.no_grad()
390
+ def __call__(
391
+ self,
392
+ prompt: Union[str, List[str]],
393
+ video_length: Optional[int],
394
+ height: Optional[int] = None,
395
+ width: Optional[int] = None,
396
+ num_inference_steps: int = 50,
397
+ guidance_scale: float = 7.5,
398
+ negative_prompt: Optional[Union[str, List[str]]] = None,
399
+ num_videos_per_prompt: Optional[int] = 1,
400
+ eta: float = 0.0,
401
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
402
+ latents: Optional[torch.FloatTensor] = None,
403
+ output_type: Optional[str] = "tensor",
404
+ return_dict: bool = True,
405
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
406
+ callback_steps: Optional[int] = 1,
407
+ **kwargs,
408
+ ):
409
+ # Default height and width to unet
410
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
411
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
412
+
413
+ # Check inputs. Raise error if not correct
414
+ self.check_inputs(prompt, height, width, callback_steps)
415
+
416
+ # Define call parameters
417
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
418
+ batch_size = 1
419
+ if latents is not None:
420
+ batch_size = latents.shape[0]
421
+ if isinstance(prompt, list):
422
+ batch_size = len(prompt)
423
+
424
+ device = self._execution_device
425
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
426
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
427
+ # corresponds to doing no classifier free guidance.
428
+ do_classifier_free_guidance = guidance_scale > 1.0
429
+
430
+ # Encode input prompt
431
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
432
+ if negative_prompt is not None:
433
+ negative_prompt = (
434
+ negative_prompt
435
+ if isinstance(negative_prompt, list)
436
+ else [negative_prompt] * batch_size
437
+ )
438
+ text_embeddings = self._encode_prompt(
439
+ prompt,
440
+ device,
441
+ num_videos_per_prompt,
442
+ do_classifier_free_guidance,
443
+ negative_prompt,
444
+ )
445
+
446
+ # Prepare timesteps
447
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
448
+ timesteps = self.scheduler.timesteps
449
+
450
+ # Prepare latent variables
451
+ num_channels_latents = self.unet.in_channels
452
+ latents = self.prepare_latents(
453
+ batch_size * num_videos_per_prompt,
454
+ num_channels_latents,
455
+ video_length,
456
+ height,
457
+ width,
458
+ text_embeddings.dtype,
459
+ device,
460
+ generator,
461
+ latents,
462
+ )
463
+ latents_dtype = latents.dtype
464
+
465
+ # Prepare extra step kwargs.
466
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
467
+
468
+ # Denoising loop
469
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
470
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
471
+ for i, t in enumerate(timesteps):
472
+ # expand the latents if we are doing classifier free guidance
473
+ latent_model_input = (
474
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
475
+ )
476
+ latent_model_input = self.scheduler.scale_model_input(
477
+ latent_model_input, t
478
+ )
479
+
480
+ # predict the noise residual
481
+ noise_pred = self.unet(
482
+ latent_model_input, t, encoder_hidden_states=text_embeddings
483
+ ).sample.to(dtype=latents_dtype)
484
+
485
+ # perform guidance
486
+ if do_classifier_free_guidance:
487
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
488
+ noise_pred = noise_pred_uncond + guidance_scale * (
489
+ noise_pred_text - noise_pred_uncond
490
+ )
491
+
492
+ # compute the previous noisy sample x_t -> x_t-1
493
+ latents = self.scheduler.step(
494
+ noise_pred, t, latents, **extra_step_kwargs
495
+ ).prev_sample
496
+
497
+ # call the callback, if provided
498
+ if i == len(timesteps) - 1 or (
499
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
500
+ ):
501
+ progress_bar.update()
502
+ if callback is not None and i % callback_steps == 0:
503
+ callback(i, t, latents)
504
+
505
+ # Post-processing
506
+ video = self.decode_latents(latents)
507
+
508
+ # Convert to tensor
509
+ if output_type == "tensor":
510
+ video = torch.from_numpy(video)
511
+
512
+ if not return_dict:
513
+ return video
514
+
515
+ return AnimationPipelineOutput(videos=video)
516
+
517
+
518
+ class AnimationCtrlPipeline(AnimationPipeline):
519
+ """
520
+ AnimationCtrlPipeline: pipeline for AnimateDiff augmented with UniCtrl
521
+ """
522
+
523
+ _optional_components = []
524
+
525
+ def __init__(
526
+ self,
527
+ vae: AutoencoderKL,
528
+ text_encoder: CLIPTextModel,
529
+ tokenizer: CLIPTokenizer,
530
+ unet: UNet3DConditionModel,
531
+ scheduler: Union[
532
+ DDIMScheduler,
533
+ PNDMScheduler,
534
+ LMSDiscreteScheduler,
535
+ EulerDiscreteScheduler,
536
+ EulerAncestralDiscreteScheduler,
537
+ DPMSolverMultistepScheduler,
538
+ ],
539
+ ):
540
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
541
+
542
+ @torch.no_grad()
543
+ def __call__(
544
+ self,
545
+ prompt: Union[str, List[str]],
546
+ video_length: Optional[int],
547
+ height: Optional[int] = None,
548
+ width: Optional[int] = None,
549
+ num_inference_steps: int = 50,
550
+ guidance_scale: float = 7.5,
551
+ negative_prompt: Optional[Union[str, List[str]]] = None,
552
+ num_videos_per_prompt: Optional[int] = 1,
553
+ eta: float = 0.0,
554
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
555
+ latents: Optional[torch.FloatTensor] = None,
556
+ output_type: Optional[str] = "tensor",
557
+ return_dict: bool = True,
558
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
559
+ callback_steps: Optional[int] = 1,
560
+ # unictrl args
561
+ use_fp16: bool = False,
562
+ **kwargs,
563
+ ):
564
+ if use_fp16:
565
+ print("Warning: using half precision for inference!")
566
+ self.vae.to(dtype=torch.float16)
567
+ self.unet.to(dtype=torch.float16)
568
+ self.text_encoder.to(dtype=torch.float16)
569
+ # Default height and width to unet
570
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
571
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
572
+
573
+ # Check inputs. Raise error if not correct
574
+ self.check_inputs(prompt, height, width, callback_steps)
575
+
576
+ # Define call parameters
577
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
578
+ batch_size = 1
579
+ if latents is not None:
580
+ batch_size = latents.shape[0]
581
+ if isinstance(prompt, list):
582
+ batch_size = len(prompt)
583
+
584
+ device = self._execution_device
585
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
586
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
587
+ # corresponds to doing no classifier free guidance.
588
+ do_classifier_free_guidance = guidance_scale > 1.0
589
+
590
+ # Encode input prompt
591
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
592
+ if negative_prompt is not None:
593
+ negative_prompt = (
594
+ negative_prompt
595
+ if isinstance(negative_prompt, list)
596
+ else [negative_prompt] * batch_size
597
+ )
598
+ text_embeddings = self._encode_prompt(
599
+ prompt,
600
+ device,
601
+ num_videos_per_prompt,
602
+ do_classifier_free_guidance,
603
+ negative_prompt,
604
+ )
605
+
606
+ # Prepare timesteps
607
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
608
+ timesteps = self.scheduler.timesteps
609
+
610
+ # Prepare latent variables
611
+ num_channels_latents = self.unet.in_channels
612
+ latents = self.prepare_latents(
613
+ batch_size * num_videos_per_prompt,
614
+ num_channels_latents,
615
+ video_length,
616
+ height,
617
+ width,
618
+ text_embeddings.dtype,
619
+ device,
620
+ generator,
621
+ latents,
622
+ )
623
+ latents_dtype = latents.dtype
624
+
625
+ # Prepare extra step kwargs.
626
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
627
+
628
+ latents = latents.to(latents_dtype)
629
+ motion_latents = latents.clone()
630
+
631
+ # Denoising loop
632
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
633
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
634
+ current_timesteps = timesteps
635
+ for i, t in enumerate(current_timesteps):
636
+ # Spatiotemporal Synchronization
637
+ motion_latents = latents.clone()
638
+ # Output Branch
639
+ latent_model_input = (
640
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
641
+ )
642
+ latent_model_input = self.scheduler.scale_model_input(
643
+ latent_model_input, t
644
+ )
645
+ # Motion Branch
646
+ motion_latent_model_input = (
647
+ torch.cat([motion_latents] * 2)
648
+ if do_classifier_free_guidance
649
+ else motion_latents
650
+ )
651
+ motion_latent_model_input = self.scheduler.scale_model_input(
652
+ motion_latent_model_input, t
653
+ )
654
+
655
+ if do_classifier_free_guidance:
656
+ concat_latent_model_input = torch.stack(
657
+ [
658
+ latent_model_input[0],
659
+ motion_latent_model_input[0],
660
+ latent_model_input[1],
661
+ motion_latent_model_input[1],
662
+ ],
663
+ dim=0,
664
+ )
665
+ concat_prompt_embeds = torch.stack(
666
+ [
667
+ text_embeddings[0],
668
+ text_embeddings[0],
669
+ text_embeddings[1],
670
+ text_embeddings[1],
671
+ ],
672
+ dim=0,
673
+ )
674
+ else:
675
+ concat_latent_model_input = torch.cat(
676
+ [
677
+ latent_model_input,
678
+ motion_latent_model_input,
679
+ ],
680
+ dim=0,
681
+ )
682
+ concat_prompt_embeds = torch.cat(
683
+ [
684
+ text_embeddings,
685
+ text_embeddings,
686
+ ],
687
+ dim=0,
688
+ )
689
+
690
+ # predict the noise residual for both branches
691
+ concat_noise_pred = self.unet(
692
+ concat_latent_model_input,
693
+ t,
694
+ encoder_hidden_states=concat_prompt_embeds,
695
+ ).sample.to(dtype=latents_dtype)
696
+
697
+ # perform guidance
698
+ if do_classifier_free_guidance:
699
+ (
700
+ noise_pred_uncond,
701
+ motion_noise_pred_uncond,
702
+ noise_pred_text,
703
+ motion_noise_pred_text,
704
+ ) = concat_noise_pred.chunk(4, dim=0)
705
+
706
+ noise_pred = noise_pred_uncond + guidance_scale * (
707
+ noise_pred_text - noise_pred_uncond
708
+ )
709
+ motion_noise_pred = motion_noise_pred_uncond + guidance_scale * (
710
+ motion_noise_pred_text - motion_noise_pred_uncond
711
+ )
712
+
713
+ else:
714
+ (
715
+ noise_pred,
716
+ motion_noise_pred,
717
+ ) = concat_noise_pred.chunk(2, dim=0)
718
+
719
+ # compute the previous noisy sample x_t -> x_t-1
720
+ latents = self.scheduler.step(
721
+ noise_pred, t, latents, **extra_step_kwargs
722
+ ).prev_sample
723
+
724
+ motion_latents = self.scheduler.step(
725
+ motion_noise_pred, t, motion_latents, **extra_step_kwargs
726
+ ).prev_sample
727
+
728
+ # call the callback, if provided
729
+ if i == len(current_timesteps) - 1 or (
730
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
731
+ ):
732
+ progress_bar.update()
733
+ if callback is not None and i % callback_steps == 0:
734
+ callback(i, t, latents)
735
+
736
+ # Post-processing
737
+ video = self.decode_latents(latents)
738
+
739
+ # Convert to tensor
740
+ if output_type == "tensor":
741
+ video = torch.from_numpy(video)
742
+
743
+ if not return_dict:
744
+ return video
745
+
746
+ return AnimationPipelineOutput(videos=video)
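The denoising loop above batches an output branch and a motion branch through one UNet call, ordered as [output uncond, motion uncond, output text, motion text], and then applies classifier-free guidance to each branch separately. A minimal sketch of that ordering and of the guidance formula follows. It uses plain Python lists as stand-ins for noise predictions; the helper names `interleave_branches` and `apply_guidance` are hypothetical and do not appear in the pipeline.

```python
from typing import List, Sequence


def interleave_branches(output: Sequence, motion: Sequence) -> List:
    """Order a guidance batch as [out_uncond, motion_uncond, out_text, motion_text]."""
    out_uncond, out_text = output
    motion_uncond, motion_text = motion
    return [out_uncond, motion_uncond, out_text, motion_text]


def apply_guidance(
    uncond: Sequence[float], text: Sequence[float], scale: float
) -> List[float]:
    """Classifier-free guidance: uncond + scale * (text - uncond), element-wise."""
    return [u + scale * (t - u) for u, t in zip(uncond, text)]


# Toy stand-ins for the four per-sample predictions (assumed values, not real UNet output).
batch = interleave_branches(
    output=([0.0, 1.0], [1.0, 3.0]),
    motion=([0.5, 0.5], [1.5, 0.5]),
)

# Unpacking mirrors concat_noise_pred.chunk(4, dim=0) in the loop above.
noise_uncond, motion_uncond, noise_text, motion_text = batch

noise_pred = apply_guidance(noise_uncond, noise_text, scale=2.0)
motion_pred = apply_guidance(motion_uncond, motion_text, scale=2.0)
```

With a scale of 1.0 the formula returns the text-conditioned prediction unchanged, which is why the pipeline only enables guidance when `guidance_scale > 1.0`.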
animatediff/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,1216 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
61
+
62
+
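To make the sign convention of `shave_segments` concrete, here is a small standalone sketch. It restates the function from the diff so that it runs on its own; the example key is an assumed one in the LDM naming style, not taken from a real checkpoint.

```python
def shave_segments(path, n_shave_prefix_segments=1):
    """Drop leading segments for n >= 0, trailing segments for n < 0."""
    parts = path.split(".")
    if n_shave_prefix_segments >= 0:
        return ".".join(parts[n_shave_prefix_segments:])
    return ".".join(parts[:n_shave_prefix_segments])


key = "output_blocks.3.1.proj_out.weight"  # assumed example key

dropped_prefix = shave_segments(key, 2)   # removes "output_blocks.3"
dropped_suffix = shave_segments(key, -1)  # removes the trailing "weight"
unchanged = shave_segments(key, 0)        # zero keeps the path as-is
```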
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(
79
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
80
+ )
81
+
82
+ mapping.append({"old": old_item, "new": new_item})
83
+
84
+ return mapping
85
+
86
+
87
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
88
+ """
89
+ Updates paths inside resnets to the new naming scheme (local renaming)
90
+ """
91
+ mapping = []
92
+ for old_item in old_list:
93
+ new_item = old_item
94
+
95
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
96
+ new_item = shave_segments(
97
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
98
+ )
99
+
100
+ mapping.append({"old": old_item, "new": new_item})
101
+
102
+ return mapping
103
+
104
+
105
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
106
+ """
107
+ Updates paths inside attentions to the new naming scheme (local renaming)
108
+ """
109
+ mapping = []
110
+ for old_item in old_list:
111
+ new_item = old_item
112
+
113
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
114
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
115
+
116
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
117
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
118
+
119
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
120
+
121
+ mapping.append({"old": old_item, "new": new_item})
122
+
123
+ return mapping
124
+
125
+
126
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
127
+ """
128
+ Updates paths inside attentions to the new naming scheme (local renaming)
129
+ """
130
+ mapping = []
131
+ for old_item in old_list:
132
+ new_item = old_item
133
+
134
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
135
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
136
+
137
+ new_item = new_item.replace("q.weight", "query.weight")
138
+ new_item = new_item.replace("q.bias", "query.bias")
139
+
140
+ new_item = new_item.replace("k.weight", "key.weight")
141
+ new_item = new_item.replace("k.bias", "key.bias")
142
+
143
+ new_item = new_item.replace("v.weight", "value.weight")
144
+ new_item = new_item.replace("v.bias", "value.bias")
145
+
146
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
147
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
148
+
149
+ new_item = shave_segments(
150
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
151
+ )
152
+
153
+ mapping.append({"old": old_item, "new": new_item})
154
+
155
+ return mapping
156
+
157
+
158
+ def assign_to_checkpoint(
159
+ paths,
160
+ checkpoint,
161
+ old_checkpoint,
162
+ attention_paths_to_split=None,
163
+ additional_replacements=None,
164
+ config=None,
165
+ ):
166
+ """
167
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
168
+ attention layers, and takes into account additional replacements that may arise.
169
+
170
+ Assigns the weights to the new checkpoint.
171
+ """
172
+ assert isinstance(
173
+ paths, list
174
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
175
+
176
+ # Splits the attention layers into three variables.
177
+ if attention_paths_to_split is not None:
178
+ for path, path_map in attention_paths_to_split.items():
179
+ old_tensor = old_checkpoint[path]
180
+ channels = old_tensor.shape[0] // 3
181
+
182
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
183
+
184
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
185
+
186
+ old_tensor = old_tensor.reshape(
187
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
188
+ )
189
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
190
+
191
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
192
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
193
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
194
+
195
+ for path in paths:
196
+ new_path = path["new"]
197
+
198
+ # These have already been assigned
199
+ if (
200
+ attention_paths_to_split is not None
201
+ and new_path in attention_paths_to_split
202
+ ):
203
+ continue
204
+
205
+ # Global renaming happens here
206
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
207
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
208
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
209
+
210
+ if additional_replacements is not None:
211
+ for replacement in additional_replacements:
212
+ new_path = new_path.replace(replacement["old"], replacement["new"])
213
+
214
+ # proj_attn.weight has to be converted from conv 1D to linear
215
+ if "proj_attn.weight" in new_path:
216
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
217
+ else:
218
+ checkpoint[new_path] = old_checkpoint[path["old"]]
219
+
220
+
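The conversion described above works in two stages: a local rename inside each block (as in `renew_resnet_paths`), then a global rename applied while assigning weights (as in `assign_to_checkpoint`). The following is a simplified sketch of that flow, assuming strings in place of tensors and leaving out the attention split and the `proj_attn` special case. The names `local_rename` and `global_assign` are hypothetical.

```python
# Local renames, copied from renew_resnet_paths in the diff.
LOCAL_RENAMES = [
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]


def local_rename(keys):
    """Stage 1: rename layer names inside a block, keeping the block prefix."""
    mapping = []
    for old in keys:
        new = old
        for src, dst in LOCAL_RENAMES:
            new = new.replace(src, dst)
        mapping.append({"old": old, "new": new})
    return mapping


def global_assign(paths, new_ckpt, old_ckpt, replacements):
    """Stage 2: rewrite the block prefix and copy the value under the new key."""
    for path in paths:
        new_key = path["new"]
        for rep in replacements:
            new_key = new_key.replace(rep["old"], rep["new"])
        new_ckpt[new_key] = old_ckpt[path["old"]]


# Strings stand in for weight tensors (assumed toy data).
old_ckpt = {
    "input_blocks.1.0.in_layers.0.weight": "w_norm1",
    "input_blocks.1.0.emb_layers.1.bias": "b_temb",
}
new_ckpt = {}
paths = local_rename(list(old_ckpt))
meta_path = {"old": "input_blocks.1.0", "new": "down_blocks.0.resnets.0"}
global_assign(paths, new_ckpt, old_ckpt, [meta_path])
```

Keeping the two stages separate lets the same local rename serve the input, middle, and output blocks, with only the `meta_path` changing per block.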
221
+ def conv_attn_to_linear(checkpoint):
222
+ keys = list(checkpoint.keys())
223
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
224
+ for key in keys:
225
+ if ".".join(key.split(".")[-2:]) in attn_keys:
226
+ if checkpoint[key].ndim > 2:
227
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
228
+ elif "proj_attn.weight" in key:
229
+ if checkpoint[key].ndim > 2:
230
+ checkpoint[key] = checkpoint[key][:, :, 0]
231
+
232
+
233
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
234
+ """
235
+ Creates a config for the diffusers based on the config of the LDM model.
236
+ """
237
+ if controlnet:
238
+ unet_params = original_config.model.params.control_stage_config.params
239
+ else:
240
+ unet_params = original_config.model.params.unet_config.params
241
+
242
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
243
+
244
+ block_out_channels = [
245
+ unet_params.model_channels * mult for mult in unet_params.channel_mult
246
+ ]
247
+
248
+ down_block_types = []
249
+ resolution = 1
250
+ for i in range(len(block_out_channels)):
251
+ block_type = (
252
+ "CrossAttnDownBlock2D"
253
+ if resolution in unet_params.attention_resolutions
254
+ else "DownBlock2D"
255
+ )
256
+ down_block_types.append(block_type)
257
+ if i != len(block_out_channels) - 1:
258
+ resolution *= 2
259
+
260
+ up_block_types = []
261
+ for i in range(len(block_out_channels)):
262
+ block_type = (
263
+ "CrossAttnUpBlock2D"
264
+ if resolution in unet_params.attention_resolutions
265
+ else "UpBlock2D"
266
+ )
267
+ up_block_types.append(block_type)
268
+ resolution //= 2
269
+
270
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
271
+
272
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
273
+ use_linear_projection = (
274
+ unet_params.use_linear_in_transformer
275
+ if "use_linear_in_transformer" in unet_params
276
+ else False
277
+ )
278
+ if use_linear_projection:
279
+ # stable diffusion 2-base-512 and 2-768
280
+ if head_dim is None:
281
+ head_dim = [5, 10, 20, 20]
282
+
283
+ class_embed_type = None
284
+ projection_class_embeddings_input_dim = None
285
+
286
+ if "num_classes" in unet_params:
287
+ if unet_params.num_classes == "sequential":
288
+ class_embed_type = "projection"
289
+ assert "adm_in_channels" in unet_params
290
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
291
+ else:
292
+ raise NotImplementedError(
293
+ f"Unknown conditional unet num_classes config: {unet_params.num_classes}"
294
+ )
295
+
296
+ config = {
297
+ "sample_size": image_size // vae_scale_factor,
298
+ "in_channels": unet_params.in_channels,
299
+ "down_block_types": tuple(down_block_types),
300
+ "block_out_channels": tuple(block_out_channels),
301
+ "layers_per_block": unet_params.num_res_blocks,
302
+ "cross_attention_dim": unet_params.context_dim,
303
+ "attention_head_dim": head_dim,
304
+ "use_linear_projection": use_linear_projection,
305
+ "class_embed_type": class_embed_type,
306
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
307
+ }
308
+
309
+ if not controlnet:
310
+ config["out_channels"] = unet_params.out_channels
311
+ config["up_block_types"] = tuple(up_block_types)
312
+
313
+ return config
314
+
315
+
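The block-type selection in `create_unet_diffusers_config` can be traced with a small standalone sketch. The values below resemble a Stable Diffusion v1 style config but are assumptions chosen for illustration, and `derive_block_types` is a hypothetical helper that repeats only the resolution bookkeeping.

```python
def derive_block_types(channel_mult, attention_resolutions):
    """Return (down_block_types, up_block_types) from the LDM-style settings."""
    down, up = [], []
    resolution = 1
    last = len(channel_mult) - 1

    # Going down: the downsampling factor doubles after every block but the last.
    for i in range(len(channel_mult)):
        has_attn = resolution in attention_resolutions
        down.append("CrossAttnDownBlock2D" if has_attn else "DownBlock2D")
        if i != last:
            resolution *= 2

    # Going up: start from the deepest factor and halve after every block.
    for _ in range(len(channel_mult)):
        has_attn = resolution in attention_resolutions
        up.append("CrossAttnUpBlock2D" if has_attn else "UpBlock2D")
        resolution //= 2

    return tuple(down), tuple(up)


# Assumed example values.
model_channels = 320
channel_mult = (1, 2, 4, 4)
attention_resolutions = (4, 2, 1)

down_types, up_types = derive_block_types(channel_mult, attention_resolutions)
block_out_channels = tuple(model_channels * m for m in channel_mult)

# A VAE with four ch_mult entries downsamples by 2 ** 3 = 8.
vae_scale_factor = 2 ** (len((1, 2, 4, 4)) - 1)
sample_size = 512 // vae_scale_factor
```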
316
+ def create_vae_diffusers_config(original_config, image_size: int):
317
+ """
318
+ Creates a config for the diffusers based on the config of the LDM model.
319
+ """
320
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
321
+ _ = original_config.model.params.first_stage_config.params.embed_dim
322
+
323
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
324
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
325
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
326
+
327
+ config = {
328
+ "sample_size": image_size,
329
+ "in_channels": vae_params.in_channels,
330
+ "out_channels": vae_params.out_ch,
331
+ "down_block_types": tuple(down_block_types),
332
+ "up_block_types": tuple(up_block_types),
333
+ "block_out_channels": tuple(block_out_channels),
334
+ "latent_channels": vae_params.z_channels,
335
+ "layers_per_block": vae_params.num_res_blocks,
336
+ }
337
+ return config
338
+
339
+
340
+ def create_diffusers_schedular(original_config):
341
+ schedular = DDIMScheduler(
342
+ num_train_timesteps=original_config.model.params.timesteps,
343
+ beta_start=original_config.model.params.linear_start,
344
+ beta_end=original_config.model.params.linear_end,
345
+ beta_schedule="scaled_linear",
346
+ )
347
+ return schedular
348
+
349
+
350
+ def create_ldm_bert_config(original_config):
351
+ bert_params = original_config.model.params.cond_stage_config.params
352
+ config = LDMBertConfig(
353
+ d_model=bert_params.n_embed,
354
+ encoder_layers=bert_params.n_layer,
355
+ encoder_ffn_dim=bert_params.n_embed * 4,
356
+ )
357
+ return config
358
+
359
+
360
+ def convert_ldm_unet_checkpoint(
361
+ checkpoint, config, path=None, extract_ema=False, controlnet=False
362
+ ):
363
+ """
364
+ Takes a state dict and a config, and returns a converted checkpoint.
365
+ """
366
+
367
+ # extract state_dict for UNet
368
+ unet_state_dict = {}
369
+ keys = list(checkpoint.keys())
370
+
371
+ if controlnet:
372
+ unet_key = "control_model."
373
+ else:
374
+ unet_key = "model.diffusion_model."
375
+
376
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
377
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
378
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
379
+ print(
380
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
381
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
382
+ )
383
+ for key in keys:
384
+ if key.startswith("model.diffusion_model"):
385
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
386
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(
387
+ flat_ema_key
388
+ )
389
+ else:
390
+ if sum(k.startswith("model_ema") for k in keys) > 100:
391
+ print(
392
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
393
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
394
+ )
395
+
396
+ for key in keys:
397
+ if key.startswith(unet_key):
398
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
399
+
400
+ new_checkpoint = {}
401
+
402
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict[
403
+ "time_embed.0.weight"
404
+ ]
405
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict[
406
+ "time_embed.0.bias"
407
+ ]
408
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict[
409
+ "time_embed.2.weight"
410
+ ]
411
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict[
412
+ "time_embed.2.bias"
413
+ ]
414
+
415
+ if config["class_embed_type"] is None:
416
+ # No parameters to port
417
+ ...
418
+ elif (
419
+ config["class_embed_type"] == "timestep"
420
+ or config["class_embed_type"] == "projection"
421
+ ):
422
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict[
423
+ "label_emb.0.0.weight"
424
+ ]
425
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict[
426
+ "label_emb.0.0.bias"
427
+ ]
428
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict[
429
+ "label_emb.0.2.weight"
430
+ ]
431
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict[
432
+ "label_emb.0.2.bias"
433
+ ]
434
+ else:
435
+ raise NotImplementedError(
436
+ f"Not implemented `class_embed_type`: {config['class_embed_type']}"
437
+ )
438
+
439
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
440
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
441
+
442
+ if not controlnet:
443
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
444
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
445
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
446
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
447
+
448
+ # Retrieves the keys for the input blocks only
449
+ num_input_blocks = len(
450
+ {
451
+ ".".join(layer.split(".")[:2])
452
+ for layer in unet_state_dict
453
+ if "input_blocks" in layer
454
+ }
455
+ )
456
+ input_blocks = {
457
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
458
+ for layer_id in range(num_input_blocks)
459
+ }
460
+
461
+ # Retrieves the keys for the middle blocks only
462
+ num_middle_blocks = len(
463
+ {
464
+ ".".join(layer.split(".")[:2])
465
+ for layer in unet_state_dict
466
+ if "middle_block" in layer
467
+ }
468
+ )
469
+ middle_blocks = {
470
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
471
+ for layer_id in range(num_middle_blocks)
472
+ }
473
+
474
+ # Retrieves the keys for the output blocks only
475
+ num_output_blocks = len(
476
+ {
477
+ ".".join(layer.split(".")[:2])
478
+ for layer in unet_state_dict
479
+ if "output_blocks" in layer
480
+ }
481
+ )
482
+ output_blocks = {
483
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
484
+ for layer_id in range(num_output_blocks)
485
+ }
486
+
487
+ for i in range(1, num_input_blocks):
488
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
489
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
490
+
491
+ resnets = [
492
+ key
493
+ for key in input_blocks[i]
494
+ if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
495
+ ]
496
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
497
+
498
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
499
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = (
500
+ unet_state_dict.pop(f"input_blocks.{i}.0.op.weight")
501
+ )
502
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = (
503
+ unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
504
+ )
505
+
506
+ paths = renew_resnet_paths(resnets)
507
+ meta_path = {
508
+ "old": f"input_blocks.{i}.0",
509
+ "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}",
510
+ }
511
+ assign_to_checkpoint(
512
+ paths,
513
+ new_checkpoint,
514
+ unet_state_dict,
515
+ additional_replacements=[meta_path],
516
+ config=config,
517
+ )
518
+
519
+ if len(attentions):
520
+ paths = renew_attention_paths(attentions)
521
+ meta_path = {
522
+ "old": f"input_blocks.{i}.1",
523
+ "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}",
524
+ }
525
+ assign_to_checkpoint(
526
+ paths,
527
+ new_checkpoint,
528
+ unet_state_dict,
529
+ additional_replacements=[meta_path],
530
+ config=config,
531
+ )
532
+
533
+ resnet_0 = middle_blocks[0]
534
+ attentions = middle_blocks[1]
535
+ resnet_1 = middle_blocks[2]
536
+
537
+ resnet_0_paths = renew_resnet_paths(resnet_0)
538
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
539
+
540
+ resnet_1_paths = renew_resnet_paths(resnet_1)
541
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
542
+
543
+ attentions_paths = renew_attention_paths(attentions)
544
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
545
+ assign_to_checkpoint(
546
+ attentions_paths,
547
+ new_checkpoint,
548
+ unet_state_dict,
549
+ additional_replacements=[meta_path],
550
+ config=config,
551
+ )
552
+
553
+ for i in range(num_output_blocks):
554
+ block_id = i // (config["layers_per_block"] + 1)
555
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
556
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
557
+ output_block_list = {}
558
+
559
+ for layer in output_block_layers:
560
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
561
+ if layer_id in output_block_list:
562
+ output_block_list[layer_id].append(layer_name)
563
+ else:
564
+ output_block_list[layer_id] = [layer_name]
565
+
566
+ if len(output_block_list) > 1:
567
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
568
+ attentions = [
569
+ key for key in output_blocks[i] if f"output_blocks.{i}.1" in key
570
+ ]
571
+
572
+ resnet_0_paths = renew_resnet_paths(resnets)
573
+ paths = renew_resnet_paths(resnets)
574
+
575
+ meta_path = {
576
+ "old": f"output_blocks.{i}.0",
577
+ "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}",
578
+ }
579
+ assign_to_checkpoint(
580
+ paths,
581
+ new_checkpoint,
582
+ unet_state_dict,
583
+ additional_replacements=[meta_path],
584
+ config=config,
585
+ )
586
+
587
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
588
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
589
+ index = list(output_block_list.values()).index(
590
+ ["conv.bias", "conv.weight"]
591
+ )
592
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = (
593
+ unet_state_dict[f"output_blocks.{i}.{index}.conv.weight"]
594
+ )
595
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = (
596
+ unet_state_dict[f"output_blocks.{i}.{index}.conv.bias"]
597
+ )
598
+
599
+ # Clear attentions as they have been attributed above.
600
+ if len(attentions) == 2:
601
+ attentions = []
602
+
603
+ if len(attentions):
604
+ paths = renew_attention_paths(attentions)
605
+ meta_path = {
606
+ "old": f"output_blocks.{i}.1",
607
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
608
+ }
609
+ assign_to_checkpoint(
610
+ paths,
611
+ new_checkpoint,
612
+ unet_state_dict,
613
+ additional_replacements=[meta_path],
614
+ config=config,
615
+ )
616
+ else:
617
+ resnet_0_paths = renew_resnet_paths(
618
+ output_block_layers, n_shave_prefix_segments=1
619
+ )
620
+ for path in resnet_0_paths:
621
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
622
+ new_path = ".".join(
623
+ [
624
+ "up_blocks",
625
+ str(block_id),
626
+ "resnets",
627
+ str(layer_in_block_id),
628
+ path["new"],
629
+ ]
630
+ )
631
+
632
+ new_checkpoint[new_path] = unet_state_dict[old_path]
633
+
634
+ if controlnet:
635
+ # conditioning embedding
636
+
637
+ orig_index = 0
638
+
639
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = (
640
+ unet_state_dict.pop(f"input_hint_block.{orig_index}.weight")
641
+ )
642
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
643
+ f"input_hint_block.{orig_index}.bias"
644
+ )
645
+
646
+ orig_index += 2
647
+
648
+ diffusers_index = 0
649
+
650
+ while diffusers_index < 6:
651
+ new_checkpoint[
652
+ f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"
653
+ ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.weight")
654
+ new_checkpoint[
655
+ f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"
656
+ ] = unet_state_dict.pop(f"input_hint_block.{orig_index}.bias")
657
+ diffusers_index += 1
658
+ orig_index += 2
659
+
660
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = (
661
+ unet_state_dict.pop(f"input_hint_block.{orig_index}.weight")
662
+ )
663
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
664
+ f"input_hint_block.{orig_index}.bias"
665
+ )
666
+
667
+ # down blocks
668
+ for i in range(num_input_blocks):
669
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(
670
+ f"zero_convs.{i}.0.weight"
671
+ )
672
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(
673
+ f"zero_convs.{i}.0.bias"
674
+ )
675
+
676
+ # mid block
677
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop(
678
+ "middle_block_out.0.weight"
679
+ )
680
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop(
681
+ "middle_block_out.0.bias"
682
+ )
683
+
684
+ return new_checkpoint
685
+
686
+
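Two pieces of key arithmetic in `convert_ldm_unet_checkpoint` are easy to misread: how a flat `input_blocks.{i}` index maps to a (block, layer) pair, and how the EMA key is flattened. The sketch below isolates both; the helper names `input_block_position` and `flat_ema_key` are hypothetical, and `layers_per_block=2` is an assumed example value.

```python
def input_block_position(i, layers_per_block):
    """Map flat input block index i (starting at 1) to (block_id, layer_in_block_id).

    Each down block spans layers_per_block resnet slots plus one downsampler slot.
    """
    span = layers_per_block + 1
    return (i - 1) // span, (i - 1) % span


def flat_ema_key(key):
    """Build the EMA key: drop the first segment and join the rest without dots."""
    return "model_ema." + "".join(key.split(".")[1:])


positions = [input_block_position(i, layers_per_block=2) for i in (1, 3, 4, 11)]
ema_key = flat_ema_key("model.diffusion_model.time_embed.0.weight")
```

Note that the flattened EMA key contains no dots after the `model_ema.` prefix, so it cannot be recovered by a simple prefix replacement on the original key.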
687
+ def convert_ldm_vae_checkpoint(checkpoint, config):
688
+ # extract state dict for VAE
689
+ vae_state_dict = {}
690
+ vae_key = "first_stage_model."
691
+ keys = list(checkpoint.keys())
692
+ for key in keys:
693
+ if key.startswith(vae_key):
694
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
695
+
696
+ new_checkpoint = {}
697
+
698
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
699
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
700
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
701
+ "encoder.conv_out.weight"
702
+ ]
703
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
704
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
705
+ "encoder.norm_out.weight"
706
+ ]
707
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
708
+ "encoder.norm_out.bias"
709
+ ]
710
+
711
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
712
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
713
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
714
+ "decoder.conv_out.weight"
715
+ ]
716
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
717
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
718
+ "decoder.norm_out.weight"
719
+ ]
720
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
721
+ "decoder.norm_out.bias"
722
+ ]
723
+
724
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
725
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
726
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
727
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
728
+
729
+ # Retrieves the keys for the encoder down blocks only
730
+ num_down_blocks = len(
731
+ {
732
+ ".".join(layer.split(".")[:3])
733
+ for layer in vae_state_dict
734
+ if "encoder.down" in layer
735
+ }
736
+ )
737
+ down_blocks = {
738
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
739
+ for layer_id in range(num_down_blocks)
740
+ }
741
+
742
+ # Retrieves the keys for the decoder up blocks only
743
+ num_up_blocks = len(
744
+ {
745
+ ".".join(layer.split(".")[:3])
746
+ for layer in vae_state_dict
747
+ if "decoder.up" in layer
748
+ }
749
+ )
750
+ up_blocks = {
751
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
752
+ for layer_id in range(num_up_blocks)
753
+ }
754
+
755
+ for i in range(num_down_blocks):
756
+ resnets = [
757
+ key
758
+ for key in down_blocks[i]
759
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
760
+ ]
761
+
762
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
763
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = (
764
+ vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
765
+ )
766
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = (
767
+ vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
768
+ )
769
+
770
+ paths = renew_vae_resnet_paths(resnets)
771
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
772
+ assign_to_checkpoint(
773
+ paths,
774
+ new_checkpoint,
775
+ vae_state_dict,
776
+ additional_replacements=[meta_path],
777
+ config=config,
778
+ )
779
+
780
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
781
+ num_mid_res_blocks = 2
782
+ for i in range(1, num_mid_res_blocks + 1):
783
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
784
+
785
+ paths = renew_vae_resnet_paths(resnets)
786
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
787
+ assign_to_checkpoint(
788
+ paths,
789
+ new_checkpoint,
790
+ vae_state_dict,
791
+ additional_replacements=[meta_path],
792
+ config=config,
793
+ )
794
+
795
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
796
+ paths = renew_vae_attention_paths(mid_attentions)
797
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
798
+ assign_to_checkpoint(
799
+ paths,
800
+ new_checkpoint,
801
+ vae_state_dict,
802
+ additional_replacements=[meta_path],
803
+ config=config,
804
+ )
805
+ conv_attn_to_linear(new_checkpoint)
806
+
807
+ for i in range(num_up_blocks):
808
+ block_id = num_up_blocks - 1 - i
809
+ resnets = [
810
+ key
811
+ for key in up_blocks[block_id]
812
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
813
+ ]
814
+
815
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
816
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = (
817
+ vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
818
+ )
819
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = (
820
+ vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
821
+ )
822
+
823
+ paths = renew_vae_resnet_paths(resnets)
824
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
825
+ assign_to_checkpoint(
826
+ paths,
827
+ new_checkpoint,
828
+ vae_state_dict,
829
+ additional_replacements=[meta_path],
830
+ config=config,
831
+ )
832
+
833
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
834
+ num_mid_res_blocks = 2
835
+ for i in range(1, num_mid_res_blocks + 1):
836
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
837
+
838
+ paths = renew_vae_resnet_paths(resnets)
839
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
840
+ assign_to_checkpoint(
841
+ paths,
842
+ new_checkpoint,
843
+ vae_state_dict,
844
+ additional_replacements=[meta_path],
845
+ config=config,
846
+ )
847
+
848
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
849
+ paths = renew_vae_attention_paths(mid_attentions)
850
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
851
+ assign_to_checkpoint(
852
+ paths,
853
+ new_checkpoint,
854
+ vae_state_dict,
855
+ additional_replacements=[meta_path],
856
+ config=config,
857
+ )
858
+ conv_attn_to_linear(new_checkpoint)
859
+ return new_checkpoint
860
+
861
+
862
+ def convert_ldm_bert_checkpoint(checkpoint, config):
863
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
864
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
865
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
866
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
867
+
868
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
869
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
870
+
871
+ def _copy_linear(hf_linear, pt_linear):
872
+ hf_linear.weight = pt_linear.weight
873
+ hf_linear.bias = pt_linear.bias
874
+
875
+ def _copy_layer(hf_layer, pt_layer):
876
+ # copy layer norms
877
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
878
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
879
+
880
+ # copy attn
881
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
882
+
883
+ # copy MLP
884
+ pt_mlp = pt_layer[1][1]
885
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
886
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
887
+
888
+ def _copy_layers(hf_layers, pt_layers):
889
+ for i, hf_layer in enumerate(hf_layers):
890
+ if i != 0:
891
+ i += i
892
+ pt_layer = pt_layers[i : i + 2]
893
+ _copy_layer(hf_layer, pt_layer)
894
+
895
+ hf_model = LDMBertModel(config).eval()
896
+
897
+ # copy embeds
898
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
899
+ hf_model.model.embed_positions.weight.data = (
900
+ checkpoint.transformer.pos_emb.emb.weight
901
+ )
902
+
903
+ # copy layer norm
904
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
905
+
906
+ # copy hidden layers
907
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
908
+
909
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
910
+
911
+ return hf_model
912
+
913
+
914
+ def convert_ldm_clip_checkpoint(checkpoint):
915
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
916
+ keys = list(checkpoint.keys())
917
+
918
+ text_model_dict = {}
919
+
920
+ for key in keys:
921
+ if key.startswith("cond_stage_model.transformer"):
922
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
923
+ key
924
+ ]
925
+
926
+ text_model.load_state_dict(text_model_dict)
927
+
928
+ return text_model
929
+
930
+
931
+ textenc_conversion_lst = [
932
+ (
933
+ "cond_stage_model.model.positional_embedding",
934
+ "text_model.embeddings.position_embedding.weight",
935
+ ),
936
+ (
937
+ "cond_stage_model.model.token_embedding.weight",
938
+ "text_model.embeddings.token_embedding.weight",
939
+ ),
940
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
941
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
942
+ ]
943
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
944
+
945
+ textenc_transformer_conversion_lst = [
946
+ # (stable-diffusion, HF Diffusers)
947
+ ("resblocks.", "text_model.encoder.layers."),
948
+ ("ln_1", "layer_norm1"),
949
+ ("ln_2", "layer_norm2"),
950
+ (".c_fc.", ".fc1."),
951
+ (".c_proj.", ".fc2."),
952
+ (".attn", ".self_attn"),
953
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
954
+ (
955
+ "token_embedding.weight",
956
+ "transformer.text_model.embeddings.token_embedding.weight",
957
+ ),
958
+ (
959
+ "positional_embedding",
960
+ "transformer.text_model.embeddings.position_embedding.weight",
961
+ ),
962
+ ]
963
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
964
+ textenc_pattern = re.compile("|".join(protected.keys()))
965
+
966
+
967
+ def convert_paint_by_example_checkpoint(checkpoint):
968
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
969
+ model = PaintByExampleImageEncoder(config)
970
+
971
+ keys = list(checkpoint.keys())
972
+
973
+ text_model_dict = {}
974
+
975
+ for key in keys:
976
+ if key.startswith("cond_stage_model.transformer"):
977
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[
978
+ key
979
+ ]
980
+
981
+ # load clip vision
982
+ model.model.load_state_dict(text_model_dict)
983
+
984
+ # load mapper
985
+ keys_mapper = {
986
+ k[len("cond_stage_model.mapper.res") :]: v
987
+ for k, v in checkpoint.items()
988
+ if k.startswith("cond_stage_model.mapper")
989
+ }
990
+
991
+ MAPPING = {
992
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
993
+ "attn.c_proj": ["attn1.to_out.0"],
994
+ "ln_1": ["norm1"],
995
+ "ln_2": ["norm3"],
996
+ "mlp.c_fc": ["ff.net.0.proj"],
997
+ "mlp.c_proj": ["ff.net.2"],
998
+ }
999
+
1000
+ mapped_weights = {}
1001
+ for key, value in keys_mapper.items():
1002
+ prefix = key[: len("blocks.i")]
1003
+ suffix = key.split(prefix)[-1].split(".")[-1]
1004
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
1005
+ mapped_names = MAPPING[name]
1006
+
1007
+ num_splits = len(mapped_names)
1008
+ for i, mapped_name in enumerate(mapped_names):
1009
+ new_name = ".".join([prefix, mapped_name, suffix])
1010
+ shape = value.shape[0] // num_splits
1011
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
1012
+
1013
+ model.mapper.load_state_dict(mapped_weights)
1014
+
1015
+ # load final layer norm
1016
+ model.final_layer_norm.load_state_dict(
1017
+ {
1018
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
1019
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
1020
+ }
1021
+ )
1022
+
1023
+ # load final proj
1024
+ model.proj_out.load_state_dict(
1025
+ {
1026
+ "bias": checkpoint["proj_out.bias"],
1027
+ "weight": checkpoint["proj_out.weight"],
1028
+ }
1029
+ )
1030
+
1031
+ # load uncond vector
1032
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
1033
+ return model
1034
+
1035
+
1036
+ def convert_open_clip_checkpoint(checkpoint):
1037
+ text_model = CLIPTextModel.from_pretrained(
1038
+ "stabilityai/stable-diffusion-2", subfolder="text_encoder"
1039
+ )
1040
+
1041
+ keys = list(checkpoint.keys())
1042
+
1043
+ text_model_dict = {}
1044
+
1045
+ if "cond_stage_model.model.text_projection" in checkpoint:
1046
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
1047
+ else:
1048
+ d_model = 1024
1049
+
1050
+ text_model_dict["text_model.embeddings.position_ids"] = (
1051
+ text_model.text_model.embeddings.get_buffer("position_ids")
1052
+ )
1053
+
1054
+ for key in keys:
1055
+ if (
1056
+ "resblocks.23" in key
1057
+ ): # Diffusers drops the final layer and only uses the penultimate layer
1058
+ continue
1059
+ if key in textenc_conversion_map:
1060
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
1061
+ if key.startswith("cond_stage_model.model.transformer."):
1062
+ new_key = key[len("cond_stage_model.model.transformer.") :]
1063
+ if new_key.endswith(".in_proj_weight"):
1064
+ new_key = new_key[: -len(".in_proj_weight")]
1065
+ new_key = textenc_pattern.sub(
1066
+ lambda m: protected[re.escape(m.group(0))], new_key
1067
+ )
1068
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][
1069
+ :d_model, :
1070
+ ]
1071
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][
1072
+ d_model : d_model * 2, :
1073
+ ]
1074
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][
1075
+ d_model * 2 :, :
1076
+ ]
1077
+ elif new_key.endswith(".in_proj_bias"):
1078
+ new_key = new_key[: -len(".in_proj_bias")]
1079
+ new_key = textenc_pattern.sub(
1080
+ lambda m: protected[re.escape(m.group(0))], new_key
1081
+ )
1082
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
1083
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][
1084
+ d_model : d_model * 2
1085
+ ]
1086
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][
1087
+ d_model * 2 :
1088
+ ]
1089
+ else:
1090
+ new_key = textenc_pattern.sub(
1091
+ lambda m: protected[re.escape(m.group(0))], new_key
1092
+ )
1093
+
1094
+ text_model_dict[new_key] = checkpoint[key]
1095
+
1096
+ text_model.load_state_dict(text_model_dict)
1097
+
1098
+ return text_model
1099
+
1100
+
1101
+ def stable_unclip_image_encoder(original_config):
1102
+ """
1103
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
1104
+
1105
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
1106
+ encoders.
1107
+ """
1108
+
1109
+ image_embedder_config = original_config.model.params.embedder_config
1110
+
1111
+ sd_clip_image_embedder_class = image_embedder_config.target
1112
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
1113
+
1114
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
1115
+ clip_model_name = image_embedder_config.params.model
1116
+
1117
+ if clip_model_name == "ViT-L/14":
1118
+ feature_extractor = CLIPImageProcessor()
1119
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
1120
+ "openai/clip-vit-large-patch14"
1121
+ )
1122
+ else:
1123
+ raise NotImplementedError(
1124
+ f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}"
1125
+ )
1126
+
1127
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
1128
+ feature_extractor = CLIPImageProcessor()
1129
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
1130
+ "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
1131
+ )
1132
+ else:
1133
+ raise NotImplementedError(
1134
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
1135
+ )
1136
+
1137
+ return feature_extractor, image_encoder
1138
+
1139
+
1140
+ def stable_unclip_image_noising_components(
1141
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
1142
+ ):
1143
+ """
1144
+ Returns the noising components for the img2img and txt2img unclip pipelines.
1145
+
1146
+ Converts the stability noise augmentor into
1147
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
1148
+ 2. a `DDPMScheduler` for holding the noise schedule
1149
+
1150
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
1151
+ """
1152
+ noise_aug_config = original_config.model.params.noise_aug_config
1153
+ noise_aug_class = noise_aug_config.target
1154
+ noise_aug_class = noise_aug_class.split(".")[-1]
1155
+
1156
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
1157
+ noise_aug_config = noise_aug_config.params
1158
+ embedding_dim = noise_aug_config.timestep_dim
1159
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
1160
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
1161
+
1162
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
1163
+ image_noising_scheduler = DDPMScheduler(
1164
+ num_train_timesteps=max_noise_level, beta_schedule=beta_schedule
1165
+ )
1166
+
1167
+ if "clip_stats_path" in noise_aug_config:
1168
+ if clip_stats_path is None:
1169
+ raise ValueError(
1170
+ "This stable unclip config requires a `clip_stats_path`"
1171
+ )
1172
+
1173
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
1174
+ clip_mean = clip_mean[None, :]
1175
+ clip_std = clip_std[None, :]
1176
+
1177
+ clip_stats_state_dict = {
1178
+ "mean": clip_mean,
1179
+ "std": clip_std,
1180
+ }
1181
+
1182
+ image_normalizer.load_state_dict(clip_stats_state_dict)
1183
+ else:
1184
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
1185
+
1186
+ return image_normalizer, image_noising_scheduler
1187
+
1188
+
1189
+ def convert_controlnet_checkpoint(
1190
+ checkpoint,
1191
+ original_config,
1192
+ checkpoint_path,
1193
+ image_size,
1194
+ upcast_attention,
1195
+ extract_ema,
1196
+ ):
1197
+ ctrlnet_config = create_unet_diffusers_config(
1198
+ original_config, image_size=image_size, controlnet=True
1199
+ )
1200
+ ctrlnet_config["upcast_attention"] = upcast_attention
1201
+
1202
+ ctrlnet_config.pop("sample_size")
1203
+
1204
+ controlnet_model = ControlNetModel(**ctrlnet_config)
1205
+
1206
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
1207
+ checkpoint,
1208
+ ctrlnet_config,
1209
+ path=checkpoint_path,
1210
+ extract_ema=extract_ema,
1211
+ controlnet=True,
1212
+ )
1213
+
1214
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
1215
+
1216
+ return controlnet_model
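Reviewer sketch for `convert_open_clip_checkpoint` above: the function renames OpenCLIP keys with a single regex pass (`textenc_pattern`) and splits each fused `in_proj_weight` into separate q/k/v row blocks of `d_model` rows. The following is a minimal, torch-free illustration of those two steps, assuming plain Python lists in place of tensors. The helper names `rename_key`, `split_fused_qkv`, and `convert_entry` are hypothetical and do not appear in the file; the rename rules are a subset of `textenc_transformer_conversion_lst`.

```python
import re

# Subset of the rename rules from `textenc_transformer_conversion_lst`.
RULES = [
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
]
# Escape each source fragment so "." is matched literally, then OR them together.
PROTECTED = {re.escape(old): new for old, new in RULES}
PATTERN = re.compile("|".join(PROTECTED.keys()))


def rename_key(key):
    """Apply every rename rule in one left-to-right regex pass (hypothetical helper)."""
    return PATTERN.sub(lambda m: PROTECTED[re.escape(m.group(0))], key)


def split_fused_qkv(rows, d_model):
    """Split a fused projection with 3 * d_model rows into q, k, v row blocks."""
    return rows[:d_model], rows[d_model : 2 * d_model], rows[2 * d_model :]


def convert_entry(key, rows, d_model):
    """Convert one OpenCLIP-style entry into HF-style entries (hypothetical helper)."""
    suffix = ".in_proj_weight"
    if key.endswith(suffix):
        base = rename_key(key[: -len(suffix)])
        q, k, v = split_fused_qkv(rows, d_model)
        return {
            base + ".q_proj.weight": q,
            base + ".k_proj.weight": k,
            base + ".v_proj.weight": v,
        }
    return {rename_key(key): rows}
```

With real checkpoints the row slices would be tensor slices such as `checkpoint[key][:d_model, :]`, as in the function above; the list version only mirrors the indexing.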
animatediff/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,154 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+
27
+
28
+ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29
+ # directly update weight in diffusers model
30
+ for key in state_dict:
31
+ # only process lora down key
32
+ if "up." in key: continue
33
+
34
+ up_key = key.replace(".down.", ".up.")
35
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
36
+ model_key = model_key.replace("to_out.", "to_out.0.")
37
+ layer_infos = model_key.split(".")[:-1]
38
+
39
+ curr_layer = pipeline.unet
40
+ while len(layer_infos) > 0:
41
+ temp_name = layer_infos.pop(0)
42
+ curr_layer = curr_layer.__getattr__(temp_name)
43
+
44
+ weight_down = state_dict[key]
45
+ weight_up = state_dict[up_key]
46
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
47
+
48
+ return pipeline
49
+
50
+
51
+
52
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
53
+ # load base model
54
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
55
+
56
+ # load LoRA weight from .safetensors
57
+ # state_dict = load_file(checkpoint_path)
58
+
59
+ visited = []
60
+
61
+ # directly update weight in diffusers model
62
+ for key in state_dict:
63
+ # it is suggested to print out the key, it usually will be something like below
64
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
65
+
66
+ # as we have set the alpha beforehand, so just skip
67
+ if ".alpha" in key or key in visited:
68
+ continue
69
+
70
+ if "text" in key:
71
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
72
+ curr_layer = pipeline.text_encoder
73
+ else:
74
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
75
+ curr_layer = pipeline.unet
76
+
77
+ # find the target layer
78
+ temp_name = layer_infos.pop(0)
79
+ while len(layer_infos) > -1:
80
+ try:
81
+ curr_layer = curr_layer.__getattr__(temp_name)
82
+ if len(layer_infos) > 0:
83
+ temp_name = layer_infos.pop(0)
84
+ elif len(layer_infos) == 0:
85
+ break
86
+ except Exception:
87
+ if len(temp_name) > 0:
88
+ temp_name += "_" + layer_infos.pop(0)
89
+ else:
90
+ temp_name = layer_infos.pop(0)
91
+
92
+ pair_keys = []
93
+ if "lora_down" in key:
94
+ pair_keys.append(key.replace("lora_down", "lora_up"))
95
+ pair_keys.append(key)
96
+ else:
97
+ pair_keys.append(key)
98
+ pair_keys.append(key.replace("lora_up", "lora_down"))
99
+
100
+ # update weight
101
+ if len(state_dict[pair_keys[0]].shape) == 4:
102
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
105
+ else:
106
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
107
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
108
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
109
+
110
+ # update visited list
111
+ for item in pair_keys:
112
+ visited.append(item)
113
+
114
+ return pipeline
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
122
+ )
123
+ parser.add_argument(
124
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
125
+ )
126
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127
+ parser.add_argument(
128
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129
+ )
130
+ parser.add_argument(
131
+ "--lora_prefix_text_encoder",
132
+ default="lora_te",
133
+ type=str,
134
+ help="The prefix of text encoder weight in safetensors",
135
+ )
136
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137
+ parser.add_argument(
138
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139
+ )
140
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141
+
142
+ args = parser.parse_args()
143
+
144
+ base_model_path = args.base_model_path
145
+ checkpoint_path = args.checkpoint_path
146
+ dump_path = args.dump_path
147
+ lora_prefix_unet = args.lora_prefix_unet
148
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
149
+ alpha = args.alpha
150
+
151
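+ # `convert` is not defined in this file: load the base pipeline and the LoRA weights, then merge with `convert_lora`
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)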
152
+
153
+ pipe = pipe.to(args.device)
154
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
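Reviewer sketch for the merge rule used by `convert_lora` and `convert_motion_lora_ckpt_to_diffusers` above, which both update a layer as W = W0 + alpha * (up @ down). The snippet below is a minimal illustration of that arithmetic and of the up/down key pairing, assuming plain Python lists instead of tensors. The names `matmul`, `merge_lora`, and `pair_keys` are hypothetical, and unlike the file's in-place `+=` this version returns a new matrix.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_up, lora_down, alpha):
    """Return W0 + alpha * (up @ down) without modifying the inputs."""
    delta = matmul(lora_up, lora_down)
    return [
        [w + alpha * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


def pair_keys(key):
    """Return (up_key, down_key) for either half of a LoRA pair, as in `convert_lora`."""
    if "lora_down" in key:
        return key.replace("lora_down", "lora_up"), key
    return key, key.replace("lora_up", "lora_down")


if __name__ == "__main__":
    base = [[0.0, 0.0], [0.0, 0.0]]
    up = [[1.0], [2.0]]    # shape (out_features, rank)
    down = [[3.0, 4.0]]    # shape (rank, in_features)
    print(merge_lora(base, up, down, alpha=0.5))
```

Because the low-rank product has the same shape as the base weight, the merge adds no parameters at inference time; the trade-off is that the LoRA cannot be removed afterwards without keeping a copy of the original weights.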
animatediff/utils/util.py ADDED
@@ -0,0 +1,237 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+
10
+ from PIL import Image
11
+ from transformers import AutoProcessor, CLIPModel
12
+ import torch.nn as nn
13
+
14
+ from safetensors import safe_open
15
+ from tqdm import tqdm
16
+ from einops import rearrange
17
+ from animatediff.utils.convert_from_ckpt import (
18
+ convert_ldm_unet_checkpoint,
19
+ convert_ldm_clip_checkpoint,
20
+ #convert_ldm_vae_checkpoint,
21
+ )
22
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_vae_checkpoint
23
+ from animatediff.utils.convert_lora_safetensor_to_diffusers import (
24
+ convert_lora,
25
+ convert_motion_lora_ckpt_to_diffusers,
26
+ )
27
+
28
+
29
+ def zero_rank_print(s):
30
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
31
+ print("### " + s)
32
+
33
+
34
+ def ToImage(videos: torch.Tensor, rescale=False, n_rows=6):
35
+ videos = rearrange(videos, "b c t h w -> t b c h w")
36
+ outputs = []
37
+ for x in videos:
38
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
39
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
40
+ if rescale:
41
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
42
+ x = (x * 255).numpy().astype(np.uint8)
43
+ outputs.append(Image.fromarray(x))
44
+ return outputs
45
+
46
+
47
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
48
+ videos = rearrange(videos, "b c t h w -> t b c h w")
49
+ outputs = []
50
+ for x in videos:
51
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
52
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
53
+ if rescale:
54
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
55
+ x = (x * 255).numpy().astype(np.uint8)
56
+ outputs.append(x)
57
+
58
+ os.makedirs(os.path.dirname(path), exist_ok=True)
59
+ imageio.mimsave(path, outputs, fps=fps)
60
+
61
+
62
+ # DDIM Inversion
63
+ @torch.no_grad()
64
+ def init_prompt(prompt, pipeline):
65
+ uncond_input = pipeline.tokenizer(
66
+ [""],
67
+ padding="max_length",
68
+ max_length=pipeline.tokenizer.model_max_length,
69
+ return_tensors="pt",
70
+ )
71
+ uncond_embeddings = pipeline.text_encoder(
72
+ uncond_input.input_ids.to(pipeline.device)
73
+ )[0]
74
+ text_input = pipeline.tokenizer(
75
+ [prompt],
76
+ padding="max_length",
77
+ max_length=pipeline.tokenizer.model_max_length,
78
+ truncation=True,
79
+ return_tensors="pt",
80
+ )
81
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
82
+ context = torch.cat([uncond_embeddings, text_embeddings])
83
+
84
+ return context
85
+
86
+
87
+ def next_step(
88
+ model_output: Union[torch.FloatTensor, np.ndarray],
89
+ timestep: int,
90
+ sample: Union[torch.FloatTensor, np.ndarray],
91
+ ddim_scheduler,
92
+ ):
93
+ timestep, next_timestep = (
94
+ min(
95
+ timestep
96
+ - ddim_scheduler.config.num_train_timesteps
97
+ // ddim_scheduler.num_inference_steps,
98
+ 999,
99
+ ),
100
+ timestep,
101
+ )
102
+ alpha_prod_t = (
103
+ ddim_scheduler.alphas_cumprod[timestep]
104
+ if timestep >= 0
105
+ else ddim_scheduler.final_alpha_cumprod
106
+ )
107
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
108
+ beta_prod_t = 1 - alpha_prod_t
109
+ next_original_sample = (
110
+ sample - beta_prod_t**0.5 * model_output
111
+ ) / alpha_prod_t**0.5
112
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
113
+ next_sample = alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction
114
+ return next_sample
115
+
116
+
117
+ def get_noise_pred_single(latents, t, context, unet):
118
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
119
+ return noise_pred
120
+
121
+
122
+ @torch.no_grad()
123
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
124
+ context = init_prompt(prompt, pipeline)
125
+ uncond_embeddings, cond_embeddings = context.chunk(2)
126
+ all_latent = [latent]
127
+ latent = latent.clone().detach()
128
+ for i in tqdm(range(num_inv_steps)):
129
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
130
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
131
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
132
+ all_latent.append(latent)
133
+ return all_latent
134
+
135
+
136
+ @torch.no_grad()
137
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
138
+ ddim_latents = ddim_loop(
139
+ pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt
140
+ )
141
+ return ddim_latents
142
+
143
+
144
+ def load_weights(
145
+ animation_pipeline,
146
+ # motion module
147
+ motion_module_path="",
148
+ motion_module_lora_configs=[],
149
+ # image layers
150
+ dreambooth_model_path="",
151
+ lora_model_path="",
152
+ lora_alpha=0.8,
153
+ ):
154
+ # 1.1 motion module
155
+ unet_state_dict = {}
156
+ if motion_module_path != "":
157
+ print(f"load motion module from {motion_module_path}")
158
+ motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
159
+ motion_module_state_dict = (
160
+ motion_module_state_dict["state_dict"]
161
+ if "state_dict" in motion_module_state_dict
162
+ else motion_module_state_dict
163
+ )
164
+ unet_state_dict.update(
165
+ {
166
+ name: param
167
+ for name, param in motion_module_state_dict.items()
168
+ if "motion_modules." in name
169
+ }
170
+ )
171
+
172
+ missing, unexpected = animation_pipeline.unet.load_state_dict(
173
+ unet_state_dict, strict=False
174
+ )
175
+ assert len(unexpected) == 0
176
+ del unet_state_dict
177
+
178
+ if dreambooth_model_path != "":
179
+ print(f"load dreambooth model from {dreambooth_model_path}")
180
+ if dreambooth_model_path.endswith(".safetensors"):
181
+ dreambooth_state_dict = {}
182
+ with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
183
+ for key in f.keys():
184
+ dreambooth_state_dict[key] = f.get_tensor(key)
185
+ elif dreambooth_model_path.endswith(".ckpt"):
186
+ dreambooth_state_dict = torch.load(
187
+ dreambooth_model_path, map_location="cpu"
188
+ )
+ else:
+ raise ValueError(f"unsupported dreambooth checkpoint format: {dreambooth_model_path}")
189
+
190
+ # 1. vae
191
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
192
+ dreambooth_state_dict, animation_pipeline.vae.config
193
+ )
194
+ animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
195
+ # 2. unet
196
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
197
+ dreambooth_state_dict, animation_pipeline.unet.config
198
+ )
199
+ animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
200
+ # 3. text_model
201
+ animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(
202
+ dreambooth_state_dict
203
+ )
204
+ del dreambooth_state_dict
205
+
206
+ if lora_model_path != "":
207
+ print(f"load lora model from {lora_model_path}")
208
+ assert lora_model_path.endswith(".safetensors")
209
+ lora_state_dict = {}
210
+ with safe_open(lora_model_path, framework="pt", device="cpu") as f:
211
+ for key in f.keys():
212
+ lora_state_dict[key] = f.get_tensor(key)
213
+
214
+ animation_pipeline = convert_lora(
215
+ animation_pipeline, lora_state_dict, alpha=lora_alpha
216
+ )
217
+ del lora_state_dict
218
+
219
+ for motion_module_lora_config in motion_module_lora_configs:
220
+ path, alpha = (
221
+ motion_module_lora_config["path"],
222
+ motion_module_lora_config["alpha"],
223
+ )
224
+ print(f"load motion LoRA from {path}")
225
+
226
+ motion_lora_state_dict = torch.load(path, map_location="cpu")
227
+ motion_lora_state_dict = (
228
+ motion_lora_state_dict["state_dict"]
229
+ if "state_dict" in motion_lora_state_dict
230
+ else motion_lora_state_dict
231
+ )
232
+
233
+ animation_pipeline = convert_motion_lora_ckpt_to_diffusers(
234
+ animation_pipeline, motion_lora_state_dict, alpha
235
+ )
236
+
237
+ return animation_pipeline
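Reviewer sketch for `next_step` above, which performs one DDIM inversion update: it estimates the clean sample from the current latent and the predicted noise, then re-noises that estimate at the next (noisier) timestep. The snippet below reproduces only that arithmetic on scalars, assuming the two cumulative alpha products are passed in directly instead of being looked up on a scheduler. The function name `ddim_inversion_step` is hypothetical.

```python
import math


def ddim_inversion_step(sample, noise_pred, alpha_prod_t, alpha_prod_t_next):
    """One DDIM inversion update on scalar values (illustrative only)."""
    # Predicted clean sample x0 from the current latent and the noise estimate.
    pred_x0 = (sample - math.sqrt(1.0 - alpha_prod_t) * noise_pred) / math.sqrt(alpha_prod_t)
    # Component pointing along the predicted noise at the next timestep.
    direction = math.sqrt(1.0 - alpha_prod_t_next) * noise_pred
    return math.sqrt(alpha_prod_t_next) * pred_x0 + direction


if __name__ == "__main__":
    # Moving to a timestep with a smaller cumulative alpha adds noise to the latent.
    print(ddim_inversion_step(sample=0.8, noise_pred=0.1, alpha_prod_t=0.9, alpha_prod_t_next=0.7))
```

A quick sanity check on the formula: when both alpha products are equal the update returns the input sample unchanged, since the re-noising step exactly undoes the denoising estimate.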
app.py ADDED
@@ -0,0 +1,715 @@
1
+ import os
2
+ import torch
3
+
4
+ # fix all the seeds for reproducibility
5
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
6
+ torch.backends.cudnn.benchmark = False
7
+ torch.use_deterministic_algorithms(True)
8
+ import ptp_utils
9
+ import random
10
+ import abc
11
+ import gradio as gr
12
+ from glob import glob
13
+ from einops import rearrange
14
+ from omegaconf import OmegaConf
15
+ from safetensors import safe_open
16
+ from diffusers import AutoencoderKL
17
+ from diffusers import DDIMScheduler
18
+ from diffusers.utils.import_utils import is_xformers_available
19
+ from transformers import CLIPTextModel, CLIPTokenizer
20
+
21
+ from animatediff.models.unet import UNet3DConditionModel
22
+ from animatediff.pipelines.pipeline_animation import AnimationPipeline
23
+ from animatediff.pipelines.pipeline_animation import AnimationCtrlPipeline
24
+ from animatediff.utils.util import save_videos_grid
25
+ from animatediff.utils.convert_from_ckpt import (
26
+ convert_ldm_unet_checkpoint,
27
+ convert_ldm_clip_checkpoint,
28
+ )
29
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
30
+ convert_ldm_vae_checkpoint,
31
+ )
32
+ from diffusers.training_utils import set_seed
33
+
34
+
35
+ pretrained_model_path = "./models/StableDiffusion"
36
+ inference_config_path = "configs/inference/inference-v1.yaml"
37
+
38
+ css = """
39
+ .toolbutton {
40
+ margin-bottom: 0em;
41
+ max-width: 2.5em;
42
+ min-width: 2.5em !important;
43
+ height: 2.5em;
44
+ }
45
+ """
46
+
47
+
48
+ class AttentionControl(abc.ABC):
49
+ def step_callback(self, x_t):
50
+ return x_t
51
+
52
+ def between_steps(self):
53
+ return
54
+
55
+ @abc.abstractmethod
56
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
57
+ raise NotImplementedError
58
+
59
+ def __call__(self, hidden_states, video_length, place_in_unet: str):
60
+ hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length)
61
+ batch_size = hidden_states.shape[0] // 2
62
+
63
+ if batch_size == 2:
64
+ # Do classifier-free guidance
65
+ hidden_states_uncondition, hidden_states_condition = hidden_states.chunk(2)
66
+
67
+ if self.cur_step <= self.motion_control_step:
68
+ hidden_states_motion_uncondition = hidden_states_uncondition[
69
+ 1
70
+ ].unsqueeze(0)
71
+ else:
72
+ hidden_states_motion_uncondition = hidden_states_uncondition[
73
+ 0
74
+ ].unsqueeze(0)
75
+
76
+ hidden_states_out_uncondition = torch.cat(
77
+ [
78
+ hidden_states_motion_uncondition,
79
+ hidden_states_uncondition[1].unsqueeze(0),
80
+ ],
81
+ dim=0,
82
+ ) # Query
83
+ hidden_states_sac_in_uncondition = self.forward(
84
+ hidden_states_uncondition[0].unsqueeze(0), video_length, place_in_unet
85
+ )
86
+ hidden_states_sac_out_uncondition = torch.cat(
87
+ [
88
+ hidden_states_sac_in_uncondition,
89
+ hidden_states_uncondition[1].unsqueeze(0),
90
+ ],
91
+ dim=0,
92
+ ) # Key & Value
93
+
94
+ if self.cur_step <= self.motion_control_step:
95
+ hidden_states_motion_condition = hidden_states_condition[1].unsqueeze(0)
96
+ else:
97
+ hidden_states_motion_condition = hidden_states_condition[0].unsqueeze(0)
98
+
99
+ hidden_states_out_condition = torch.cat(
100
+ [
101
+ hidden_states_motion_condition,
102
+ hidden_states_condition[1].unsqueeze(0),
103
+ ],
104
+ dim=0,
105
+ ) # Query
106
+ hidden_states_sac_in_condition = self.forward(
107
+ hidden_states_condition[0].unsqueeze(0), video_length, place_in_unet
108
+ )
109
+ hidden_states_sac_out_condition = torch.cat(
110
+ [
111
+ hidden_states_sac_in_condition,
112
+ hidden_states_condition[1].unsqueeze(0),
113
+ ],
114
+ dim=0,
115
+ ) # Key & Value
116
+
117
+ hidden_states_out = torch.cat(
118
+ [hidden_states_out_uncondition, hidden_states_out_condition], dim=0
119
+ )
120
+ hidden_states_sac_out = torch.cat(
121
+ [hidden_states_sac_out_uncondition, hidden_states_sac_out_condition],
122
+ dim=0,
123
+ )
124
+
125
+ elif batch_size == 1:
126
+ if self.cur_step <= self.motion_control_step:
127
+ hidden_states_motion = hidden_states[1].unsqueeze(0)
128
+ else:
129
+ hidden_states_motion = hidden_states[0].unsqueeze(0)
130
+
131
+ hidden_states_out = torch.cat(
132
+ [hidden_states_motion, hidden_states[1].unsqueeze(0)], dim=0
133
+ ) # Query
134
+ hidden_states_sac_in = self.forward(
135
+ hidden_states[0].unsqueeze(0), video_length, place_in_unet
136
+ )
137
+ hidden_states_sac_out = torch.cat(
138
+ [hidden_states_sac_in, hidden_states[1].unsqueeze(0)], dim=0
139
+ ) # Key & Value
140
+
141
+ else:
142
+ raise gr.Error(f"Attention control is not implemented for a hidden_states batch dimension of {hidden_states.shape[0]}")
143
+ hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c", f=video_length)
144
+ hidden_states_out = rearrange(
145
+ hidden_states_out, "b f d c -> (b f) d c", f=video_length
146
+ )
147
+ hidden_states_sac_out = rearrange(
148
+ hidden_states_sac_out, "b f d c -> (b f) d c", f=video_length
149
+ )
150
+ self.cur_att_layer += 1
151
+ if self.cur_att_layer == self.num_att_layers:
152
+ self.cur_att_layer = 0
153
+ self.cur_step += 1
154
+ return hidden_states_out, hidden_states_sac_out, hidden_states_sac_out
155
+
156
+ def reset(self):
157
+ self.cur_step = 0
158
+ self.cur_att_layer = 0
159
+ self.num_att_layers = -1
160
+ self.motion_control_step = 0
161
+
162
+ def __init__(self):
163
+ self.cur_step = 0
164
+ self.cur_att_layer = 0
165
+ self.num_att_layers = -1
166
+ self.motion_control_step = 0
167
+
168
+
169
+ class EmptyControl(AttentionControl):
170
+ def forward(self, hidden_states, video_length, place_in_unet):
171
+ return hidden_states
172
+
173
+
174
+ class FreeSAC(AttentionControl):
175
+ def forward(self, hidden_states, video_length, place_in_unet):
176
+ hidden_states_sac = (
177
+ hidden_states[:, 0, :, :].unsqueeze(1).repeat(1, video_length, 1, 1)
178
+ )
179
+ return hidden_states_sac
180
+
181
+
182
+ examples = [
183
+ # 0-RealisticVision
184
+ [
185
+ "realisticVisionV60B1_v20Novae.safetensors",
186
+ "mm_sd_v14.ckpt",
187
+ "A panda standing on a surfboard in the ocean under moonlight.",
188
+ "worst quality, low quality, nsfw, logo",
189
+ 0.2,
190
+ 512,
191
+ 512,
192
+ "12345",
193
+ ["use_fp16"],
194
+ ],
195
+ [
196
+ "toonyou_beta3.safetensors",
197
+ "mm_sd_v14.ckpt",
198
+ "(best quality, masterpiece), 1girl, looking at viewer, blurry background, upper body, contemporary, dress",
199
+ "(worst quality, low quality)",
200
+ 0.2,
201
+ 512,
202
+ 512,
203
+ "12345",
204
+ ["use_fp16"],
205
+ ],
206
+ [
207
+ "lyriel_v16.safetensors",
208
+ "mm_sd_v14.ckpt",
209
+ "hypercars cyberpunk moving, muted colors, swirling color smokes, legend, cityscape, space",
210
+ "3d, cartoon, anime, sketches, worst quality, low quality, nsfw, logo",
211
+ 0.2,
212
+ 512,
213
+ 512,
214
+ "12345",
215
+ ["use_fp16"],
216
+ ],
217
+ [
218
+ "rcnzCartoon3d_v10.safetensors",
219
+ "mm_sd_v14.ckpt",
220
+ "A cute raccoon playing guitar in a boat on the ocean",
221
+ "worst quality, low quality, nsfw, logo",
222
+ 0.2,
223
+ 512,
224
+ 512,
225
+ "42",
226
+ ["use_fp16"],
227
+ ],
228
+ [
229
+ "majicmixRealistic_v5Preview.safetensors",
230
+ "mm_sd_v14.ckpt",
231
+ "1girl, reading book",
232
+ "(ng_deepnegative_v1_75t:1.2), (badhandv4:1), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, watermark, moles",
233
+ 0.2,
234
+ 512,
235
+ 512,
236
+ "12345",
237
+ ["use_fp16"],
238
+ ],
239
+ ]
240
+
241
+ # clean Gradio cache
242
+ print("### Cleaning cached examples ...")
243
+ os.system("rm -rf gradio_cached_examples/")
244
+
245
+
246
+ class AnimateController:
247
+ def __init__(self):
248
+ # config dirs
249
+ self.basedir = os.getcwd()
250
+ self.stable_diffusion_dir = os.path.join(
251
+ self.basedir, "models", "StableDiffusion"
252
+ )
253
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
254
+ self.personalized_model_dir = os.path.join(
255
+ self.basedir, "models", "DreamBooth_LoRA"
256
+ )
257
+ self.savedir = os.path.join(self.basedir, "samples")
258
+ os.makedirs(self.savedir, exist_ok=True)
259
+
260
+ self.base_model_list = [None]
261
+ self.motion_module_list = []
262
+ self.selected_base_model = None
263
+ self.selected_motion_module = None
264
+ self.set_width = None
265
+ self.set_height = None
266
+
267
+ self.refresh_motion_module()
268
+ self.refresh_personalized_model()
269
+
270
+ # config models
271
+ self.inference_config = OmegaConf.load(inference_config_path)
272
+
273
+ self.tokenizer = CLIPTokenizer.from_pretrained(
274
+ pretrained_model_path, subfolder="tokenizer"
275
+ )
276
+ self.text_encoder = CLIPTextModel.from_pretrained(
277
+ pretrained_model_path, subfolder="text_encoder"
278
+ ).cuda()
279
+ self.vae = AutoencoderKL.from_pretrained(
280
+ pretrained_model_path, subfolder="vae"
281
+ ).cuda()
282
+ self.unet = UNet3DConditionModel.from_pretrained_2d(
283
+ pretrained_model_path,
284
+ subfolder="unet",
285
+ unet_additional_kwargs=OmegaConf.to_container(
286
+ self.inference_config.unet_additional_kwargs
287
+ ),
288
+ ).cuda()
289
+
290
+ self.freq_filter = None
291
+
292
+ self.update_base_model(self.base_model_list[-2])
293
+ self.update_motion_module(self.motion_module_list[0])
294
+
295
+ def refresh_motion_module(self):
296
+ motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
297
+ self.motion_module_list = sorted(
298
+ [os.path.basename(p) for p in motion_module_list]
299
+ )
300
+
301
+ def refresh_personalized_model(self):
302
+ base_model_list = glob(
303
+ os.path.join(self.personalized_model_dir, "*.safetensors")
304
+ )
305
+ self.base_model_list += sorted([os.path.basename(p) for p in base_model_list])
306
+
307
+ def update_base_model(self, base_model_dropdown):
308
+ self.selected_base_model = base_model_dropdown
309
+ if base_model_dropdown == "None" or base_model_dropdown is None:
310
+ return gr.Dropdown.update()
311
+
312
+ base_model_dropdown = os.path.join(
313
+ self.personalized_model_dir, base_model_dropdown
314
+ )
315
+ base_model_state_dict = {}
316
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
317
+ for key in f.keys():
318
+ base_model_state_dict[key] = f.get_tensor(key)
319
+
320
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
321
+ base_model_state_dict, self.vae.config
322
+ )
323
+ self.vae.load_state_dict(converted_vae_checkpoint)
324
+
325
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(
326
+ base_model_state_dict, self.unet.config
327
+ )
328
+ self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
329
+
330
+ self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
331
+ return gr.Dropdown.update()
332
+
333
+ def update_motion_module(self, motion_module_dropdown):
334
+ self.selected_motion_module = motion_module_dropdown
335
+
336
+ motion_module_dropdown = os.path.join(
337
+ self.motion_module_dir, motion_module_dropdown
338
+ )
339
+ motion_module_state_dict = torch.load(
340
+ motion_module_dropdown, map_location="cpu"
341
+ )
342
+ _, unexpected = self.unet.load_state_dict(
343
+ motion_module_state_dict, strict=False
344
+ )
345
+ assert len(unexpected) == 0
346
+ return gr.Dropdown.update()
347
+
348
+ def run_pipeline(self, pipeline, args):
349
+ # Initialize CUDA context in the subprocess
350
+ torch.cuda.init()
351
+ # Run the pipeline with the given arguments
352
+ return pipeline(**args)
353
+
354
+ def animate_ctrl(
355
+ self,
356
+ base_model_dropdown,
357
+ motion_module_dropdown,
358
+ prompt_textbox,
359
+ negative_prompt_textbox,
360
+ motion_control,
361
+ width_slider,
362
+ height_slider,
363
+ seed_textbox,
364
+ # speed up
365
+ speed_up_options,
366
+ ):
367
+ set_seed(42)
368
+ inference_step = 25
369
+
370
+ if self.selected_base_model != base_model_dropdown:
371
+ self.update_base_model(base_model_dropdown)
372
+ if self.selected_motion_module != motion_module_dropdown:
373
+ self.update_motion_module(motion_module_dropdown)
374
+
375
+ if is_xformers_available():
376
+ self.unet.enable_xformers_memory_efficient_attention()
377
+
378
+ if seed_textbox and int(seed_textbox) > 0:
379
+ seed = int(seed_textbox)
380
+ else:
381
+ seed = random.randint(1, 10**16)
382
+ torch.manual_seed(int(seed))
383
+
384
+ assert seed == torch.initial_seed()
385
+ print(f"### seed: {seed}")
386
+
387
+ generator = torch.Generator(device="cuda:0")
388
+ generator.manual_seed(seed)
389
+
390
+ pipeline = AnimationCtrlPipeline(
391
+ vae=self.vae,
392
+ text_encoder=self.text_encoder,
393
+ tokenizer=self.tokenizer,
394
+ unet=self.unet,
395
+ scheduler=DDIMScheduler(
396
+ **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)
397
+ ),
398
+ ).to("cuda")
399
+
400
+ motion_control_step = motion_control * inference_step
401
+
402
+ attn_controller = FreeSAC()
403
+ attn_controller.motion_control_step = motion_control_step
404
+ ptp_utils.register_attention_control(pipeline, attn_controller)
405
+
406
+ sample_output_ctrl = pipeline(
407
+ prompt_textbox,
408
+ negative_prompt=negative_prompt_textbox,
409
+ num_inference_steps=inference_step,
410
+ guidance_scale=7.5,
411
+ width=width_slider,
412
+ height=height_slider,
413
+ video_length=16,
414
+ use_fp16=True if "use_fp16" in speed_up_options else False,
415
+ generator=generator,
416
+ )
417
+
418
+ ctrl_sample = sample_output_ctrl.videos
419
+
420
+ save_ctrl_sample_path = os.path.join(self.savedir, "ctrl_sample.mp4")
421
+ save_videos_grid(ctrl_sample, save_ctrl_sample_path)
422
+
423
+ json_config = {
424
+ "prompt": prompt_textbox,
425
+ "n_prompt": negative_prompt_textbox,
426
+ "width": width_slider,
427
+ "height": height_slider,
428
+ "seed": seed,
429
+ "base_model": base_model_dropdown,
430
+ "motion_module": motion_module_dropdown,
431
+ "use_fp16": True if "use_fp16" in speed_up_options else False,
432
+ }
433
+
434
+ del attn_controller
435
+ del pipeline
436
+ torch.cuda.empty_cache()
437
+ return (
438
+ gr.Video.update(value=save_ctrl_sample_path),
439
+ gr.Json.update(value=json_config),
440
+ )
441
+
442
+ def animate(
443
+ self,
444
+ base_model_dropdown,
445
+ motion_module_dropdown,
446
+ prompt_textbox,
447
+ negative_prompt_textbox,
448
+ motion_control,
449
+ width_slider,
450
+ height_slider,
451
+ seed_textbox,
452
+ speed_up_options,
453
+ # FreeInit parameter; the UI does not pass it, so it defaults to None
454
+ filter_type_dropdown=None,
455
+ ):
456
+ # set global seed
457
+ set_seed(42)
458
+ # set inference step
459
+ inference_step = 25
460
+
461
+ if self.selected_base_model != base_model_dropdown:
462
+ self.update_base_model(base_model_dropdown)
463
+ if self.selected_motion_module != motion_module_dropdown:
464
+ self.update_motion_module(motion_module_dropdown)
465
+
466
+ if is_xformers_available():
467
+ self.unet.enable_xformers_memory_efficient_attention()
468
+
469
+ if seed_textbox and int(seed_textbox) >= 0:
470
+ seed = int(seed_textbox)
471
+ else:
472
+ seed = random.randint(0, 2**32 - 1)
473
+ torch.manual_seed(int(seed))
474
+
475
+ assert seed == torch.initial_seed()
476
+ print(f"seed: {seed}")
477
+
478
+ generator = torch.Generator(device="cuda:0")
479
+ generator.manual_seed(seed)
480
+
481
+ pipeline = AnimationPipeline(
482
+ vae=self.vae,
483
+ text_encoder=self.text_encoder,
484
+ tokenizer=self.tokenizer,
485
+ unet=self.unet,
486
+ scheduler=DDIMScheduler(
487
+ **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)
488
+ ),
489
+ ).to("cuda")
490
+
491
+ attn_controller = EmptyControl()
492
+ attn_controller.motion_control_step = -1
493
+ ptp_utils.register_attention_control(pipeline, attn_controller)
494
+
495
+ sample_output_orig = pipeline(
496
+ prompt_textbox,
497
+ negative_prompt=negative_prompt_textbox,
498
+ num_inference_steps=inference_step,
499
+ guidance_scale=7.5,
500
+ width=width_slider,
501
+ height=height_slider,
502
+ video_length=16,
503
+ use_fp16=(
504
+ True if speed_up_options and "use_fp16" in speed_up_options else False
505
+ ),
506
+ generator=generator,
507
+ )
508
+
509
+ orig_sample = sample_output_orig.videos
510
+
511
+ save_orig_sample_path = os.path.join(self.savedir, "orig_sample.mp4")
512
+ save_videos_grid(orig_sample, save_orig_sample_path)
513
+
514
+ json_config = {
515
+ "prompt": prompt_textbox,
516
+ "n_prompt": negative_prompt_textbox,
517
+ "width": width_slider,
518
+ "height": height_slider,
519
+ "seed": seed,
520
+ "base_model": base_model_dropdown,
521
+ "motion_module": motion_module_dropdown,
522
+ "filter_type": filter_type_dropdown,
523
+ "use_fp16": (
524
+ True if speed_up_options and "use_fp16" in speed_up_options else False
525
+ ),
526
+ }
527
+ del pipeline
528
+ torch.cuda.empty_cache()
529
+
530
+ return (
531
+ gr.Video.update(value=save_orig_sample_path),
532
+ gr.Json.update(value=json_config),
533
+ )
534
+
535
+
536
+ controller = AnimateController()
537
+
538
+
539
+ def ui():
540
+ with gr.Blocks(css=css) as demo:
541
+ # gr.Markdown('# FreeInit')
542
+ gr.Markdown(
543
+ """
544
+ <div align="center">
545
+ <h1>UniCtrl: Improving the Spatiotemporal Consistency of Text-to-Video Diffusion Models via Training-Free Unified Attention Control</h1>
546
+ </div>
547
+ """
548
+ )
549
+ gr.Markdown(
550
+ """
551
+ <p align="center">
552
+ <a title="Project Page" href="https://unified-attention-control.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
553
+ <img src="https://img.shields.io/badge/Project-Website-5B7493?logo=googlechrome&logoColor=5B7493">
554
+ </a>
555
+ <a title="arXiv" href="https://arxiv.org/abs/2312.07537" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
556
+ <img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=b31b1b">
557
+ </a>
558
+ <a title="GitHub" href="https://github.com/XuweiyiChen/UniCtrl" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
559
+ <img src="https://img.shields.io/github/stars/XuweiyiChen/UniCtrl?label=GitHub%E2%98%85&&logo=github" alt="badge-github-stars">
560
+ </a>
561
+ </p>
562
+ """
563
+ )
564
+ gr.Markdown(
565
+ """
566
+ Official Gradio Demo for ***UniCtrl: Improving the Spatiotemporal Consistency of Text-to-Video Diffusion Models via Training-Free Unified Attention Control***.
567
+ UniCtrl improves the spatiotemporal consistency of diffusion-based video generation at inference time. In this demo, we apply UniCtrl to [AnimateDiff v1](https://github.com/guoyww/AnimateDiff) as an example. Sampling time: ~80s.<br>
568
+ """
569
+ )
570
+
571
+ with gr.Row():
572
+ with gr.Column():
573
+ prompt_textbox = gr.Textbox(
574
+ label="Prompt", lines=3, placeholder="Enter your prompt here"
575
+ )
576
+ negative_prompt_textbox = gr.Textbox(
577
+ label="Negative Prompt",
578
+ lines=3,
579
+ value="worst quality, low quality, nsfw, logo",
580
+ )
581
+ motion_control = gr.Slider(
582
+ label="Motion Injection Degree",
583
+ value=0.2,
584
+ minimum=0,
585
+ maximum=1,
586
+ step=0.1,
587
+ info="Motion Control Strength",
588
+ )
589
+
590
+ gr.Markdown(
591
+ """
592
+ *Prompt Tips:*
593
+
594
+ For each personalized model in `Model Settings`, refer to its CivitAI page to learn how to write good prompts for it:
595
+ - [`realisticVisionV60B1_v20Novae.safetensors`](https://civitai.com/models/4201?modelVersionId=130072)
596
+ - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
597
+ - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
598
+ - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
599
+ - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
600
+ """
601
+ )
602
+
603
+ with gr.Accordion("Model Settings", open=False):
604
+ gr.Markdown(
605
+ """
606
+ Select personalized model and motion module for AnimateDiff.
607
+ """
608
+ )
609
+ base_model_dropdown = gr.Dropdown(
610
+ label="Base DreamBooth Model",
611
+ choices=controller.base_model_list,
612
+ value=controller.base_model_list[-2],
613
+ interactive=True,
614
+ info="Select personalized text-to-image model from community",
615
+ )
616
+ motion_module_dropdown = gr.Dropdown(
617
+ label="Motion Module",
618
+ choices=controller.motion_module_list,
619
+ value=controller.motion_module_list[0],
620
+ interactive=True,
621
+ info="Select motion module. mm_sd_v14.ckpt is recommended for larger movements.",
622
+ )
623
+
624
+ base_model_dropdown.change(
625
+ fn=controller.update_base_model,
626
+ inputs=[base_model_dropdown],
627
+ outputs=[base_model_dropdown],
628
+ )
629
+ motion_module_dropdown.change(
630
+ fn=controller.update_motion_module,
631
+ inputs=[motion_module_dropdown],
632
+ outputs=[base_model_dropdown],
633
+ )
634
+
635
+ with gr.Accordion("Advanced", open=False):
636
+ with gr.Row():
637
+ width_slider = gr.Slider(
638
+ label="Width", value=512, minimum=256, maximum=1024, step=64
639
+ )
640
+ height_slider = gr.Slider(
641
+ label="Height",
642
+ value=512,
643
+ minimum=256,
644
+ maximum=1024,
645
+ step=64,
646
+ )
647
+ with gr.Row():
648
+ seed_textbox = gr.Textbox(label="Seed", value=442)
649
+ seed_button = gr.Button(
650
+ value="\U0001F3B2", elem_classes="toolbutton"
651
+ )
652
+ seed_button.click(
653
+ fn=lambda: gr.Textbox.update(value=random.randint(1, 10**9)),
654
+ inputs=[],
655
+ outputs=[seed_textbox],
656
+ )
657
+ with gr.Row():
658
+ speed_up_options = gr.CheckboxGroup(
659
+ ["use_fp16"],
660
+ label="Speed-Up Options",
661
+ value=["use_fp16"],
662
+ )
663
+
664
+ with gr.Column():
665
+ with gr.Row():
666
+ orig_video = gr.Video(label="AnimateDiff", interactive=False)
667
+ ctrl_video = gr.Video(
668
+ label="AnimateDiff + UniCtrl", interactive=False
669
+ )
670
+ with gr.Row():
671
+ generate_button = gr.Button(
672
+ value="Generate Original", variant="primary"
673
+ )
674
+ generate_button_ctr = gr.Button(
675
+ value="Generate UniCtrl", variant="primary"
676
+ )
677
+ with gr.Row():
678
+ json_config = gr.Json(label="Config", value=None)
679
+
680
+ inputs = [
681
+ base_model_dropdown,
682
+ motion_module_dropdown,
683
+ prompt_textbox,
684
+ negative_prompt_textbox,
685
+ motion_control,
686
+ width_slider,
687
+ height_slider,
688
+ seed_textbox,
689
+ speed_up_options,
690
+ ]
691
+
692
+ generate_button.click(
693
+ fn=controller.animate, inputs=inputs, outputs=[orig_video, json_config]
694
+ )
695
+ generate_button_ctr.click(
696
+ fn=controller.animate_ctrl,
697
+ inputs=inputs,
698
+ outputs=[ctrl_video, json_config],
699
+ )
700
+
701
+ gr.Examples(
702
+ fn=controller.animate_ctrl,
703
+ examples=examples,
704
+ inputs=inputs,
705
+ outputs=[ctrl_video, json_config],
706
+ cache_examples=True,
707
+ )
708
+
709
+ return demo
710
+
711
+
712
+ if __name__ == "__main__":
713
+ demo = ui()
714
+ demo.queue(max_size=20)
715
+ demo.launch(server_name="localhost", share=True)
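A rough illustration of how the `motion_control` slider in app.py maps to denoising steps: `animate_ctrl` multiplies the slider value by the 25 inference steps, and `AttentionControl.__call__` takes its motion-injection branch while `cur_step <= motion_control_step`. The helper name `motion_injection_steps` is hypothetical and not part of the commit, and the sketch assumes `cur_step` counts denoising steps starting from 0.

```python
from typing import List


def motion_injection_steps(motion_control: float, inference_steps: int = 25) -> List[int]:
    """Return the step indices for which `cur_step <= motion_control_step` holds.

    Hypothetical helper: it mirrors the comparison in the diff, not the
    attention manipulation itself.
    """
    # As in animate_ctrl, the threshold is left as a (possibly fractional) float.
    motion_control_step = motion_control * inference_steps
    return [step for step in range(inference_steps) if step <= motion_control_step]


if __name__ == "__main__":
    for value in (0.0, 0.2, 1.0):
        print(value, motion_injection_steps(value))
```

With the default slider value of 0.2 this selects steps 0 through 5. Because the comparison is `<=`, a slider value of 0 still selects step 0, whereas `animate` disables the branch entirely by setting `motion_control_step = -1`.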
configs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/eval0.yaml ADDED
@@ -0,0 +1,4 @@
1
+ filter_params:
2
+ method: 'gaussian'
3
+ d_s: 0.25
4
+ d_t: 0.25
configs/eval1.yaml ADDED
@@ -0,0 +1,5 @@
1
+ filter_params:
2
+ method: 'butterworth'
3
+ n: 4
4
+ d_s: 0.25
5
+ d_t: 0.25
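For orientation, the two `method` values in eval0.yaml and eval1.yaml name standard low-pass filter families. The sketch below shows only the textbook one-dimensional responses as a function of a normalized frequency distance `d`. How the repository combines the spatial cutoff `d_s` and the temporal cutoff `d_t` is not shown in this diff, so the function names and the single `cutoff` argument are assumptions made for illustration.

```python
import math


def gaussian_low_pass(d: float, cutoff: float) -> float:
    """Textbook Gaussian low-pass response: exp(-d^2 / (2 * cutoff^2))."""
    return math.exp(-(d * d) / (2.0 * cutoff * cutoff))


def butterworth_low_pass(d: float, cutoff: float, n: int) -> float:
    """Textbook Butterworth low-pass response of order n: 1 / (1 + (d / cutoff)^(2n))."""
    return 1.0 / (1.0 + (d / cutoff) ** (2 * n))


if __name__ == "__main__":
    # Values taken from the configs: cutoff 0.25, Butterworth order n = 4.
    for d in (0.0, 0.125, 0.25, 0.5):
        print(d, gaussian_low_pass(d, 0.25), butterworth_low_pass(d, 0.25, 4))
```

At the cutoff itself the Butterworth response is exactly 0.5 for any order, and a higher `n` makes the transition around the cutoff steeper. The Gaussian response falls off smoothly with no order parameter.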
configs/inference/.ipynb_checkpoints/inference-v1-checkpoint.yaml ADDED
@@ -0,0 +1,26 @@
1
+ unet_additional_kwargs:
2
+ unet_use_cross_frame_attention: false
3
+ unet_use_temporal_attention: false
4
+ use_motion_module: true
5
+ motion_module_resolutions:
6
+ - 1
7
+ - 2
8
+ - 4
9
+ - 8
10
+ motion_module_mid_block: false
11
+ motion_module_decoder_only: false
12
+ motion_module_type: Vanilla
13
+ motion_module_kwargs:
14
+ num_attention_heads: 8
15
+ num_transformer_block: 1
16
+ attention_block_types:
17
+ - Temporal_Self
18
+ - Temporal_Self
19
+ temporal_position_encoding: true
20
+ temporal_position_encoding_max_len: 24
21
+ temporal_attention_dim_div: 1
22
+
23
+ noise_scheduler_kwargs:
24
+ beta_start: 0.00085
25
+ beta_end: 0.012
26
+ beta_schedule: "linear"
configs/inference/inference-v1.yaml ADDED
@@ -0,0 +1,26 @@
1
+ unet_additional_kwargs:
2
+ unet_use_cross_frame_attention: false
3
+ unet_use_temporal_attention: false
4
+ use_motion_module: true
5
+ motion_module_resolutions:
6
+ - 1
7
+ - 2
8
+ - 4
9
+ - 8
10
+ motion_module_mid_block: false
11
+ motion_module_decoder_only: false
12
+ motion_module_type: Vanilla
13
+ motion_module_kwargs:
14
+ num_attention_heads: 8
15
+ num_transformer_block: 1
16
+ attention_block_types:
17
+ - Temporal_Self
18
+ - Temporal_Self
19
+ temporal_position_encoding: true
20
+ temporal_position_encoding_max_len: 24
21
+ temporal_attention_dim_div: 1
22
+
23
+ noise_scheduler_kwargs:
24
+ beta_start: 0.00085
25
+ beta_end: 0.012
26
+ beta_schedule: "linear"
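To make the `noise_scheduler_kwargs` above concrete, here is a minimal pure-Python sketch of what a `"linear"` beta schedule with these endpoints looks like. It assumes 1000 training timesteps, which is the default of diffusers' `DDIMScheduler` but is not set in this config. The function names are hypothetical, and the real scheduler computes the same quantities with torch tensors in float32, so values may differ slightly in the last digits.

```python
from typing import List


def linear_betas(beta_start: float, beta_end: float, num_train_timesteps: int = 1000) -> List[float]:
    """Evenly spaced betas from beta_start to beta_end (inclusive)."""
    step = (beta_end - beta_start) / (num_train_timesteps - 1)
    return [beta_start + i * step for i in range(num_train_timesteps)]


def alphas_cumprod(betas: List[float]) -> List[float]:
    """Running product of (1 - beta_t), the signal fraction kept at each timestep."""
    out = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


if __name__ == "__main__":
    betas = linear_betas(0.00085, 0.012)
    acp = alphas_cumprod(betas)
    print(betas[0], betas[-1])
    print(acp[0], acp[-1])
```

The cumulative product decreases monotonically from just under 1 toward 0, which is what lets the scheduler interpolate between a clean sample and noise.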
configs/inference/inference-v2.yaml ADDED
@@ -0,0 +1,27 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions:
7
+ - 1
8
+ - 2
9
+ - 4
10
+ - 8
11
+ motion_module_mid_block: true
12
+ motion_module_decoder_only: false
13
+ motion_module_type: Vanilla
14
+ motion_module_kwargs:
15
+ num_attention_heads: 8
16
+ num_transformer_block: 1
17
+ attention_block_types:
18
+ - Temporal_Self
19
+ - Temporal_Self
20
+ temporal_position_encoding: true
21
+ temporal_position_encoding_max_len: 32
22
+ temporal_attention_dim_div: 1
23
+
24
+ noise_scheduler_kwargs:
25
+ beta_start: 0.00085
26
+ beta_end: 0.012
27
+ beta_schedule: "linear"
configs/prompts/.DS_Store ADDED
Binary file (6.15 kB). View file
 
configs/prompts/1-ToonYou.yaml ADDED
@@ -0,0 +1,23 @@
1
+ ToonYou:
2
+ motion_module:
3
+ - "models/Motion_Module/mm_sd_v14.ckpt"
4
+ - "models/Motion_Module/mm_sd_v15.ckpt"
5
+
6
+ dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
7
+ lora_model_path: ""
8
+
9
+ seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
10
+ steps: 25
11
+ guidance_scale: 7.5
12
+
13
+ prompt:
14
+ - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
15
+ - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
16
+ - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
17
+ - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
18
+
19
+ n_prompt:
20
+ - ""
21
+ - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
22
+ - ""
23
+ - ""
configs/prompts/2-Lyriel.yaml ADDED
@@ -0,0 +1,23 @@
1
+ Lyriel:
2
+ motion_module:
3
+ - "models/Motion_Module/mm_sd_v14.ckpt"
4
+ - "models/Motion_Module/mm_sd_v15.ckpt"
5
+
6
+ dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
7
+ lora_model_path: ""
8
+
9
+ seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
10
+ steps: 25
11
+ guidance_scale: 7.5
12
+
13
+ prompt:
14
+ - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
15
+ - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
16
+ - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
17
+ - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
18
+
19
+ n_prompt:
20
+ - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
21
+ - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
22
+ - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
23
+ - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"
configs/prompts/3-RcnzCartoon.yaml ADDED
@@ -0,0 +1,23 @@
1
+ RcnzCartoon:
2
+ motion_module:
3
+ - "models/Motion_Module/mm_sd_v14.ckpt"
4
+ - "models/Motion_Module/mm_sd_v15.ckpt"
5
+
6
+ dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
7
+ lora_model_path: ""
8
+
9
+ seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
10
+ steps: 25
11
+ guidance_scale: 7.5
12
+
13
+ prompt:
14
+ - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
15
+ - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
16
+ - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
17
+ - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
18
+
19
+ n_prompt:
20
+ - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
21
+ - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
22
+ - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
23
+ - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"
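For orientation, `guidance_scale` and `n_prompt` are conventionally combined through classifier-free guidance. The sketch below is not taken from this commit; it assumes the pipeline follows the usual formulation, in which the negative prompt supplies the unconditional branch. The function name `apply_guidance` and the toy vectors are hypothetical.

```python
from typing import List


def apply_guidance(noise_uncond: List[float],
                   noise_cond: List[float],
                   guidance_scale: float) -> List[float]:
    """Blend two noise predictions: uncond + scale * (cond - uncond)."""
    if len(noise_uncond) != len(noise_cond):
        raise ValueError("predictions must have the same length")
    return [u + guidance_scale * (c - u)
            for u, c in zip(noise_uncond, noise_cond)]


# Toy values standing in for per-element model outputs (hypothetical).
uncond = [0.0, 1.0, -2.0]   # prediction conditioned on n_prompt
cond = [1.0, 1.0, 0.0]      # prediction conditioned on prompt
guided = apply_guidance(uncond, cond, guidance_scale=7.5)
```

Under this formulation a scale of 1.0 reproduces the conditional prediction, and larger values such as 7.5 push the result further away from the negative prompt.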
configs/prompts/4-MajicMix.yaml ADDED
@@ -0,0 +1,23 @@
+ MajicMix:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors"
+   lora_model_path: ""
+
+   seed: [1572448948722921032, 1099474677988590681, 6488833139725635347, 18339859844376517918]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic"
+     - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting"
+     - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below"
+     - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic"
+
+   n_prompt:
+     - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles"
+     - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
+     - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
+     - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people"
configs/prompts/5-RealisticVision.yaml ADDED
@@ -0,0 +1,23 @@
+ RealisticVision:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: [5658137986800322009, 12099779162349365895, 10499524853910852697, 16768009035333711932]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+
+   n_prompt:
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
configs/prompts/6-Tusun.yaml ADDED
@@ -0,0 +1,21 @@
+ Tusun:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/moonfilm_reality20.safetensors"
+   lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors"
+   lora_alpha: 0.6
+
+   seed: [10154078483724687116, 2664393535095473805, 4231566096207622938, 1713349740448094493]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
+     - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body"
+
+   n_prompt:
+     - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative"
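As a reading aid for `lora_model_path` and `lora_alpha`: in the standard LoRA formulation, a low-rank update is scaled and added to each affected base weight. The sketch below illustrates that arithmetic on nested lists. It assumes `lora_alpha` plays the role of that scale factor here; the helper names `matmul` and `merge_lora` and the toy matrices are hypothetical and not part of this commit.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product (rows of a times columns of b)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(weight: Matrix, up: Matrix, down: Matrix, alpha: float) -> Matrix:
    """Return weight + alpha * (up @ down) without mutating the inputs."""
    delta = matmul(up, down)
    return [[w + alpha * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]


# Hypothetical 2x2 base weight with a rank-1 update.
base = [[1.0, 0.0], [0.0, 1.0]]
up = [[1.0], [2.0]]      # shape (2, 1)
down = [[0.5, 0.0]]      # shape (1, 2)
merged = merge_lora(base, up, down, alpha=0.6)
```

With `alpha=0.0` the base weights are returned unchanged, which is one way to sanity-check a merge.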
configs/prompts/7-FilmVelvia.yaml ADDED
@@ -0,0 +1,24 @@
+ FilmVelvia:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors"
+   lora_model_path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors"
+   lora_alpha: 0.6
+
+   seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name"
+     - ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir"
+     - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark"
+     - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, "
+
+   n_prompt:
+     - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
+     - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
configs/prompts/8-GhibliBackground.yaml ADDED
@@ -0,0 +1,21 @@
+ GhibliBackground:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+     - "models/Motion_Module/mm_sd_v15.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/CounterfeitV30_25.safetensors"
+   lora_model_path: "models/DreamBooth_LoRA/lora_Ghibli_n3.safetensors"
+   lora_alpha: 1.0
+
+   seed: [8775748474469046618, 5893874876080607656, 11911465742147695752, 12437784838692000640]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall"
+     - "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter"
+     - ",mysterious sea area, fantasy,build,concept"
+     - "Tomb Raider,Scenography,Old building"
+
+   n_prompt:
+     - "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality"
configs/prompts/unictrl_examples/RealisticVision_v1.yaml ADDED
@@ -0,0 +1,30 @@
+ RealisticVision:
+   motion_module:
+     - "models/Motion_Module/mm_sd_v14.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae.safetensors"
+   lora_model_path: ""
+
+   seed: [442, 123]
+   steps: 25
+   guidance_scale: 7.5
+
+
+   filter_params:
+     method: 'butterworth'
+     n: 4
+     d_s: 0.25
+     d_t: 0.25
+
+   # filter_params:
+   #   method: 'gaussian'
+   #   d_s: 0.25
+   #   d_t: 0.25
+
+   prompt:
+     - "A cat wearing sunglasses and working as a lifeguard at a pool."
+     - "A panda cooking in the kitchen"
+
+   n_prompt:
+     - "worst quality, low quality, nsfw, logo"
+     - ""
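The `filter_params` block names a Butterworth filter with order `n` and two cutoffs. The sketch below shows the textbook Butterworth low-pass response and one plausible way to build a 3D frequency mask from it. It is an illustration under assumptions, not the code in this commit: it assumes `d_s` and `d_t` are normalized spatial and temporal cutoff frequencies, and the function names and grid sizes are hypothetical.

```python
import math
from typing import List


def butterworth_lowpass(distance: float, cutoff: float, order: int) -> float:
    """Butterworth low-pass response: 1 / (1 + (distance / cutoff)^(2 * order))."""
    return 1.0 / (1.0 + (distance / cutoff) ** (2 * order))


def temporal_spatial_mask(frames: int, height: int, width: int,
                          d_s: float, d_t: float, n: int) -> List[List[List[float]]]:
    """Build a (frames, height, width) mask over centred, normalized frequencies.

    Each axis is mapped to [-1, 1); the temporal axis is rescaled by d_s / d_t
    so that a single cutoff (d_s) applies to the combined distance.
    """
    mask = []
    for t in range(frames):
        ft = (d_s / d_t) * (2.0 * t / frames - 1.0)
        plane = []
        for h in range(height):
            fh = 2.0 * h / height - 1.0
            row = []
            for w in range(width):
                fw = 2.0 * w / width - 1.0
                dist = math.sqrt(ft * ft + fh * fh + fw * fw)
                row.append(butterworth_lowpass(dist, d_s, n))
            plane.append(row)
        mask.append(plane)
    return mask


# Values mirror the config above; the grid size is arbitrary.
mask = temporal_spatial_mask(frames=16, height=8, width=8, d_s=0.25, d_t=0.25, n=4)
```

The response is 1.0 at zero frequency and exactly 0.5 at the cutoff; a higher order `n` makes the transition between pass band and stop band steeper.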
configs/prompts/unictrl_examples/RealisticVision_v2.yaml ADDED
@@ -0,0 +1,35 @@
+ RealisticVision:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae.safetensors"
+   lora_model_path: ""
+
+   seed: [9620, 913, 6840, 1334]
+   steps: 25
+   guidance_scale: 7.5
+
+   filter_params:
+     method: 'butterworth'
+     n: 4
+     d_s: 0.25
+     d_t: 0.25
+
+   # filter_params:
+   #   method: 'gaussian'
+   #   d_s: 0.25
+   #   d_t: 0.25
+
+   prompt:
+     - "A panda cooking in the kitchen"
+     - "A cat wearing sunglasses and working as a lifeguard at a pool."
+     - "A confused panda in calculus class"
+     - "A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy"
+
+   n_prompt:
+     - ""
+     - ""
+     - ""
+     - ""
+
configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml ADDED
@@ -0,0 +1,189 @@
+ ZoomIn:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ ZoomOut:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ PanLeft:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_PanLeft.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ PanRight:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_PanRight.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ TiltUp:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_TiltUp.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ TiltDown:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_TiltDown.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ RollingAnticlockwise:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+
+
+
+ RollingClockwise:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   motion_module_lora_configs:
+     - path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt"
+       alpha: 1.0
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: 45987230
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+
+   n_prompt:
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
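Note that `seed` is a single integer in this file but a list with one entry per prompt in the other prompt configs. A loader therefore has to accept both shapes. The sketch below shows one way to normalize them; it is an assumption about how a consumer might handle the field, and `normalize_seeds` is a hypothetical helper, not a function from this commit.

```python
from typing import List, Sequence, Union


def normalize_seeds(seed: Union[int, Sequence[int]], num_prompts: int) -> List[int]:
    """Return one seed per prompt from a scalar or list-valued `seed` entry."""
    seeds = [seed] if isinstance(seed, int) else list(seed)
    # A single seed is reused for every prompt.
    if len(seeds) == 1:
        seeds = seeds * num_prompts
    if len(seeds) != num_prompts:
        raise ValueError(f"expected {num_prompts} seeds, got {len(seeds)}")
    return seeds


scalar_case = normalize_seeds(45987230, num_prompts=1)
list_case = normalize_seeds([9620, 913, 6840, 1334], num_prompts=4)
```

Raising on a length mismatch, rather than silently truncating, keeps a prompt from being paired with the wrong seed.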
configs/prompts/v2/5-RealisticVision.yaml ADDED
@@ -0,0 +1,23 @@
+ RealisticVision:
+   inference_config: "configs/inference/inference-v2.yaml"
+   motion_module:
+     - "models/Motion_Module/mm_sd_v15_v2.ckpt"
+
+   dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
+   lora_model_path: ""
+
+   seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302]
+   steps: 25
+   guidance_scale: 7.5
+
+   prompt:
+     - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
+     - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
+     - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
+
+   n_prompt:
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+     - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
configs/training/image_finetune.yaml ADDED
@@ -0,0 +1,48 @@
+ image_finetune: true
+
+ output_dir: "outputs"
+ pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+ noise_scheduler_kwargs:
+   num_train_timesteps: 1000
+   beta_start: 0.00085
+   beta_end: 0.012
+   beta_schedule: "scaled_linear"
+   steps_offset: 1
+   clip_sample: false
+
+ train_data:
+   csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+   video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+   sample_size: 256
+
+ validation_data:
+   prompts:
+     - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+     - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+     - "Robot dancing in times square."
+     - "Pacific coast, carmel by the sea ocean and waves."
+   num_inference_steps: 25
+   guidance_scale: 8.
+
+ trainable_modules:
+   - "."
+
+ unet_checkpoint_path: ""
+
+ learning_rate: 1.e-5
+ train_batch_size: 50
+
+ max_train_epoch: -1
+ max_train_steps: 100
+ checkpointing_epochs: -1
+ checkpointing_steps: 60
+
+ validation_steps: 5000
+ validation_steps_tuple: [2, 50]
+
+ global_seed: 42
+ mixed_precision_training: true
+ enable_xformers_memory_efficient_attention: True
+
+ is_debug: False
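The two training configs differ in `trainable_modules`: `"."` here and `"motion_modules."` in `training.yaml`. One plausible reading is that each entry is matched as a substring of a parameter name, so `"."` selects every parameter and `"motion_modules."` selects only the motion module. The sketch below illustrates that reading. The substring rule is an assumption, and `select_trainable` and the parameter names are hypothetical.

```python
from typing import Iterable, List


def select_trainable(param_names: Iterable[str], patterns: Iterable[str]) -> List[str]:
    """Keep parameter names that contain at least one pattern as a substring."""
    patterns = list(patterns)
    return [name for name in param_names
            if any(pattern in name for pattern in patterns)]


# Hypothetical parameter names, for illustration only.
names = [
    "down_blocks.0.resnets.0.conv1.weight",
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight",
    "conv_out.bias",
]
everything = select_trainable(names, ["."])
motion_only = select_trainable(names, ["motion_modules."])
```

Under this rule the image fine-tuning config would update the whole UNet, while the video training config would leave the image layers frozen.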
configs/training/training.yaml ADDED
@@ -0,0 +1,66 @@
+ image_finetune: false
+
+ output_dir: "outputs"
+ pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+ unet_additional_kwargs:
+   use_motion_module : true
+   motion_module_resolutions : [ 1,2,4,8 ]
+   unet_use_cross_frame_attention : false
+   unet_use_temporal_attention : false
+
+   motion_module_type: Vanilla
+   motion_module_kwargs:
+     num_attention_heads : 8
+     num_transformer_block : 1
+     attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
+     temporal_position_encoding : true
+     temporal_position_encoding_max_len : 24
+     temporal_attention_dim_div : 1
+     zero_initialize : true
+
+ noise_scheduler_kwargs:
+   num_train_timesteps: 1000
+   beta_start: 0.00085
+   beta_end: 0.012
+   beta_schedule: "linear"
+   steps_offset: 1
+   clip_sample: false
+
+ train_data:
+   csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+   video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+   sample_size: 256
+   sample_stride: 4
+   sample_n_frames: 16
+
+ validation_data:
+   prompts:
+     - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+     - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+     - "Robot dancing in times square."
+     - "Pacific coast, carmel by the sea ocean and waves."
+   num_inference_steps: 25
+   guidance_scale: 8.
+
+ trainable_modules:
+   - "motion_modules."
+
+ unet_checkpoint_path: ""
+
+ learning_rate: 1.e-4
+ train_batch_size: 4
+
+ max_train_epoch: -1
+ max_train_steps: 100
+ checkpointing_epochs: -1
+ checkpointing_steps: 60
+
+ validation_steps: 5000
+ validation_steps_tuple: [2, 50]
+
+ global_seed: 42
+ mixed_precision_training: true
+ enable_xformers_memory_efficient_attention: True
+
+ is_debug: False
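The `noise_scheduler_kwargs` blocks use `beta_schedule: "linear"` here and `"scaled_linear"` in `image_finetune.yaml`, with the same `beta_start` and `beta_end`. As commonly defined (for example in diffusers schedulers), `linear` interpolates the betas directly, while `scaled_linear` interpolates their square roots and then squares the result. The sketch below reproduces both in plain Python; `linspace` and `make_betas` are hypothetical helper names, and the code is an illustration rather than the scheduler used by this repository.

```python
from typing import List


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Evenly spaced values from start to stop, inclusive."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def make_betas(beta_start: float, beta_end: float,
               num_train_timesteps: int, beta_schedule: str) -> List[float]:
    """Return the per-timestep beta values for the named schedule."""
    if beta_schedule == "linear":
        return linspace(beta_start, beta_end, num_train_timesteps)
    if beta_schedule == "scaled_linear":
        # Interpolate in sqrt space, then square each value.
        roots = linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps)
        return [r * r for r in roots]
    raise ValueError(f"unsupported beta_schedule: {beta_schedule!r}")


linear = make_betas(0.00085, 0.012, 1000, "linear")
scaled = make_betas(0.00085, 0.012, 1000, "scaled_linear")
```

Both schedules share their endpoints, but `scaled_linear` stays below `linear` in between, so it adds noise more slowly over the middle timesteps.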
download_bashscripts/0-MotionModule.sh ADDED
@@ -0,0 +1,2 @@
+ gdown 1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq -O models/Motion_Module/
+ gdown 1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu -O models/Motion_Module/
download_bashscripts/1-ToonYou.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/78775 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/2-Lyriel.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/72396 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/3-RcnzCartoon.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/71009 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/4-MajicMix.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/79068 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/5-RealisticVision.sh ADDED
@@ -0,0 +1,2 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/29460 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/6-Tusun.sh ADDED
@@ -0,0 +1,3 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/97261 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
+ wget https://civitai.com/api/download/models/50705 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/7-FilmVelvia.sh ADDED
@@ -0,0 +1,3 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/90115 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
+ wget https://civitai.com/api/download/models/92475 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
download_bashscripts/8-GhibliBackground.sh ADDED
@@ -0,0 +1,3 @@
+ #!/bin/bash
+ wget https://civitai.com/api/download/models/102828 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
+ wget https://civitai.com/api/download/models/57618 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate
environment.yaml ADDED
@@ -0,0 +1,25 @@
+ name: animatediff_pt2
+ channels:
+   - pytorch
+   - nvidia
+ dependencies:
+   - python=3.10
+   - pytorch=2.1.2
+   - torchvision=0.16.2
+   - torchaudio=2.1.2
+   - pytorch-cuda=11.8
+   - pip
+   - pip:
+     - accelerate==0.25.0
+     - diffusers==0.26.1
+     - transformers==4.25.1
+     - imageio==2.27.0
+     - decord==0.6.0
+     - gdown
+     - einops
+     - omegaconf
+     - safetensors
+     - gradio==3.41.2
+     - wandb
+     - imageio-ffmpeg
+     - av
models/.DS_Store ADDED
Binary file (6.15 kB).
models/DreamBooth_LoRA/Put personalized T2I checkpoints here.txt ADDED
File without changes
models/DreamBooth_LoRA/lyriel_v16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cdfd07bbcceec4cea1984cb3fc1d723dbcf66ce1ca3a9bc7060e90fc54065b2b
+ size 482344960
models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de829e47b30a6ccd50186c598a93275cca888b0e29c031f6e4a7eb9e94bf57bb
+ size 482607104
models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8f363dfbc22412bb4ec8f0e62eca01bc23c780e352521267fbc8a94956621361
+ size 480247808
models/DreamBooth_LoRA/realisticVisionV60B1_v20Novae.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d08a6c2431f19dc98fc6a02b0f4d74713d43816316ec8958eb701ea95fa58711
+ size 480772096
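The `.safetensors` entries above are Git LFS pointer files, not the model weights themselves: each is a short text file with `version`, `oid` and `size` keys, and the real binary is stored by the LFS backend. The sketch below parses one such pointer; `parse_lfs_pointer` is a hypothetical helper written for illustration, and the sample text is copied from the last pointer above.

```python
from typing import Dict, Union


def parse_lfs_pointer(text: str) -> Dict[str, Union[str, int]]:
    """Split a Git LFS pointer into version, hash algorithm, digest and size."""
    fields = {}
    for line in text.strip().splitlines():
        # Each line is "<key> <value>", separated by a single space.
        key, _, value = line.partition(" ")
        fields[key] = value
    algorithm, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "algorithm": algorithm,
        "digest": digest,
        "size": int(fields["size"]),  # size of the real file in bytes
    }


pointer = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:d08a6c2431f19dc98fc6a02b0f4d74713d43816316ec8958eb701ea95fa58711\n"
    "size 480772096\n"
)
```

Comparing `size` and the SHA-256 digest against a downloaded file is a quick way to confirm that the checkpoint was fetched completely.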