adaface-neurips committed
Commit
02cc20b
1 Parent(s): 550b1c1
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +7 -0
  2. Dockerfile +15 -0
  3. README.md +6 -8
  4. README2.md +241 -0
  5. adaface/adaface-infer.py +131 -0
  6. adaface/adaface-translate.py +208 -0
  7. adaface/adaface_wrapper.py +286 -0
  8. adaface/arc2face_models.py +303 -0
  9. adaface/subj_basis_generator.py +758 -0
  10. adaface/util.py +341 -0
  11. animatediff/models/attention.py +327 -0
  12. animatediff/models/attention_bkp.py +326 -0
  13. animatediff/models/motion_module.py +552 -0
  14. animatediff/models/motion_module_bkp.py +331 -0
  15. animatediff/models/resnet.py +217 -0
  16. animatediff/models/sparse_controlnet.py +587 -0
  17. animatediff/models/unet.py +600 -0
  18. animatediff/models/unet_blocks.py +760 -0
  19. animatediff/pipelines/pipeline_animation.py +793 -0
  20. animatediff/sd/.gitattributes +35 -0
  21. animatediff/sd/feature_extractor/preprocessor_config.json +20 -0
  22. animatediff/sd/model_index.json +32 -0
  23. animatediff/sd/safety_checker/config.json +175 -0
  24. animatediff/sd/safety_checker/pytorch_model.bin +3 -0
  25. animatediff/sd/scheduler/scheduler_config.json +13 -0
  26. animatediff/sd/text_encoder/config.json +25 -0
  27. animatediff/sd/text_encoder/pytorch_model.bin +3 -0
  28. animatediff/sd/tokenizer/merges.txt +0 -0
  29. animatediff/sd/tokenizer/special_tokens_map.json +24 -0
  30. animatediff/sd/tokenizer/tokenizer_config.json +34 -0
  31. animatediff/sd/tokenizer/vocab.json +0 -0
  32. animatediff/sd/unet/config.json +36 -0
  33. animatediff/sd/unet/diffusion_pytorch_model.bin +3 -0
  34. animatediff/sd/v1-inference.yaml +70 -0
  35. animatediff/sd/vae/config.json +29 -0
  36. animatediff/sd/vae/diffusion_pytorch_model.bin +3 -0
  37. animatediff/utils/convert_from_ckpt.py +959 -0
  38. animatediff/utils/convert_lora_safetensor_to_diffusers.py +152 -0
  39. animatediff/utils/convert_original_stable_diffusion_to_diffusers.py +188 -0
  40. animatediff/utils/util.py +225 -0
  41. app.py +402 -0
  42. assets/alita/alita armor orig.mp4 +0 -0
  43. assets/alita/alita armor.mp4 +0 -0
  44. assets/alita/alita beach orig.mp4 +0 -0
  45. assets/alita/alita beach.mp4 +0 -0
  46. assets/alita/alita cooking orig.mp4 +0 -0
  47. assets/alita/alita cooking.mp4 +0 -0
  48. assets/alita/alita dancing orig.mp4 +0 -0
  49. assets/alita/alita dancing.mp4 +0 -0
  50. assets/alita/alita iron man orig.mp4 +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
1
+ __pycache__/*
2
+ __pycache__/
3
+ *.pyc
4
+ gradio_cached_examples/*
5
+ gradio_cached_examples/
6
+ samples/*
7
+ samples/
Dockerfile ADDED
@@ -0,0 +1,15 @@
1
+ FROM python:3.8-slim
2
+ ENV PYTHONUNBUFFERED=1
3
+
4
+ RUN apt-get update && \
5
+ apt-get install -y \
6
+ bash \
7
+ git git-lfs \
8
+ wget curl procps \
9
+ htop vim nano && \
10
+ rm -rf /var/lib/apt/lists/*
11
+
12
+ WORKDIR /app
13
+ COPY --link --chown=1000 ./ /app
14
+
15
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,10 @@
  ---
- title: Adaface Animate
- emoji: 🌖
- colorFrom: gray
+ title: AdaFace-Animate
+ emoji: 🎨
+ colorFrom: yellow
  colorTo: green
  sdk: gradio
- sdk_version: 4.36.1
+ sdk_version: 4.27.0
  app_file: app.py
- pinned: false
+ pinned: true
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
README2.md ADDED
@@ -0,0 +1,241 @@
1
+ # AdaFace-Animate
2
+
3
+ This folder contains the preliminary implementation of **AdaFace-Animate**.
4
+ It is a zero-shot, subject-guided animation generator conditioned on human subject images, built by combining AnimateDiff, ID-Animator and AdaFace. ID-Animator provides AnimateDiff with rough subject characteristics, while AdaFace supplies refined and more authentic facial details.
5
+
6
+ Please refer to our NeurIPS 2024 submission for more details about AdaFace:
7
+
8
+ **AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization**
9
+ <br>
10
+
11
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97 Hugging Face-Spaces-yellow)](https://huggingface.co/spaces/adaface-neurips/adaface-animate)
12
+
13
+ This pipeline uses 4 pretrained models: [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [AnimateDiff v3](https://github.com/guoyww/animatediff), [ID-Animator](https://github.com/ID-Animator/ID-Animator) and [AdaFace](https://huggingface.co/adaface-neurips/adaface).
14
+
15
+ AnimateDiff uses an SD-1.5-type checkpoint, referred to as a "DreamBooth" model. The DreamBooth model we use, which we call "SAR", is an average of three SD-1.5 models: the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). In our experiments, this averaged model performs better than any of the individual models.
16
+
17
+ ## Procedures of Generation
18
+ We find that using an initial image helps stabilize the animation sequence and improve its quality. When generating each example video, an initial image is first generated by AdaFace with the same prompt used for the video. This image is blended with multiple frames of random noise, with blending weights decreasing with $t$ (a sketch of this blending is given below). The multi-frame blended noise is then converted into a 1-second animation by AnimateDiff, conditioned on both the AdaFace and ID-Animator embeddings.
19
+
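+ Below is a minimal sketch of this blending step, assuming $t$ indexes the animation frames; the frame count, latent shape, and linear decay schedule are illustrative assumptions (in the app they are controlled by "Init Image Strength" and "Final Weight of the Init Image"), not the exact implementation.
+
+ ```python
+ import torch
+
+ def blend_init_latent_with_noise(init_latent, num_frames=16,
+                                  init_strength=1.0, final_weight=0.2):
+     # init_latent: [1, 4, H, W] latent of the AdaFace-generated initial image.
+     _, c, h, w = init_latent.shape
+     noise = torch.randn(num_frames, c, h, w, device=init_latent.device)
+     # Per-frame weight of the init image, decreasing with the frame index t.
+     weights = torch.linspace(init_strength, final_weight, num_frames,
+                              device=init_latent.device).view(-1, 1, 1, 1)
+     # Early frames stay close to the init image; later frames are mostly noise.
+     return weights * init_latent + (1.0 - weights) * noise  # [num_frames, 4, H, W]
+ ```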
20
+ ## Gallery
21
+ [Gallery](./assets/) contains 100 subject videos generated by us. They belong to 10 celebrities, each with 10 different prompts. The (shortened) prompts are: "Armor Suit", "Iron Man Costume", "Superman Costume", "Wielding a Lightsaber", "Walking on the beach", "Cooking", "Dancing", "Playing Guitar", "Reading", and "Running".
22
+
23
+ Some example videos are shown below. The full set of videos can be found in [Gallery](./assets/).
24
+
25
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
26
+
27
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
28
+ <tr style="line-height: 1">
29
+ <td width="25%" style="text-align: center">Input (Celebrities)</td>
30
+ <td width="25%" style="text-align: center">Animation 1: Playing Guitar</td>
31
+ <td width="25%" style="text-align: center">Animation 2: Cooking</td>
32
+ <td width="25%" style="text-align: center">Animation 3: Dancing</td>
33
+ </tr>
34
+ <tr>
35
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
36
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
37
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/83a08691-4f4e-4898-b4ae-be5dfcd1fb85" type="video/mp4"></video></td>
38
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a" type="video/mp4"></video></td>
39
+ </tr>
40
+ <tr>
41
+ <td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
42
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
43
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/54f3745f-abf6-4608-93c5-d8e103d05dc7" type="video/mp4"></video></td>
44
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681" type="video/mp4"></video></td>
45
+ </tr>
46
+ <tr>
47
+ <td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
48
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
49
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/68ad643c-8a2b-43a8-9c7b-10c36c4912d4" type="video/mp4"></video></td>
50
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
51
+ </tr>
52
+
53
+ </table>
54
+
55
+ To illustrate the wide range of applications of our method, we animated 8 internet memes. 4 of them are shown in the table below. The full gallery can be found in [memes](./assets/memes/).
56
+
57
+ <table class="center">
58
+ <tr style="line-height: 1">
59
+ <td width=25% style="text-align: center">Input (Memes)</td>
60
+ <td width=25% style="text-align: center">Animation</td>
61
+ <td width=25% style="text-align: center">Input</td>
62
+ <td width=25% style="text-align: center">Animation</td>
63
+ </tr>
64
+ <tr>
65
+ <td style="text-align: center">Yao Ming Laugh</td><td></td><td style="text-align: center">Girl Burning House</td></td><td>
66
+ </tr>
67
+ <tr>
68
+ <td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
69
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
70
+ <td><img src="assets/memes/girl burning house.jpg" style="width:100%"></td>
71
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/11c83ae1-dece-4798-baf5-e608ab8709e3" type="video/mp4"></video></td>
72
+ </tr>
73
+ <tr>
74
+ <td style="text-align: center">Girl with a Pearl Earring</td><td></td><td style="text-align: center">Great Gatsby</td></td><td>
75
+ </tr>
76
+ <tr>
77
+ <td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
78
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
79
+ <td><img src="assets/memes/great gatsby.jpg" style="width:100%"></td>
80
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/db941805-8a6c-4596-ba3a-247b54baa5ef" type="video/mp4"></video></td>
81
+ </tr>
82
+ </table>
83
+
84
+ ## Comparison with ID-Animator, with AdaFace Initial Images
85
+ To compare with the baseline method ID-Animator, we disable AdaFace for each video and generate the corresponding video with ID-Animator under otherwise identical settings: the same subject image(s), the same initial image, and the same random seed and prompt. The table below compares some of these videos side by side with the AdaFace-Animate videos. The full set of ID-Animator videos can be found in each subject folder in [Gallery](./assets/), named "* orig.mp4".
86
+
87
+ **NOTE** Since the ID-Animator videos here use initial images generated by **AdaFace**, this comparison actually gives ID-Animator an advantage over its original setup, which does not use AdaFace initial images.
88
+
89
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
90
+
91
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
92
+ <tr style="line-height: 1">
93
+ <td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image: Playing Guitar</td>
94
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
95
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
96
+ <td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image Dancing</td>
97
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
98
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
99
+ </tr>
100
+ <tr>
101
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence playing guitar.jpg" style="width:100%"></td>
102
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f5d0f2c6-f4bd-4517-bfa1-021db1577895" type="video/mp4"></video></td>
103
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
104
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence dancing.jpg" style="width:100%"></td>
105
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/421f5b81-e1a7-459a-869a-f7f6dc51a74e" type="video/mp4"></video></td>
106
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a"
107
+ type="video/mp4"></video></td>
108
+ </tr>
109
+ <tr>
110
+ <td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun playing guitar.jpg" style="width:100%"></td>
111
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3bbfc15b-4205-4052-b5cc-c4f8d6d17027" type="video/mp4"></video></td>
112
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
113
+ <td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun dancing.jpg" style="width:100%"></td>
114
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/75e191f4-87e2-486c-90e7-c9e21a1bf494" type="video/mp4"></video></td>
115
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681"
116
+ type="video/mp4"></video></td>
117
+ </tr>
118
+ <tr>
119
+ <td style="text-align: center"><img src="assets/gakki/init images/gakki playing guitar.jpg" style="width:100%"></td>
120
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/6a5579ce-23e3-4603-8917-00a16d6a3682" type="video/mp4"></video></td>
121
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
122
+ <td style="text-align: center"><img src="assets/gakki/init images/gakki dancing.jpg" style="width:100%"></td>
123
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28082e58-a0ed-4492-8c51-cb563f92baeb" type="video/mp4"></video></td>
124
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
125
+ </tr>
126
+
127
+ </table>
128
+
129
+ The table below compares the animated internet memes. The initial image for each video is the meme image itself. For "Yao Ming laughing" and "Great Gatsby", 2-3 extra portrait photos of the subject are included as subject images to enhance facial fidelity; for the other memes, the meme image is the only subject image. The full set of ID-Animator meme videos can be found in [memes](./assets/memes/), named "* orig.mp4".
130
+
131
+ <table class="center" style="width: 60%;">
132
+ <tr style="line-height: 1">
133
+ <td width=20% style="text-align: center">Input (Memes)</td>
134
+ <td width=20% style="text-align: center">ID-Animator</td>
135
+ <td width=20% style="text-align: center">AdaFace-Animate</td>
136
+ </tr>
137
+ <tr>
138
+ <td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
139
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/9daf814c-ae8a-476d-9c32-fa9ef6be16d9" type="video/mp4"></video></td>
140
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
141
+ </tr>
+ <tr>
142
+ <td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
143
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/05ed29d5-4eaa-4a0a-bee2-bc77e5649f58" type="video/mp4"></video></td>
144
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
145
+ </tr>
146
+ </table>
147
+
148
+ We can see that the subjects in the AdaFace-Animate videos have more authentic facial features and better-preserved facial expressions, while the subjects in the ID-Animator videos are less authentic and less faithful to the original images.
149
+
150
+ ## Comparison with ID-Animator, without AdaFace Initial Images
151
+ To exclude the effect of the AdaFace initial images, we generate a subset of videos with AdaFace-Animate / ID-Animator *without initial images*, under otherwise the same settings as above. The table below shows a selection of these videos; the complete set can be found in [no-init](./assets/no-init/). Without the help of AdaFace initial images, the compositionality (i.e., the overall layout) deteriorates on some prompts: some background objects are suppressed by over-expressed facial features. Moreover, the performance gap between AdaFace-Animate and ID-Animator becomes more pronounced.
152
+
153
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
154
+
155
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
156
+ <tr style="line-height: 1">
157
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">Input (Celebrities)</td>
158
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
159
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
160
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
161
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
162
+ </tr>
163
+ <tr>
164
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
165
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2c3fa70b-4a38-48d1-aead-cd94976f6beb" type="video/mp4"></video></td>
166
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f658f9e6-c3b6-4c4a-920c-00a89b98d97a" type="video/mp4"></video></td>
167
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2de5cb38-f62c-4e9d-90ad-9bbb72d1ba7a" type="video/mp4"></video></td>
168
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b39e66d-c696-4022-81ff-6afae8147981" type="video/mp4"></video></td>
169
+ </tr>
170
+ <tr>
171
+ <td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
172
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/7f7f8cd0-7ca3-47b4-a44d-8c6b399bdbc4" type="video/mp4"></video></td>
173
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/eb173058-2314-470a-8cf4-3702036022ad" type="video/mp4"></video></td>
174
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/cd5a9687-bae0-47fd-b82c-febc0d343ac2" type="video/mp4"></video></td>
175
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/e08c778c-5e87-40f6-a7a1-328f5d0d016f" type="video/mp4"></video></td>
176
+ </tr>
177
+ <tr>
178
+ <td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
179
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0370714b-d10c-422d-adee-76f6221aa1be" type="video/mp4"></video></td>
180
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/79cd95d2-95ea-4854-816e-2caf0cbebf94" type="video/mp4"></video></td>
181
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/60fa6b2a-6e1a-48c0-a777-4b62504ff679" type="video/mp4"></video></td>
182
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/c72836ee-9d7a-4525-a48a-7017b60f83f3" type="video/mp4"></video></td>
183
+ </tr>
184
+
185
+ </table>
186
+
187
+ ## Installation
188
+
189
+ ### Manually Download Model Checkpoints
190
+ - Download Stable Diffusion V1.5 into ``animatediff/sd``:
191
+
192
+ ``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 animatediff/sd``
193
+ - Download AnimateDiff motion module into ``models/v3_sd15_mm.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt
194
+ - Download Animatediff adapter into ``models/v3_adapter_sd_v15.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt
195
+ - Download ID-Animator checkpoint into ``models/animator.ckpt`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/animator.ckpt
196
+ - Download CLIP Image encoder into ``models/image_encoder/`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/tree/main/image_encoder
197
+ - Download AdaFace checkpoint into ``models/adaface/`` from: https://huggingface.co/adaface-neurips/adaface/tree/main/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt
198
+
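+ The individual checkpoint files above can also be fetched programmatically. A minimal sketch using `huggingface_hub` (an assumption: repo IDs and file names are taken from the links above, and the local paths follow the layout above):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # AnimateDiff motion module and adapter (rename the adapter to v3_adapter_sd_v15.ckpt if app.py expects that name).
+ hf_hub_download("guoyww/animatediff", "v3_sd15_mm.ckpt", local_dir="models")
+ hf_hub_download("guoyww/animatediff", "v3_sd15_adapter.ckpt", local_dir="models")
+ # ID-Animator checkpoint (hosted in a Space, hence repo_type="space").
+ hf_hub_download("ID-Animator/ID-Animator", "animator.ckpt", repo_type="space", local_dir="models")
+ # AdaFace checkpoint.
+ hf_hub_download("adaface-neurips/adaface",
+                 "subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt",
+                 local_dir="models/adaface")
+ # The CLIP image encoder folder can be fetched similarly with snapshot_download(..., allow_patterns="image_encoder/*").
+ ```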
199
+ ### Prepare the SAR Model
200
+
201
+ Manually download the three `.safetensors` models: the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). Save them to `models/sar`.
202
+
203
+ Run the following command to generate an average of the three models:
204
+ ```
205
+ python3 scripts/avg_models.py --input models/sar/absolutereality_v181.safetensors models/sar/realisticVisionV40_v40VAE.safetensors models/sar/v1-5-pruned.safetensors --output models/sar/sar.safetensors
206
+ ```
207
+
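+ For reference, the averaging is just an element-wise mean of the three state dicts. A minimal sketch of the idea (the actual `scripts/avg_models.py` may differ in key filtering and dtype handling):
+
+ ```python
+ import torch
+ from safetensors.torch import load_file, save_file
+
+ def average_checkpoints(paths, out_path):
+     state_dicts = [load_file(p) for p in paths]
+     avg = {}
+     for key, ref in state_dicts[0].items():
+         if all(key in sd and sd[key].shape == ref.shape for sd in state_dicts):
+             # Element-wise mean of the corresponding weights.
+             avg[key] = torch.stack([sd[key].float() for sd in state_dicts]).mean(0).to(ref.dtype)
+         else:
+             # Fall back to the first checkpoint's tensor if a key is missing elsewhere.
+             avg[key] = ref
+     save_file(avg, out_path)
+ ```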
208
+ \[Optional Improvement\]
209
+ 1. You can replace the VAE of the SAR model with the [MSE-840000 finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/tree/main) for slightly better video details:
210
+ ```
211
+ python3 scripts/repl_vae.py --base_ckpt models/sar/sar.safetensors --vae_ckpt models/sar/vae-ft-mse-840000-ema-pruned.ckpt --out_ckpt models/sar/sar-vae.safetensors
212
+ mv models/sar/sar-vae.safetensors models/sar/sar.safetensors
213
+ ```
214
+
215
+ 2. You can replace the text encoder of the SAR model with the text encoder of [DreamShaper V8](https://civitai.com/models/4384?modelVersionId=252914) for slightly more authentic facial features:
216
+ ```
217
+ python3 scripts/repl_textencoder.py --base_ckpt models/sar/sar.safetensors --te_ckpt models/sar/dreamshaper_8.safetensors --out_ckpt models/sar/sar2.safetensors
218
+ mv models/sar/sar2.safetensors models/sar/sar.safetensors
219
+ ```
220
+ ### Inference
221
+
222
+ Run the demo inference script:
223
+ ```
224
+ python3 app.py
225
+ ```
226
+ Then open the Gradio interface at `<local-ip-address>:7860`, or at the `https://*.gradio.live` URL shown in the terminal.
227
+
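+ Besides the Gradio app, the repository also ships `adaface/adaface-infer.py` for generating still images with AdaFace alone. A hypothetical invocation based on its command-line flags (the subject folder is a placeholder; the checkpoint path assumes the download layout above):
+
+ ```
+ python3 adaface/adaface-infer.py \
+   --embman_ckpt models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt \
+   --subject /path/to/subject/images \
+   --prompt "a person z playing guitar" \
+   --out_image_count 4
+ ```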
228
+ #### Use of Initial Image
229
+ The use of an initial image is optional. It usually helps stabilize the animation sequence and improve the quality.
230
+
231
+ You can generate 3 initial images in one go by clicking "Generate 3 new init images". The images are based on the same prompt as the video generation, but you can also use different prompts for the initial images and the video. Select the desired initial image by clicking on it, then click "Generate Video". If none of the initial images is good enough, click "Generate 3 new init images" again.
232
+
233
+ ### Common Issues
234
+ 1. **Defocus**. This is the most significant potential issue. When the subject is far from the camera, the model may fail to generate a clear face or control the subject's facial details. In this situation, consider increasing the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also prepend "face portrait of" to the prompt to help the model focus on the face.
235
+ 2. **Motion Degeneration**. When the subject is too close to the camera, the model may fail to generate correct motions and poses and render only the face. In this situation, consider decreasing the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also adjust the prompt slightly so that it focuses on the whole body.
236
+ 3. **Lesser Facial Characteristics**. If the subject's facial characteristics are not distinctive enough, increase the weight of "AdaFace Embedding ID CFG Scale".
237
+ 4. **Unstable Motions**. If the generated video has unstable motions, this is probably due to the limitations of AnimateDiff. Nonetheless, you can make it more stable by using a carefully selected initial image and optionally increasing "Init Image Strength" and "Final Weight of the Init Image". Note that a larger "Final Weight of the Init Image" makes the motion in the generated video less dynamic.
238
+
239
+
240
+ ## Disclaimer
241
+ This project is intended for academic purposes only. We do not accept responsibility for user-generated content. Users are solely responsible for their own actions. The contributors to this project are not legally affiliated with, nor are they liable for, the actions of users. Please use this generative model responsibly, in accordance with ethical and legal standards.
adaface/adaface-infer.py ADDED
@@ -0,0 +1,131 @@
1
+ from adaface.adaface_wrapper import AdaFaceWrapper
2
+ import torch
3
+ #import torch.nn.functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import os, argparse, glob, re
7
+
8
+ def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
9
+ if num_images_per_row > len(images):
10
+ num_images_per_row = len(images)
11
+
12
+ os.makedirs(save_dir, exist_ok=True)
13
+
14
+ num_columns = int(np.ceil(len(images) / num_images_per_row))
15
+ # Save 4 images as a grid image in save_dir
16
+ grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
17
+ for i, image in enumerate(images):
18
+ image = image.resize((512, 512))
19
+ grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))
20
+
21
+ prompt_sig = prompt.replace(" ", "_").replace(",", "_")
22
+ grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
23
+ if os.path.exists(grid_filepath):
24
+ grid_count = 2
25
+ grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
26
+ while os.path.exists(grid_filepath):
27
+ grid_count += 1
28
+ grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
29
+
30
+ grid_image.save(grid_filepath)
31
+ print(f"Saved to {grid_filepath}")
32
+
33
+ def seed_everything(seed):
34
+ np.random.seed(seed)
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ torch.backends.cudnn.deterministic = True
38
+ torch.backends.cudnn.benchmark = False
39
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
40
+
41
+ def parse_args():
42
+ parser = argparse.ArgumentParser()
43
+ parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
44
+ help="Type of checkpoints to use (default: SD 1.5)")
45
+ parser.add_argument("--embman_ckpt", type=str, required=True,
46
+ help="Path to the checkpoint of the embedding manager")
47
+ parser.add_argument("--subject", type=str, required=True)
48
+ parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
49
+ parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
50
+ parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
51
+ parser.add_argument("--noise", dest='noise_level', type=float, default=0)
52
+ parser.add_argument("--randface", action="store_true")
53
+ parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
54
+ help="Guidance scale for the diffusion model")
55
+ parser.add_argument("--id_cfg_scale", type=float, default=1,
56
+ help="CFG scale when generating the identity embeddings")
57
+
58
+ parser.add_argument("--subject_string",
59
+ type=str, default="z",
60
+ help="Subject placeholder string used in prompts to denote the concept.")
61
+ parser.add_argument("--num_vectors", type=int, default=16,
62
+ help="Number of vectors used to represent the subject.")
63
+ parser.add_argument("--num_images_per_row", type=int, default=4,
64
+ help="Number of images to display in a row in the output grid image.")
65
+ parser.add_argument("--num_inference_steps", type=int, default=50,
66
+ help="Number of DDIM inference steps")
67
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
68
+ parser.add_argument("--seed", type=int, default=42,
69
+ help="the seed (for reproducible sampling). Set to -1 to disable.")
70
+ args = parser.parse_args()
71
+
72
+ return args
73
+
74
+ if __name__ == "__main__":
75
+ args = parse_args()
76
+ if args.seed != -1:
77
+ seed_everything(args.seed)
78
+
79
+ if re.match(r"^\d+$", args.device):
80
+ args.device = f"cuda:{args.device}"
81
+ print(f"Using device {args.device}")
82
+
83
+ adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
84
+ args.subject_string, args.num_vectors, args.num_inference_steps)
85
+
86
+ if not args.randface:
87
+ image_folder = args.subject
88
+ if image_folder.endswith("/"):
89
+ image_folder = image_folder[:-1]
90
+
91
+ if os.path.isfile(image_folder):
92
+ # Get the second to the last part of the path
93
+ subject_name = os.path.basename(os.path.dirname(image_folder))
94
+ image_paths = [image_folder]
95
+
96
+ else:
97
+ subject_name = os.path.basename(image_folder)
98
+ image_types = ["*.jpg", "*.png", "*.jpeg"]
99
+ alltype_image_paths = []
100
+ for image_type in image_types:
101
+ # glob returns the full path.
102
+ image_paths = glob.glob(os.path.join(image_folder, image_type))
103
+ if len(image_paths) > 0:
104
+ alltype_image_paths.extend(image_paths)
105
+
106
+ # Filter out images of "*_mask.png"
107
+ alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
108
+
109
+ # image_paths contain at most args.example_image_count full image paths.
110
+ if args.example_image_count > 0:
111
+ image_paths = alltype_image_paths[:args.example_image_count]
112
+ else:
113
+ image_paths = alltype_image_paths
114
+ else:
115
+ subject_name = None
116
+ image_paths = None
117
+ image_folder = None
118
+
119
+ subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
120
+ rand_face_embs = torch.randn(1, 512)
121
+
122
+ pre_face_embs = rand_face_embs if args.randface else None
123
+ noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
124
+ # args.noise_level: the *relative* std of the noise added to the face embeddings.
125
+ # A noise level of 0.08 could change gender, but 0.06 is usually safe.
126
+ # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
127
+ adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
128
+ out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
129
+ update_text_encoder=True)
130
+ images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
131
+ save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
adaface/adaface-translate.py ADDED
@@ -0,0 +1,208 @@
1
+ from adaface.adaface_wrapper import AdaFaceWrapper
2
+ import torch
3
+ #import torch.nn.functional as F
4
+ from PIL import Image
5
+ import numpy as np
6
+ import os, argparse, glob, re, shutil
7
+
8
+ def str2bool(v):
9
+ if isinstance(v, bool):
10
+ return v
11
+ if v.lower() in ("yes", "true", "t", "y", "1"):
12
+ return True
13
+ elif v.lower() in ("no", "false", "f", "n", "0"):
14
+ return False
15
+ else:
16
+ raise argparse.ArgumentTypeError("Boolean value expected.")
17
+
18
+ def seed_everything(seed):
19
+ np.random.seed(seed)
20
+ torch.manual_seed(seed)
21
+ torch.cuda.manual_seed_all(seed)
22
+ torch.backends.cudnn.deterministic = True
23
+ torch.backends.cudnn.benchmark = False
24
+ os.environ["PL_GLOBAL_SEED"] = str(seed)
25
+
26
+ def parse_args():
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
29
+ help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
30
+ parser.add_argument("--embman_ckpt", type=str, required=True,
31
+ help="Path to the checkpoint of the embedding manager")
32
+ parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
33
+ # If True, the input folder contains images of mixed subjects.
34
+ # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
35
+ parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
36
+ help="Whether the input folder contains images of mixed subjects")
37
+ parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
38
+ parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
39
+ parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
40
+ parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
41
+ parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
42
+ parser.add_argument("--noise", dest='noise_level', type=float, default=0)
43
+ parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
44
+ help="Guidance scale for the diffusion model")
45
+ parser.add_argument("--ref_img_strength", type=float, default=0.8,
46
+ help="Strength of the reference image in the output image.")
47
+ parser.add_argument("--subject_string",
48
+ type=str, default="z",
49
+ help="Subject placeholder string used in prompts to denote the concept.")
50
+ parser.add_argument("--num_vectors", type=int, default=16,
51
+ help="Number of vectors used to represent the subject.")
52
+ parser.add_argument("--prompt", type=str, default="a person z")
53
+ parser.add_argument("--num_images_per_row", type=int, default=4,
54
+ help="Number of images to display in a row in the output grid image.")
55
+ parser.add_argument("--num_inference_steps", type=int, default=50,
56
+ help="Number of DDIM inference steps")
57
+ parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
58
+ parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
59
+ parser.add_argument("--seed", type=int, default=42,
60
+ help="the seed (for reproducible sampling). Set to -1 to disable.")
61
+ args = parser.parse_args()
62
+
63
+ return args
64
+
65
+ if __name__ == "__main__":
66
+ args = parse_args()
67
+ if args.seed != -1:
68
+ seed_everything(args.seed)
69
+
70
+ # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
71
+ # --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
72
+ # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/shaohua/VGGface2_HQ_masks/
73
+ # --is_mix_subj_folder 0 --out_folder /data/shaohua/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
74
+ if args.num_gpus > 1:
75
+ from accelerate import PartialState
76
+ distributed_state = PartialState()
77
+ args.device = distributed_state.device
78
+ process_index = distributed_state.process_index
79
+ elif re.match(r"^\d+$", args.device):
80
+ args.device = f"cuda:{args.device}"
81
+ distributed_state = None
82
+ process_index = 0
83
+
84
+ adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
85
+ args.subject_string, args.num_vectors, args.num_inference_steps)
86
+
87
+ in_folder = args.in_folder
88
+ if os.path.isfile(in_folder):
89
+ subject_folders = [ os.path.dirname(in_folder) ]
90
+ images_by_subject = [[in_folder]]
91
+ else:
92
+ if not args.is_mix_subj_folder:
93
+ in_folders = [in_folder]
94
+ else:
95
+ in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]
96
+
97
+ images_by_subject = []
98
+ subject_folders = []
99
+ for in_folder in in_folders:
100
+ image_types = ["*.jpg", "*.png", "*.jpeg"]
101
+ alltype_image_paths = []
102
+ for image_type in image_types:
103
+ # glob returns the full path.
104
+ image_paths = glob.glob(os.path.join(in_folder, image_type))
105
+ if len(image_paths) > 0:
106
+ alltype_image_paths.extend(image_paths)
107
+
108
+ # Filter out images of "*_mask.png"
109
+ alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
110
+ alltype_image_paths = sorted(alltype_image_paths)
111
+
112
+ if not args.is_mix_subj_folder:
113
+ # image_paths contain at most args.max_images_per_subject full image paths.
114
+ if args.max_images_per_subject > 0:
115
+ image_paths = alltype_image_paths[:args.max_images_per_subject]
116
+ else:
117
+ image_paths = alltype_image_paths
118
+
119
+ images_by_subject.append(image_paths)
120
+ subject_folders.append(in_folder)
121
+ else:
122
+ # Each image in the folder is treated as an individual subject.
123
+ images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
124
+ subject_folders.extend([in_folder] * len(alltype_image_paths))
125
+
126
+ if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
127
+ break
128
+
129
+ if args.trans_subject_count > 0:
130
+ images_by_subject = images_by_subject[:args.trans_subject_count]
131
+ subject_folders = subject_folders[:args.trans_subject_count]
132
+
133
+ out_image_count = 0
134
+ out_mask_count = 0
135
+ if not args.out_folder.endswith("/"):
136
+ args.out_folder += "/"
137
+
138
+ if args.num_gpus > 1:
139
+ # Split the subjects across the GPUs.
140
+ subject_folders = subject_folders[process_index::args.num_gpus]
141
+ images_by_subject = images_by_subject[process_index::args.num_gpus]
142
+ #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
143
+
144
+ for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
145
+ # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
146
+ # Otherwise, we use the folder name as the signature of the images.
147
+ images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])
148
+
149
+ print(f"Translating {images_sig}...")
150
+ with torch.no_grad():
151
+ adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
152
+ out_id_embs_scale=1, noise_level=args.noise_level,
153
+ update_text_encoder=True)
154
+
155
+ # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
156
+ subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
157
+ if not os.path.exists(subject_out_folder):
158
+ os.makedirs(subject_out_folder)
159
+ print(f"Output images will be saved to {subject_out_folder}")
160
+
161
+ in_images = []
162
+ for image_path in image_paths:
163
+ image = Image.open(image_path).convert("RGB").resize((512, 512))
164
+ # [512, 512, 3] -> [3, 512, 512].
165
+ image = np.array(image).transpose(2, 0, 1)
166
+ # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
167
+ image = torch.tensor(image).unsqueeze(0).float().cuda()
168
+ in_images.append(image)
169
+
170
+ # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
171
+ # NOTE: For simplicity, we do not check overly large batch sizes.
172
+ in_images = torch.cat(in_images, dim=0)
173
+ # in_images: [5, 3, 512, 512].
174
+ # Normalize the pixel values to [0, 1].
175
+ in_images = in_images / 255.0
176
+ num_out_images = len(in_images) * args.out_count_per_input_image
177
+
178
+ with torch.no_grad():
179
+ # args.noise_level: the *relative* std of the noise added to the face embeddings.
180
+ # A noise level of 0.08 could change gender, but 0.06 is usually safe.
181
+ # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
182
+ # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
183
+ out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
184
+
185
+ for img_i, img in enumerate(out_images):
186
+ # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
187
+ subj_i = img_i % len(in_images)
188
+ copy_i = img_i // len(in_images)
189
+ image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
190
+ if copy_i == 0:
191
+ img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
192
+ else:
193
+ img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
194
+
195
+ if args.copy_masks:
196
+ mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
197
+ if os.path.exists(mask_path):
198
+ if copy_i == 0:
199
+ shutil.copy(mask_path, subject_out_folder)
200
+ else:
201
+ mask_filename_stem = image_filename_stem
202
+ shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))
203
+
204
+ out_mask_count += 1
205
+
206
+ out_image_count += len(out_images)
207
+
208
+ print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
adaface/adaface_wrapper.py ADDED
@@ -0,0 +1,286 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from diffusers import (
5
+ StableDiffusionPipeline,
6
+ StableDiffusionImg2ImgPipeline,
7
+ UNet2DConditionModel,
8
+ DDIMScheduler,
9
+ AutoencoderKL,
10
+ )
11
+ from insightface.app import FaceAnalysis
12
+ from adaface.arc2face_models import CLIPTextModelWrapper
13
+ from adaface.util import get_arc2face_id_prompt_embs
14
+ import re, os
15
+
16
+ class AdaFaceWrapper(nn.Module):
17
+ def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
18
+ subject_string='z', num_vectors=16,
19
+ num_inference_steps=50, negative_prompt=None,
20
+ use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
21
+ '''
22
+ pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
23
+ removed from the pipeline to release RAM.
24
+ '''
25
+ super().__init__()
26
+ self.pipeline_name = pipeline_name
27
+ self.base_model_path = base_model_path
28
+ self.adaface_ckpt_path = adaface_ckpt_path
29
+ self.use_840k_vae = use_840k_vae
30
+ self.use_ds_text_encoder = use_ds_text_encoder
31
+ self.subject_string = subject_string
32
+ self.num_vectors = num_vectors
33
+ self.num_inference_steps = num_inference_steps
34
+ self.device = device
35
+ self.is_training = is_training
36
+ self.initialize_pipeline()
37
+ self.extend_tokenizer_and_text_encoder()
38
+ if negative_prompt is None:
39
+ self.negative_prompt = \
40
+ "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
41
+ "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
42
+ "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
43
+ "nude, naked, nsfw, topless, bare breasts"
44
+ else:
45
+ self.negative_prompt = negative_prompt
46
+
47
+ def load_subj_basis_generator(self, adaface_ckpt_path):
48
+ ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
49
+ string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
50
+ if self.subject_string not in string_to_subj_basis_generator_dict:
51
+ print(f"Subject '{self.subject_string}' not found in the embedding manager.")
52
+ breakpoint()
53
+
54
+ self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
55
+ # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
56
+ # But we don't do layerwise embeddings here, so we set it to 1.
57
+ self.subj_basis_generator.num_out_layers = 1
58
+ print(f"Loaded subject basis generator for '{self.subject_string}'.")
59
+ print(repr(self.subj_basis_generator))
60
+ self.subj_basis_generator.to(self.device)
61
+ if self.is_training:
62
+ self.subj_basis_generator.train()
63
+ else:
64
+ self.subj_basis_generator.eval()
65
+
66
+ def initialize_pipeline(self):
67
+ self.load_subj_basis_generator(self.adaface_ckpt_path)
68
+ # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
69
+ # in the UNet image space.
70
+ arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
71
+ 'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
72
+ )
73
+ self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
74
+
75
+ if self.use_840k_vae:
76
+ # The 840000-step vae model is slightly better in face details than the original vae model.
77
+ # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
78
+ vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
79
+ else:
80
+ vae = None
81
+
82
+ if self.use_ds_text_encoder:
83
+ # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
84
+ # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
85
+ text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
86
+ else:
87
+ text_encoder = None
88
+
89
+ remove_unet = False
90
+
91
+ if self.pipeline_name == "img2img":
92
+ PipelineClass = StableDiffusionImg2ImgPipeline
93
+ elif self.pipeline_name == "text2img":
94
+ PipelineClass = StableDiffusionPipeline
95
+ # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
96
+ elif self.pipeline_name is None:
97
+ PipelineClass = StableDiffusionPipeline
98
+ remove_unet = True
99
+ else:
100
+ raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
101
+
102
+ if os.path.isfile(self.base_model_path):
103
+ pipeline = PipelineClass.from_single_file(
104
+ self.base_model_path,
105
+ torch_dtype=torch.float16
106
+ )
107
+ else:
108
+ pipeline = PipelineClass.from_pretrained(
109
+ self.base_model_path,
110
+ torch_dtype=torch.float16,
111
+ safety_checker=None
112
+ )
113
+ print(f"Loaded pipeline from {self.base_model_path}.")
114
+
115
+ if self.use_840k_vae:
116
+ pipeline.vae = vae
117
+ print("Replaced the VAE with the 840k-step VAE.")
118
+
119
+ if self.use_ds_text_encoder:
120
+ pipeline.text_encoder = text_encoder
121
+ print("Replaced the text encoder with the DreamShaper text encoder.")
122
+
123
+ if remove_unet:
124
+ # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
125
+ pipeline.unet = None
126
+ pipeline.vae = None
127
+ print("Removed UNet and VAE from the pipeline.")
128
+
129
+ noise_scheduler = DDIMScheduler(
130
+ num_train_timesteps=1000,
131
+ beta_start=0.00085,
132
+ beta_end=0.012,
133
+ beta_schedule="scaled_linear",
134
+ clip_sample=False,
135
+ set_alpha_to_one=False,
136
+ steps_offset=1,
137
+ )
138
+
139
+ pipeline.scheduler = noise_scheduler
140
+ self.pipeline = pipeline.to(self.device)
141
+ # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
142
+ # Note there's a second "model" in the path.
143
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
144
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
145
+ # Patch the missing tokenizer in the subj_basis_generator.
146
+ if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
147
+ self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
148
+ print("Patched the missing tokenizer in the subj_basis_generator.")
149
+
150
+ def extend_tokenizer_and_text_encoder(self):
151
+ if self.num_vectors < 1:
152
+ raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
153
+
154
+ tokenizer = self.pipeline.tokenizer
155
+ # Add z0, z1, z2, ..., z15.
156
+ self.placeholder_tokens = []
157
+ for i in range(0, self.num_vectors):
158
+ self.placeholder_tokens.append(f"{self.subject_string}_{i}")
159
+
160
+ self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
161
+
162
+ # Add the new tokens to the tokenizer.
163
+ num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
164
+ if num_added_tokens != self.num_vectors:
165
+ raise ValueError(
166
+ f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
167
+ " `subject_string` that is not already in the tokenizer.")
168
+
169
+ print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
170
+
171
+ # placeholder_token_ids: [49408, ..., 49423].
172
+ self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
173
+ # print(self.placeholder_token_ids)
174
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
175
+ old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
176
+ self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
177
+ new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
178
+ print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
179
+
180
+ # Extend pipeline.text_encoder with the adaface subject embeddings.
181
+ # subj_embs: [16, 768].
182
+ def update_text_encoder_subj_embs(self, subj_embs):
183
+ # Copy the AdaFace subject embeddings into the newly added placeholder token embeddings.
184
+ token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
185
+ with torch.no_grad():
186
+ for i, token_id in enumerate(self.placeholder_token_ids):
187
+ token_embeds[token_id] = subj_embs[i]
188
+ print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
189
+
190
+ def update_prompt(self, prompt):
191
+ # If the placeholder tokens are already in the prompt, then return the prompt as is.
192
+ if self.placeholder_tokens_str in prompt:
193
+ return prompt
194
+
195
+ # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
196
+ if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
197
+ print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
198
+ comp_prompt = self.placeholder_tokens_str + " " + prompt
199
+ else:
200
+ # Replace the subject string 'z' with the placeholder tokens.
201
+ comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
202
+ return comp_prompt
203
+
204
+ # image_paths: a list of image paths. image_folder: the parent folder name.
205
+ def generate_adaface_embeddings(self, image_paths, image_folder=None,
206
+ pre_face_embs=None, gen_rand_face=False,
207
+ out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
208
+ # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
209
+ # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
210
+ # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
211
+ # The same applies to id_prompt_emb.
212
+ # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
213
+ # Here id_batch_size = 1, so
214
+ # faceid_embeds: [1, 512]. NOT used later.
215
+ # id_prompt_emb: [1, 16, 768].
216
+ # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
217
+ # arc2face prompt template: "photo of a id person"
218
+ # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
219
+ faceid_embeds, id_prompt_emb \
220
+ = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
221
+ extract_faceid_embeds=not gen_rand_face,
222
+ pre_face_embs=pre_face_embs,
223
+ # image_folder is passed only for logging purposes.
224
+ # image_paths contains the paths of the images.
225
+ image_folder=image_folder, image_paths=image_paths,
226
+ images_np=None,
227
+ id_batch_size=1,
228
+ device=self.device,
229
+ # input_max_length == 22: only keep the first 22 tokens,
230
+ # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
231
+ # The results are indistinguishable from input_max_length=77.
232
+ input_max_length=22,
233
+ noise_level=noise_level,
234
+ return_core_id_embs=True,
235
+ gen_neg_prompt=False,
236
+ verbose=True)
237
+
238
+ # adaface_subj_embs: [1, 1, 16, 768].
239
+ # adaface_prompt_embs: [1, 77, 768] (not used).
240
+ adaface_subj_embs, adaface_prompt_embs = \
241
+ self.subj_basis_generator(id_prompt_emb, None, None,
242
+ out_id_embs_scale=out_id_embs_scale,
243
+ is_face=True, is_training=False,
244
+ adaface_prompt_embs_inf_type='full_half_pad')
245
+ # adaface_subj_embs: [16, 768]
246
+ adaface_subj_embs = adaface_subj_embs.squeeze()
247
+ if update_text_encoder:
248
+ self.update_text_encoder_subj_embs(adaface_subj_embs)
249
+ return adaface_subj_embs
250
+
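+ # Note: the returned adaface_subj_embs is a [16, 768] tensor. When update_text_encoder=True,
+ # these embeddings have already been written into the placeholder token slots of the text encoder,
+ # so a subsequent encode_prompt() call will pick them up.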
251
+ def encode_prompt(self, prompt, device="cuda", verbose=False):
252
+ prompt = self.update_prompt(prompt)
253
+ if verbose:
254
+ print(f"Prompt: {prompt}")
255
+
256
+ # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
257
+ # So we manually move it to GPU here.
258
+ self.pipeline.text_encoder.to(device)
259
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
260
+ prompt_embeds_, negative_prompt_embeds_ = \
261
+ self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
262
+ do_classifier_free_guidance=True, negative_prompt=self.negative_prompt)
263
+ return prompt_embeds_, negative_prompt_embeds_
264
+
265
+ # ref_img_strength is used only in the img2img pipeline.
266
+ def forward(self, noise, prompt, guidance_scale=4.0, out_image_count=4, ref_img_strength=0.8, verbose=False):
267
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
268
+ prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, device=self.device, verbose=verbose)
269
+
270
+ # Repeat the prompt embeddings for all images in the batch.
271
+ prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
272
+ negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
273
+ noise = noise.to(self.device).to(torch.float16)
274
+
275
+ # noise: [BS, 4, 64, 64]
276
+ # When the pipeline is text2img, strength is ignored.
277
+ images = self.pipeline(image=noise,
278
+ prompt_embeds=prompt_embeds_,
279
+ negative_prompt_embeds=negative_prompt_embeds_,
280
+ num_inference_steps=self.num_inference_steps,
281
+ guidance_scale=guidance_scale,
282
+ num_images_per_prompt=1,
283
+ strength=ref_img_strength).images
284
+ # images: [BS, 3, 512, 512]
285
+ return images
286
+
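+ # Illustrative usage sketch (assuming an already-constructed wrapper instance named "adaface"):
+ #   adaface.generate_adaface_embeddings(image_paths=["face1.jpg", "face2.jpg"])  # updates the text encoder
+ #   noise = torch.randn(4, 4, 64, 64)
+ #   images = adaface.forward(noise, "portrait of z in a park", guidance_scale=4.0, out_image_count=4)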
adaface/arc2face_models.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from transformers.models.clip.modeling_clip import CLIPAttention
5
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, List
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
9
+ _make_causal_mask = AttentionMaskConverter._make_causal_mask
10
+ _expand_mask = AttentionMaskConverter._expand_mask
11
+
12
+ from adaface.util import add_noise_to_tensor
13
+
14
+ # Extend CLIPAttention by using multiple k_proj and v_proj in each head.
15
+ # To avoid increasing the computation too much, we don't extend q_proj.
16
+ class CLIPAttentionMKV(nn.Module):
17
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
18
+
19
+ def __init__(self, config, multiplier=2):
20
+ super().__init__()
21
+ self.config = config
22
+ self.embed_dim = config.hidden_size
23
+ self.num_heads = config.num_attention_heads
24
+ self.head_dim = self.embed_dim // self.num_heads
25
+ if self.head_dim * self.num_heads != self.embed_dim:
26
+ raise ValueError(
27
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
28
+ f" {self.num_heads})."
29
+ )
30
+ self.scale = self.head_dim**-0.5
31
+ self.dropout = config.attention_dropout
32
+ self.multiplier = multiplier
33
+
34
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
35
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
36
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
37
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
38
+
39
+ # The (approximately) repeated token features are repeated along the last dim in tensor
40
+ # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
41
+ # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
42
+ # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
43
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
44
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
45
+
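+ # Illustrative layout: with multiplier=2 and tokens [t0, t1], the k/v sequence after _shape()
+ # is ordered [t0, t0, t1, t1] along the length dim, i.e. the copies of each token stay adjacent.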
46
+ def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
47
+ noise_std_is_relative=True, keep_norm=False, verbose=False):
48
+ self.multiplier *= multiplier
49
+ # q_proj and out_proj are the same as the original CLIPAttention.
50
+ self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
51
+ self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
52
+ self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
53
+ self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()
54
+
55
+ # bias doesn't need noise perturbation, as after the weights are noised,
56
+ # different copies of the weight/bias will receive different gradients,
57
+ # making the bias terms diverge and become identifiable after training.
58
+ self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
59
+ self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)
60
+
61
+ self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
62
+ self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)
63
+
64
+ if noise_std > 0:
65
+ ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
66
+ ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
67
+ # Adding noise to the extra copies of the weights (keep the first copy unchanged).
68
+ self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
69
+ add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
70
+ noise_std, noise_std_is_relative, keep_norm)
71
+ if verbose:
72
+ NEW_V_SHAPE = list(self.v_proj.weight.shape)
73
+ NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
74
+ print(f"Layer {layer_idx}: noise of std {noise_std} added to {NOISED_V_SHAPE} (out of {NEW_V_SHAPE}) of v_proj")
75
+
76
+ ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
77
+ ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
78
+ # Adding noise to the extra copies of the weights.
79
+ self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
80
+ add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
81
+ noise_std, noise_std_is_relative, keep_norm)
82
+ if verbose:
83
+ NEW_K_SHAPE = list(self.k_proj.weight.shape)
84
+ NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
85
+ print(f"Layer {layer_idx}: noise of std {noise_std} added to {NOISED_K_SHAPE} (out of {NEW_K_SHAPE}) of k_proj")
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ causal_attention_mask: Optional[torch.Tensor] = None,
92
+ output_attentions: Optional[bool] = False,
93
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
94
+ """Input shape: Batch x Time x Channel"""
95
+
96
+ bsz, tgt_len, embed_dim = hidden_states.size()
97
+
98
+ query_states = self.q_proj(hidden_states) * self.scale
99
+ # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
100
+ # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
101
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
102
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
103
+
104
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
105
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
106
+ key_states = key_states.view(*proj_shape)
107
+ value_states = value_states.view(*proj_shape)
108
+
109
+ src_len = key_states.size(1)
110
+ # src_len0 is the original src_len without the multiplier.
111
+ src_len0 = src_len // self.multiplier
112
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
113
+
114
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
115
+ raise ValueError(
116
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
117
+ f" {attn_weights.size()}"
118
+ )
119
+
120
+ # apply the causal_attention_mask first
121
+ if causal_attention_mask is not None:
122
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
123
+ raise ValueError(
124
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
125
+ f" {causal_attention_mask.size()}"
126
+ )
127
+ # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
128
+ # If reshaping it as (self.multiplier, src_len0), it will become
+ # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
+ # and the mask will be applied to the wrong elements.
+ # If reshaping it as (src_len0, self.multiplier), each row holds the self.multiplier copies of one token,
+ # i.e. [[token0, token0], [token1, token1], ..., [tokenN-1, tokenN-1]], so broadcasting the mask
+ # over the last dim masks all the copies of token i together, which is desired.
134
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
135
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
136
+
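+ # Worked example (multiplier=2, src_len0=3): the last dim is ordered [t0, t0, t1, t1, t2, t2];
+ # viewing it as (src_len0=3, multiplier=2) groups the two copies of each token into one row,
+ # so the (..., src_len0, 1)-shaped mask broadcasts over the copy dim and masks both copies at once.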
137
+ if attention_mask is not None:
138
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
139
+ raise ValueError(
140
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
141
+ )
142
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
143
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
144
+
145
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
146
+
147
+ if output_attentions:
148
+ # this operation is a bit awkward, but it's required to
149
+ # make sure that attn_weights keeps its gradient.
150
+ # In order to do so, attn_weights have to reshaped
151
+ # twice and have to be reused in the following
152
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
153
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
154
+ else:
155
+ attn_weights_reshaped = None
156
+
157
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
158
+
159
+ attn_output = torch.bmm(attn_probs, value_states)
160
+
161
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
168
+ attn_output = attn_output.transpose(1, 2)
169
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
170
+
171
+ attn_output = self.out_proj(attn_output)
172
+
173
+ return attn_output, attn_weights_reshaped
174
+
175
+ class CLIPTextModelWrapper(CLIPTextModel):
176
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
177
+ # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.Tensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ position_ids: Optional[torch.Tensor] = None,
183
+ output_attentions: Optional[bool] = None,
184
+ output_hidden_states: Optional[bool] = None,
185
+ return_dict: Optional[bool] = None,
186
+ input_token_embs: Optional[torch.Tensor] = None,
187
+ hidden_state_layer_weights: Optional[torch.Tensor] = None,
188
+ return_token_embs: Optional[bool] = False,
189
+ ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
190
+
191
+ if return_token_embs:
192
+ return self.text_model.embeddings.token_embedding(input_ids)
193
+
194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
195
+
196
+ output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
197
+ output_hidden_states = (
198
+ output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
199
+ )
200
+ if hidden_state_layer_weights is not None:
201
+ output_hidden_states = True
202
+ return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
203
+
204
+ if input_ids is None:
205
+ raise ValueError("You have to specify input_ids")
206
+
207
+ input_shape = input_ids.size()
208
+ input_ids = input_ids.view(-1, input_shape[-1])
209
+
210
+ hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
211
+
212
+ # CLIP's text model uses causal mask, prepare it here.
213
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
214
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
215
+ # expand attention_mask
216
+ if attention_mask is not None:
217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
218
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
219
+
220
+ encoder_outputs = self.text_model.encoder(
221
+ inputs_embeds=hidden_states,
222
+ attention_mask=attention_mask,
223
+ causal_attention_mask=causal_attention_mask,
224
+ output_attentions=output_attentions,
225
+ # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
226
+ output_hidden_states=output_hidden_states,
227
+ return_dict=return_dict,
228
+ )
229
+
230
+ # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
231
+ # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
232
+ # encoder_outputs[0] == encoder_outputs[1][12].
233
+ if hidden_state_layer_weights is None:
234
+ last_hidden_state = encoder_outputs[0]
235
+ else:
236
+ num_hidden_state_layers = len(hidden_state_layer_weights)
237
+ last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
238
+ hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
239
+ # Normalize the weights to sum to 1 across layers.
240
+ # hidden_state_layer_weights: [3, 1] or [3, 768].
241
+ hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
242
+ # [3, 1/768] -> [3, 1, 1, 1/768]
243
+ hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
244
+ # A weighted sum of last_hidden_states.
245
+ # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
246
+ last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)
247
+
248
+ last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
249
+
250
+ # self.text_model.eos_token_id == 2 is True.
251
+ if self.text_model.eos_token_id == 2:
252
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
253
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
254
+ # ------------------------------------------------------------
255
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
256
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
257
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
258
+ pooled_output = last_hidden_state[
259
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
260
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
261
+ ]
262
+ else:
263
+ # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
264
+ pooled_output = last_hidden_state[
265
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
266
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
267
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
268
+ .int()
269
+ .argmax(dim=-1),
270
+ ]
271
+
272
+ if not return_dict:
273
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
274
+
275
+ return BaseModelOutputWithPooling(
276
+ last_hidden_state=last_hidden_state,
277
+ pooler_output=pooled_output,
278
+ hidden_states=encoder_outputs.hidden_states,
279
+ attentions=encoder_outputs.attentions,
280
+ )
281
+
282
+ # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
283
+ # The layer indexed by end_layer_idx is not included.
284
+ # If both layer indices are -1, then apply to all layers (0-11).
285
+ def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
286
+ num_extended_layers = 0
287
+
288
+ for layer_idx, layer in enumerate(self.text_model.encoder.layers):
289
+ if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
290
+ continue
291
+ if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
292
+ break
293
+ # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
294
+ if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
295
+ breakpoint()
296
+ old_attn_layer = layer.self_attn
297
+ if not isinstance(old_attn_layer, CLIPAttentionMKV):
298
+ layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
299
+ layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
300
+ num_extended_layers += 1
301
+
302
+ return num_extended_layers
303
+
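+ # Illustrative usage sketch (layer indices here are just an example):
+ #   text_encoder = CLIPTextModelWrapper.from_pretrained("openai/clip-vit-large-patch14")
+ #   # Double the k/v projections of encoder layers 8-11 (end index is exclusive).
+ #   text_encoder.extend_clip_attention_MKV_multiplier(begin_layer_idx=8, end_layer_idx=12, multiplier=2)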
adaface/subj_basis_generator.py ADDED
@@ -0,0 +1,758 @@
1
+ # Borrowed from ip-adapter resampler.py.
2
+ # https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
3
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
4
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
5
+
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from einops.layers.torch import Rearrange
13
+ from transformers import CLIPVisionModel, CLIPTokenizer
14
+
15
+ import numpy as np
16
+ from torch import einsum
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple
19
+ from transformers.utils import ModelOutput
20
+ from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
21
+ from adaface.arc2face_models import CLIPTextModelWrapper
22
+ import sys
23
+ sys.modules['ldm'] = sys.modules['adaface']
24
+
25
+ def reshape_tensor(x, num_heads):
26
+ bs, length, width = x.shape
27
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
28
+ x = x.view(bs, length, num_heads, -1)
29
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
30
+ x = x.transpose(1, 2)
31
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
32
+ x = x.reshape(bs, num_heads, length, -1)
33
+ return x
34
+
35
+ # FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
36
+ def FeedForward(dim, mult=4, p_dropout=0.1):
37
+ inner_dim = int(dim * mult)
38
+ return nn.Sequential(
39
+ nn.LayerNorm(dim),
40
+ nn.Linear(dim, inner_dim, bias=False),
41
+ nn.GELU(),
42
+ nn.Linear(inner_dim, dim, bias=False),
43
+ nn.Dropout(p_dropout),
44
+ )
45
+
46
+ # IP-Adapter FaceID class. Only used in knn-faces.py.
47
+ # From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
48
+ class IP_MLPProjModel(nn.Module):
49
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
50
+ super().__init__()
51
+
52
+ self.cross_attention_dim = cross_attention_dim
53
+ self.num_tokens = num_tokens
54
+
55
+ self.proj = nn.Sequential(
56
+ nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
57
+ nn.GELU(),
58
+ nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
59
+ )
60
+ self.norm = nn.LayerNorm(cross_attention_dim)
61
+
62
+ def forward(self, id_embeds):
63
+ x = self.proj(id_embeds)
64
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
65
+ x = self.norm(x)
66
+ return x
67
+
68
+ # group_dim: the tensor dimension that corresponds to the multiple groups.
69
+ class LearnedSoftAggregate(nn.Module):
70
+ def __init__(self, num_feat, group_dim, keepdim=False):
71
+ super(LearnedSoftAggregate, self).__init__()
72
+ self.group_dim = group_dim
73
+ # num_feat = 1: element-wise score function & softmax.
74
+ # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
75
+ self.num_feat = num_feat
76
+ self.feat2score = nn.Linear(num_feat, 1, bias=False)
77
+ self.keepdim = keepdim
78
+
79
+ def forward(self, x, score_basis=None):
80
+ # If there's only one mode, do nothing.
81
+ if x.shape[self.group_dim] == 1:
82
+ if self.keepdim:
83
+ return x
84
+ else:
85
+ return x.squeeze(self.group_dim)
86
+
87
+ # Assume the last dim of x is the feature dim.
88
+ if score_basis is None:
89
+ score_basis = x
90
+
91
+ if self.num_feat == 1:
92
+ mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
93
+ else:
94
+ mode_scores = self.feat2score(score_basis)
95
+ attn_probs = mode_scores.softmax(dim=self.group_dim)
96
+ x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
97
+ return x_aggr
98
+
99
+ def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
100
+ num_output_vecs, elementwise_affine=True, p_dropout=0.1):
101
+ return nn.Sequential(
102
+ # Project to [BS, lora_rank * output_dim * num_modes].
103
+ # It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456.
104
+ nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
105
+ # Reshape to [BS, lora_rank, output_dim].
106
+ Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
107
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
108
+ # Aggregate [BS, num_modes, lora_rank, output_dim] -> [BS, lora_rank, output_dim].
109
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
110
+ else Rearrange('b () q d -> b q d'),
111
+ nn.Dropout(p_dropout),
112
+ # Permute to [BS, output_dim, lora_rank].
113
+ Rearrange('b q d -> b d q'),
114
+ # Project to [BS, output_dim, num_output_vecs].
115
+ nn.Linear(lora_rank, num_output_vecs, bias=False),
116
+ # Permute to [BS, num_output_vecs, output_dim].
117
+ Rearrange('b d q -> b q d'),
118
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
119
+ nn.Dropout(p_dropout),
120
+ )
121
+
122
+ def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
123
+ return nn.Sequential(
124
+ # Project to [BS, num_output_vecs * output_dim].
125
+ nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
126
+ # Reshape to [BS, num_output_vecs, output_dim].
127
+ Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
128
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
129
+ nn.Dropout(p_dropout),
130
+ )
131
+
132
+ # Input: [BS, N, D].
133
+ def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
134
+ if output_dim == -1:
135
+ output_dim = input_dim
136
+
137
+ return nn.Sequential(
138
+ nn.Linear(input_dim, output_dim * num_modes, bias=False),
139
+ # Reshape to [BS, num_output_vecs, output_dim].
140
+ Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
141
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
142
+ # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
143
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
144
+ else Rearrange('b n () d -> b n d'),
145
+ nn.Dropout(p_dropout),
146
+ )
147
+
148
+ # Low-rank to high-rank transformation.
149
+ def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
150
+ return nn.Sequential(
151
+ # Permute to [BS, output_dim, lora_rank].
152
+ Rearrange('b q d -> b d q'),
153
+ # Project to [BS, output_dim, hira_rank].
154
+ nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
155
+ # Reshape and permute to [BS, num_modes, num_output_vecs, output_dim].
156
+ Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
157
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
158
+ # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
159
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
160
+ else Rearrange('b () q d -> b q d'),
161
+ nn.Dropout(p_dropout),
162
+ )
163
+
164
+ class PerceiverAttention(nn.Module):
165
+ def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
166
+ super().__init__()
167
+ self.scale = dim_head**-0.5
168
+ self.dim_head = dim_head
169
+ self.num_heads = num_heads
170
+ inner_dim = dim_head * num_heads
171
+
172
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
173
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
174
+
175
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
176
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
177
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
178
+
179
+ def forward(self, x, latent_queries):
180
+ """
181
+ Args:
182
+ x (torch.Tensor): image features
183
+ shape (b, n1, D)
184
+ latent (torch.Tensor): latent features
185
+ shape (b, n2, D)
186
+ """
187
+ x = self.norm1(x)
188
+ latent_queries = self.norm2(latent_queries)
189
+
190
+ b, l, _ = latent_queries.shape
191
+
192
+ q = self.to_q(latent_queries)
193
+ kv_input = torch.cat((x, latent_queries), dim=-2)
194
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
195
+
196
+ q = reshape_tensor(q, self.num_heads)
197
+ k = reshape_tensor(k, self.num_heads)
198
+ v = reshape_tensor(v, self.num_heads)
199
+
200
+ # attention
201
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
202
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
203
+ attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
204
+ out = attn @ v
205
+
206
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
207
+
208
+ return self.to_out(out)
209
+
210
+
211
+ class CrossAttention(nn.Module):
212
+ # output_dim is always the same as input_dim.
213
+ # num_q only matters when q_aware_to_v is True.
214
+ # If q_aware_to_v is False, query x in forward() is still usable.
215
+ def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
216
+ identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
217
+ q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
218
+ identity_to_out=False, out_has_skip=False):
219
+ super().__init__()
220
+ dim_head = input_dim // num_heads
221
+ inner_dim = dim_head * num_heads
222
+
223
+ self.num_heads = num_heads
224
+ self.q_aware_to_v = q_aware_to_v
225
+ self.v_has_skip = v_has_skip
226
+ self.to_q = nn.Sequential(
227
+ nn.Linear(input_dim, inner_dim, bias=False),
228
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
229
+ ) if not identity_to_q else nn.Identity()
230
+ self.to_k = nn.Sequential(
231
+ nn.Linear(input_dim, inner_dim, bias=False),
232
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
233
+ ) if not identity_to_k else nn.Identity()
234
+
235
+ self.v_repeat = v_repeat
236
+ self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104.
237
+
238
+ # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
239
+ # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
240
+ if q_aware_to_v:
241
+ # all_q_mid: 104 * 64 = 6656.
242
+ all_q_mid = num_q_group * q_aware_to_v_lora_rank
243
+ self.to_v = nn.Sequential(
244
+ # number of params: 768 * 6656 = 5,111,808.
245
+ # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
246
+ # Each 768-dim vec is dispersed into 104 64-dim vecs.
247
+ nn.Linear(input_dim, all_q_mid, bias=False),
248
+ nn.LayerNorm(all_q_mid, elementwise_affine=True),
249
+ # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
250
+ Rearrange('b n q -> b q n', q=all_q_mid),
251
+ # Each q_aware_to_v projection has its own linear layer.
252
+ # The total number of parameters will be 6656*768 = 5,111,808.
253
+ # Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
254
+ nn.Conv1d(
255
+ in_channels=all_q_mid,
256
+ out_channels=num_q_group * input_dim,
257
+ kernel_size=1,
258
+ groups=num_q_group,
259
+ bias=False,
260
+ ),
261
+ # Output: [BS, 104, 16, 768].
262
+ Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
263
+ nn.LayerNorm(input_dim, elementwise_affine=True),
264
+ )
265
+ else:
266
+ self.to_v = nn.Sequential(
267
+ nn.Linear(input_dim, inner_dim, bias=False),
268
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
269
+ ) if not identity_to_v else nn.Identity()
270
+
271
+ if identity_to_out:
272
+ assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."
273
+
274
+ if identity_to_out:
275
+ self.to_out = nn.Identity()
276
+ else:
277
+ self.to_out = nn.Sequential(
278
+ nn.Linear(input_dim, input_dim, bias=False),
279
+ nn.Dropout(p_dropout),
280
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
281
+ )
282
+
283
+ self.out_has_skip = out_has_skip
284
+ self.attn_drop = nn.Dropout(p_dropout)
285
+
286
+ def forward(self, x, context=None, attn_mat=None, return_attn=False):
287
+ h = self.num_heads
288
+
289
+ if context is None:
290
+ context = x
291
+
292
+ if attn_mat is None:
293
+ # q: [BS, Q, D] -> [BS, Q, D].
294
+ q = self.to_q(x)
295
+ # k: [BS, L, D] -> [BS, L, D].
296
+ k = self.to_k(context)
297
+ # q: [6, 512, 128], k: [6, 17, 128].
298
+ q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
299
+
300
+ if self.q_aware_to_v:
301
+ # context: [BS, L, D]. v: [BS, Q, L, D].
302
+ # There are effectively Q to_v projections.
303
+ v = self.to_v(context)
304
+ if self.v_has_skip:
305
+ v = v + context.unsqueeze(1)
306
+ else:
307
+ # v: [BS, L, D].
308
+ v = self.to_v(context)
309
+ if self.v_has_skip:
310
+ v = v + context
311
+
312
+ #print(v.shape)
313
+
314
+ if self.q_aware_to_v:
315
+ # v: [6, 64, 17, 128].
316
+ # v is query-specific, so there's an extra dim for the query.
317
+ v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
318
+ # Each v is for a query group with 512/64 = 8 queries.
319
+ # So each v is repeated 8 times to match the number of queries.
320
+ # v: [6, 64, 17, 128] -> [6, 512, 17, 128].
321
+ v = v.repeat(1, self.v_repeat, 1, 1)
322
+ else:
323
+ v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)
324
+
325
+ if attn_mat is None:
326
+ scale = q.size(-1) ** -0.25
327
+ sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
328
+ # sim: [6, 64, 17]. 6: bs 1 * h 6.
329
+ # attention, what we cannot get enough of
330
+ # NOTE: the normalization is done across tokens, not across pixels.
331
+ # So for each pixel, the sum of attention scores across tokens is 1.
332
+ attn = sim.softmax(dim=-1)
333
+ attn = self.attn_drop(attn)
334
+ #print(attn.std())
335
+ else:
336
+ attn = attn_mat
337
+
338
+ if self.q_aware_to_v:
339
+ # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
340
+ # out is combined with different attn weights and v for different queries.
341
+ out = einsum('b i j, b i j d -> b i d', attn, v)
342
+ else:
343
+ # v: [6, 17, 128]. out: [6, 32, 128].
344
+ out = einsum('b i j, b j d -> b i d', attn, v)
345
+
346
+ # [6, 32, 128] -> [1, 32, 768].
347
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
348
+
349
+ if self.out_has_skip:
350
+ out = self.to_out(out) + out
351
+ else:
352
+ out = self.to_out(out)
353
+
354
+ if return_attn:
355
+ return out, attn
356
+ else:
357
+ return out
358
+
359
+ class SubjBasisGenerator(nn.Module):
360
+ def __init__(
361
+ self,
362
+ # number of cross-attention heads. Half of the number of heads 12 of OpenAI clip-vit-large-patch14:
363
+ # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
364
+ num_heads=6,
365
+ num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
366
+ num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4.
367
+ num_out_layers=16, # number of layers of output embeddings.
368
+ image_embedding_dim=768, # CLIP image feature dimension, as per config.json above.
369
+ # DINO vits16 has 6 attention heads:
370
+ # https://huggingface.co/facebook/dino-vits16/blob/main/config.json
371
+ dino_embedding_dim=384, # DINO object feature dimension for objects.
372
+ output_dim=768, # CLIP text embedding input dimension.
373
+ placeholder_is_bg: bool = False, # Whether the placeholder is for the image background.
374
+ prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
375
+ zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj.
376
+ learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
377
+ bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
378
+ ):
379
+ super().__init__()
380
+
381
+ self.placeholder_is_bg = placeholder_is_bg
382
+ self.num_out_layers = num_out_layers
383
+ self.num_out_embs_per_layer = num_out_embs_per_layer
384
+ # subj: 64, bg: 32.
385
+ self.num_out_embs = num_out_layers * num_out_embs_per_layer
386
+ self.output_dim = output_dim
387
+ # num_id_vecs should be the number of core ID embs, 16.
388
+ # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
389
+ self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
390
+ self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
391
+ self.pos_embs_ln = nn.LayerNorm(output_dim)
392
+ self.zs_extra_words_scale = zs_extra_words_scale
393
+ self.output_scale = output_dim ** -0.5
394
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
395
+
396
+ if not self.placeholder_is_bg:
397
+ # [1, 384] -> [1, 16, 768].
398
+ # TODO: use CLIPTextModelWrapper as obj_proj_in.
399
+ self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)
400
+
401
+ # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
402
+ # If self.placeholder_is_bg: prompt2token_proj is set to None.
403
+ self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
404
+ self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
405
+ self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
406
+ print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
407
+ # Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
408
+ # Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
409
+ if prompt2token_proj_grad_scale == 0:
410
+ self.freeze_prompt2token_proj()
411
+
412
+ self.prompt2token_proj_attention_multiplier = -1
413
+ self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
414
+ self.pad_embeddings = None
415
+ self.bg_proj_in = None
416
+ else:
417
+ # For background placeholders, face and object embeddings are not used as they are foreground.
418
+ self.obj_proj_in = None
419
+ self.prompt2token_proj = None
420
+ print("Bg prompt2token_proj is set to None.")
421
+
422
+ self.bg_proj_in = nn.Sequential(
423
+ nn.Linear(image_embedding_dim, output_dim, bias=False),
424
+ nn.LayerNorm(output_dim),
425
+ )
426
+
427
+ self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
428
+ self.latent_queries_ln = nn.LayerNorm(output_dim)
429
+
430
+ self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
431
+ identity_to_v = False
432
+ v_has_skip = not identity_to_v # True
433
+ identity_to_out = not bg_prompt_translator_has_to_out_proj # True
434
+ out_has_skip = not identity_to_out # False
435
+ # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
436
+ # dim=768, num_heads=6.
437
+ self.prompt_translator = \
438
+ CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
439
+ identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
440
+ q_aware_to_v=False, v_has_skip=v_has_skip,
441
+ num_q=0, # When not q_aware_to_v, num_q is not referenced.
442
+ identity_to_out=identity_to_out,
443
+ out_has_skip=out_has_skip)
444
+ '''
445
+ prompt_translator: CLIPEncoder
446
+ # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
447
+ # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
448
+ (0): CLIPEncoderLayer(
449
+ (self_attn): CLIPAttention(
450
+ (k_proj): Linear(in_features=768, out_features=768, bias=True)
451
+ (v_proj): Linear(in_features=768, out_features=768, bias=True)
452
+ (q_proj): Linear(in_features=768, out_features=768, bias=True)
453
+ (out_proj): Linear(in_features=768, out_features=768, bias=True)
454
+ )
455
+ (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
456
+ (mlp): CLIPMLP(
457
+ (activation_fn): QuickGELUActivation()
458
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
459
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
460
+ )
461
+ (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
462
+ )
463
+ '''
464
+
465
+ print(repr(self))
466
+
467
+ # raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
468
+ # or DINO embeddings for objects.
469
+ # arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
470
+ def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
471
+ is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
472
+
473
+ if not self.placeholder_is_bg:
474
+ BS = arc2face_id_embs.shape[0]
475
+ else:
476
+ # If bg, then arc2face_id_embs is set to None, but clip_features is not None.
477
+ BS = clip_features.shape[0]
478
+
479
+ adaface_prompt_embs = None
480
+ if not hasattr(self, 'clip_tokenizer'):
481
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
482
+
483
+ # No need to use raw_id_embs if placeholder_is_bg.
484
+ if not self.placeholder_is_bg:
485
+ if is_face:
486
+ assert arc2face_id_embs is not None
487
+ # arc2face_embs has been projected to the (modified) prompt embedding space
488
+ # by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
489
+ # the text encoder and the U-Net.
490
+ # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
491
+ # arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
492
+ # adaface_prompt_embs is projected to the prompt embedding spaces. This is the
493
+ # original U-Net prompt embedding space.
494
+
495
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
496
+ hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
497
+ # return_emb_types: a list of strings, each string is among
498
+ # ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
499
+ # Using b_core_e is more computationally efficient than using full_zeroed_extra.
500
+ # But there is an unknow BUG that causes crash when using b_core_e.
501
+ if is_training:
502
+ return_emb_types = ['full_pad', 'core']
503
+ else:
504
+ # adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
505
+ return_emb_types = [adaface_prompt_embs_inf_type, 'core']
506
+
507
+ if self.pad_embeddings is None:
508
+ self.generate_pad_embeddings()
509
+ else:
510
+ self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)
511
+
512
+ with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
513
+ # If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
514
+ # and (at most) two extra words in full_prompt_embs, without BOS and EOS.
515
+ # If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
516
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
517
+ # zs_extra_words_scale is only effective when list_extra_words is not None.
518
+ # adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
519
+ adaface_prompt_embs, core_id_embs = \
520
+ arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
521
+ self.prompt2token_proj,
522
+ arc2face_id_embs,
523
+ list_extra_words=None,
524
+ return_emb_types=return_emb_types,
525
+ pad_embeddings=self.pad_embeddings,
526
+ hidden_state_layer_weights=hidden_state_layer_weights,
527
+ input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
528
+ # Reduce the update rate to prompt2token_proj.
529
+ adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
530
+ core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
531
+ elif raw_id_embs is not None:
532
+ # id_embs: [BS, 384] -> [BS, 18, 768].
533
+ # obj_proj_in is expected to project the DINO object features to
534
+ # the token embedding space. So no need to use prompt2token_proj.
535
+ id_embs = self.obj_proj_in(raw_id_embs)
536
+ else:
537
+ breakpoint()
538
+ else:
539
+ # Otherwise, context is the ad-hoc CLIP image features.
540
+ # id_embs: [BS, 257, 768].
541
+ id_embs = self.bg_proj_in(clip_features)
542
+
543
+ if self.placeholder_is_bg:
544
+ id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
545
+ latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
546
+ # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
547
+ # and we take the first 16*4=64.
548
+ # Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
549
+ # prompt_translator: better named as bg_prompt_translator. It maps the bg features
550
+ # to bg prompt embeddings.
551
+ with torch.set_grad_enabled(self.training):
552
+ id_embs_out = self.prompt_translator(latent_queries, id_embs)
553
+ # [BS, 64, 768] -> [BS, 16, 4, 768]
554
+ id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
555
+ adaface_subj_embs = id_embs_out * self.output_scale # * 0.036
556
+ else:
557
+ # adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
558
+ adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)
559
+
560
+ # If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
561
+ if out_id_embs_scale != 1:
562
+ # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
563
+ pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
564
+ adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
565
+ + pad_embeddings * (1 - out_id_embs_scale)
566
+
567
+ return adaface_subj_embs, adaface_prompt_embs
568
+
569
+ def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
570
+ if learnable_hidden_state_weights_scheme == 'none':
571
+ self.hidden_state_layer_weights = None
572
+ # A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
573
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
574
+ print("hidden_state_layer_weights is set to None.")
575
+
576
+ elif learnable_hidden_state_weights_scheme == 'per-layer':
577
+ # Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
578
+ # 'per-layer': Different weights for different layers, but the same for different channels.
579
+ # hidden_state_layer_weights: [3, 1].
580
+ self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
581
+ requires_grad=True)
582
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
583
+ print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
584
+ else:
585
+ breakpoint()
586
+
587
+ def generate_pad_embeddings(self):
588
+ # clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
589
+ # prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
590
+ # slightly better than the original pad embeddings.
591
+ clip_embeddings = self.prompt2token_proj.text_model.embeddings
592
+ # clip_embeddings() and clip_embeddings.token_embedding() differ in that
593
+ # clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
594
+ # Adding positional embeddings seems to help somewhat.
595
+ # pad_tokens: pad_token_id 49407 repeated 77 times.
596
+ # pad_token_id is the EOS token. But BOS is 49406.
597
+ pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
598
+ # pad_embeddings: [77, 768].
599
+ pad_embeddings = clip_embeddings(pad_tokens)[0]
600
+ # We don't allow face recon to influence the pad embeddings.
601
+ # Otherwise, face identity will leak into the pad embeddings.
602
+ self.pad_embeddings = pad_embeddings.detach()
603
+
604
+ def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
605
+ if multiplier > 1:
606
+ num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
607
+ self.prompt2token_proj_attention_multiplier = multiplier
608
+ print(f"{num_extended_layers} layers in prompt2token_proj_attention are extended x{multiplier}")
609
+
610
+ def freeze_prompt2token_proj(self):
611
+ # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
612
+ # Then we don't have to check whether it's for subj or bg.
613
+ if self.prompt2token_proj is not None:
614
+ frozen_param_names = []
615
+ for param_name, param in self.prompt2token_proj.named_parameters():
616
+ if param.requires_grad:
617
+ param.requires_grad = False
618
+ frozen_param_names.append(param_name)
619
+ # If param is already frozen, then no need to freeze it again.
620
+ print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
621
+ #print(f"Frozen parameters:\n{frozen_param_names}")
622
+
623
+ def __repr__(self):
624
+ type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
625
+ # For compatibility with the previous version.
626
+ if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
627
+ self.bg_prompt_translator_has_to_out_proj = False
628
+ if not hasattr(self, 'num_out_embs'):
629
+ self.num_out_embs = -1
630
+ return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
631
+ f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
632
+
633
+ @dataclass
634
+ class BaseModelOutputWithPooling2(ModelOutput):
635
+ """
636
+ Base class for model's outputs that also contains a pooling of the last hidden states.
637
+
638
+ Args:
639
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
640
+ Sequence of hidden-states at the output of the last layer of the model.
641
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
642
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
643
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
644
+ the classification token after processing through a linear layer and a tanh activation function. The linear
645
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
646
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
647
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
648
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
649
+
650
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
651
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
652
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
653
+ sequence_length)`.
654
+
655
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
656
+ heads.
657
+ """
658
+
659
+ last_hidden_state: torch.FloatTensor = None
660
+ pooler_output: torch.FloatTensor = None
661
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
662
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
663
+ attn_mask: Optional[torch.FloatTensor] = None
664
+
665
+ # Revised from CLIPVisionTransformer to support attention mask.
666
+ # self: a CLIPVisionTransformer instance.
667
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
668
+ # pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
669
+ # attn_mask: B*H*W attention mask.
670
+ def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
671
+ output_attentions = None,
672
+ output_hidden_states = None, return_dict = None):
673
+
674
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
675
+ output_hidden_states = (
676
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
677
+ )
678
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
679
+
680
+ if pixel_values is None:
681
+ raise ValueError("You have to specify pixel_values")
682
+
683
+ # Visual tokens are flattened in embeddings().
684
+ # self.embeddings: CLIPVisionEmbeddings.
685
+ # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
686
+ # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
687
+ hidden_states = self.embeddings(pixel_values)
688
+ hidden_states = self.pre_layrnorm(hidden_states)
689
+
690
+ if attn_mask is not None:
691
+ # feat_edge_size: 16.
692
+ feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
693
+ # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
694
+ attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
695
+ # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
696
+ attn_mask = attn_mask.flatten(2)
697
+ # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
698
+ # This 1 corresponds to class_embeds, which is always attended to.
699
+ attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
700
+ attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
701
+ else:
702
+ attn_mask_pairs = None
703
+
704
+ # encoder: CLIPEncoder.
705
+ encoder_outputs = self.encoder(
706
+ inputs_embeds=hidden_states,
707
+ # New feature: (***The official documentation is wrong***)
708
+ # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
709
+ # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`:
710
+ # - 1 for pairs that are **not masked**,
711
+ # - 0 for pairs that are **masked**.
712
+ # attention_mask is eventually used by CLIPEncoderLayer:
713
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
714
+ attention_mask=attn_mask_pairs,
715
+ output_attentions=output_attentions, # False
716
+ output_hidden_states=output_hidden_states, # True
717
+ return_dict=return_dict, # True
718
+ )
719
+
720
+ # last_hidden_state: [BS, 257, 1280]
721
+ last_hidden_state = encoder_outputs[0]
722
+ pooled_output = last_hidden_state[:, 0, :]
723
+ pooled_output = self.post_layernorm(pooled_output)
724
+
725
+ # return_dict is True.
726
+ if not return_dict:
727
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
728
+
729
+ return BaseModelOutputWithPooling2(
730
+ last_hidden_state=last_hidden_state,
731
+ pooler_output=pooled_output,
732
+ hidden_states=encoder_outputs.hidden_states,
733
+ attentions=encoder_outputs.attentions,
734
+ # Newly added: return resized flattened attention mask.
735
+ # [BS, 1, 257] -> [BS, 257, 1]
736
+ attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
737
+ )
738
+
739
+
740
+ class CLIPVisionModelWithMask(CLIPVisionModel):
741
+ def __init__(self, config):
742
+ super().__init__(config)
743
+ # Replace vision_model.forward() with the new one that supports mask.
744
+ self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
745
+
746
+ def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
747
+ output_hidden_states = None, return_dict = None):
748
+
749
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
+
751
+ return self.vision_model(
752
+ pixel_values=pixel_values,
753
+ attn_mask=attn_mask,
754
+ output_attentions=output_attentions,
755
+ output_hidden_states=output_hidden_states,
756
+ return_dict=return_dict,
757
+ )
758
+
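For orientation, here is a minimal sketch (not part of this commit) of how CLIPVisionModelWithMask might be exercised. The import path, the use of a randomly initialized default CLIPVisionConfig, and the mask resolution are assumptions purely to illustrate the tensor shapes; real use would load pretrained CLIP ViT weights matching the shapes in the comments above.

```python
# Minimal sketch; assumes CLIPVisionModelWithMask is importable as below and uses a
# randomly initialized default config only to show how the mask flows through forward().
import torch
from transformers import CLIPVisionConfig
# from adaface.arc2face_models import CLIPVisionModelWithMask   # hypothetical import path

config = CLIPVisionConfig()                      # default config: 224x224 input -> 7x7 patch grid
model  = CLIPVisionModelWithMask(config).eval()

pixel_values = torch.randn(2, 3, 224, 224)       # [BS, 3, H, W] preprocessed images
attn_mask    = torch.ones(2, 512, 512)           # [BS, H, W] binary face mask at image resolution

with torch.no_grad():
    out = model(pixel_values=pixel_values, attn_mask=attn_mask, output_hidden_states=True)

# last_hidden_state: [BS, 1 + num_patches, hidden]; attn_mask is returned as [BS, 1 + num_patches, 1].
print(out.last_hidden_state.shape, out.pooler_output.shape, out.attn_mask.shape)
```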
adaface/util.py ADDED
@@ -0,0 +1,341 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+
8
+ # add_noise_to_tensor() adds Gaussian noise to the tensor; when noise_std_is_relative, noise_std is scaled by the tensor's mean std.
9
+ def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
10
+ std_dim=-1, norm_dim=-1):
11
+ if noise_std_is_relative:
12
+ ts_std_mean = ts.std(dim=std_dim).mean().detach()
13
+ noise_std *= ts_std_mean
14
+
15
+ noise = torch.randn_like(ts) * noise_std
16
+ if keep_norm:
17
+ orig_norm = ts.norm(dim=norm_dim, keepdim=True)
18
+ ts = ts + noise
19
+ new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
20
+ ts = ts * orig_norm / (new_norm + 1e-8)
21
+ else:
22
+ ts = ts + noise
23
+
24
+ return ts
25
+
26
+
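As a quick illustration of the keep_norm option above, here is a toy sketch (not part of the commit; it assumes add_noise_to_tensor is in scope): with keep_norm=True the noised tensor is rescaled back to the original per-vector norm.

```python
# Toy sketch: relative Gaussian noise with and without norm preservation.
import torch

x = torch.randn(4, 512)
x_noised    = add_noise_to_tensor(x, noise_std=0.1)                  # norms drift slightly
x_norm_kept = add_noise_to_tensor(x, noise_std=0.1, keep_norm=True)  # norms preserved

print(x.norm(dim=-1))            # original norms
print(x_noised.norm(dim=-1))     # perturbed norms
print(x_norm_kept.norm(dim=-1))  # matches the original norms (up to the 1e-8 regularizer)
```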
27
+ # Revised from RevGrad, by removing the grad negation.
28
+ class ScaleGrad(torch.autograd.Function):
29
+ @staticmethod
30
+ def forward(ctx, input_, alpha_, debug=False):
31
+ ctx.save_for_backward(alpha_, debug)
32
+ output = input_
33
+ if debug:
34
+ print(f"input: {input_.abs().mean().item()}")
35
+ return output
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_output): # pragma: no cover
39
+ # saved_tensors returns a tuple of tensors.
40
+ alpha_, debug = ctx.saved_tensors
41
+ if ctx.needs_input_grad[0]:
42
+ grad_output2 = grad_output * alpha_
43
+ if debug:
44
+ print(f"grad_output2: {grad_output2.abs().mean().item()}")
45
+ else:
46
+ grad_output2 = None
47
+ return grad_output2, None, None
48
+
49
+ class GradientScaler(nn.Module):
50
+ def __init__(self, alpha=1., debug=False, *args, **kwargs):
51
+ """
52
+ A gradient scaling layer.
53
+ This layer has no parameters, and simply scales the gradient in the backward pass.
54
+ """
55
+ super().__init__(*args, **kwargs)
56
+
57
+ self._alpha = torch.tensor(alpha, requires_grad=False)
58
+ self._debug = torch.tensor(debug, requires_grad=False)
59
+
60
+ def forward(self, input_):
61
+ _debug = self._debug if hasattr(self, '_debug') else False
62
+ return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
63
+
64
+ def gen_gradient_scaler(alpha, debug=False):
65
+ if alpha == 1:
66
+ return nn.Identity()
67
+ if alpha > 0:
68
+ return GradientScaler(alpha, debug=debug)
69
+ else:
70
+ assert alpha == 0
71
+ # Don't use lambda function here, otherwise the object can't be pickled.
72
+ return torch.detach
73
+
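A small sketch (not part of the commit; assumes gen_gradient_scaler is in scope) of what the scaler does: the forward pass is an identity, while gradients flowing back through it are multiplied by alpha, or blocked entirely when alpha == 0.

```python
# Toy sketch: gradients through the scaler are multiplied by alpha; the forward pass is identity.
import torch

x = torch.randn(3, requires_grad=True)
scaler = gen_gradient_scaler(0.1)      # returns a GradientScaler module
scaler(x).sum().backward()
print(x.grad)                          # == 0.1 * torch.ones(3)

x2 = torch.randn(3, requires_grad=True)
detach_fn = gen_gradient_scaler(0)     # alpha == 0 -> torch.detach, which blocks gradients entirely
z = detach_fn(x2).sum()
# z.requires_grad is False here, so calling z.backward() would raise an error.
```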
74
+ #@torch.autocast(device_type="cuda")
75
+ # In AdaFaceWrapper, input_max_length is 22.
76
+ def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
77
+ input_max_length=77, return_full_and_core_embs=True):
78
+
79
+ '''
80
+ arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
81
+ face_embs: (N, 512) normalized ArcFace embeddings.
82
+ return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
83
+ If False, return only the core embeddings.
84
+
85
+ '''
86
+
87
+ # arcface_token_id: 1014
88
+ arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
89
+
90
+ # This step should be quite fast, and there's no need to cache the input_ids.
91
+ input_ids = tokenizer(
92
+ "photo of a id person",
93
+ truncation=True,
94
+ padding="max_length",
95
+ max_length=input_max_length, #tokenizer.model_max_length,
96
+ return_tensors="pt",
97
+ ).input_ids.to(face_embs.device)
98
+ # input_ids: [1, 77] or [3, 77] (during training).
99
+ input_ids = input_ids.repeat(len(face_embs), 1)
100
+ face_embs_dtype = face_embs.dtype
101
+ face_embs = face_embs.to(arc2face_text_encoder.dtype)
102
+ # face_embs_padded: [1, 512] -> [1, 768].
103
+ face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
104
+ # arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
105
+ # The second call does the ordinary CLIP text encoding pass.
106
+ token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
107
+ token_embs[input_ids==arcface_token_id] = face_embs_padded
108
+
109
+ prompt_embeds = arc2face_text_encoder(
110
+ input_ids=input_ids,
111
+ input_token_embs=token_embs,
112
+ return_token_embs=False
113
+ )[0]
114
+
115
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
116
+ prompt_embeds = prompt_embeds.to(face_embs_dtype)
117
+
118
+ if return_full_and_core_embs:
119
+ # token 4: 'id' in "photo of a id person".
120
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
121
+ # [N, 77, 768] -> [N, 16, 768]
122
+ return prompt_embeds, prompt_embeds[:, 4:20]
123
+ else:
124
+ # [N, 16, 768]
125
+ return prompt_embeds[:, 4:20]
126
+
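The key step in arc2face_forward_face_embs is substituting the zero-padded ArcFace embedding into the token-embedding slot of the word "id". The sketch below (not part of the commit, with made-up tensors; the position index 4 follows the comments above) isolates just that padding-and-substitution step.

```python
# Toy sketch of the padding + token substitution performed above (shapes follow the comments).
import torch
import torch.nn.functional as F

hidden_size = 768
face_embs   = F.normalize(torch.randn(1, 512), p=2, dim=-1)   # stand-in for an ArcFace embedding

# Zero-pad 512 -> 768 so it fits into a CLIP token-embedding slot.
face_embs_padded = F.pad(face_embs, (0, hidden_size - face_embs.shape[-1]))

token_embs = torch.randn(1, 77, hidden_size)   # stand-in for the "photo of a id person" token embeddings
arcface_token_pos = 4                          # position of "id" in the tokenized prompt
token_embs[:, arcface_token_pos] = face_embs_padded   # the text encoder then runs on these modified embeddings

print(face_embs_padded.shape, token_embs.shape)        # [1, 768], [1, 77, 768]
```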
127
+ def get_b_core_e_embeddings(prompt_embeds, length=22):
128
+ b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
129
+ return b_core_e_embs
130
+
131
+ # return_emb_types: a list of strings, each string is among ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'].
132
+ def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
133
+ return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
134
+ input_max_length=77, zs_extra_words_scale=0.5):
135
+
136
+ '''
137
+ inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
138
+ inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
139
+ face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
140
+ list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt.
141
+ return_emb_types: a list of strings specifying which variants of the prompt embeddings to return
142
+ (among 'full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e').
143
+ '''
144
+
145
+ if list_extra_words is not None:
146
+ if len(list_extra_words) != len(face_prompt_embs):
147
+ if len(face_prompt_embs) > 1:
148
+ print("Warn: list_extra_words has different length as face_prompt_embs.")
149
+ if len(list_extra_words) == 1:
150
+ list_extra_words = list_extra_words * len(face_prompt_embs)
151
+ else:
152
+ breakpoint()
153
+ else:
154
+ # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
155
+ # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
156
+ list_extra_words = list_extra_words[:1]
157
+
158
+ for extra_words in list_extra_words:
159
+ assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
160
+ # 16 ", " are placeholders for face_prompt_embs.
161
+ prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
162
+ else:
163
+ # 16 ", " are placeholders for face_prompt_embs.
164
+ # No extra words are added to the prompt.
165
+ prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]
166
+
167
+ # This step should be quite fast, and there's no need to cache the input_ids.
168
+ # input_ids: [BS, 77].
169
+ input_ids = clip_tokenizer(
170
+ prompt_templates,
171
+ truncation=True,
172
+ padding="max_length",
173
+ max_length=input_max_length,
174
+ return_tensors="pt",
175
+ ).input_ids.to(face_prompt_embs.device)
176
+
177
+ face_prompt_embs_dtype = face_prompt_embs.dtype
178
+ face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)
179
+
180
+ # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
181
+ token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
182
+ # token 4: first ", " in the template prompt.
183
+ # Replace embeddings of 16 placeholder ", " with face_prompt_embs.
184
+ token_embs[:, 4:20] = face_prompt_embs
185
+
186
+ # This call does the ordinary CLIP text encoding pass.
187
+ prompt_embeds = inverse_text_encoder(
188
+ input_ids=input_ids,
189
+ input_token_embs=token_embs,
190
+ hidden_state_layer_weights=hidden_state_layer_weights,
191
+ return_token_embs=False
192
+ )[0]
193
+
194
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
195
+ prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
196
+ # token 4: first ", " in the template prompt.
197
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
198
+ # 20:22 are embeddings of the (at most) two extra words.
199
+ # [N, 77, 768] -> [N, 16, 768]
200
+ core_prompt_embs = prompt_embeds[:, 4:20]
201
+ if list_extra_words is not None:
202
+ # [N, 16, 768] -> [N, 18, 768]
203
+ extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
204
+ core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)
205
+
206
+ return_prompts = []
207
+ for emb_type in return_emb_types:
208
+ if emb_type == 'full':
209
+ return_prompts.append(prompt_embeds)
210
+ elif emb_type == 'full_half_pad':
211
+ prompt_embeds2 = prompt_embeds.clone()
212
+ PADS = prompt_embeds2.shape[1] - 23
213
+ if PADS >= 2:
214
+ # Fill half of the remaining embeddings with pad embeddings.
215
+ prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
216
+ return_prompts.append(prompt_embeds2)
217
+ elif emb_type == 'full_pad':
218
+ prompt_embeds2 = prompt_embeds.clone()
219
+ # Fill the 22nd to the second last embeddings with pad embeddings.
220
+ prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
221
+ return_prompts.append(prompt_embeds2)
222
+ elif emb_type == 'core':
223
+ return_prompts.append(core_prompt_embs)
224
+ elif emb_type == 'full_zeroed_extra':
225
+ prompt_embeds2 = prompt_embeds.clone()
226
+ # Only add two pad embeddings. The remaining embeddings are set to 0.
227
+ # Make the positional embeddings align with the actual positions.
228
+ prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
229
+ prompt_embeds2[:, 24:-1] = 0
230
+ return_prompts.append(prompt_embeds2)
231
+ elif emb_type == 'b_core_e':
232
+ # The first 22 embeddings, plus the last EOS embedding.
233
+ b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
234
+ return_prompts.append(b_core_e_embs)
235
+ else:
236
+ breakpoint()
237
+
238
+ return return_prompts
239
+
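To make the slot arithmetic above easier to follow, here is a small sketch (not part of the commit) of the placeholder template and the index ranges it implies; the tokenization shown is approximate.

```python
# Toy sketch: the placeholder template whose comma slots receive the identity embeddings.
template = "photo of a " + ", " * 16        # 16 ", " placeholders for the 16 identity embeddings
print(template)
# After CLIP tokenization this becomes roughly:
#   [BOS, photo, of, a, ",", ",", ..., ",", <extra word tokens>, EOS, PAD, ...]
#    0    1      2   3   4  ............19   20:22
# which is why positions 4:20 are read back as the core embeddings and 20:22 as the extra words.
```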
240
+ # if pre_face_embs is None, generate random face embeddings [BS, 512].
241
+ # image_folder is passed only for logging purposes. image_paths contains the paths of the images.
242
+ def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
243
+ extract_faceid_embeds, pre_face_embs,
244
+ image_folder, image_paths, images_np,
245
+ id_batch_size, device,
246
+ input_max_length=77, noise_level=0.0,
247
+ return_core_id_embs=False,
248
+ gen_neg_prompt=False, verbose=False):
249
+ if extract_faceid_embeds:
250
+ image_count = 0
251
+ faceid_embeds = []
252
+ if image_paths is not None:
253
+ images_np = []
254
+ for image_path in image_paths:
255
+ image_np = np.array(Image.open(image_path))
256
+ images_np.append(image_np)
257
+
258
+ for i, image_np in enumerate(images_np):
259
+ image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
260
+ # Remove alpha channel if it exists.
261
+ if image_obj.mode == 'RGBA':
262
+ image_obj = image_obj.convert('RGB')
263
+ # This seems NOT a bug. The input image should be in BGR format, as per
264
+ # https://github.com/deepinsight/insightface/issues/524
265
+ image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
266
+ image_np = np.array(image_obj)  # NOTE: this overwrites the BGR conversion above, so the detector actually receives an RGB array.
267
+
268
+ face_infos = face_app.get(image_np)
269
+ if verbose and image_paths is not None:
270
+ print(image_paths[i], len(face_infos))
271
+ # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
272
+ if len(face_infos) == 0:
273
+ continue
274
+ # only use the maximum face
275
+ face_info = sorted(face_infos, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1]
276
+ # Each faceid_embed: [1, 512]
277
+ faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
278
+ image_count += 1
279
+
280
+ if verbose:
281
+ if image_folder is not None:
282
+ print(f"Extracted ID embeddings from {image_count} images in {image_folder}")
283
+ else:
284
+ print(f"Extracted ID embeddings from {image_count} images")
285
+
286
+ if len(faceid_embeds) == 0:
287
+ print("No face detected")
288
+ breakpoint()
289
+
290
+ # faceid_embeds: [10, 512]
291
+ faceid_embeds = torch.cat(faceid_embeds, dim=0)
292
+ # faceid_embeds: [10, 512] -> [1, 512].
293
+ # All face embeddings are averaged, so the resulting prompt embeddings are shared across the input images.
294
+ faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(torch.float16).to(device)
295
+ else:
296
+ # Random face embeddings. faceid_embeds: [BS, 512].
297
+ if pre_face_embs is None:
298
+ faceid_embeds = torch.randn(id_batch_size, 512)
299
+ else:
300
+ faceid_embeds = pre_face_embs
301
+ if pre_face_embs.shape[0] == 1:
302
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
303
+
304
+ faceid_embeds = faceid_embeds.to(torch.float16).to(device)
305
+
306
+ if noise_level > 0:
307
+ # If id_batch_size > 1, then after adding noise, the id_batch_size embeddings will differ from each other.
308
+ faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)
309
+
310
+ faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)
311
+
312
+ # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
313
+ with torch.no_grad():
314
+ arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
315
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
316
+ faceid_embeds, input_max_length=input_max_length,
317
+ return_full_and_core_embs=True)
318
+ if return_core_id_embs:
319
+ arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
320
+ # If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
321
+ # So we need to repeat faceid_embeds.
322
+ if extract_faceid_embeds:
323
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
324
+ arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)
325
+
326
+ if gen_neg_prompt:
327
+ with torch.no_grad():
328
+ arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
329
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
330
+ torch.zeros_like(faceid_embeds),
331
+ input_max_length=input_max_length,
332
+ return_full_and_core_embs=True)
333
+ if return_core_id_embs:
334
+ arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb
335
+
336
+ #if extract_faceid_embeds:
337
+ # arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
338
+ return faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
339
+ else:
340
+ return faceid_embeds, arc2face_pos_prompt_emb
341
+
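The extract_faceid_embeds branch above boils down to averaging the per-image ArcFace embeddings and renormalizing them; a toy sketch (not part of the commit, with random stand-in embeddings):

```python
# Toy sketch: fuse several per-image ID embeddings into a single subject embedding.
import torch
import torch.nn.functional as F

per_image_embs = [torch.randn(1, 512) for _ in range(10)]   # stand-ins for insightface normed_embedding
faceid_embeds  = torch.cat(per_image_embs, dim=0)           # [10, 512]
faceid_embeds  = faceid_embeds.mean(dim=0, keepdim=True)    # [1, 512]
faceid_embeds  = F.normalize(faceid_embeds, p=2, dim=-1)    # back onto the unit hypersphere
print(faceid_embeds.shape)
```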
animatediff/models/attention.py ADDED
@@ -0,0 +1,327 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward, AdaLayerNorm,Attention
15
+
16
+ from einops import rearrange, repeat
17
+ import pdb
18
+
19
+ from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0
20
+ @dataclass
21
+ class Transformer3DModelOutput(BaseOutput):
22
+ sample: torch.FloatTensor
23
+ from diffusers.utils import logging
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class Transformer3DModel(ModelMixin, ConfigMixin):
34
+ @register_to_config
35
+ def __init__(
36
+ self,
37
+ num_attention_heads: int = 16,
38
+ attention_head_dim: int = 88,
39
+ in_channels: Optional[int] = None,
40
+ num_layers: int = 1,
41
+ dropout: float = 0.0,
42
+ norm_num_groups: int = 32,
43
+ cross_attention_dim: Optional[int] = None,
44
+ attention_bias: bool = False,
45
+ activation_fn: str = "geglu",
46
+ num_embeds_ada_norm: Optional[int] = None,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ unet_use_cross_frame_attention=None,
51
+ unet_use_temporal_attention=None,
52
+ processor: Optional["AttnProcessor"] = None,
53
+ ):
54
+ super().__init__()
55
+ self.use_linear_projection = use_linear_projection
56
+ self.num_attention_heads = num_attention_heads
57
+ self.attention_head_dim = attention_head_dim
58
+ inner_dim = num_attention_heads * attention_head_dim
59
+
60
+ # Define input layers
61
+ self.in_channels = in_channels
62
+
63
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
64
+ if use_linear_projection:
65
+ self.proj_in = nn.Linear(in_channels, inner_dim)
66
+ else:
67
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
68
+
69
+ # Define transformers blocks
70
+ self.transformer_blocks = nn.ModuleList(
71
+ [
72
+ BasicTransformerBlock(
73
+ inner_dim,
74
+ num_attention_heads,
75
+ attention_head_dim,
76
+ dropout=dropout,
77
+ cross_attention_dim=cross_attention_dim,
78
+ activation_fn=activation_fn,
79
+ num_embeds_ada_norm=num_embeds_ada_norm,
80
+ attention_bias=attention_bias,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ )
87
+ for d in range(num_layers)
88
+ ]
89
+ )
90
+
91
+ # 4. Define output layers
92
+ if use_linear_projection:
93
+ self.proj_out = nn.Linear(in_channels, inner_dim)
94
+ else:
95
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
96
+ # if processor is None:
97
+ # processor = (
98
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
99
+ # )
100
+ # self.set_processor(processor)
101
+ # def set_processor(self, processor: "AttnProcessor") -> None:
102
+ # r"""
103
+ # Set the attention processor to use.
104
+
105
+ # Args:
106
+ # processor (`AttnProcessor`):
107
+ # The attention processor to use.
108
+ # """
109
+ # # if current processor is in `self._modules` and if passed `processor` is not, we need to
110
+ # # pop `processor` from `self._modules`
111
+ # if (
112
+ # hasattr(self, "processor")
113
+ # and isinstance(self.processor, torch.nn.Module)
114
+ # and not isinstance(processor, torch.nn.Module)
115
+ # ):
116
+ # logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
117
+ # self._modules.pop("processor")
118
+
119
+ # self.processor = processor
120
+
121
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
122
+ # Input
123
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
124
+ video_length = hidden_states.shape[2]
125
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
126
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
127
+
128
+ batch, channel, height, weight = hidden_states.shape
129
+ residual = hidden_states
130
+
131
+ hidden_states = self.norm(hidden_states)
132
+ if not self.use_linear_projection:
133
+ hidden_states = self.proj_in(hidden_states)
134
+ inner_dim = hidden_states.shape[1]
135
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
136
+ else:
137
+ inner_dim = hidden_states.shape[1]
138
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
139
+ hidden_states = self.proj_in(hidden_states)
140
+
141
+ # Blocks
142
+ for block in self.transformer_blocks:
143
+ hidden_states = block(
144
+ hidden_states,
145
+ encoder_hidden_states=encoder_hidden_states,
146
+ timestep=timestep,
147
+ video_length=video_length
148
+ )
149
+
150
+ # Output
151
+ if not self.use_linear_projection:
152
+ hidden_states = (
153
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
154
+ )
155
+ hidden_states = self.proj_out(hidden_states)
156
+ else:
157
+ hidden_states = self.proj_out(hidden_states)
158
+ hidden_states = (
159
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
160
+ )
161
+
162
+ output = hidden_states + residual
163
+
164
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
165
+ if not return_dict:
166
+ return (output,)
167
+
168
+ return Transformer3DModelOutput(sample=output)
169
+
170
+
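The frame-folding convention used in Transformer3DModel.forward (and mirrored in the motion modules below) can be checked with a few lines of einops; a toy sketch, not part of the commit:

```python
# Toy sketch: how video tensors are folded into the batch dimension for per-frame spatial attention.
import torch
from einops import rearrange, repeat

b, c, f, h, w = 1, 320, 16, 32, 32
hidden_states         = torch.randn(b, c, f, h, w)
encoder_hidden_states = torch.randn(b, 77, 768)            # one text embedding per video

hidden_states         = rearrange(hidden_states, "b c f h w -> (b f) c h w")      # [16, 320, 32, 32]
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=f)  # [16, 77, 768]

print(hidden_states.shape, encoder_hidden_states.shape)
```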
171
+ class BasicTransformerBlock(nn.Module):
172
+ def __init__(
173
+ self,
174
+ dim: int,
175
+ num_attention_heads: int,
176
+ attention_head_dim: int,
177
+ dropout=0.0,
178
+ cross_attention_dim: Optional[int] = None,
179
+ activation_fn: str = "geglu",
180
+ num_embeds_ada_norm: Optional[int] = None,
181
+ attention_bias: bool = False,
182
+ only_cross_attention: bool = False,
183
+ upcast_attention: bool = False,
184
+
185
+ unet_use_cross_frame_attention = None,
186
+ unet_use_temporal_attention = None,
187
+ ):
188
+ super().__init__()
189
+ self.only_cross_attention = only_cross_attention
190
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
191
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
192
+ self.unet_use_temporal_attention = unet_use_temporal_attention
193
+
194
+ # SC-Attn
195
+ assert unet_use_cross_frame_attention is not None
196
+ if unet_use_cross_frame_attention:
197
+ self.attn1 = SparseCausalAttention2D(
198
+ query_dim=dim,
199
+ heads=num_attention_heads,
200
+ dim_head=attention_head_dim,
201
+ dropout=dropout,
202
+ bias=attention_bias,
203
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
204
+ upcast_attention=upcast_attention,
205
+ )
206
+ else:
207
+ #self-attention
208
+ self.attn1 = Attention(
209
+ query_dim=dim,
210
+ heads=num_attention_heads,
211
+ dim_head=attention_head_dim,
212
+ dropout=dropout,
213
+ bias=attention_bias,
214
+ upcast_attention=upcast_attention,
215
+ cross_attention_dim=None,
216
+ )
217
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
218
+
219
+ # Cross-Attn
220
+ if cross_attention_dim is not None:
221
+ self.attn2 = Attention(
222
+ query_dim=dim,
223
+ cross_attention_dim=cross_attention_dim,
224
+ heads=num_attention_heads,
225
+ dim_head=attention_head_dim,
226
+ dropout=dropout,
227
+ bias=attention_bias,
228
+ upcast_attention=upcast_attention,
229
+ )
230
+ else:
231
+ self.attn2 = None
232
+
233
+ if cross_attention_dim is not None:
234
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
235
+ else:
236
+ self.norm2 = None
237
+
238
+ # Feed-forward
239
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
240
+ self.norm3 = nn.LayerNorm(dim)
241
+
242
+ # Temp-Attn
243
+ assert unet_use_temporal_attention is not None
244
+ if unet_use_temporal_attention:
245
+ self.attn_temp = Attention(
246
+ query_dim=dim,
247
+ heads=num_attention_heads,
248
+ dim_head=attention_head_dim,
249
+ dropout=dropout,
250
+ bias=attention_bias,
251
+ upcast_attention=upcast_attention,
252
+ )
253
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
254
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
255
+
256
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None):
257
+ if not is_xformers_available():
258
+ print("Here is how to install it")
259
+ raise ModuleNotFoundError(
260
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
261
+ " xformers",
262
+ name="xformers",
263
+ )
264
+ elif not torch.cuda.is_available():
265
+ raise ValueError(
266
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
267
+ " available for GPU "
268
+ )
269
+ else:
270
+ try:
271
+ # Make sure we can run the memory efficient attention
272
+ _ = xformers.ops.memory_efficient_attention(
273
+ torch.randn((1, 2, 40), device="cuda"),
274
+ torch.randn((1, 2, 40), device="cuda"),
275
+ torch.randn((1, 2, 40), device="cuda"),
276
+ )
277
+ except Exception as e:
278
+ raise e
279
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
280
+ if self.attn2 is not None:
281
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
282
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
283
+
284
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
285
+ # SparseCausal-Attention
286
+ norm_hidden_states = (
287
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
288
+ )
289
+
290
+ # if self.only_cross_attention:
291
+ # hidden_states = (
292
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
293
+ # )
294
+ # else:
295
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
296
+
297
+ # pdb.set_trace()
298
+ if self.unet_use_cross_frame_attention:
299
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
300
+ else:
301
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
302
+
303
+ if self.attn2 is not None:
304
+ # Cross-Attention
305
+ norm_hidden_states = (
306
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
307
+ )
308
+ hidden_states = (
309
+ self.attn2(
310
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
311
+ )
312
+ + hidden_states
313
+ )
314
+ # Feed-forward
315
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
316
+
317
+ # Temporal-Attention
318
+ if self.unet_use_temporal_attention:
319
+ d = hidden_states.shape[1]
320
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
321
+ norm_hidden_states = (
322
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
323
+ )
324
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
325
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
326
+
327
+ return hidden_states
animatediff/models/attention_bkp.py ADDED
@@ -0,0 +1,326 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15
+
16
+ from einops import rearrange, repeat
17
+ import pdb
18
+
19
+ from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0
20
+ @dataclass
21
+ class Transformer3DModelOutput(BaseOutput):
22
+ sample: torch.FloatTensor
23
+ from diffusers.utils import logging
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+ if is_xformers_available():
27
+ import xformers
28
+ import xformers.ops
29
+ else:
30
+ xformers = None
31
+
32
+
33
+ class Transformer3DModel(ModelMixin, ConfigMixin):
34
+ @register_to_config
35
+ def __init__(
36
+ self,
37
+ num_attention_heads: int = 16,
38
+ attention_head_dim: int = 88,
39
+ in_channels: Optional[int] = None,
40
+ num_layers: int = 1,
41
+ dropout: float = 0.0,
42
+ norm_num_groups: int = 32,
43
+ cross_attention_dim: Optional[int] = None,
44
+ attention_bias: bool = False,
45
+ activation_fn: str = "geglu",
46
+ num_embeds_ada_norm: Optional[int] = None,
47
+ use_linear_projection: bool = False,
48
+ only_cross_attention: bool = False,
49
+ upcast_attention: bool = False,
50
+ unet_use_cross_frame_attention=None,
51
+ unet_use_temporal_attention=None,
52
+ processor: Optional["AttnProcessor"] = None,
53
+ ):
54
+ super().__init__()
55
+ self.use_linear_projection = use_linear_projection
56
+ self.num_attention_heads = num_attention_heads
57
+ self.attention_head_dim = attention_head_dim
58
+ inner_dim = num_attention_heads * attention_head_dim
59
+
60
+ # Define input layers
61
+ self.in_channels = in_channels
62
+
63
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
64
+ if use_linear_projection:
65
+ self.proj_in = nn.Linear(in_channels, inner_dim)
66
+ else:
67
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
68
+
69
+ # Define transformers blocks
70
+ self.transformer_blocks = nn.ModuleList(
71
+ [
72
+ BasicTransformerBlock(
73
+ inner_dim,
74
+ num_attention_heads,
75
+ attention_head_dim,
76
+ dropout=dropout,
77
+ cross_attention_dim=cross_attention_dim,
78
+ activation_fn=activation_fn,
79
+ num_embeds_ada_norm=num_embeds_ada_norm,
80
+ attention_bias=attention_bias,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ )
87
+ for d in range(num_layers)
88
+ ]
89
+ )
90
+
91
+ # 4. Define output layers
92
+ if use_linear_projection:
93
+ self.proj_out = nn.Linear(in_channels, inner_dim)
94
+ else:
95
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
96
+ # if processor is None:
97
+ # processor = (
98
+ # AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
99
+ # )
100
+ # self.set_processor(processor)
101
+ def set_processor(self, processor: "AttnProcessor") -> None:
102
+ r"""
103
+ Set the attention processor to use.
104
+
105
+ Args:
106
+ processor (`AttnProcessor`):
107
+ The attention processor to use.
108
+ """
109
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
110
+ # pop `processor` from `self._modules`
111
+ if (
112
+ hasattr(self, "processor")
113
+ and isinstance(self.processor, torch.nn.Module)
114
+ and not isinstance(processor, torch.nn.Module)
115
+ ):
116
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
117
+ self._modules.pop("processor")
118
+
119
+ self.processor = processor
120
+
121
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
122
+ # Input
123
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
124
+ video_length = hidden_states.shape[2]
125
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
126
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
127
+
128
+ batch, channel, height, weight = hidden_states.shape
129
+ residual = hidden_states
130
+
131
+ hidden_states = self.norm(hidden_states)
132
+ if not self.use_linear_projection:
133
+ hidden_states = self.proj_in(hidden_states)
134
+ inner_dim = hidden_states.shape[1]
135
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
136
+ else:
137
+ inner_dim = hidden_states.shape[1]
138
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
139
+ hidden_states = self.proj_in(hidden_states)
140
+
141
+ # Blocks
142
+ for block in self.transformer_blocks:
143
+ hidden_states = block(
144
+ hidden_states,
145
+ encoder_hidden_states=encoder_hidden_states,
146
+ timestep=timestep,
147
+ video_length=video_length
148
+ )
149
+
150
+ # Output
151
+ if not self.use_linear_projection:
152
+ hidden_states = (
153
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
154
+ )
155
+ hidden_states = self.proj_out(hidden_states)
156
+ else:
157
+ hidden_states = self.proj_out(hidden_states)
158
+ hidden_states = (
159
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
160
+ )
161
+
162
+ output = hidden_states + residual
163
+
164
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
165
+ if not return_dict:
166
+ return (output,)
167
+
168
+ return Transformer3DModelOutput(sample=output)
169
+
170
+
171
+ class BasicTransformerBlock(nn.Module):
172
+ def __init__(
173
+ self,
174
+ dim: int,
175
+ num_attention_heads: int,
176
+ attention_head_dim: int,
177
+ dropout=0.0,
178
+ cross_attention_dim: Optional[int] = None,
179
+ activation_fn: str = "geglu",
180
+ num_embeds_ada_norm: Optional[int] = None,
181
+ attention_bias: bool = False,
182
+ only_cross_attention: bool = False,
183
+ upcast_attention: bool = False,
184
+
185
+ unet_use_cross_frame_attention = None,
186
+ unet_use_temporal_attention = None,
187
+ ):
188
+ super().__init__()
189
+ self.only_cross_attention = only_cross_attention
190
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
191
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
192
+ self.unet_use_temporal_attention = unet_use_temporal_attention
193
+
194
+ # SC-Attn
195
+ assert unet_use_cross_frame_attention is not None
196
+ if unet_use_cross_frame_attention:
197
+ self.attn1 = SparseCausalAttention2D(
198
+ query_dim=dim,
199
+ heads=num_attention_heads,
200
+ dim_head=attention_head_dim,
201
+ dropout=dropout,
202
+ bias=attention_bias,
203
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
204
+ upcast_attention=upcast_attention,
205
+ )
206
+ else:
207
+ self.attn1 = CrossAttention(
208
+ query_dim=dim,
209
+ heads=num_attention_heads,
210
+ dim_head=attention_head_dim,
211
+ dropout=dropout,
212
+ bias=attention_bias,
213
+ upcast_attention=upcast_attention,
214
+ )
215
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
216
+
217
+ # Cross-Attn
218
+ if cross_attention_dim is not None:
219
+ self.attn2 = CrossAttention(
220
+ query_dim=dim,
221
+ cross_attention_dim=cross_attention_dim,
222
+ heads=num_attention_heads,
223
+ dim_head=attention_head_dim,
224
+ dropout=dropout,
225
+ bias=attention_bias,
226
+ upcast_attention=upcast_attention,
227
+ )
228
+ else:
229
+ self.attn2 = None
230
+
231
+ if cross_attention_dim is not None:
232
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
233
+ else:
234
+ self.norm2 = None
235
+
236
+ # Feed-forward
237
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
238
+ self.norm3 = nn.LayerNorm(dim)
239
+
240
+ # Temp-Attn
241
+ assert unet_use_temporal_attention is not None
242
+ if unet_use_temporal_attention:
243
+ self.attn_temp = CrossAttention(
244
+ query_dim=dim,
245
+ heads=num_attention_heads,
246
+ dim_head=attention_head_dim,
247
+ dropout=dropout,
248
+ bias=attention_bias,
249
+ upcast_attention=upcast_attention,
250
+ )
251
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
252
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
253
+
254
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None):
255
+ if not is_xformers_available():
256
+ print("Here is how to install it")
257
+ raise ModuleNotFoundError(
258
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
259
+ " xformers",
260
+ name="xformers",
261
+ )
262
+ elif not torch.cuda.is_available():
263
+ raise ValueError(
264
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
265
+ " available for GPU "
266
+ )
267
+ else:
268
+ try:
269
+ # Make sure we can run the memory efficient attention
270
+ _ = xformers.ops.memory_efficient_attention(
271
+ torch.randn((1, 2, 40), device="cuda"),
272
+ torch.randn((1, 2, 40), device="cuda"),
273
+ torch.randn((1, 2, 40), device="cuda"),
274
+ )
275
+ except Exception as e:
276
+ raise e
277
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
278
+ if self.attn2 is not None:
279
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
280
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
281
+
282
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
283
+ # SparseCausal-Attention
284
+ norm_hidden_states = (
285
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
286
+ )
287
+
288
+ # if self.only_cross_attention:
289
+ # hidden_states = (
290
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
291
+ # )
292
+ # else:
293
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
294
+
295
+ # pdb.set_trace()
296
+ if self.unet_use_cross_frame_attention:
297
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
298
+ else:
299
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
300
+
301
+ if self.attn2 is not None:
302
+ # Cross-Attention
303
+ norm_hidden_states = (
304
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
305
+ )
306
+ hidden_states = (
307
+ self.attn2(
308
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
309
+ )
310
+ + hidden_states
311
+ )
312
+
313
+ # Feed-forward
314
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
315
+
316
+ # Temporal-Attention
317
+ if self.unet_use_temporal_attention:
318
+ d = hidden_states.shape[1]
319
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
320
+ norm_hidden_states = (
321
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
322
+ )
323
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
324
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
325
+
326
+ return hidden_states
animatediff/models/motion_module.py ADDED
@@ -0,0 +1,552 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import torchvision
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import FeedForward,Attention
15
+
16
+ from einops import rearrange, repeat
17
+ import math
18
+
19
+
20
+ def zero_module(module):
21
+ # Zero out the parameters of a module and return it.
22
+ for p in module.parameters():
23
+ p.detach().zero_()
24
+ return module
25
+
26
+
27
+ @dataclass
28
+ class TemporalTransformer3DModelOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ if is_xformers_available():
33
+ import xformers
34
+ import xformers.ops
35
+ else:
36
+ xformers = None
37
+
38
+
39
+ def get_motion_module(
40
+ in_channels,
41
+ motion_module_type: str,
42
+ motion_module_kwargs: dict
43
+ ):
44
+ if motion_module_type == "Vanilla":
45
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
46
+ else:
47
+ raise ValueError
48
+
49
+
50
+ class VanillaTemporalModule(nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_channels,
54
+ num_attention_heads = 8,
55
+ num_transformer_block = 2,
56
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
57
+ cross_frame_attention_mode = None,
58
+ temporal_position_encoding = False,
59
+ temporal_position_encoding_max_len = 24,
60
+ temporal_attention_dim_div = 1,
61
+ zero_initialize = True,
62
+ ):
63
+ super().__init__()
64
+
65
+ self.temporal_transformer = TemporalTransformer3DModel(
66
+ in_channels=in_channels,
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
69
+ num_layers=num_transformer_block,
70
+ attention_block_types=attention_block_types,
71
+ cross_frame_attention_mode=cross_frame_attention_mode,
72
+ temporal_position_encoding=temporal_position_encoding,
73
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
74
+ )
75
+
76
+ if zero_initialize:
77
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
78
+
79
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
80
+ hidden_states = input_tensor
81
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
82
+
83
+ output = hidden_states
84
+ return output
85
+
86
+
87
+ class TemporalTransformer3DModel(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels,
91
+ num_attention_heads,
92
+ attention_head_dim,
93
+
94
+ num_layers,
95
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
96
+ dropout = 0.0,
97
+ norm_num_groups = 32,
98
+ cross_attention_dim = 768,
99
+ activation_fn = "geglu",
100
+ attention_bias = False,
101
+ upcast_attention = False,
102
+
103
+ cross_frame_attention_mode = None,
104
+ temporal_position_encoding = False,
105
+ temporal_position_encoding_max_len = 24,
106
+ ):
107
+ super().__init__()
108
+
109
+ inner_dim = num_attention_heads * attention_head_dim
110
+
111
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
112
+ self.proj_in = nn.Linear(in_channels, inner_dim)
113
+
114
+ self.transformer_blocks = nn.ModuleList(
115
+ [
116
+ TemporalTransformerBlock(
117
+ dim=inner_dim,
118
+ num_attention_heads=num_attention_heads,
119
+ attention_head_dim=attention_head_dim,
120
+ attention_block_types=attention_block_types,
121
+ dropout=dropout,
122
+ norm_num_groups=norm_num_groups,
123
+ cross_attention_dim=cross_attention_dim,
124
+ activation_fn=activation_fn,
125
+ attention_bias=attention_bias,
126
+ upcast_attention=upcast_attention,
127
+ cross_frame_attention_mode=cross_frame_attention_mode,
128
+ temporal_position_encoding=temporal_position_encoding,
129
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
130
+ )
131
+ for d in range(num_layers)
132
+ ]
133
+ )
134
+ self.proj_out = nn.Linear(inner_dim, in_channels)
135
+
136
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
137
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
138
+ video_length = hidden_states.shape[2]
139
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
140
+
141
+ batch, channel, height, weight = hidden_states.shape
142
+ residual = hidden_states
143
+
144
+ hidden_states = self.norm(hidden_states)
145
+ inner_dim = hidden_states.shape[1]
146
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
147
+ hidden_states = self.proj_in(hidden_states)
148
+
149
+ # Transformer Blocks
150
+ for block in self.transformer_blocks:
151
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
152
+
153
+ # output
154
+ hidden_states = self.proj_out(hidden_states)
155
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
156
+
157
+ output = hidden_states + residual
158
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
159
+
160
+ return output
161
+
162
+
163
+ class TemporalTransformerBlock(nn.Module):
164
+ def __init__(
165
+ self,
166
+ dim,
167
+ num_attention_heads,
168
+ attention_head_dim,
169
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
170
+ dropout = 0.0,
171
+ norm_num_groups = 32,
172
+ cross_attention_dim = 768,
173
+ activation_fn = "geglu",
174
+ attention_bias = False,
175
+ upcast_attention = False,
176
+ cross_frame_attention_mode = None,
177
+ temporal_position_encoding = False,
178
+ temporal_position_encoding_max_len = 24,
179
+ ):
180
+ super().__init__()
181
+
182
+ attention_blocks = []
183
+ norms = []
184
+
185
+ for block_name in attention_block_types:
186
+ attention_blocks.append(
187
+ VersatileAttention(
188
+ attention_mode=block_name.split("_")[0],
189
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
190
+
191
+ query_dim=dim,
192
+ heads=num_attention_heads,
193
+ dim_head=attention_head_dim,
194
+ dropout=dropout,
195
+ bias=attention_bias,
196
+ upcast_attention=upcast_attention,
197
+
198
+ cross_frame_attention_mode=cross_frame_attention_mode,
199
+ temporal_position_encoding=temporal_position_encoding,
200
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
201
+ )
202
+ )
203
+ norms.append(nn.LayerNorm(dim))
204
+
205
+ self.attention_blocks = nn.ModuleList(attention_blocks)
206
+ self.norms = nn.ModuleList(norms)
207
+
208
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
209
+ self.ff_norm = nn.LayerNorm(dim)
210
+
211
+
212
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
213
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
214
+ norm_hidden_states = norm(hidden_states)
215
+ hidden_states = attention_block(
216
+ norm_hidden_states,
217
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
218
+ video_length=video_length,
219
+ ) + hidden_states
220
+
221
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
222
+
223
+ output = hidden_states
224
+ return output
225
+
226
+
227
+ class PositionalEncoding(nn.Module):
228
+ def __init__(
229
+ self,
230
+ d_model,
231
+ dropout = 0.,
232
+ max_len = 24
233
+ ):
234
+ super().__init__()
235
+ self.dropout = nn.Dropout(p=dropout)
236
+ position = torch.arange(max_len).unsqueeze(1)
237
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
238
+ pe = torch.zeros(1, max_len, d_model)
239
+ pe[0, :, 0::2] = torch.sin(position * div_term)
240
+ pe[0, :, 1::2] = torch.cos(position * div_term)
241
+ self.register_buffer('pe', pe)
242
+
243
+ def forward(self, x):
244
+ x = x + self.pe[:, :x.size(1)]
245
+ return self.dropout(x)
246
+
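PositionalEncoding above is the standard sinusoidal scheme, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)), stored as a [1, max_len, d_model] buffer. A quick sketch (not part of the commit; assumes the class is in scope) to see the shapes:

```python
# Toy sketch: the temporal positional encoding adds a [1, max_len, d_model] sinusoidal table.
import torch

pe = PositionalEncoding(d_model=64, dropout=0.0, max_len=24)
x  = torch.zeros(2, 16, 64)          # [batch*spatial, frames, channels], as used by VersatileAttention
y  = pe(x)
print(pe.pe.shape, y.shape)          # torch.Size([1, 24, 64]) torch.Size([2, 16, 64])
```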
247
+ class CrossAttention(nn.Module):
248
+ r"""
249
+ A cross attention layer.
250
+
251
+ Parameters:
252
+ query_dim (`int`): The number of channels in the query.
253
+ cross_attention_dim (`int`, *optional*):
254
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
255
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
256
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
257
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
258
+ bias (`bool`, *optional*, defaults to False):
259
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
260
+ """
261
+
262
+ def __init__(
263
+ self,
264
+ query_dim: int,
265
+ cross_attention_dim: Optional[int] = None,
266
+ heads: int = 8,
267
+ dim_head: int = 64,
268
+ dropout: float = 0.0,
269
+ bias=False,
270
+ upcast_attention: bool = False,
271
+ upcast_softmax: bool = False,
272
+ added_kv_proj_dim: Optional[int] = None,
273
+ norm_num_groups: Optional[int] = None,
274
+ ):
275
+ super().__init__()
276
+ inner_dim = dim_head * heads
277
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
278
+ self.upcast_attention = upcast_attention
279
+ self.upcast_softmax = upcast_softmax
280
+
281
+ self.scale = dim_head**-0.5
282
+
283
+ self.heads = heads
284
+ # for slice_size > 0 the attention score computation
285
+ # is split across the batch axis to save memory
286
+ # You can set slice_size with `set_attention_slice`
287
+ self.sliceable_head_dim = heads
288
+ self._slice_size = None
289
+ self._use_memory_efficient_attention_xformers = False
290
+ self.added_kv_proj_dim = added_kv_proj_dim
291
+
292
+ if norm_num_groups is not None:
293
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
294
+ else:
295
+ self.group_norm = None
296
+
297
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
298
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
299
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
300
+
301
+ if self.added_kv_proj_dim is not None:
302
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
303
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
304
+
305
+ self.to_out = nn.ModuleList([])
306
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
307
+ self.to_out.append(nn.Dropout(dropout))
308
+
309
+ def reshape_heads_to_batch_dim(self, tensor):
310
+ batch_size, seq_len, dim = tensor.shape
311
+ head_size = self.heads
312
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
313
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
314
+ return tensor
315
+
316
+ def reshape_batch_dim_to_heads(self, tensor):
317
+ batch_size, seq_len, dim = tensor.shape
318
+ head_size = self.heads
319
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
320
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
321
+ return tensor
322
+
323
+ def set_attention_slice(self, slice_size):
324
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
325
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
326
+
327
+ self._slice_size = slice_size
328
+
329
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
330
+ batch_size, sequence_length, _ = hidden_states.shape
331
+
332
+ encoder_hidden_states = encoder_hidden_states
333
+
334
+ if self.group_norm is not None:
335
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
336
+
337
+ query = self.to_q(hidden_states)
338
+ dim = query.shape[-1]
339
+ query = self.reshape_heads_to_batch_dim(query)
340
+
341
+ if self.added_kv_proj_dim is not None:
342
+ key = self.to_k(hidden_states)
343
+ value = self.to_v(hidden_states)
344
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
345
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
346
+
347
+ key = self.reshape_heads_to_batch_dim(key)
348
+ value = self.reshape_heads_to_batch_dim(value)
349
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
350
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
351
+
352
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
353
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
354
+ else:
355
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
356
+ key = self.to_k(encoder_hidden_states)
357
+ value = self.to_v(encoder_hidden_states)
358
+
359
+ key = self.reshape_heads_to_batch_dim(key)
360
+ value = self.reshape_heads_to_batch_dim(value)
361
+
362
+ if attention_mask is not None:
363
+ if attention_mask.shape[-1] != query.shape[1]:
364
+ target_length = query.shape[1]
365
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
366
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
367
+
368
+ # attention, what we cannot get enough of
369
+ if self._use_memory_efficient_attention_xformers:
370
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
371
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
372
+ hidden_states = hidden_states.to(query.dtype)
373
+ else:
374
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
375
+ hidden_states = self._attention(query, key, value, attention_mask)
376
+ else:
377
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
378
+
379
+ # linear proj
380
+ hidden_states = self.to_out[0](hidden_states)
381
+
382
+ # dropout
383
+ hidden_states = self.to_out[1](hidden_states)
384
+ return hidden_states
385
+
386
+ def _attention(self, query, key, value, attention_mask=None):
387
+ if self.upcast_attention:
388
+ query = query.float()
389
+ key = key.float()
390
+
391
+ attention_scores = torch.baddbmm(
392
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
393
+ query,
394
+ key.transpose(-1, -2),
395
+ beta=0,
396
+ alpha=self.scale,
397
+ )
398
+
399
+ if attention_mask is not None:
400
+ attention_scores = attention_scores + attention_mask
401
+
402
+ if self.upcast_softmax:
403
+ attention_scores = attention_scores.float()
404
+
405
+ attention_probs = attention_scores.softmax(dim=-1)
406
+
407
+ # cast back to the original dtype
408
+ attention_probs = attention_probs.to(value.dtype)
409
+
410
+ # compute attention output
411
+ hidden_states = torch.bmm(attention_probs, value)
412
+
413
+ # reshape hidden_states
414
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
415
+ return hidden_states
416
+
417
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
418
+ batch_size_attention = query.shape[0]
419
+ hidden_states = torch.zeros(
420
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
421
+ )
422
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
423
+ for i in range(hidden_states.shape[0] // slice_size):
424
+ start_idx = i * slice_size
425
+ end_idx = (i + 1) * slice_size
426
+
427
+ query_slice = query[start_idx:end_idx]
428
+ key_slice = key[start_idx:end_idx]
429
+
430
+ if self.upcast_attention:
431
+ query_slice = query_slice.float()
432
+ key_slice = key_slice.float()
433
+
434
+ attn_slice = torch.baddbmm(
435
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
436
+ query_slice,
437
+ key_slice.transpose(-1, -2),
438
+ beta=0,
439
+ alpha=self.scale,
440
+ )
441
+
442
+ if attention_mask is not None:
443
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
444
+
445
+ if self.upcast_softmax:
446
+ attn_slice = attn_slice.float()
447
+
448
+ attn_slice = attn_slice.softmax(dim=-1)
449
+
450
+ # cast back to the original dtype
451
+ attn_slice = attn_slice.to(value.dtype)
452
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
453
+
454
+ hidden_states[start_idx:end_idx] = attn_slice
455
+
456
+ # reshape hidden_states
457
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
458
+ return hidden_states
459
+
460
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
461
+ # TODO attention_mask
462
+ query = query.contiguous()
463
+ key = key.contiguous()
464
+ value = value.contiguous()
465
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
466
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
467
+ return hidden_states
468
+
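The two reshape helpers above are the usual multi-head bookkeeping: heads are folded into the batch dimension before the attention matmuls and unfolded again afterwards. A minimal round-trip check of that mapping, with illustrative sizes that are not taken from any model config:

# Round trip of the head/batch reshaping used by CrossAttention above.
import torch

heads, batch, seq, dim = 8, 2, 77, 320

def heads_to_batch(t, head_size):
    b, s, d = t.shape
    t = t.reshape(b, s, head_size, d // head_size)
    return t.permute(0, 2, 1, 3).reshape(b * head_size, s, d // head_size)

def batch_to_heads(t, head_size):
    b, s, d = t.shape
    t = t.reshape(b // head_size, head_size, s, d)
    return t.permute(0, 2, 1, 3).reshape(b // head_size, s, d * head_size)

x = torch.randn(batch, seq, dim)
assert torch.equal(batch_to_heads(heads_to_batch(x, heads), heads), x)  # lossless round trip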
469
+ class VersatileAttention(CrossAttention):
470
+ def __init__(
471
+ self,
472
+ attention_mode = None,
473
+ cross_frame_attention_mode = None,
474
+ temporal_position_encoding = False,
475
+ temporal_position_encoding_max_len = 24,
476
+ *args, **kwargs
477
+ ):
478
+ super().__init__(*args, **kwargs)
479
+ assert attention_mode == "Temporal"
480
+
481
+ self.attention_mode = attention_mode
482
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
483
+
484
+ self.pos_encoder = PositionalEncoding(
485
+ kwargs["query_dim"],
486
+ dropout=0.,
487
+ max_len=temporal_position_encoding_max_len
488
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
489
+
490
+ def extra_repr(self):
491
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
492
+
493
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
494
+ batch_size, sequence_length, _ = hidden_states.shape
495
+
496
+ if self.attention_mode == "Temporal":
497
+ d = hidden_states.shape[1]
498
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
499
+
500
+ if self.pos_encoder is not None:
501
+ hidden_states = self.pos_encoder(hidden_states)
502
+
503
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
504
+ else:
505
+ raise NotImplementedError
506
+
507
+ encoder_hidden_states = encoder_hidden_states
508
+
509
+ if self.group_norm is not None:
510
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
511
+
512
+ query = self.to_q(hidden_states)
513
+ dim = query.shape[-1]
514
+ query = self.reshape_heads_to_batch_dim(query)
515
+
516
+ if self.added_kv_proj_dim is not None:
517
+ raise NotImplementedError
518
+
519
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
520
+ key = self.to_k(encoder_hidden_states)
521
+ value = self.to_v(encoder_hidden_states)
522
+
523
+ key = self.reshape_heads_to_batch_dim(key)
524
+ value = self.reshape_heads_to_batch_dim(value)
525
+
526
+ if attention_mask is not None:
527
+ if attention_mask.shape[-1] != query.shape[1]:
528
+ target_length = query.shape[1]
529
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
530
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
531
+
532
+ # attention, what we cannot get enough of
533
+ if self._use_memory_efficient_attention_xformers:
534
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
535
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
536
+ hidden_states = hidden_states.to(query.dtype)
537
+ else:
538
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
539
+ hidden_states = self._attention(query, key, value, attention_mask)
540
+ else:
541
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
542
+
543
+ # linear proj
544
+ hidden_states = self.to_out[0](hidden_states)
545
+
546
+ # dropout
547
+ hidden_states = self.to_out[1](hidden_states)
548
+
549
+ if self.attention_mode == "Temporal":
550
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
551
+
552
+ return hidden_states
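`VersatileAttention` turns the spatial self-attention above into temporal attention purely through the two `rearrange` calls: spatial positions are folded into the batch axis so that attention runs across the `f` frames at each location, and the layout is restored afterwards. A minimal sketch of that round trip, with purely illustrative shapes:

# Sketch of the frame/space reshuffle performed in VersatileAttention.forward.
import torch
from einops import rearrange

b, f, d, c = 2, 16, 64, 320                         # clips, frames, spatial tokens, channels
x = torch.randn(b * f, d, c)                        # the UNet sees (b*f) frame feature maps

x_t = rearrange(x, "(b f) d c -> (b d) f c", f=f)   # attention now mixes the f frames per token
assert x_t.shape == (b * d, f, c)

x_back = rearrange(x_t, "(b d) f c -> (b f) d c", d=d)
assert torch.equal(x, x_back)                       # the mapping is lossless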
animatediff/models/motion_module_bkp.py ADDED
@@ -0,0 +1,331 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ import torchvision
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers import ModelMixin
12
+ from diffusers.utils import BaseOutput
13
+ from diffusers.utils.import_utils import is_xformers_available
14
+ from diffusers.models.attention import CrossAttention, FeedForward
15
+
16
+ from einops import rearrange, repeat
17
+ import math
18
+
19
+
20
+ def zero_module(module):
21
+ # Zero out the parameters of a module and return it.
22
+ for p in module.parameters():
23
+ p.detach().zero_()
24
+ return module
25
+
26
+
27
+ @dataclass
28
+ class TemporalTransformer3DModelOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ if is_xformers_available():
33
+ import xformers
34
+ import xformers.ops
35
+ else:
36
+ xformers = None
37
+
38
+
39
+ def get_motion_module(
40
+ in_channels,
41
+ motion_module_type: str,
42
+ motion_module_kwargs: dict
43
+ ):
44
+ if motion_module_type == "Vanilla":
45
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
46
+ else:
47
+ raise ValueError
48
+
49
+
50
+ class VanillaTemporalModule(nn.Module):
51
+ def __init__(
52
+ self,
53
+ in_channels,
54
+ num_attention_heads = 8,
55
+ num_transformer_block = 2,
56
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
57
+ cross_frame_attention_mode = None,
58
+ temporal_position_encoding = False,
59
+ temporal_position_encoding_max_len = 24,
60
+ temporal_attention_dim_div = 1,
61
+ zero_initialize = True,
62
+ ):
63
+ super().__init__()
64
+
65
+ self.temporal_transformer = TemporalTransformer3DModel(
66
+ in_channels=in_channels,
67
+ num_attention_heads=num_attention_heads,
68
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
69
+ num_layers=num_transformer_block,
70
+ attention_block_types=attention_block_types,
71
+ cross_frame_attention_mode=cross_frame_attention_mode,
72
+ temporal_position_encoding=temporal_position_encoding,
73
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
74
+ )
75
+
76
+ if zero_initialize:
77
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
78
+
79
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
80
+ hidden_states = input_tensor
81
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
82
+
83
+ output = hidden_states
84
+ return output
85
+
86
+
87
+ class TemporalTransformer3DModel(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_channels,
91
+ num_attention_heads,
92
+ attention_head_dim,
93
+
94
+ num_layers,
95
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
96
+ dropout = 0.0,
97
+ norm_num_groups = 32,
98
+ cross_attention_dim = 768,
99
+ activation_fn = "geglu",
100
+ attention_bias = False,
101
+ upcast_attention = False,
102
+
103
+ cross_frame_attention_mode = None,
104
+ temporal_position_encoding = False,
105
+ temporal_position_encoding_max_len = 24,
106
+ ):
107
+ super().__init__()
108
+
109
+ inner_dim = num_attention_heads * attention_head_dim
110
+
111
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
112
+ self.proj_in = nn.Linear(in_channels, inner_dim)
113
+
114
+ self.transformer_blocks = nn.ModuleList(
115
+ [
116
+ TemporalTransformerBlock(
117
+ dim=inner_dim,
118
+ num_attention_heads=num_attention_heads,
119
+ attention_head_dim=attention_head_dim,
120
+ attention_block_types=attention_block_types,
121
+ dropout=dropout,
122
+ norm_num_groups=norm_num_groups,
123
+ cross_attention_dim=cross_attention_dim,
124
+ activation_fn=activation_fn,
125
+ attention_bias=attention_bias,
126
+ upcast_attention=upcast_attention,
127
+ cross_frame_attention_mode=cross_frame_attention_mode,
128
+ temporal_position_encoding=temporal_position_encoding,
129
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
130
+ )
131
+ for d in range(num_layers)
132
+ ]
133
+ )
134
+ self.proj_out = nn.Linear(inner_dim, in_channels)
135
+
136
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
137
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
138
+ video_length = hidden_states.shape[2]
139
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
140
+
141
+ batch, channel, height, weight = hidden_states.shape
142
+ residual = hidden_states
143
+
144
+ hidden_states = self.norm(hidden_states)
145
+ inner_dim = hidden_states.shape[1]
146
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
147
+ hidden_states = self.proj_in(hidden_states)
148
+
149
+ # Transformer Blocks
150
+ for block in self.transformer_blocks:
151
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
152
+
153
+ # output
154
+ hidden_states = self.proj_out(hidden_states)
155
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
156
+
157
+ output = hidden_states + residual
158
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
159
+
160
+ return output
161
+
162
+
163
+ class TemporalTransformerBlock(nn.Module):
164
+ def __init__(
165
+ self,
166
+ dim,
167
+ num_attention_heads,
168
+ attention_head_dim,
169
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
170
+ dropout = 0.0,
171
+ norm_num_groups = 32,
172
+ cross_attention_dim = 768,
173
+ activation_fn = "geglu",
174
+ attention_bias = False,
175
+ upcast_attention = False,
176
+ cross_frame_attention_mode = None,
177
+ temporal_position_encoding = False,
178
+ temporal_position_encoding_max_len = 24,
179
+ ):
180
+ super().__init__()
181
+
182
+ attention_blocks = []
183
+ norms = []
184
+
185
+ for block_name in attention_block_types:
186
+ attention_blocks.append(
187
+ VersatileAttention(
188
+ attention_mode=block_name.split("_")[0],
189
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
190
+
191
+ query_dim=dim,
192
+ heads=num_attention_heads,
193
+ dim_head=attention_head_dim,
194
+ dropout=dropout,
195
+ bias=attention_bias,
196
+ upcast_attention=upcast_attention,
197
+
198
+ cross_frame_attention_mode=cross_frame_attention_mode,
199
+ temporal_position_encoding=temporal_position_encoding,
200
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
201
+ )
202
+ )
203
+ norms.append(nn.LayerNorm(dim))
204
+
205
+ self.attention_blocks = nn.ModuleList(attention_blocks)
206
+ self.norms = nn.ModuleList(norms)
207
+
208
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
209
+ self.ff_norm = nn.LayerNorm(dim)
210
+
211
+
212
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
213
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
214
+ norm_hidden_states = norm(hidden_states)
215
+ hidden_states = attention_block(
216
+ norm_hidden_states,
217
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
218
+ video_length=video_length,
219
+ ) + hidden_states
220
+
221
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
222
+
223
+ output = hidden_states
224
+ return output
225
+
226
+
227
+ class PositionalEncoding(nn.Module):
228
+ def __init__(
229
+ self,
230
+ d_model,
231
+ dropout = 0.,
232
+ max_len = 24
233
+ ):
234
+ super().__init__()
235
+ self.dropout = nn.Dropout(p=dropout)
236
+ position = torch.arange(max_len).unsqueeze(1)
237
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
238
+ pe = torch.zeros(1, max_len, d_model)
239
+ pe[0, :, 0::2] = torch.sin(position * div_term)
240
+ pe[0, :, 1::2] = torch.cos(position * div_term)
241
+ self.register_buffer('pe', pe)
242
+
243
+ def forward(self, x):
244
+ x = x + self.pe[:, :x.size(1)]
245
+ return self.dropout(x)
246
+
247
+
248
+ class VersatileAttention(CrossAttention):
249
+ def __init__(
250
+ self,
251
+ attention_mode = None,
252
+ cross_frame_attention_mode = None,
253
+ temporal_position_encoding = False,
254
+ temporal_position_encoding_max_len = 24,
255
+ *args, **kwargs
256
+ ):
257
+ super().__init__(*args, **kwargs)
258
+ assert attention_mode == "Temporal"
259
+
260
+ self.attention_mode = attention_mode
261
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
262
+
263
+ self.pos_encoder = PositionalEncoding(
264
+ kwargs["query_dim"],
265
+ dropout=0.,
266
+ max_len=temporal_position_encoding_max_len
267
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
268
+
269
+ def extra_repr(self):
270
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
271
+
272
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
273
+ batch_size, sequence_length, _ = hidden_states.shape
274
+
275
+ if self.attention_mode == "Temporal":
276
+ d = hidden_states.shape[1]
277
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
278
+
279
+ if self.pos_encoder is not None:
280
+ hidden_states = self.pos_encoder(hidden_states)
281
+
282
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
283
+ else:
284
+ raise NotImplementedError
285
+
286
+ encoder_hidden_states = encoder_hidden_states
287
+
288
+ if self.group_norm is not None:
289
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
290
+
291
+ query = self.to_q(hidden_states)
292
+ dim = query.shape[-1]
293
+ query = self.reshape_heads_to_batch_dim(query)
294
+
295
+ if self.added_kv_proj_dim is not None:
296
+ raise NotImplementedError
297
+
298
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
299
+ key = self.to_k(encoder_hidden_states)
300
+ value = self.to_v(encoder_hidden_states)
301
+
302
+ key = self.reshape_heads_to_batch_dim(key)
303
+ value = self.reshape_heads_to_batch_dim(value)
304
+
305
+ if attention_mask is not None:
306
+ if attention_mask.shape[-1] != query.shape[1]:
307
+ target_length = query.shape[1]
308
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
309
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
310
+
311
+ # attention, what we cannot get enough of
312
+ if self._use_memory_efficient_attention_xformers:
313
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
314
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
315
+ hidden_states = hidden_states.to(query.dtype)
316
+ else:
317
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
318
+ hidden_states = self._attention(query, key, value, attention_mask)
319
+ else:
320
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
321
+
322
+ # linear proj
323
+ hidden_states = self.to_out[0](hidden_states)
324
+
325
+ # dropout
326
+ hidden_states = self.to_out[1](hidden_states)
327
+
328
+ if self.attention_mode == "Temporal":
329
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
330
+
331
+ return hidden_states
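`PositionalEncoding`, defined identically in both copies of this module, is the standard sinusoidal table truncated to `max_len` frames, so each frame index receives a distinct additive code before temporal attention. A quick sanity check of that construction (illustrative sizes, dropout omitted):

# Sanity check of the sinusoidal positional table built by PositionalEncoding.
import math
import torch

d_model, max_len = 8, 24
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)

x = torch.zeros(2, 16, d_model)            # 2 clips of 16 frames
out = x + pe[:, :x.size(1)]                # what forward() does, dropout aside
assert out.shape == (2, 16, d_model)
assert torch.allclose(out[0, 0, 1::2], torch.ones(d_model // 2))   # cos(0) = 1 at frame 0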
animatediff/models/resnet.py ADDED
@@ -0,0 +1,217 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
19
+
20
+
21
+ class InflatedGroupNorm(nn.GroupNorm):
22
+ def forward(self, x):
23
+ video_length = x.shape[2]
24
+
25
+ x = rearrange(x, "b c f h w -> (b f) c h w")
26
+ x = super().forward(x)
27
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
28
+
29
+ return x
30
+
31
+
32
+ class Upsample3D(nn.Module):
33
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
34
+ super().__init__()
35
+ self.channels = channels
36
+ self.out_channels = out_channels or channels
37
+ self.use_conv = use_conv
38
+ self.use_conv_transpose = use_conv_transpose
39
+ self.name = name
40
+
41
+ conv = None
42
+ if use_conv_transpose:
43
+ raise NotImplementedError
44
+ elif use_conv:
45
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
46
+
47
+ def forward(self, hidden_states, output_size=None):
48
+ assert hidden_states.shape[1] == self.channels
49
+
50
+ if self.use_conv_transpose:
51
+ raise NotImplementedError
52
+
53
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
54
+ dtype = hidden_states.dtype
55
+ if dtype == torch.bfloat16:
56
+ hidden_states = hidden_states.to(torch.float32)
57
+
58
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
59
+ if hidden_states.shape[0] >= 64:
60
+ hidden_states = hidden_states.contiguous()
61
+
62
+ # if `output_size` is passed we force the interpolation output
63
+ # size and do not make use of `scale_factor=2`
64
+ if output_size is None:
65
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
66
+ else:
67
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
68
+
69
+ # If the input is bfloat16, we cast back to bfloat16
70
+ if dtype == torch.bfloat16:
71
+ hidden_states = hidden_states.to(dtype)
72
+
73
+ # if self.use_conv:
74
+ # if self.name == "conv":
75
+ # hidden_states = self.conv(hidden_states)
76
+ # else:
77
+ # hidden_states = self.Conv2d_0(hidden_states)
78
+ hidden_states = self.conv(hidden_states)
79
+
80
+ return hidden_states
81
+
82
+
83
+ class Downsample3D(nn.Module):
84
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.padding = padding
90
+ stride = 2
91
+ self.name = name
92
+
93
+ if use_conv:
94
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ def forward(self, hidden_states):
99
+ assert hidden_states.shape[1] == self.channels
100
+ if self.use_conv and self.padding == 0:
101
+ raise NotImplementedError
102
+
103
+ assert hidden_states.shape[1] == self.channels
104
+ hidden_states = self.conv(hidden_states)
105
+
106
+ return hidden_states
107
+
108
+
109
+ class ResnetBlock3D(nn.Module):
110
+ def __init__(
111
+ self,
112
+ *,
113
+ in_channels,
114
+ out_channels=None,
115
+ conv_shortcut=False,
116
+ dropout=0.0,
117
+ temb_channels=512,
118
+ groups=32,
119
+ groups_out=None,
120
+ pre_norm=True,
121
+ eps=1e-6,
122
+ non_linearity="swish",
123
+ time_embedding_norm="default",
124
+ output_scale_factor=1.0,
125
+ use_in_shortcut=None,
126
+ use_inflated_groupnorm=False,
127
+ ):
128
+ super().__init__()
129
+ self.pre_norm = pre_norm
130
+ self.pre_norm = True
131
+ self.in_channels = in_channels
132
+ out_channels = in_channels if out_channels is None else out_channels
133
+ self.out_channels = out_channels
134
+ self.use_conv_shortcut = conv_shortcut
135
+ self.time_embedding_norm = time_embedding_norm
136
+ self.output_scale_factor = output_scale_factor
137
+
138
+ if groups_out is None:
139
+ groups_out = groups
140
+
141
+ assert use_inflated_groupnorm is not None
142
+ if use_inflated_groupnorm:
143
+ self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
144
+ else:
145
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
146
+
147
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+
149
+ if temb_channels is not None:
150
+ if self.time_embedding_norm == "default":
151
+ time_emb_proj_out_channels = out_channels
152
+ elif self.time_embedding_norm == "scale_shift":
153
+ time_emb_proj_out_channels = out_channels * 2
154
+ else:
155
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
156
+
157
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
158
+ else:
159
+ self.time_emb_proj = None
160
+
161
+ if use_inflated_groupnorm:
162
+ self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
163
+ else:
164
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
165
+
166
+ self.dropout = torch.nn.Dropout(dropout)
167
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
168
+
169
+ if non_linearity == "swish":
170
+ self.nonlinearity = lambda x: F.silu(x)
171
+ elif non_linearity == "mish":
172
+ self.nonlinearity = Mish()
173
+ elif non_linearity == "silu":
174
+ self.nonlinearity = nn.SiLU()
175
+
176
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
177
+
178
+ self.conv_shortcut = None
179
+ if self.use_in_shortcut:
180
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
181
+
182
+ def forward(self, input_tensor, temb):
183
+ hidden_states = input_tensor
184
+
185
+ hidden_states = self.norm1(hidden_states)
186
+ hidden_states = self.nonlinearity(hidden_states)
187
+
188
+ hidden_states = self.conv1(hidden_states)
189
+
190
+ if temb is not None:
191
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
192
+
193
+ if temb is not None and self.time_embedding_norm == "default":
194
+ hidden_states = hidden_states + temb
195
+
196
+ hidden_states = self.norm2(hidden_states)
197
+
198
+ if temb is not None and self.time_embedding_norm == "scale_shift":
199
+ scale, shift = torch.chunk(temb, 2, dim=1)
200
+ hidden_states = hidden_states * (1 + scale) + shift
201
+
202
+ hidden_states = self.nonlinearity(hidden_states)
203
+
204
+ hidden_states = self.dropout(hidden_states)
205
+ hidden_states = self.conv2(hidden_states)
206
+
207
+ if self.conv_shortcut is not None:
208
+ input_tensor = self.conv_shortcut(input_tensor)
209
+
210
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
211
+
212
+ return output_tensor
213
+
214
+
215
+ class Mish(torch.nn.Module):
216
+ def forward(self, hidden_states):
217
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
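`InflatedConv3d` and `InflatedGroupNorm` let pretrained 2D image weights operate on 5D video tensors by folding the frame axis into the batch axis, so every frame is processed independently and temporal mixing is left entirely to the motion modules. A small sketch of that equivalence, using hypothetical shapes:

# Sketch: an "inflated" conv is just the 2D conv applied to each frame.
import torch
import torch.nn as nn
from einops import rearrange

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)

video = torch.randn(2, 4, 16, 32, 32)                       # (batch, channels, frames, h, w)
frames = rearrange(video, "b c f h w -> (b f) c h w")
out = rearrange(conv2d(frames), "(b f) c h w -> b c f h w", f=16)

ref = torch.stack([conv2d(video[:, :, i]) for i in range(16)], dim=2)   # frame-by-frame reference
assert torch.allclose(out, ref, atol=1e-6)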
animatediff/models/sparse_controlnet.py ADDED
@@ -0,0 +1,587 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # Changes were made to this source code by Yuwei Guo.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+ from torch.nn import functional as F
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
26
+ # from diffusers.modeling_utils import ModelMixin
27
+ from diffusers import ModelMixin
28
+
29
+
30
+ from .unet_blocks import (
31
+ CrossAttnDownBlock3D,
32
+ DownBlock3D,
33
+ UNetMidBlock3DCrossAttn,
34
+ get_down_block,
35
+ )
36
+ from einops import repeat, rearrange
37
+ from .resnet import InflatedConv3d
38
+
39
+ from diffusers import UNet2DConditionModel
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class SparseControlNetOutput(BaseOutput):
46
+ down_block_res_samples: Tuple[torch.Tensor]
47
+ mid_block_res_sample: torch.Tensor
48
+
49
+
50
+ class SparseControlNetConditioningEmbedding(nn.Module):
51
+ def __init__(
52
+ self,
53
+ conditioning_embedding_channels: int,
54
+ conditioning_channels: int = 3,
55
+ block_out_channels: Tuple[int] = (16, 32, 96, 256),
56
+ ):
57
+ super().__init__()
58
+
59
+ self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
60
+
61
+ self.blocks = nn.ModuleList([])
62
+
63
+ for i in range(len(block_out_channels) - 1):
64
+ channel_in = block_out_channels[i]
65
+ channel_out = block_out_channels[i + 1]
66
+ self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
67
+ self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
68
+
69
+ self.conv_out = zero_module(
70
+ InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
71
+ )
72
+
73
+ def forward(self, conditioning):
74
+ embedding = self.conv_in(conditioning)
75
+ embedding = F.silu(embedding)
76
+
77
+ for block in self.blocks:
78
+ embedding = block(embedding)
79
+ embedding = F.silu(embedding)
80
+
81
+ embedding = self.conv_out(embedding)
82
+
83
+ return embedding
84
+
85
+
86
+ class SparseControlNetModel(ModelMixin, ConfigMixin):
87
+ _supports_gradient_checkpointing = True
88
+
89
+ @register_to_config
90
+ def __init__(
91
+ self,
92
+ in_channels: int = 4,
93
+ conditioning_channels: int = 3,
94
+ flip_sin_to_cos: bool = True,
95
+ freq_shift: int = 0,
96
+ down_block_types: Tuple[str] = (
97
+ "CrossAttnDownBlock2D",
98
+ "CrossAttnDownBlock2D",
99
+ "CrossAttnDownBlock2D",
100
+ "DownBlock2D",
101
+ ),
102
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
103
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
104
+ layers_per_block: int = 2,
105
+ downsample_padding: int = 1,
106
+ mid_block_scale_factor: float = 1,
107
+ act_fn: str = "silu",
108
+ norm_num_groups: Optional[int] = 32,
109
+ norm_eps: float = 1e-5,
110
+ cross_attention_dim: int = 1280,
111
+ attention_head_dim: Union[int, Tuple[int]] = 8,
112
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
113
+ use_linear_projection: bool = False,
114
+ class_embed_type: Optional[str] = None,
115
+ num_class_embeds: Optional[int] = None,
116
+ upcast_attention: bool = False,
117
+ resnet_time_scale_shift: str = "default",
118
+ projection_class_embeddings_input_dim: Optional[int] = None,
119
+ controlnet_conditioning_channel_order: str = "rgb",
120
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
121
+ global_pool_conditions: bool = False,
122
+
123
+ use_motion_module = True,
124
+ motion_module_resolutions = ( 1,2,4,8 ),
125
+ motion_module_mid_block = False,
126
+ motion_module_type = "Vanilla",
127
+ motion_module_kwargs = {
128
+ "num_attention_heads": 8,
129
+ "num_transformer_block": 1,
130
+ "attention_block_types": ["Temporal_Self"],
131
+ "temporal_position_encoding": True,
132
+ "temporal_position_encoding_max_len": 32,
133
+ "temporal_attention_dim_div": 1,
134
+ "causal_temporal_attention": False,
135
+ },
136
+
137
+ concate_conditioning_mask: bool = True,
138
+ use_simplified_condition_embedding: bool = False,
139
+
140
+ set_noisy_sample_input_to_zero: bool = False,
141
+ ):
142
+ super().__init__()
143
+
144
+ # If `num_attention_heads` is not defined (which is the case for most models)
145
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
146
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
147
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
148
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
149
+ # which is why we correct for the naming here.
150
+ num_attention_heads = num_attention_heads or attention_head_dim
151
+
152
+ # Check inputs
153
+ if len(block_out_channels) != len(down_block_types):
154
+ raise ValueError(
155
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
156
+ )
157
+
158
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
159
+ raise ValueError(
160
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
161
+ )
162
+
163
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
164
+ raise ValueError(
165
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
166
+ )
167
+
168
+ # input
169
+ self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
170
+
171
+ conv_in_kernel = 3
172
+ conv_in_padding = (conv_in_kernel - 1) // 2
173
+ self.conv_in = InflatedConv3d(
174
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
175
+ )
176
+
177
+ if concate_conditioning_mask:
178
+ conditioning_channels = conditioning_channels + 1
179
+ self.concate_conditioning_mask = concate_conditioning_mask
180
+
181
+ # control net conditioning embedding
182
+ if use_simplified_condition_embedding:
183
+ self.controlnet_cond_embedding = zero_module(
184
+ InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
185
+ )
186
+ else:
187
+ self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
188
+ conditioning_embedding_channels=block_out_channels[0],
189
+ block_out_channels=conditioning_embedding_out_channels,
190
+ conditioning_channels=conditioning_channels,
191
+ )
192
+ self.use_simplified_condition_embedding = use_simplified_condition_embedding
193
+
194
+ # time
195
+ time_embed_dim = block_out_channels[0] * 4
196
+
197
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
198
+ timestep_input_dim = block_out_channels[0]
199
+
200
+ self.time_embedding = TimestepEmbedding(
201
+ timestep_input_dim,
202
+ time_embed_dim,
203
+ act_fn=act_fn,
204
+ )
205
+
206
+ # class embedding
207
+ if class_embed_type is None and num_class_embeds is not None:
208
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
209
+ elif class_embed_type == "timestep":
210
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
211
+ elif class_embed_type == "identity":
212
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
213
+ elif class_embed_type == "projection":
214
+ if projection_class_embeddings_input_dim is None:
215
+ raise ValueError(
216
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
217
+ )
218
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
219
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
220
+ # 2. it projects from an arbitrary input dimension.
221
+ #
222
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
223
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
224
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
225
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
226
+ else:
227
+ self.class_embedding = None
228
+
229
+
230
+ self.down_blocks = nn.ModuleList([])
231
+ self.controlnet_down_blocks = nn.ModuleList([])
232
+
233
+ if isinstance(only_cross_attention, bool):
234
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
235
+
236
+ if isinstance(attention_head_dim, int):
237
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
238
+
239
+ if isinstance(num_attention_heads, int):
240
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
241
+
242
+ # down
243
+ output_channel = block_out_channels[0]
244
+
245
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
246
+ controlnet_block = zero_module(controlnet_block)
247
+ self.controlnet_down_blocks.append(controlnet_block)
248
+
249
+ for i, down_block_type in enumerate(down_block_types):
250
+ res = 2 ** i
251
+ input_channel = output_channel
252
+ output_channel = block_out_channels[i]
253
+ is_final_block = i == len(block_out_channels) - 1
254
+
255
+ down_block = get_down_block(
256
+ down_block_type,
257
+ num_layers=layers_per_block,
258
+ in_channels=input_channel,
259
+ out_channels=output_channel,
260
+ temb_channels=time_embed_dim,
261
+ add_downsample=not is_final_block,
262
+ resnet_eps=norm_eps,
263
+ resnet_act_fn=act_fn,
264
+ resnet_groups=norm_num_groups,
265
+ cross_attention_dim=cross_attention_dim,
266
+ attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
267
+ downsample_padding=downsample_padding,
268
+ use_linear_projection=use_linear_projection,
269
+ only_cross_attention=only_cross_attention[i],
270
+ upcast_attention=upcast_attention,
271
+ resnet_time_scale_shift=resnet_time_scale_shift,
272
+
273
+ use_inflated_groupnorm=True,
274
+
275
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
276
+ motion_module_type=motion_module_type,
277
+ motion_module_kwargs=motion_module_kwargs,
278
+ )
279
+ self.down_blocks.append(down_block)
280
+
281
+ for _ in range(layers_per_block):
282
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
283
+ controlnet_block = zero_module(controlnet_block)
284
+ self.controlnet_down_blocks.append(controlnet_block)
285
+
286
+ if not is_final_block:
287
+ controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
288
+ controlnet_block = zero_module(controlnet_block)
289
+ self.controlnet_down_blocks.append(controlnet_block)
290
+
291
+ # mid
292
+ mid_block_channel = block_out_channels[-1]
293
+
294
+ controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
295
+ controlnet_block = zero_module(controlnet_block)
296
+ self.controlnet_mid_block = controlnet_block
297
+
298
+ self.mid_block = UNetMidBlock3DCrossAttn(
299
+ in_channels=mid_block_channel,
300
+ temb_channels=time_embed_dim,
301
+ resnet_eps=norm_eps,
302
+ resnet_act_fn=act_fn,
303
+ output_scale_factor=mid_block_scale_factor,
304
+ resnet_time_scale_shift=resnet_time_scale_shift,
305
+ cross_attention_dim=cross_attention_dim,
306
+ attn_num_head_channels=num_attention_heads[-1],
307
+ resnet_groups=norm_num_groups,
308
+ use_linear_projection=use_linear_projection,
309
+ upcast_attention=upcast_attention,
310
+
311
+ use_inflated_groupnorm=True,
312
+ use_motion_module=use_motion_module and motion_module_mid_block,
313
+ motion_module_type=motion_module_type,
314
+ motion_module_kwargs=motion_module_kwargs,
315
+ )
316
+
317
+ @classmethod
318
+ def from_unet(
319
+ cls,
320
+ unet: UNet2DConditionModel,
321
+ controlnet_conditioning_channel_order: str = "rgb",
322
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
323
+ load_weights_from_unet: bool = True,
324
+
325
+ controlnet_additional_kwargs: dict = {},
326
+ ):
327
+ controlnet = cls(
328
+ in_channels=unet.config.in_channels,
329
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
330
+ freq_shift=unet.config.freq_shift,
331
+ down_block_types=unet.config.down_block_types,
332
+ only_cross_attention=unet.config.only_cross_attention,
333
+ block_out_channels=unet.config.block_out_channels,
334
+ layers_per_block=unet.config.layers_per_block,
335
+ downsample_padding=unet.config.downsample_padding,
336
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
337
+ act_fn=unet.config.act_fn,
338
+ norm_num_groups=unet.config.norm_num_groups,
339
+ norm_eps=unet.config.norm_eps,
340
+ cross_attention_dim=unet.config.cross_attention_dim,
341
+ attention_head_dim=unet.config.attention_head_dim,
342
+ num_attention_heads=unet.config.num_attention_heads,
343
+ use_linear_projection=unet.config.use_linear_projection,
344
+ class_embed_type=unet.config.class_embed_type,
345
+ num_class_embeds=unet.config.num_class_embeds,
346
+ upcast_attention=unet.config.upcast_attention,
347
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
348
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
349
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
350
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
351
+
352
+ **controlnet_additional_kwargs,
353
+ )
354
+
355
+ if load_weights_from_unet:
356
+ m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
357
+ assert len(u) == 0
358
+ m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
359
+ assert len(u) == 0
360
+ m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
361
+ assert len(u) == 0
362
+
363
+ if controlnet.class_embedding:
364
+ m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
365
+ assert len(u) == 0
366
+ m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
367
+ assert len(u) == 0
368
+ m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
369
+ assert len(u) == 0
370
+
371
+ return controlnet
372
+
373
+ @staticmethod
374
+ def image_layer_filter(state_dict):
375
+ new_state_dict = {}
376
+ for name, param in state_dict.items():
377
+ if "motion_modules." in name or "lora" in name: continue
378
+ new_state_dict[name] = param
379
+ return new_state_dict
380
+
381
+ # Copied from diffusers.models.UNet2DConditionModel.set_attention_slice
382
+ def set_attention_slice(self, slice_size):
383
+ r"""
384
+ Enable sliced attention computation.
385
+
386
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
387
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
388
+
389
+ Args:
390
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
391
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
392
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
393
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
394
+ must be a multiple of `slice_size`.
395
+ """
396
+ sliceable_head_dims = []
397
+
398
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
399
+ if hasattr(module, "set_attention_slice"):
400
+ sliceable_head_dims.append(module.sliceable_head_dim)
401
+
402
+ for child in module.children():
403
+ fn_recursive_retrieve_sliceable_dims(child)
404
+
405
+ # retrieve number of attention layers
406
+ for module in self.children():
407
+ fn_recursive_retrieve_sliceable_dims(module)
408
+
409
+ num_sliceable_layers = len(sliceable_head_dims)
410
+
411
+ if slice_size == "auto":
412
+ # half the attention head size is usually a good trade-off between
413
+ # speed and memory
414
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
415
+ elif slice_size == "max":
416
+ # make smallest slice possible
417
+ slice_size = num_sliceable_layers * [1]
418
+
419
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
420
+
421
+ if len(slice_size) != len(sliceable_head_dims):
422
+ raise ValueError(
423
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
424
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
425
+ )
426
+
427
+ for i in range(len(slice_size)):
428
+ size = slice_size[i]
429
+ dim = sliceable_head_dims[i]
430
+ if size is not None and size > dim:
431
+ raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
432
+
433
+ # Recursively walk through all the children.
434
+ # Any children which exposes the set_attention_slice method
435
+ # gets the message
436
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
437
+ if hasattr(module, "set_attention_slice"):
438
+ module.set_attention_slice(slice_size.pop())
439
+
440
+ for child in module.children():
441
+ fn_recursive_set_attention_slice(child, slice_size)
442
+
443
+ reversed_slice_size = list(reversed(slice_size))
444
+ for module in self.children():
445
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
446
+
447
+ def _set_gradient_checkpointing(self, module, value=False):
448
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D)):
449
+ module.gradient_checkpointing = value
450
+
451
+ def forward(
452
+ self,
453
+ sample: torch.FloatTensor,
454
+ timestep: Union[torch.Tensor, float, int],
455
+ encoder_hidden_states: torch.Tensor,
456
+
457
+ controlnet_cond: torch.FloatTensor,
458
+ conditioning_mask: Optional[torch.FloatTensor] = None,
459
+
460
+ conditioning_scale: float = 1.0,
461
+ class_labels: Optional[torch.Tensor] = None,
462
+ attention_mask: Optional[torch.Tensor] = None,
463
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
464
+ guess_mode: bool = False,
465
+ return_dict: bool = True,
466
+ ) -> Union[SparseControlNetOutput, Tuple]:
467
+
468
+ # set input noise to zero
469
+ if self.set_noisy_sample_input_to_zero:
470
+ sample = torch.zeros_like(sample).to(sample.device)
471
+
472
+ # prepare attention_mask
473
+ if attention_mask is not None:
474
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
475
+ attention_mask = attention_mask.unsqueeze(1)
476
+
477
+ # 1. time
478
+ timesteps = timestep
479
+ if not torch.is_tensor(timesteps):
480
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
481
+ # This would be a good case for the `match` statement (Python 3.10+)
482
+ is_mps = sample.device.type == "mps"
483
+ if isinstance(timestep, float):
484
+ dtype = torch.float32 if is_mps else torch.float64
485
+ else:
486
+ dtype = torch.int32 if is_mps else torch.int64
487
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
488
+ elif len(timesteps.shape) == 0:
489
+ timesteps = timesteps[None].to(sample.device)
490
+
491
+ timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
492
+ encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
493
+
494
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
495
+ timesteps = timesteps.expand(sample.shape[0])
496
+
497
+ t_emb = self.time_proj(timesteps)
498
+
499
+ # timesteps does not contain any weights and will always return f32 tensors
500
+ # but time_embedding might actually be running in fp16. so we need to cast here.
501
+ # there might be better ways to encapsulate this.
502
+ t_emb = t_emb.to(dtype=self.dtype)
503
+ emb = self.time_embedding(t_emb)
504
+
505
+ if self.class_embedding is not None:
506
+ if class_labels is None:
507
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
508
+
509
+ if self.config.class_embed_type == "timestep":
510
+ class_labels = self.time_proj(class_labels)
511
+
512
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
513
+ emb = emb + class_emb
514
+
515
+ # 2. pre-process
516
+ sample = self.conv_in(sample)
517
+ if self.concate_conditioning_mask:
518
+ controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
519
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
520
+
521
+ sample = sample + controlnet_cond
522
+
523
+ # 3. down
524
+ down_block_res_samples = (sample,)
525
+ for downsample_block in self.down_blocks:
526
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
527
+ sample, res_samples = downsample_block(
528
+ hidden_states=sample,
529
+ temb=emb,
530
+ encoder_hidden_states=encoder_hidden_states,
531
+ attention_mask=attention_mask,
532
+ # cross_attention_kwargs=cross_attention_kwargs,
533
+ )
534
+ else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
535
+
536
+ down_block_res_samples += res_samples
537
+
538
+ # 4. mid
539
+ if self.mid_block is not None:
540
+ sample = self.mid_block(
541
+ sample,
542
+ emb,
543
+ encoder_hidden_states=encoder_hidden_states,
544
+ attention_mask=attention_mask,
545
+ # cross_attention_kwargs=cross_attention_kwargs,
546
+ )
547
+
548
+ # 5. controlnet blocks
549
+ controlnet_down_block_res_samples = ()
550
+
551
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
552
+ down_block_res_sample = controlnet_block(down_block_res_sample)
553
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
554
+
555
+ down_block_res_samples = controlnet_down_block_res_samples
556
+
557
+ mid_block_res_sample = self.controlnet_mid_block(sample)
558
+
559
+ # 6. scaling
560
+ if guess_mode and not self.config.global_pool_conditions:
561
+ scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
562
+
563
+ scales = scales * conditioning_scale
564
+ down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
565
+ mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
566
+ else:
567
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
568
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
569
+
570
+ if self.config.global_pool_conditions:
571
+ down_block_res_samples = [
572
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
573
+ ]
574
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
575
+
576
+ if not return_dict:
577
+ return (down_block_res_samples, mid_block_res_sample)
578
+
579
+ return SparseControlNetOutput(
580
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
581
+ )
582
+
583
+
584
+ def zero_module(module):
585
+ for p in module.parameters():
586
+ nn.init.zeros_(p)
587
+ return module
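Note that every residual projection in `SparseControlNetModel` is wrapped in `zero_module`, so a freshly initialized controlnet injects exactly zero into the UNet until it is trained; in `guess_mode` the trained residuals are further scaled by a logspace ramp from 0.1 to 1.0. A minimal illustration of the zero-initialization property, with a hypothetical layer size:

# Sketch: zero-initialised output projections make an untrained ControlNet a no-op.
import torch
import torch.nn as nn

def zero_module(module):
    # Same helper as above: zero every parameter in place.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))
residual = proj(torch.randn(1, 320, 8, 8))
assert torch.count_nonzero(residual) == 0    # the injected residual starts at exactly zero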
animatediff/models/unet.py ADDED
@@ -0,0 +1,600 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union, Dict
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers import ModelMixin
16
+ from diffusers.utils import BaseOutput, logging
17
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
18
+ from .unet_blocks import (
19
+ CrossAttnDownBlock3D,
20
+ CrossAttnUpBlock3D,
21
+ DownBlock3D,
22
+ UNetMidBlock3DCrossAttn,
23
+ UpBlock3D,
24
+ get_down_block,
25
+ get_up_block,
26
+ )
27
+ from .resnet import InflatedConv3d, InflatedGroupNorm
28
+ from diffusers.models.attention_processor import (
29
+ ADDED_KV_ATTENTION_PROCESSORS,
30
+ CROSS_ATTENTION_PROCESSORS,
31
+ Attention,
32
+ AttentionProcessor,
33
+ AttnAddedKVProcessor,
34
+ AttnProcessor,
35
+ )
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ @dataclass
41
+ class UNet3DConditionOutput(BaseOutput):
42
+ sample: torch.FloatTensor
43
+
44
+
45
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
46
+ _supports_gradient_checkpointing = True
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ sample_size: Optional[int] = None,
52
+ in_channels: int = 4,
53
+ out_channels: int = 4,
54
+ center_input_sample: bool = False,
55
+ flip_sin_to_cos: bool = True,
56
+ freq_shift: int = 0,
57
+ down_block_types: Tuple[str] = (
58
+ "CrossAttnDownBlock3D",
59
+ "CrossAttnDownBlock3D",
60
+ "CrossAttnDownBlock3D",
61
+ "DownBlock3D",
62
+ ),
63
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
64
+ up_block_types: Tuple[str] = (
65
+ "UpBlock3D",
66
+ "CrossAttnUpBlock3D",
67
+ "CrossAttnUpBlock3D",
68
+ "CrossAttnUpBlock3D"
69
+ ),
70
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
71
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
72
+ layers_per_block: int = 2,
73
+ downsample_padding: int = 1,
74
+ mid_block_scale_factor: float = 1,
75
+ act_fn: str = "silu",
76
+ norm_num_groups: int = 32,
77
+ norm_eps: float = 1e-5,
78
+ cross_attention_dim: int = 1280,
79
+ attention_head_dim: Union[int, Tuple[int]] = 8,
80
+ dual_cross_attention: bool = False,
81
+ use_linear_projection: bool = False,
82
+ class_embed_type: Optional[str] = None,
83
+ num_class_embeds: Optional[int] = None,
84
+ upcast_attention: bool = False,
85
+ resnet_time_scale_shift: str = "default",
86
+
87
+ use_inflated_groupnorm=False,
88
+
89
+ # Additional
90
+ use_motion_module = False,
91
+ motion_module_resolutions = ( 1,2,4,8 ),
92
+ motion_module_mid_block = False,
93
+ motion_module_decoder_only = False,
94
+ motion_module_type = None,
95
+ motion_module_kwargs = {},
96
+ unet_use_cross_frame_attention = False,
97
+ unet_use_temporal_attention = False,
98
+ ):
99
+ super().__init__()
100
+
101
+ self.sample_size = sample_size
102
+ time_embed_dim = block_out_channels[0] * 4
103
+
104
+ # input
105
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
106
+
107
+ # time
108
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
109
+ timestep_input_dim = block_out_channels[0]
110
+
111
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
112
+
113
+ # class embedding
114
+ if class_embed_type is None and num_class_embeds is not None:
115
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
116
+ elif class_embed_type == "timestep":
117
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
118
+ elif class_embed_type == "identity":
119
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
120
+ else:
121
+ self.class_embedding = None
122
+
123
+ self.down_blocks = nn.ModuleList([])
124
+ self.mid_block = None
125
+ self.up_blocks = nn.ModuleList([])
126
+
127
+ if isinstance(only_cross_attention, bool):
128
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
129
+
130
+ if isinstance(attention_head_dim, int):
131
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
132
+
133
+ # down
134
+ output_channel = block_out_channels[0]
135
+ for i, down_block_type in enumerate(down_block_types):
136
+ res = 2 ** i
137
+ input_channel = output_channel
138
+ output_channel = block_out_channels[i]
139
+ is_final_block = i == len(block_out_channels) - 1
140
+
141
+ down_block = get_down_block(
142
+ down_block_type,
143
+ num_layers=layers_per_block,
144
+ in_channels=input_channel,
145
+ out_channels=output_channel,
146
+ temb_channels=time_embed_dim,
147
+ add_downsample=not is_final_block,
148
+ resnet_eps=norm_eps,
149
+ resnet_act_fn=act_fn,
150
+ resnet_groups=norm_num_groups,
151
+ cross_attention_dim=cross_attention_dim,
152
+ attn_num_head_channels=attention_head_dim[i],
153
+ downsample_padding=downsample_padding,
154
+ dual_cross_attention=dual_cross_attention,
155
+ use_linear_projection=use_linear_projection,
156
+ only_cross_attention=only_cross_attention[i],
157
+ upcast_attention=upcast_attention,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+
160
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
161
+ unet_use_temporal_attention=unet_use_temporal_attention,
162
+ use_inflated_groupnorm=use_inflated_groupnorm,
163
+
164
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ self.down_blocks.append(down_block)
169
+
170
+ # mid
171
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
172
+ self.mid_block = UNetMidBlock3DCrossAttn(
173
+ in_channels=block_out_channels[-1],
174
+ temb_channels=time_embed_dim,
175
+ resnet_eps=norm_eps,
176
+ resnet_act_fn=act_fn,
177
+ output_scale_factor=mid_block_scale_factor,
178
+ resnet_time_scale_shift=resnet_time_scale_shift,
179
+ cross_attention_dim=cross_attention_dim,
180
+ attn_num_head_channels=attention_head_dim[-1],
181
+ resnet_groups=norm_num_groups,
182
+ dual_cross_attention=dual_cross_attention,
183
+ use_linear_projection=use_linear_projection,
184
+ upcast_attention=upcast_attention,
185
+
186
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
187
+ unet_use_temporal_attention=unet_use_temporal_attention,
188
+ use_inflated_groupnorm=use_inflated_groupnorm,
189
+
190
+ use_motion_module=use_motion_module and motion_module_mid_block,
191
+ motion_module_type=motion_module_type,
192
+ motion_module_kwargs=motion_module_kwargs,
193
+ )
194
+ else:
195
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
196
+
197
+ # count how many layers upsample the videos
198
+ self.num_upsamplers = 0
199
+
200
+ # up
201
+ reversed_block_out_channels = list(reversed(block_out_channels))
202
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
203
+ only_cross_attention = list(reversed(only_cross_attention))
204
+ output_channel = reversed_block_out_channels[0]
205
+ for i, up_block_type in enumerate(up_block_types):
206
+ res = 2 ** (3 - i)
207
+ is_final_block = i == len(block_out_channels) - 1
208
+
209
+ prev_output_channel = output_channel
210
+ output_channel = reversed_block_out_channels[i]
211
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
212
+
213
+ # add upsample block for all BUT final layer
214
+ if not is_final_block:
215
+ add_upsample = True
216
+ self.num_upsamplers += 1
217
+ else:
218
+ add_upsample = False
219
+
220
+ up_block = get_up_block(
221
+ up_block_type,
222
+ num_layers=layers_per_block + 1,
223
+ in_channels=input_channel,
224
+ out_channels=output_channel,
225
+ prev_output_channel=prev_output_channel,
226
+ temb_channels=time_embed_dim,
227
+ add_upsample=add_upsample,
228
+ resnet_eps=norm_eps,
229
+ resnet_act_fn=act_fn,
230
+ resnet_groups=norm_num_groups,
231
+ cross_attention_dim=cross_attention_dim,
232
+ attn_num_head_channels=reversed_attention_head_dim[i],
233
+ dual_cross_attention=dual_cross_attention,
234
+ use_linear_projection=use_linear_projection,
235
+ only_cross_attention=only_cross_attention[i],
236
+ upcast_attention=upcast_attention,
237
+ resnet_time_scale_shift=resnet_time_scale_shift,
238
+
239
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
240
+ unet_use_temporal_attention=unet_use_temporal_attention,
241
+ use_inflated_groupnorm=use_inflated_groupnorm,
242
+
243
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
244
+ motion_module_type=motion_module_type,
245
+ motion_module_kwargs=motion_module_kwargs,
246
+ )
247
+ self.up_blocks.append(up_block)
248
+ prev_output_channel = output_channel
249
+
250
+ # out
251
+ if use_inflated_groupnorm:
252
+ self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
253
+ else:
254
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
255
+ self.conv_act = nn.SiLU()
256
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
257
+
258
+ def set_attention_slice(self, slice_size):
259
+ r"""
260
+ Enable sliced attention computation.
261
+
262
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
263
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
264
+
265
+ Args:
266
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
267
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
268
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
269
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
270
+ must be a multiple of `slice_size`.
271
+ """
272
+ sliceable_head_dims = []
273
+
274
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
275
+ if hasattr(module, "set_attention_slice"):
276
+ sliceable_head_dims.append(module.sliceable_head_dim)
277
+
278
+ for child in module.children():
279
+ fn_recursive_retrieve_slicable_dims(child)
280
+
281
+ # retrieve number of attention layers
282
+ for module in self.children():
283
+ fn_recursive_retrieve_slicable_dims(module)
284
+
285
+ num_slicable_layers = len(sliceable_head_dims)
286
+
287
+ if slice_size == "auto":
288
+ # half the attention head size is usually a good trade-off between
289
+ # speed and memory
290
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
291
+ elif slice_size == "max":
292
+ # make smallest slice possible
293
+ slice_size = num_slicable_layers * [1]
294
+
295
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
296
+
297
+ if len(slice_size) != len(sliceable_head_dims):
298
+ raise ValueError(
299
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
300
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
301
+ )
302
+
303
+ for i in range(len(slice_size)):
304
+ size = slice_size[i]
305
+ dim = sliceable_head_dims[i]
306
+ if size is not None and size > dim:
307
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
308
+
309
+ # Recursively walk through all the children.
310
+ # Any children which exposes the set_attention_slice method
311
+ # gets the message
312
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
313
+ if hasattr(module, "set_attention_slice"):
314
+ module.set_attention_slice(slice_size.pop())
315
+
316
+ for child in module.children():
317
+ fn_recursive_set_attention_slice(child, slice_size)
318
+
319
+ reversed_slice_size = list(reversed(slice_size))
320
+ for module in self.children():
321
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
322
+
323
+ def _set_gradient_checkpointing(self, module, value=False):
324
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
325
+ module.gradient_checkpointing = value
326
+
327
+ def forward(
328
+ self,
329
+ sample: torch.FloatTensor,
330
+ timestep: Union[torch.Tensor, float, int],
331
+ encoder_hidden_states: torch.Tensor,
332
+ class_labels: Optional[torch.Tensor] = None,
333
+ attention_mask: Optional[torch.Tensor] = None,
334
+
335
+ # support controlnet
336
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
337
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
338
+
339
+ return_dict: bool = True,
340
+ ) -> Union[UNet3DConditionOutput, Tuple]:
341
+ r"""
342
+ Args:
343
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
344
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
345
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
346
+ return_dict (`bool`, *optional*, defaults to `True`):
347
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
348
+
349
+ Returns:
350
+ [`UNet3DConditionOutput`] or `tuple`:
351
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
352
+ returning a tuple, the first element is the sample tensor.
353
+ """
354
+ # By default samples have to be at least a multiple of the overall upsampling factor.
355
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
356
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
357
+ # on the fly if necessary.
358
+ default_overall_up_factor = 2**self.num_upsamplers
359
+
360
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
361
+ forward_upsample_size = False
362
+ upsample_size = None
363
+
364
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
365
+ logger.info("Forward upsample size to force interpolation output size.")
366
+ forward_upsample_size = True
367
+
368
+ # prepare attention_mask
369
+ if attention_mask is not None:
370
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
371
+ attention_mask = attention_mask.unsqueeze(1)
372
+
373
+ # center input if necessary
374
+ if self.config.center_input_sample:
375
+ sample = 2 * sample - 1.0
376
+
377
+ # time
378
+ timesteps = timestep
379
+ if not torch.is_tensor(timesteps):
380
+ # This would be a good case for the `match` statement (Python 3.10+)
381
+ is_mps = sample.device.type == "mps"
382
+ if isinstance(timestep, float):
383
+ dtype = torch.float32 if is_mps else torch.float64
384
+ else:
385
+ dtype = torch.int32 if is_mps else torch.int64
386
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
387
+ elif len(timesteps.shape) == 0:
388
+ timesteps = timesteps[None].to(sample.device)
389
+
390
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
391
+ timesteps = timesteps.expand(sample.shape[0])
392
+
393
+ t_emb = self.time_proj(timesteps)
394
+
395
+ # timesteps does not contain any weights and will always return f32 tensors
396
+ # but time_embedding might actually be running in fp16. so we need to cast here.
397
+ # there might be better ways to encapsulate this.
398
+ t_emb = t_emb.to(dtype=self.dtype)
399
+ emb = self.time_embedding(t_emb)
400
+
401
+ if self.class_embedding is not None:
402
+ if class_labels is None:
403
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
404
+
405
+ if self.config.class_embed_type == "timestep":
406
+ class_labels = self.time_proj(class_labels)
407
+
408
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
409
+ emb = emb + class_emb
410
+
411
+ # pre-process
412
+ sample = self.conv_in(sample)
413
+
414
+ # down
415
+ down_block_res_samples = (sample,)
416
+ for downsample_block in self.down_blocks:
417
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
418
+ sample, res_samples = downsample_block(
419
+ hidden_states=sample,
420
+ temb=emb,
421
+ encoder_hidden_states=encoder_hidden_states,
422
+ attention_mask=attention_mask,
423
+ )
424
+ else:
425
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
426
+
427
+ down_block_res_samples += res_samples
428
+
429
+ # support controlnet
430
+ down_block_res_samples = list(down_block_res_samples)
431
+ if down_block_additional_residuals is not None:
432
+ for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
433
+ if down_block_additional_residual.dim() == 4: # broadcast
434
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
435
+ down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
436
+
437
+ # mid
438
+ sample = self.mid_block(
439
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
440
+ )
441
+
442
+ # support controlnet
443
+ if mid_block_additional_residual is not None:
444
+ if mid_block_additional_residual.dim() == 4: # broadcast
445
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
446
+ sample = sample + mid_block_additional_residual
447
+
448
+ # up
449
+ for i, upsample_block in enumerate(self.up_blocks):
450
+ is_final_block = i == len(self.up_blocks) - 1
451
+
452
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
453
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
454
+
455
+ # if we have not reached the final block and need to forward the
456
+ # upsample size, we do it here
457
+ if not is_final_block and forward_upsample_size:
458
+ upsample_size = down_block_res_samples[-1].shape[2:]
459
+
460
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
461
+ sample = upsample_block(
462
+ hidden_states=sample,
463
+ temb=emb,
464
+ res_hidden_states_tuple=res_samples,
465
+ encoder_hidden_states=encoder_hidden_states,
466
+ upsample_size=upsample_size,
467
+ attention_mask=attention_mask,
468
+ )
469
+ else:
470
+ sample = upsample_block(
471
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
472
+ )
473
+
474
+ # post-process
475
+ sample = self.conv_norm_out(sample)
476
+ sample = self.conv_act(sample)
477
+ sample = self.conv_out(sample)
478
+
479
+ if not return_dict:
480
+ return (sample,)
481
+
482
+ return UNet3DConditionOutput(sample=sample)
483
+
484
+ @classmethod
485
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
486
+ if subfolder is not None:
487
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
488
+ print(f"Loading 3D UNet's pretrained weights from {pretrained_model_path} ...")
489
+
490
+ config_file = os.path.join(pretrained_model_path, 'config.json')
491
+ if not os.path.isfile(config_file):
492
+ raise RuntimeError(f"{config_file} does not exist")
493
+ with open(config_file, "r") as f:
494
+ config = json.load(f)
495
+ config["_class_name"] = cls.__name__
496
+ config["down_block_types"] = [
497
+ "CrossAttnDownBlock3D",
498
+ "CrossAttnDownBlock3D",
499
+ "CrossAttnDownBlock3D",
500
+ "DownBlock3D"
501
+ ]
502
+ config["up_block_types"] = [
503
+ "UpBlock3D",
504
+ "CrossAttnUpBlock3D",
505
+ "CrossAttnUpBlock3D",
506
+ "CrossAttnUpBlock3D"
507
+ ]
508
+ # config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
509
+ from diffusers.utils import WEIGHTS_NAME
510
+ model = cls.from_config(config, **unet_additional_kwargs)
511
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
512
+ # from safetensors import safe_open
513
+ # state_dict={}
514
+ # # model_file = "/ssd1/hexuanhua/AnimateDiff/outputs/training-2024-02-17T10-07-50/checkpoints/checkpoint.ckpt"
515
+ # with safe_open("/home/zjy/data/hexuanhua/huggingface_model/hub/models--SG161222--Realistic_Vision_V4.0_noVAE/snapshots/1bd8c538b40236e642a1427ed154a50ef5bdd3df/unet/diffusion_pytorch_model.safetensors", framework="pt", device="cpu") as f:
516
+ # for key in f.keys():
517
+ # state_dict[key] = f.get_tensor(key)
518
+
519
+ if not os.path.isfile(model_file):
520
+ raise RuntimeError(f"{model_file} does not exist")
521
+ state_dict = torch.load(model_file, map_location="cpu")
522
+
523
+ m, u = model.load_state_dict(state_dict, strict=False)
524
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
525
+
526
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
527
+ print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
528
+
529
+ return model
530
+ @property
531
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
532
+ r"""
533
+ Returns:
534
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
535
+ indexed by its weight name.
536
+ """
537
+ # set recursively
538
+ processors = {}
539
+
540
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
541
+ if hasattr(module, "get_processor"):
542
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
543
+
544
+ for sub_name, child in module.named_children():
545
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
546
+
547
+ return processors
548
+
549
+ for name, module in self.named_children():
550
+ fn_recursive_add_processors(name, module, processors)
551
+
552
+ return processors
553
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
554
+ r"""
555
+ Sets the attention processor to use to compute attention.
556
+
557
+ Parameters:
558
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
559
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
560
+ for **all** `Attention` layers.
561
+
562
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
563
+ processor. This is strongly recommended when setting trainable attention processors.
564
+
565
+ """
566
+ count = len(self.attn_processors.keys())
567
+
568
+ if isinstance(processor, dict) and len(processor) != count:
569
+ raise ValueError(
570
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
571
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
572
+ )
573
+
574
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
575
+ if hasattr(module, "set_processor"):
576
+ if not isinstance(processor, dict):
577
+ module.set_processor(processor)
578
+ else:
579
+ module.set_processor(processor.pop(f"{name}.processor"))
580
+
581
+ for sub_name, child in module.named_children():
582
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
583
+
584
+ for name, module in self.named_children():
585
+ fn_recursive_attn_processor(name, module, processor)
586
+
587
+ def set_default_attn_processor(self):
588
+ """
589
+ Disables custom attention processors and sets the default attention implementation.
590
+ """
591
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
592
+ processor = AttnAddedKVProcessor()
593
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
594
+ processor = AttnProcessor()
595
+ else:
596
+ raise ValueError(
597
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
598
+ )
599
+
600
+ self.set_attn_processor(processor)
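
The class added above is driven through its `from_pretrained_2d` classmethod, which reads a 2D Stable Diffusion `unet/config.json`, rewrites the block types to their 3D counterparts, and loads the 2D weights with `strict=False` so that any newly added motion modules remain randomly initialized. The following is a minimal usage sketch based only on that signature; the checkpoint path and tensor shapes are assumptions (an SD-1.5-style layout with 4 latent channels and 768-dim CLIP text embeddings), not something confirmed by this diff.

```python
import torch
from animatediff.models.unet import UNet3DConditionModel

# unet_additional_kwargs must be a dict (from_pretrained_2d unpacks it with **).
# An empty-ish dict with use_motion_module=False builds the inflated 3D UNet
# without temporal motion modules; enabling them requires the motion-module
# config that normally comes from an AnimateDiff YAML file (not shown here).
unet = UNet3DConditionModel.from_pretrained_2d(
    "animatediff/sd",                 # assumed path to an SD-1.5-style checkpoint directory
    subfolder="unet",
    unet_additional_kwargs={"use_motion_module": False},
)

# Video latents are 5D: (batch, channels, frames, height, width).
sample = torch.randn(1, 4, 16, 64, 64)
timestep = torch.tensor([500])
text_emb = torch.randn(1, 77, 768)    # CLIP ViT-L/14 hidden states for SD 1.x
with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states=text_emb).sample
print(out.shape)                      # expected: torch.Size([1, 4, 16, 64, 64])
```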
animatediff/models/unet_blocks.py ADDED
@@ -0,0 +1,760 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import pdb
11
+
12
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+
31
+ unet_use_cross_frame_attention=False,
32
+ unet_use_temporal_attention=False,
33
+ use_inflated_groupnorm=False,
34
+
35
+ use_motion_module=None,
36
+
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
41
+ if down_block_type == "DownBlock3D":
42
+ return DownBlock3D(
43
+ num_layers=num_layers,
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ temb_channels=temb_channels,
47
+ add_downsample=add_downsample,
48
+ resnet_eps=resnet_eps,
49
+ resnet_act_fn=resnet_act_fn,
50
+ resnet_groups=resnet_groups,
51
+ downsample_padding=downsample_padding,
52
+ resnet_time_scale_shift=resnet_time_scale_shift,
53
+
54
+ use_inflated_groupnorm=use_inflated_groupnorm,
55
+
56
+ use_motion_module=use_motion_module,
57
+ motion_module_type=motion_module_type,
58
+ motion_module_kwargs=motion_module_kwargs,
59
+ )
60
+ elif down_block_type == "CrossAttnDownBlock3D":
61
+ if cross_attention_dim is None:
62
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
63
+ return CrossAttnDownBlock3D(
64
+ num_layers=num_layers,
65
+ in_channels=in_channels,
66
+ out_channels=out_channels,
67
+ temb_channels=temb_channels,
68
+ add_downsample=add_downsample,
69
+ resnet_eps=resnet_eps,
70
+ resnet_act_fn=resnet_act_fn,
71
+ resnet_groups=resnet_groups,
72
+ downsample_padding=downsample_padding,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attn_num_head_channels=attn_num_head_channels,
75
+ dual_cross_attention=dual_cross_attention,
76
+ use_linear_projection=use_linear_projection,
77
+ only_cross_attention=only_cross_attention,
78
+ upcast_attention=upcast_attention,
79
+ resnet_time_scale_shift=resnet_time_scale_shift,
80
+
81
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
82
+ unet_use_temporal_attention=unet_use_temporal_attention,
83
+ use_inflated_groupnorm=use_inflated_groupnorm,
84
+
85
+ use_motion_module=use_motion_module,
86
+ motion_module_type=motion_module_type,
87
+ motion_module_kwargs=motion_module_kwargs,
88
+ )
89
+ raise ValueError(f"{down_block_type} does not exist.")
90
+
91
+
92
+ def get_up_block(
93
+ up_block_type,
94
+ num_layers,
95
+ in_channels,
96
+ out_channels,
97
+ prev_output_channel,
98
+ temb_channels,
99
+ add_upsample,
100
+ resnet_eps,
101
+ resnet_act_fn,
102
+ attn_num_head_channels,
103
+ resnet_groups=None,
104
+ cross_attention_dim=None,
105
+ dual_cross_attention=False,
106
+ use_linear_projection=False,
107
+ only_cross_attention=False,
108
+ upcast_attention=False,
109
+ resnet_time_scale_shift="default",
110
+
111
+ unet_use_cross_frame_attention=False,
112
+ unet_use_temporal_attention=False,
113
+ use_inflated_groupnorm=False,
114
+
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
120
+ if up_block_type == "UpBlock3D":
121
+ return UpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ resnet_time_scale_shift=resnet_time_scale_shift,
132
+
133
+ use_inflated_groupnorm=use_inflated_groupnorm,
134
+
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
142
+ return CrossAttnUpBlock3D(
143
+ num_layers=num_layers,
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ prev_output_channel=prev_output_channel,
147
+ temb_channels=temb_channels,
148
+ add_upsample=add_upsample,
149
+ resnet_eps=resnet_eps,
150
+ resnet_act_fn=resnet_act_fn,
151
+ resnet_groups=resnet_groups,
152
+ cross_attention_dim=cross_attention_dim,
153
+ attn_num_head_channels=attn_num_head_channels,
154
+ dual_cross_attention=dual_cross_attention,
155
+ use_linear_projection=use_linear_projection,
156
+ only_cross_attention=only_cross_attention,
157
+ upcast_attention=upcast_attention,
158
+ resnet_time_scale_shift=resnet_time_scale_shift,
159
+
160
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
161
+ unet_use_temporal_attention=unet_use_temporal_attention,
162
+ use_inflated_groupnorm=use_inflated_groupnorm,
163
+
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+
190
+ unet_use_cross_frame_attention=False,
191
+ unet_use_temporal_attention=False,
192
+ use_inflated_groupnorm=False,
193
+
194
+ use_motion_module=None,
195
+
196
+ motion_module_type=None,
197
+ motion_module_kwargs=None,
198
+ ):
199
+ super().__init__()
200
+
201
+ self.has_cross_attention = True
202
+ self.attn_num_head_channels = attn_num_head_channels
203
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
204
+
205
+ # there is always at least one resnet
206
+ resnets = [
207
+ ResnetBlock3D(
208
+ in_channels=in_channels,
209
+ out_channels=in_channels,
210
+ temb_channels=temb_channels,
211
+ eps=resnet_eps,
212
+ groups=resnet_groups,
213
+ dropout=dropout,
214
+ time_embedding_norm=resnet_time_scale_shift,
215
+ non_linearity=resnet_act_fn,
216
+ output_scale_factor=output_scale_factor,
217
+ pre_norm=resnet_pre_norm,
218
+
219
+ use_inflated_groupnorm=use_inflated_groupnorm,
220
+ )
221
+ ]
222
+ attentions = []
223
+ motion_modules = []
224
+
225
+ for _ in range(num_layers):
226
+ if dual_cross_attention:
227
+ raise NotImplementedError
228
+ attentions.append(
229
+ Transformer3DModel(
230
+ attn_num_head_channels,
231
+ in_channels // attn_num_head_channels,
232
+ in_channels=in_channels,
233
+ num_layers=1,
234
+ cross_attention_dim=cross_attention_dim,
235
+ norm_num_groups=resnet_groups,
236
+ use_linear_projection=use_linear_projection,
237
+ upcast_attention=upcast_attention,
238
+
239
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
240
+ unet_use_temporal_attention=unet_use_temporal_attention,
241
+ )
242
+ )
243
+ motion_modules.append(
244
+ get_motion_module(
245
+ in_channels=in_channels,
246
+ motion_module_type=motion_module_type,
247
+ motion_module_kwargs=motion_module_kwargs,
248
+ ) if use_motion_module else None
249
+ )
250
+ resnets.append(
251
+ ResnetBlock3D(
252
+ in_channels=in_channels,
253
+ out_channels=in_channels,
254
+ temb_channels=temb_channels,
255
+ eps=resnet_eps,
256
+ groups=resnet_groups,
257
+ dropout=dropout,
258
+ time_embedding_norm=resnet_time_scale_shift,
259
+ non_linearity=resnet_act_fn,
260
+ output_scale_factor=output_scale_factor,
261
+ pre_norm=resnet_pre_norm,
262
+
263
+ use_inflated_groupnorm=use_inflated_groupnorm,
264
+ )
265
+ )
266
+
267
+ self.attentions = nn.ModuleList(attentions)
268
+ self.resnets = nn.ModuleList(resnets)
269
+ self.motion_modules = nn.ModuleList(motion_modules)
270
+
271
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
272
+ hidden_states = self.resnets[0](hidden_states, temb)
273
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
274
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
275
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
276
+ hidden_states = resnet(hidden_states, temb)
277
+
278
+ return hidden_states
279
+
280
+
281
+ class CrossAttnDownBlock3D(nn.Module):
282
+ def __init__(
283
+ self,
284
+ in_channels: int,
285
+ out_channels: int,
286
+ temb_channels: int,
287
+ dropout: float = 0.0,
288
+ num_layers: int = 1,
289
+ resnet_eps: float = 1e-6,
290
+ resnet_time_scale_shift: str = "default",
291
+ resnet_act_fn: str = "swish",
292
+ resnet_groups: int = 32,
293
+ resnet_pre_norm: bool = True,
294
+ attn_num_head_channels=1,
295
+ cross_attention_dim=1280,
296
+ output_scale_factor=1.0,
297
+ downsample_padding=1,
298
+ add_downsample=True,
299
+ dual_cross_attention=False,
300
+ use_linear_projection=False,
301
+ only_cross_attention=False,
302
+ upcast_attention=False,
303
+
304
+ unet_use_cross_frame_attention=False,
305
+ unet_use_temporal_attention=False,
306
+ use_inflated_groupnorm=False,
307
+
308
+ use_motion_module=None,
309
+
310
+ motion_module_type=None,
311
+ motion_module_kwargs=None,
312
+ ):
313
+ super().__init__()
314
+ resnets = []
315
+ attentions = []
316
+ motion_modules = []
317
+
318
+ self.has_cross_attention = True
319
+ self.attn_num_head_channels = attn_num_head_channels
320
+
321
+ for i in range(num_layers):
322
+ in_channels = in_channels if i == 0 else out_channels
323
+ resnets.append(
324
+ ResnetBlock3D(
325
+ in_channels=in_channels,
326
+ out_channels=out_channels,
327
+ temb_channels=temb_channels,
328
+ eps=resnet_eps,
329
+ groups=resnet_groups,
330
+ dropout=dropout,
331
+ time_embedding_norm=resnet_time_scale_shift,
332
+ non_linearity=resnet_act_fn,
333
+ output_scale_factor=output_scale_factor,
334
+ pre_norm=resnet_pre_norm,
335
+
336
+ use_inflated_groupnorm=use_inflated_groupnorm,
337
+ )
338
+ )
339
+ if dual_cross_attention:
340
+ raise NotImplementedError
341
+ attentions.append(
342
+ Transformer3DModel(
343
+ attn_num_head_channels,
344
+ out_channels // attn_num_head_channels,
345
+ in_channels=out_channels,
346
+ num_layers=1,
347
+ cross_attention_dim=cross_attention_dim,
348
+ norm_num_groups=resnet_groups,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+
353
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
354
+ unet_use_temporal_attention=unet_use_temporal_attention,
355
+ )
356
+ )
357
+ motion_modules.append(
358
+ get_motion_module(
359
+ in_channels=out_channels,
360
+ motion_module_type=motion_module_type,
361
+ motion_module_kwargs=motion_module_kwargs,
362
+ ) if use_motion_module else None
363
+ )
364
+
365
+ self.attentions = nn.ModuleList(attentions)
366
+ self.resnets = nn.ModuleList(resnets)
367
+ self.motion_modules = nn.ModuleList(motion_modules)
368
+
369
+ if add_downsample:
370
+ self.downsamplers = nn.ModuleList(
371
+ [
372
+ Downsample3D(
373
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
374
+ )
375
+ ]
376
+ )
377
+ else:
378
+ self.downsamplers = None
379
+
380
+ self.gradient_checkpointing = False
381
+
382
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
383
+ output_states = ()
384
+
385
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
386
+ if self.training and self.gradient_checkpointing:
387
+
388
+ def create_custom_forward(module, return_dict=None):
389
+ def custom_forward(*inputs):
390
+ if return_dict is not None:
391
+ return module(*inputs, return_dict=return_dict)
392
+ else:
393
+ return module(*inputs)
394
+
395
+ return custom_forward
396
+
397
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
398
+ hidden_states = torch.utils.checkpoint.checkpoint(
399
+ create_custom_forward(attn, return_dict=False),
400
+ hidden_states,
401
+ encoder_hidden_states,
402
+ )[0]
403
+ if motion_module is not None:
404
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
405
+
406
+ else:
407
+ hidden_states = resnet(hidden_states, temb)
408
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
409
+
410
+ # add motion module
411
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
412
+
413
+ output_states += (hidden_states,)
414
+
415
+ if self.downsamplers is not None:
416
+ for downsampler in self.downsamplers:
417
+ hidden_states = downsampler(hidden_states)
418
+
419
+ output_states += (hidden_states,)
420
+
421
+ return hidden_states, output_states
422
+
423
+
424
+ class DownBlock3D(nn.Module):
425
+ def __init__(
426
+ self,
427
+ in_channels: int,
428
+ out_channels: int,
429
+ temb_channels: int,
430
+ dropout: float = 0.0,
431
+ num_layers: int = 1,
432
+ resnet_eps: float = 1e-6,
433
+ resnet_time_scale_shift: str = "default",
434
+ resnet_act_fn: str = "swish",
435
+ resnet_groups: int = 32,
436
+ resnet_pre_norm: bool = True,
437
+ output_scale_factor=1.0,
438
+ add_downsample=True,
439
+ downsample_padding=1,
440
+
441
+ use_inflated_groupnorm=False,
442
+
443
+ use_motion_module=None,
444
+ motion_module_type=None,
445
+ motion_module_kwargs=None,
446
+ ):
447
+ super().__init__()
448
+ resnets = []
449
+ motion_modules = []
450
+
451
+ for i in range(num_layers):
452
+ in_channels = in_channels if i == 0 else out_channels
453
+ resnets.append(
454
+ ResnetBlock3D(
455
+ in_channels=in_channels,
456
+ out_channels=out_channels,
457
+ temb_channels=temb_channels,
458
+ eps=resnet_eps,
459
+ groups=resnet_groups,
460
+ dropout=dropout,
461
+ time_embedding_norm=resnet_time_scale_shift,
462
+ non_linearity=resnet_act_fn,
463
+ output_scale_factor=output_scale_factor,
464
+ pre_norm=resnet_pre_norm,
465
+
466
+ use_inflated_groupnorm=use_inflated_groupnorm,
467
+ )
468
+ )
469
+ motion_modules.append(
470
+ get_motion_module(
471
+ in_channels=out_channels,
472
+ motion_module_type=motion_module_type,
473
+ motion_module_kwargs=motion_module_kwargs,
474
+ ) if use_motion_module else None
475
+ )
476
+
477
+ self.resnets = nn.ModuleList(resnets)
478
+ self.motion_modules = nn.ModuleList(motion_modules)
479
+
480
+ if add_downsample:
481
+ self.downsamplers = nn.ModuleList(
482
+ [
483
+ Downsample3D(
484
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
485
+ )
486
+ ]
487
+ )
488
+ else:
489
+ self.downsamplers = None
490
+
491
+ self.gradient_checkpointing = False
492
+
493
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
494
+ output_states = ()
495
+
496
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
497
+ if self.training and self.gradient_checkpointing:
498
+ def create_custom_forward(module):
499
+ def custom_forward(*inputs):
500
+ return module(*inputs)
501
+
502
+ return custom_forward
503
+
504
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
505
+ if motion_module is not None:
506
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
507
+ else:
508
+ hidden_states = resnet(hidden_states, temb)
509
+
510
+ # add motion module
511
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
512
+
513
+ output_states += (hidden_states,)
514
+
515
+ if self.downsamplers is not None:
516
+ for downsampler in self.downsamplers:
517
+ hidden_states = downsampler(hidden_states)
518
+
519
+ output_states += (hidden_states,)
520
+
521
+ return hidden_states, output_states
522
+
523
+
524
+ class CrossAttnUpBlock3D(nn.Module):
525
+ def __init__(
526
+ self,
527
+ in_channels: int,
528
+ out_channels: int,
529
+ prev_output_channel: int,
530
+ temb_channels: int,
531
+ dropout: float = 0.0,
532
+ num_layers: int = 1,
533
+ resnet_eps: float = 1e-6,
534
+ resnet_time_scale_shift: str = "default",
535
+ resnet_act_fn: str = "swish",
536
+ resnet_groups: int = 32,
537
+ resnet_pre_norm: bool = True,
538
+ attn_num_head_channels=1,
539
+ cross_attention_dim=1280,
540
+ output_scale_factor=1.0,
541
+ add_upsample=True,
542
+ dual_cross_attention=False,
543
+ use_linear_projection=False,
544
+ only_cross_attention=False,
545
+ upcast_attention=False,
546
+
547
+ unet_use_cross_frame_attention=False,
548
+ unet_use_temporal_attention=False,
549
+ use_inflated_groupnorm=False,
550
+
551
+ use_motion_module=None,
552
+
553
+ motion_module_type=None,
554
+ motion_module_kwargs=None,
555
+ ):
556
+ super().__init__()
557
+ resnets = []
558
+ attentions = []
559
+ motion_modules = []
560
+
561
+ self.has_cross_attention = True
562
+ self.attn_num_head_channels = attn_num_head_channels
563
+
564
+ for i in range(num_layers):
565
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
566
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
567
+
568
+ resnets.append(
569
+ ResnetBlock3D(
570
+ in_channels=resnet_in_channels + res_skip_channels,
571
+ out_channels=out_channels,
572
+ temb_channels=temb_channels,
573
+ eps=resnet_eps,
574
+ groups=resnet_groups,
575
+ dropout=dropout,
576
+ time_embedding_norm=resnet_time_scale_shift,
577
+ non_linearity=resnet_act_fn,
578
+ output_scale_factor=output_scale_factor,
579
+ pre_norm=resnet_pre_norm,
580
+
581
+ use_inflated_groupnorm=use_inflated_groupnorm,
582
+ )
583
+ )
584
+ if dual_cross_attention:
585
+ raise NotImplementedError
586
+ attentions.append(
587
+ Transformer3DModel(
588
+ attn_num_head_channels,
589
+ out_channels // attn_num_head_channels,
590
+ in_channels=out_channels,
591
+ num_layers=1,
592
+ cross_attention_dim=cross_attention_dim,
593
+ norm_num_groups=resnet_groups,
594
+ use_linear_projection=use_linear_projection,
595
+ only_cross_attention=only_cross_attention,
596
+ upcast_attention=upcast_attention,
597
+
598
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
599
+ unet_use_temporal_attention=unet_use_temporal_attention,
600
+ )
601
+ )
602
+ motion_modules.append(
603
+ get_motion_module(
604
+ in_channels=out_channels,
605
+ motion_module_type=motion_module_type,
606
+ motion_module_kwargs=motion_module_kwargs,
607
+ ) if use_motion_module else None
608
+ )
609
+
610
+ self.attentions = nn.ModuleList(attentions)
611
+ self.resnets = nn.ModuleList(resnets)
612
+ self.motion_modules = nn.ModuleList(motion_modules)
613
+
614
+ if add_upsample:
615
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
616
+ else:
617
+ self.upsamplers = None
618
+
619
+ self.gradient_checkpointing = False
620
+
621
+ def forward(
622
+ self,
623
+ hidden_states,
624
+ res_hidden_states_tuple,
625
+ temb=None,
626
+ encoder_hidden_states=None,
627
+ upsample_size=None,
628
+ attention_mask=None,
629
+ ):
630
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
631
+ # pop res hidden states
632
+ res_hidden_states = res_hidden_states_tuple[-1]
633
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
634
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
635
+
636
+ if self.training and self.gradient_checkpointing:
637
+
638
+ def create_custom_forward(module, return_dict=None):
639
+ def custom_forward(*inputs):
640
+ if return_dict is not None:
641
+ return module(*inputs, return_dict=return_dict)
642
+ else:
643
+ return module(*inputs)
644
+
645
+ return custom_forward
646
+
647
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
648
+ hidden_states = torch.utils.checkpoint.checkpoint(
649
+ create_custom_forward(attn, return_dict=False),
650
+ hidden_states,
651
+ encoder_hidden_states,
652
+ )[0]
653
+ if motion_module is not None:
654
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
655
+
656
+ else:
657
+ hidden_states = resnet(hidden_states, temb)
658
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
659
+
660
+ # add motion module
661
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
662
+
663
+ if self.upsamplers is not None:
664
+ for upsampler in self.upsamplers:
665
+ hidden_states = upsampler(hidden_states, upsample_size)
666
+
667
+ return hidden_states
668
+
669
+
670
+ class UpBlock3D(nn.Module):
671
+ def __init__(
672
+ self,
673
+ in_channels: int,
674
+ prev_output_channel: int,
675
+ out_channels: int,
676
+ temb_channels: int,
677
+ dropout: float = 0.0,
678
+ num_layers: int = 1,
679
+ resnet_eps: float = 1e-6,
680
+ resnet_time_scale_shift: str = "default",
681
+ resnet_act_fn: str = "swish",
682
+ resnet_groups: int = 32,
683
+ resnet_pre_norm: bool = True,
684
+ output_scale_factor=1.0,
685
+ add_upsample=True,
686
+
687
+ use_inflated_groupnorm=False,
688
+
689
+ use_motion_module=None,
690
+ motion_module_type=None,
691
+ motion_module_kwargs=None,
692
+ ):
693
+ super().__init__()
694
+ resnets = []
695
+ motion_modules = []
696
+
697
+ for i in range(num_layers):
698
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
699
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
700
+
701
+ resnets.append(
702
+ ResnetBlock3D(
703
+ in_channels=resnet_in_channels + res_skip_channels,
704
+ out_channels=out_channels,
705
+ temb_channels=temb_channels,
706
+ eps=resnet_eps,
707
+ groups=resnet_groups,
708
+ dropout=dropout,
709
+ time_embedding_norm=resnet_time_scale_shift,
710
+ non_linearity=resnet_act_fn,
711
+ output_scale_factor=output_scale_factor,
712
+ pre_norm=resnet_pre_norm,
713
+
714
+ use_inflated_groupnorm=use_inflated_groupnorm,
715
+ )
716
+ )
717
+ motion_modules.append(
718
+ get_motion_module(
719
+ in_channels=out_channels,
720
+ motion_module_type=motion_module_type,
721
+ motion_module_kwargs=motion_module_kwargs,
722
+ ) if use_motion_module else None
723
+ )
724
+
725
+ self.resnets = nn.ModuleList(resnets)
726
+ self.motion_modules = nn.ModuleList(motion_modules)
727
+
728
+ if add_upsample:
729
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
730
+ else:
731
+ self.upsamplers = None
732
+
733
+ self.gradient_checkpointing = False
734
+
735
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
736
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
737
+ # pop res hidden states
738
+ res_hidden_states = res_hidden_states_tuple[-1]
739
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
740
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
741
+
742
+ if self.training and self.gradient_checkpointing:
743
+ def create_custom_forward(module):
744
+ def custom_forward(*inputs):
745
+ return module(*inputs)
746
+
747
+ return custom_forward
748
+
749
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
750
+ if motion_module is not None:
751
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
752
+ else:
753
+ hidden_states = resnet(hidden_states, temb)
754
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
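
Each block above walks its layers in the same order: ResnetBlock3D, then spatial (cross-)attention where present, then an optional temporal motion module. The spatial layers treat frames as extra batch entries, while the temporal module attends across frames at each spatial position. The snippet below is only an illustrative, self-contained sketch of that tensor bookkeeping (it is not the repository's motion-module code, and the head count and shapes are arbitrary); it uses `einops`, which the pipeline code in this commit already depends on.

```python
import torch
from einops import rearrange

b, c, f, h, w = 1, 64, 16, 32, 32
video = torch.randn(b, c, f, h, w)

# Spatial layers (resnets / cross-attention) fold frames into the batch dimension.
spatial_view = rearrange(video, "b c f h w -> (b f) c h w")   # (16, 64, 32, 32)

# A temporal module instead attends over the frame axis at every spatial location.
temporal_view = rearrange(video, "b c f h w -> (b h w) f c")  # (1024, 16, 64)
temporal_attn = torch.nn.MultiheadAttention(embed_dim=c, num_heads=8, batch_first=True)
out, _ = temporal_attn(temporal_view, temporal_view, temporal_view)

video_out = rearrange(out, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
print(video_out.shape)                                        # torch.Size([1, 64, 16, 32, 32])
```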
animatediff/pipelines/pipeline_animation.py ADDED
@@ -0,0 +1,793 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import PIL.Image
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from diffusers.utils import is_accelerate_available
13
+ from packaging import version
14
+ from transformers import CLIPTextModel, CLIPTokenizer
15
+
16
+ from diffusers.configuration_utils import FrozenDict
17
+ from diffusers.models import AutoencoderKL
18
+ from diffusers import DiffusionPipeline
19
+ from diffusers.schedulers import (
20
+ DDIMScheduler,
21
+ DPMSolverMultistepScheduler,
22
+ EulerAncestralDiscreteScheduler,
23
+ EulerDiscreteScheduler,
24
+ LMSDiscreteScheduler,
25
+ PNDMScheduler,
26
+ )
27
+ from diffusers.utils import deprecate, logging, BaseOutput
28
+
29
+ from einops import rearrange
30
+
31
+ from ..models.unet import UNet3DConditionModel
32
+ from ..models.sparse_controlnet import SparseControlNetModel
33
+ import pdb
34
+ import PIL
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+ # image: either PIL.Image.Image or torch.Tensor.
39
+ def preprocess_image(image, h=512, w=512):
40
+ if isinstance(image, torch.Tensor):
41
+ return image
42
+ elif isinstance(image, PIL.Image.Image):
43
+ # image: [1, 512, 512, 3]
44
+ image = np.array(image.resize((w, h), resample=PIL.Image.LANCZOS))[None, :]
45
+ image = image.astype(np.float16) * 2 / 255.0 - 1.0
46
+ # image: [1, 3, 512, 512]
47
+ image = image.transpose(0, 3, 1, 2)
48
+ image = torch.from_numpy(image)
49
+ else:
50
+ breakpoint()
51
+ return image
52
+
53
+ @dataclass
54
+ class AnimationPipelineOutput(BaseOutput):
55
+ videos: Union[torch.Tensor, np.ndarray]
56
+
57
+
58
+ class AnimationPipeline(DiffusionPipeline):
59
+ _optional_components = []
60
+
61
+ def __init__(
62
+ self,
63
+ vae: AutoencoderKL,
64
+ text_encoder: CLIPTextModel,
65
+ tokenizer: CLIPTokenizer,
66
+ unet: UNet3DConditionModel,
67
+ scheduler: Union[
68
+ DDIMScheduler,
69
+ PNDMScheduler,
70
+ LMSDiscreteScheduler,
71
+ EulerDiscreteScheduler,
72
+ EulerAncestralDiscreteScheduler,
73
+ DPMSolverMultistepScheduler,
74
+ ],
75
+ controlnet: Union[SparseControlNetModel, None] = None,
76
+ torch_dtype=torch.float32,
77
+ ):
78
+ super().__init__()
79
+
80
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
81
+ deprecation_message = (
82
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
83
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
84
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
85
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
86
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
87
+ " file"
88
+ )
89
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
90
+ new_config = dict(scheduler.config)
91
+ new_config["steps_offset"] = 1
92
+ scheduler._internal_dict = FrozenDict(new_config)
93
+
94
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
95
+ deprecation_message = (
96
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
97
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
98
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
99
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
100
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
101
+ )
102
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
103
+ new_config = dict(scheduler.config)
104
+ new_config["clip_sample"] = False
105
+ scheduler._internal_dict = FrozenDict(new_config)
106
+
107
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
108
+ version.parse(unet.config._diffusers_version).base_version
109
+ ) < version.parse("0.9.0.dev0")
110
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
111
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
112
+ deprecation_message = (
113
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
114
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
115
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
116
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
117
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
118
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
119
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
120
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
121
+ " the `unet/config.json` file"
122
+ )
123
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
124
+ new_config = dict(unet.config)
125
+ new_config["sample_size"] = 64
126
+ unet._internal_dict = FrozenDict(new_config)
127
+ self.torch_dtype = torch_dtype
128
+ self.register_modules(
129
+ vae=vae.to(self.torch_dtype),
130
+ text_encoder=text_encoder.to(self.torch_dtype),
131
+ tokenizer=tokenizer,
132
+ unet=unet.to(self.torch_dtype),
133
+ scheduler=scheduler,
134
+ # controlnet=controlnet.to(self.torch_dtype),
135
+ )
136
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
137
+ if controlnet is not None: self.controlnet = controlnet.to(self.torch_dtype)
138
+ def enable_vae_slicing(self):
139
+ self.vae.enable_slicing()
140
+
141
+ def disable_vae_slicing(self):
142
+ self.vae.disable_slicing()
143
+
144
+ def enable_sequential_cpu_offload(self, gpu_id=0):
145
+ if is_accelerate_available():
146
+ from accelerate import cpu_offload
147
+ else:
148
+ raise ImportError("Please install accelerate via `pip install accelerate`")
149
+
150
+ device = torch.device(f"cuda:{gpu_id}")
151
+
152
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
153
+ if cpu_offloaded_model is not None:
154
+ cpu_offload(cpu_offloaded_model, device)
155
+
156
+
157
+ @property
158
+ def _execution_device(self):
159
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
160
+ return self.device
161
+ for module in self.unet.modules():
162
+ if (
163
+ hasattr(module, "_hf_hook")
164
+ and hasattr(module._hf_hook, "execution_device")
165
+ and module._hf_hook.execution_device is not None
166
+ ):
167
+ return torch.device(module._hf_hook.execution_device)
168
+ return self.device
169
+
170
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
171
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
172
+
173
+ text_inputs = self.tokenizer(
174
+ prompt,
175
+ padding="max_length",
176
+ max_length=self.tokenizer.model_max_length,
177
+ truncation=True,
178
+ return_tensors="pt",
179
+ )
180
+ text_input_ids = text_inputs.input_ids
181
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
182
+
183
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
184
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
185
+ logger.warning(
186
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
187
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
188
+ )
189
+
190
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
191
+ attention_mask = text_inputs.attention_mask.to(device)
192
+ else:
193
+ attention_mask = None
194
+
195
+ text_embeddings = self.text_encoder(
196
+ text_input_ids.to(device),
197
+ attention_mask=attention_mask,
198
+ )
199
+ text_embeddings = text_embeddings[0]
200
+
201
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
202
+ bs_embed, seq_len, _ = text_embeddings.shape
203
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
204
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
205
+
206
+ # get unconditional embeddings for classifier free guidance
207
+ if do_classifier_free_guidance:
208
+ uncond_tokens: List[str]
209
+ if negative_prompt is None:
210
+ uncond_tokens = [""] * batch_size
211
+ elif type(prompt) is not type(negative_prompt):
212
+ raise TypeError(
213
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
214
+ f" {type(prompt)}."
215
+ )
216
+ elif isinstance(negative_prompt, str):
217
+ uncond_tokens = [negative_prompt]
218
+ elif batch_size != len(negative_prompt):
219
+ raise ValueError(
220
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
221
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
222
+ " the batch size of `prompt`."
223
+ )
224
+ else:
225
+ uncond_tokens = negative_prompt
226
+
227
+ max_length = text_input_ids.shape[-1]
228
+ uncond_input = self.tokenizer(
229
+ uncond_tokens,
230
+ padding="max_length",
231
+ max_length=max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+
236
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
237
+ attention_mask = uncond_input.attention_mask.to(device)
238
+ else:
239
+ attention_mask = None
240
+
241
+ uncond_embeddings = self.text_encoder(
242
+ uncond_input.input_ids.to(device),
243
+ attention_mask=attention_mask,
244
+ )
245
+ uncond_embeddings = uncond_embeddings[0]
246
+
247
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
248
+ seq_len = uncond_embeddings.shape[1]
249
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
250
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
251
+
252
+ # For classifier free guidance, we need to do two forward passes.
253
+ # Here we concatenate the unconditional and text embeddings into a single batch
254
+ # to avoid doing two forward passes
255
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
256
+
257
+ return text_embeddings
258
+
259
+ def encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
260
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
261
+ # print(batch_size)
262
+ # exit()
263
+ text_inputs = self.tokenizer(
264
+ prompt,
265
+ padding="max_length",
266
+ max_length=self.tokenizer.model_max_length,
267
+ truncation=True,
268
+ return_tensors="pt",
269
+ )
270
+ text_input_ids = text_inputs.input_ids
271
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
272
+
273
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
274
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
275
+ logger.warning(
276
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
277
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
278
+ )
279
+
280
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
281
+ attention_mask = text_inputs.attention_mask.to(device)
282
+ else:
283
+ attention_mask = None
284
+
285
+ text_embeddings = self.text_encoder(
286
+ text_input_ids.to(device),
287
+ attention_mask=attention_mask,
288
+ )
289
+ text_embeddings = text_embeddings[0]
290
+
291
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
292
+ bs_embed, seq_len, _ = text_embeddings.shape
293
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
294
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
295
+
296
+ # get unconditional embeddings for classifier free guidance
297
+ if do_classifier_free_guidance:
298
+ uncond_tokens: List[str]
299
+ if negative_prompt is None:
300
+ uncond_tokens = [""] * batch_size
301
+ elif type(prompt) is not type(negative_prompt):
302
+ raise TypeError(
303
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
304
+ f" {type(prompt)}."
305
+ )
306
+ elif isinstance(negative_prompt, str):
307
+ uncond_tokens = [negative_prompt]
308
+ elif batch_size != len(negative_prompt):
309
+ raise ValueError(
310
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
311
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
312
+ " the batch size of `prompt`."
313
+ )
314
+ else:
315
+ uncond_tokens = negative_prompt
316
+ max_length = text_input_ids.shape[-1]
317
+ uncond_input = self.tokenizer(
318
+ uncond_tokens,
319
+ padding="max_length",
320
+ max_length=max_length,
321
+ truncation=True,
322
+ return_tensors="pt",
323
+ )
324
+
325
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
326
+ attention_mask = uncond_input.attention_mask.to(device)
327
+ else:
328
+ attention_mask = None
329
+
330
+ uncond_embeddings = self.text_encoder(
331
+ uncond_input.input_ids.to(device),
332
+ attention_mask=attention_mask,
333
+ )
334
+ uncond_embeddings = uncond_embeddings[0]
335
+
336
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
337
+ seq_len = uncond_embeddings.shape[1]
338
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
339
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
340
+
341
+ # For classifier free guidance, we need to do two forward passes.
342
+ # Here we concatenate the unconditional and text embeddings into a single batch
343
+ # to avoid doing two forward passes
344
+ # text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
345
+ # print("encode here!!!")
346
+ # print("shape of text_embeddings",text_embeddings.shape)
347
+ # print("shape of uncond_embeddings",uncond_embeddings.shape)
348
+ return text_embeddings, uncond_embeddings
349
+
350
+
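The two tensors returned by `encode_prompt` are later stacked (negative first) so a single UNet forward pass serves both halves of classifier-free guidance; the guided prediction is the usual linear combination. A minimal numeric sketch of that combination, using illustrative stand-in shapes rather than real UNet outputs:

import torch

# Sketch of the classifier-free guidance step used in the denoising loops below.
guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 16, 8, 8)                  # stacked [uncond, text] predictions (illustrative)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)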
351
+ def decode_latents(self, latents):
352
+ video_length = latents.shape[2]
353
+ latents = 1 / 0.18215 * latents
354
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
355
+ # video = self.vae.decode(latents).sample
356
+ video = []
357
+ for frame_idx in range(latents.shape[0]):
358
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
359
+ video = torch.cat(video)
360
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
361
+ video = (video / 2 + 0.5).clamp(0, 1)
362
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
363
+ video = video.cpu().float().numpy()
364
+ return video
365
+
366
+ def prepare_extra_step_kwargs(self, generator, eta):
367
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
368
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
369
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
370
+ # and should be between [0, 1]
371
+
372
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
373
+ extra_step_kwargs = {}
374
+ if accepts_eta:
375
+ extra_step_kwargs["eta"] = eta
376
+
377
+ # check if the scheduler accepts generator
378
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
379
+ if accepts_generator:
380
+ extra_step_kwargs["generator"] = generator
381
+ return extra_step_kwargs
382
+
383
+ def check_inputs(self, prompt, height, width, callback_steps, prompt_embedding):
384
+ if not isinstance(prompt, str) and not isinstance(prompt, list) and prompt_embedding is None:
385
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
386
+
387
+ if height % 8 != 0 or width % 8 != 0:
388
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
389
+
390
+ if (callback_steps is None) or (
391
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
392
+ ):
393
+ raise ValueError(
394
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
395
+ f" {type(callback_steps)}."
396
+ )
397
+
398
+ def prepare_latents(self, init_image, init_image_strength, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
399
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
400
+
401
+ if init_image is not None:
402
+ # init_image: either PIL.Image.Image or torch.Tensor.
403
+ image = preprocess_image(init_image, height, width)
404
+ image = image.to(device=device, dtype=dtype)
405
+ if isinstance(generator, list):
406
+ init_latents = [
407
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
408
+ ]
409
+ init_latents = torch.cat(init_latents, dim=0)
410
+ else:
411
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
412
+ else:
413
+ init_latents = None
414
+
415
+
416
+ if isinstance(generator, list) and len(generator) != batch_size:
417
+ raise ValueError(
418
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
419
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
420
+ )
421
+ if latents is None:
422
+ rand_device = "cpu" if device.type == "mps" else device
423
+
424
+ if isinstance(generator, list):
425
+ shape = shape
426
+ # shape = (1,) + shape[1:]
427
+ # ignore init latents for batch model
428
+ latents = [
429
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
430
+ for i in range(batch_size)
431
+ ]
432
+ latents = torch.cat(latents, dim=0).to(device)
433
+ else:
434
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
435
+ if init_latents is not None:
436
+ blend_frames = video_length // 2
437
+ init_image_strength, init_image_final_weight = init_image_strength
438
+ for i in range(video_length):
439
+ dist_to_end = (blend_frames - float(i)) / blend_frames
440
+ # When dist_to_end falls below init_image_final_weight, it is clamped to init_image_final_weight,
442
+ # so that later frames are still initialized with a small amount of init_latents.
442
+ dist_to_end = max(dist_to_end, init_image_final_weight)
443
+ # Changed from /30 to /100.
444
+ # gradually reduce the init alpha along video frames (loosening the restriction)
445
+ init_alpha = dist_to_end * init_image_strength / 100
446
+ latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
447
+ else:
448
+ if latents.shape != shape:
449
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
450
+ latents = latents.to(device)
451
+
452
+ # scale the initial noise by the standard deviation required by the scheduler
453
+ if init_latents is None:
454
+ latents = latents * self.scheduler.init_noise_sigma
455
+ return latents
456
+
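The frame-dependent blending above is the core of the init-image conditioning: early frames lean heavily on the encoded init image, later frames mostly on fresh noise, with `init_image_final_weight` acting as a floor. A self-contained sketch of just that schedule, using random stand-in tensors and illustrative strength values:

import torch

def blend_init_latents(latents, init_latents, init_image_strength, init_image_final_weight):
    # latents: (B, C, F, H, W) noise; init_latents: (B, C, H, W) encoded init image (stand-ins here).
    video_length = latents.shape[2]
    blend_frames = video_length // 2
    for i in range(video_length):
        dist_to_end = (blend_frames - float(i)) / blend_frames
        dist_to_end = max(dist_to_end, init_image_final_weight)   # keep a small floor for late frames
        init_alpha = dist_to_end * init_image_strength / 100      # same /100 scaling as above
        latents[:, :, i] = init_latents * init_alpha + latents[:, :, i] * (1 - init_alpha)
    return latents

blended = blend_init_latents(torch.randn(1, 4, 16, 8, 8), torch.randn(1, 4, 8, 8),
                             init_image_strength=20.0, init_image_final_weight=0.05)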
457
+ @torch.no_grad()
458
+ def __call__(
459
+ self,
460
+ prompt: Union[str, List[str]],
461
+ video_length: Optional[int],
462
+ init_image: Union[PIL.Image.Image, torch.Tensor],
463
+ init_image_strength: float = 1.0,
464
+ height: Optional[int] = None,
465
+ width: Optional[int] = None,
466
+ num_inference_steps: int = 50,
467
+ guidance_scale: float = 7.5,
468
+ negative_prompt: Optional[Union[str, List[str]]] = None,
469
+ num_videos_per_prompt: Optional[int] = 1,
470
+ eta: float = 0.0,
471
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
472
+ latents: Optional[torch.FloatTensor] = None,
473
+ output_type: Optional[str] = "tensor",
474
+ return_dict: bool = True,
475
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
476
+ callback_steps: Optional[int] = 1,
477
+ # support embeddings
478
+ prompt_embeds: Optional[torch.FloatTensor] = None,
479
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
480
+ # support controlnet
481
+ controlnet_images: torch.FloatTensor = None,
482
+ controlnet_image_index: list = [0],
483
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
484
+ **kwargs,
485
+ ):
486
+ # Default height and width to unet
487
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
488
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
489
+
490
+ if isinstance(prompt_embeds, (list, tuple)):
491
+ prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps = prompt_embeds
492
+ prompt_embeds = prompt_embeds_begin
493
+ do_prompt_embeds_annealing = True
494
+ else:
495
+ do_prompt_embeds_annealing = False
496
+
497
+ # Check inputs. Raise error if not correct
498
+ self.check_inputs(prompt, height, width, callback_steps, prompt_embeds)
499
+
500
+ # Define call parameters
501
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
502
+ batch_size = 1
503
+ if latents is not None:
504
+ batch_size = latents.shape[0]
505
+ if isinstance(prompt, list):
506
+ batch_size = len(prompt)
507
+
508
+ device = self._execution_device
509
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
510
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
511
+ # corresponds to doing no classifier free guidance.
512
+ do_classifier_free_guidance = guidance_scale > 1.0
513
+
514
+ # Encode input prompt
515
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
516
+ if negative_prompt is not None:
517
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
518
+ if prompt_embeds is None:
519
+ text_embeddings = self._encode_prompt(
520
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
521
+ )
522
+ # If do_prompt_embeds_annealing is True, prompt_embeds and text_embeddings will be assigned in the loop below,
523
+ # and this is just to avoid type error.
524
+ # Otherwise, text_embeddings won't be replaced.
525
+ else:
526
+ text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
527
+
528
+ # print(text_embeddings.shape)
529
+ # return
530
+ # Prepare timesteps
531
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
532
+ timesteps = self.scheduler.timesteps
533
+
534
+ # Prepare latent variables
535
+ num_channels_latents = self.unet.in_channels
536
+ latents = self.prepare_latents(
537
+ init_image,
538
+ init_image_strength,
539
+ batch_size * num_videos_per_prompt,
540
+ num_channels_latents,
541
+ video_length,
542
+ height,
543
+ width,
544
+ text_embeddings.dtype,
545
+ device,
546
+ generator,
547
+ latents,
548
+ ).to(self.torch_dtype)
549
+ latents_dtype = latents.dtype
550
+
551
+ # Prepare extra step kwargs.
552
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
553
+
554
+ # Denoising loop
555
+ # num_warmup_steps = 0. num_inference_steps: 30.
556
+ # [958, 925, 892, 859, 826, 793, 760, 727, 694, 661, 628, 595, 562, 529,
557
+ # 496, 463, 430, 397, 364, 331, 298, 265, 232, 199, 166, 133, 100, 67,
558
+ # 34, 1]
559
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
560
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
561
+ for i, t in enumerate(timesteps):
562
+ # expand the latents if we are doing classifier free guidance
563
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
564
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
565
+
566
+ down_block_additional_residuals = mid_block_additional_residual = None
567
+ if (getattr(self, "controlnet", None) is not None) and (controlnet_images is not None):
568
+ assert controlnet_images.dim() == 5
569
+
570
+ controlnet_noisy_latents = latent_model_input
571
+ controlnet_prompt_embeds = text_embeddings
572
+
573
+ controlnet_images = controlnet_images.to(latents.device)
574
+
575
+ controlnet_cond_shape = list(controlnet_images.shape)
576
+ controlnet_cond_shape[2] = video_length
577
+ controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device).to(latents.dtype)
578
+
579
+ controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
580
+ controlnet_conditioning_mask_shape[1] = 1
581
+ controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device).to(latents.dtype)
582
+
583
+ assert controlnet_images.shape[2] >= len(controlnet_image_index)
584
+ controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
585
+ controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
586
+
587
+
588
+ down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
589
+ controlnet_noisy_latents, t,
590
+ encoder_hidden_states=controlnet_prompt_embeds,
591
+ controlnet_cond=controlnet_cond,
592
+ conditioning_mask=controlnet_conditioning_mask,
593
+ conditioning_scale=controlnet_conditioning_scale,
594
+ guess_mode=False, return_dict=False,
595
+ )
596
+
597
+ if do_prompt_embeds_annealing:
598
+ # i: 0 to num_inference_steps. Anneal the first adaface_anneal_steps steps.
599
+ # If adaface_anneal_steps == 0, then anneal_factor is always 1.
600
+ anneal_factor = i / adaface_anneal_steps if i < adaface_anneal_steps else 1
601
+ prompt_embeds_annealed = prompt_embeds_begin + anneal_factor * (prompt_embeds_end - prompt_embeds_begin)
602
+ text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds_annealed])
603
+
604
+ # predict the noise residual
605
+ noise_pred = self.unet(
606
+ latent_model_input, t,
607
+ encoder_hidden_states=text_embeddings,
608
+ down_block_additional_residuals = down_block_additional_residuals,
609
+ mid_block_additional_residual = mid_block_additional_residual,
610
+ ).sample.to(dtype=latents_dtype)
611
+
612
+ # perform guidance
613
+ if do_classifier_free_guidance:
614
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
615
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
616
+
617
+ # compute the previous noisy sample x_t -> x_t-1
618
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
619
+
620
+ # call the callback, if provided
621
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
622
+ progress_bar.update()
623
+ if callback is not None and i % callback_steps == 0:
624
+ callback(i, t, latents)
625
+
626
+ # Post-processing
627
+ video = self.decode_latents(latents)
628
+
629
+ # Convert to tensor
630
+ if output_type == "tensor":
631
+ video = torch.from_numpy(video)
632
+
633
+ if not return_dict:
634
+ return video
635
+
636
+ return AnimationPipelineOutput(videos=video)
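When `prompt_embeds` is passed as a (begin, end, steps) tuple, `__call__` above linearly anneals from the begin to the end embeddings over the first `adaface_anneal_steps` denoising steps. A compact sketch of the schedule (step counts illustrative):

# anneal_factor ramps 0.0 -> 0.9 over the first 10 steps here, then stays at 1;
# the embeddings used at step i are begin + factor * (end - begin).
adaface_anneal_steps, num_inference_steps = 10, 30
anneal_factors = [i / adaface_anneal_steps if i < adaface_anneal_steps else 1
                  for i in range(num_inference_steps)]
print(anneal_factors[:12])   # [0.0, 0.1, ..., 0.9, 1, 1]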
637
+ @torch.no_grad()
638
+ def video_edit(
639
+ self,
640
+ prompt: Union[str, List[str]],
641
+ video_length: Optional[int],
642
+ init_image: Union[PIL.Image.Image, torch.Tensor],
643
+ init_image_strength: float = 1.0,
644
+ height: Optional[int] = None,
645
+ width: Optional[int] = None,
646
+ num_inference_steps: int = 50,
647
+ guidance_scale: float = 7.5,
648
+ negative_prompt: Optional[Union[str, List[str]]] = None,
649
+ num_videos_per_prompt: Optional[int] = 1,
650
+ eta: float = 0.0,
651
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
652
+ latents: Optional[torch.FloatTensor] = None,
653
+ output_type: Optional[str] = "tensor",
654
+ return_dict: bool = True,
655
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
656
+ callback_steps: Optional[int] = 1,
657
+ # support embeddings
658
+ prompt_embeds: Optional[torch.FloatTensor] = None,
659
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
660
+ # support controlnet
661
+ controlnet_images: torch.FloatTensor = None,
662
+ controlnet_image_index: list = [0],
663
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
664
+ **kwargs,
665
+ ):
666
+ # Default height and width to unet
667
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
668
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
669
+
670
+ # Check inputs. Raise error if not correct
671
+ self.check_inputs(prompt, height, width, callback_steps, prompt_embeds)
672
+
673
+ # Define call parameters
674
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
675
+ batch_size = 1
676
+ if latents is not None:
677
+ batch_size = latents.shape[0]
678
+ if isinstance(prompt, list):
679
+ batch_size = len(prompt)
680
+
681
+ device = self._execution_device
682
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
683
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
684
+ # corresponds to doing no classifier free guidance.
685
+ do_classifier_free_guidance = guidance_scale > 1.0
686
+
687
+ # Encode input prompt
688
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
689
+ if negative_prompt is not None:
690
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
691
+ if prompt_embeds is None:
692
+ text_embeddings = self._encode_prompt(
693
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
694
+ )
695
+ else:
696
+ text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
697
+
698
+ # print(text_embeddings.shape)
699
+ # return
700
+ # Prepare timesteps
701
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
702
+ timesteps = self.scheduler.timesteps
703
+
704
+ # Prepare latent variables
705
+ num_channels_latents = self.unet.in_channels
706
+ latents = self.prepare_latents(
707
+ init_image,
708
+ init_image_strength,
709
+ batch_size * num_videos_per_prompt,
710
+ num_channels_latents,
711
+ video_length,
712
+ height,
713
+ width,
714
+ text_embeddings.dtype,
715
+ device,
716
+ generator,
717
+ latents,
718
+ ).to(self.torch_dtype)
719
+ latents_dtype = latents.dtype
720
+
721
+ # Prepare extra step kwargs.
722
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
723
+
724
+ # Denoising loop
725
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
726
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
727
+ for i, t in enumerate(timesteps):
728
+ # expand the latents if we are doing classifier free guidance
729
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
730
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
731
+
732
+ down_block_additional_residuals = mid_block_additional_residual = None
733
+ if (getattr(self, "controlnet", None) is not None) and (controlnet_images is not None):
734
+ assert controlnet_images.dim() == 5
735
+
736
+ controlnet_noisy_latents = latent_model_input
737
+ controlnet_prompt_embeds = text_embeddings
738
+
739
+ controlnet_images = controlnet_images.to(latents.device)
740
+
741
+ controlnet_cond_shape = list(controlnet_images.shape)
742
+ controlnet_cond_shape[2] = video_length
743
+ controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device)
744
+
745
+ controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
746
+ controlnet_conditioning_mask_shape[1] = 1
747
+ controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device)
748
+
749
+ assert controlnet_images.shape[2] >= len(controlnet_image_index)
750
+ controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
751
+ controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
752
+
753
+ down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
754
+ controlnet_noisy_latents, t,
755
+ encoder_hidden_states=controlnet_prompt_embeds,
756
+ controlnet_cond=controlnet_cond,
757
+ conditioning_mask=controlnet_conditioning_mask,
758
+ conditioning_scale=controlnet_conditioning_scale,
759
+ guess_mode=False, return_dict=False,
760
+ )
761
+ # predict the noise residual
762
+ noise_pred = self.unet(
763
+ latent_model_input, t,
764
+ encoder_hidden_states=text_embeddings,
765
+ down_block_additional_residuals = down_block_additional_residuals,
766
+ mid_block_additional_residual = mid_block_additional_residual,
767
+ ).sample.to(dtype=latents_dtype)
768
+
769
+ # perform guidance
770
+ if do_classifier_free_guidance:
771
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
772
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
773
+
774
+ # compute the previous noisy sample x_t -> x_t-1
775
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
776
+
777
+ # call the callback, if provided
778
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
779
+ progress_bar.update()
780
+ if callback is not None and i % callback_steps == 0:
781
+ callback(i, t, latents)
782
+
783
+ # Post-processing
784
+ video = self.decode_latents(latents)
785
+
786
+ # Convert to tensor
787
+ if output_type == "tensor":
788
+ video = torch.from_numpy(video)
789
+
790
+ if not return_dict:
791
+ return video
792
+
793
+ return AnimationPipelineOutput(videos=video)
animatediff/sd/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.npy filter=lfs diff=lfs merge=lfs -text
14
+ *.npz filter=lfs diff=lfs merge=lfs -text
15
+ *.onnx filter=lfs diff=lfs merge=lfs -text
16
+ *.ot filter=lfs diff=lfs merge=lfs -text
17
+ *.parquet filter=lfs diff=lfs merge=lfs -text
18
+ *.pb filter=lfs diff=lfs merge=lfs -text
19
+ *.pickle filter=lfs diff=lfs merge=lfs -text
20
+ *.pkl filter=lfs diff=lfs merge=lfs -text
21
+ *.pt filter=lfs diff=lfs merge=lfs -text
22
+ *.pth filter=lfs diff=lfs merge=lfs -text
23
+ *.rar filter=lfs diff=lfs merge=lfs -text
24
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
25
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
26
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
27
+ *.tflite filter=lfs diff=lfs merge=lfs -text
28
+ *.tgz filter=lfs diff=lfs merge=lfs -text
29
+ *.wasm filter=lfs diff=lfs merge=lfs -text
30
+ *.xz filter=lfs diff=lfs merge=lfs -text
31
+ *.zip filter=lfs diff=lfs merge=lfs -text
32
+ *.zst filter=lfs diff=lfs merge=lfs -text
33
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
34
+ v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
35
+ v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
animatediff/sd/feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 224,
3
+ "do_center_crop": true,
4
+ "do_convert_rgb": true,
5
+ "do_normalize": true,
6
+ "do_resize": true,
7
+ "feature_extractor_type": "CLIPFeatureExtractor",
8
+ "image_mean": [
9
+ 0.48145466,
10
+ 0.4578275,
11
+ 0.40821073
12
+ ],
13
+ "image_std": [
14
+ 0.26862954,
15
+ 0.26130258,
16
+ 0.27577711
17
+ ],
18
+ "resample": 3,
19
+ "size": 224
20
+ }
animatediff/sd/model_index.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.6.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "safety_checker": [
9
+ "stable_diffusion",
10
+ "StableDiffusionSafetyChecker"
11
+ ],
12
+ "scheduler": [
13
+ "diffusers",
14
+ "PNDMScheduler"
15
+ ],
16
+ "text_encoder": [
17
+ "transformers",
18
+ "CLIPTextModel"
19
+ ],
20
+ "tokenizer": [
21
+ "transformers",
22
+ "CLIPTokenizer"
23
+ ],
24
+ "unet": [
25
+ "diffusers",
26
+ "UNet2DConditionModel"
27
+ ],
28
+ "vae": [
29
+ "diffusers",
30
+ "AutoencoderKL"
31
+ ]
32
+ }
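model_index.json is the manifest diffusers reads to decide which class backs each sub-folder. A hedged loading sketch, assuming this `animatediff/sd` directory (including its LFS weight files) is checked out locally:

from diffusers import StableDiffusionPipeline

# from_pretrained() parses model_index.json and assembles the listed sub-models
# (scheduler, text_encoder, tokenizer, unet, vae, safety_checker, feature_extractor).
pipe = StableDiffusionPipeline.from_pretrained("animatediff/sd")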
animatediff/sd/safety_checker/config.json ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
3
+ "_name_or_path": "CompVis/stable-diffusion-safety-checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "bos_token_id": 0,
18
+ "chunk_size_feed_forward": 0,
19
+ "cross_attention_hidden_size": null,
20
+ "decoder_start_token_id": null,
21
+ "diversity_penalty": 0.0,
22
+ "do_sample": false,
23
+ "dropout": 0.0,
24
+ "early_stopping": false,
25
+ "encoder_no_repeat_ngram_size": 0,
26
+ "eos_token_id": 2,
27
+ "exponential_decay_length_penalty": null,
28
+ "finetuning_task": null,
29
+ "forced_bos_token_id": null,
30
+ "forced_eos_token_id": null,
31
+ "hidden_act": "quick_gelu",
32
+ "hidden_size": 768,
33
+ "id2label": {
34
+ "0": "LABEL_0",
35
+ "1": "LABEL_1"
36
+ },
37
+ "initializer_factor": 1.0,
38
+ "initializer_range": 0.02,
39
+ "intermediate_size": 3072,
40
+ "is_decoder": false,
41
+ "is_encoder_decoder": false,
42
+ "label2id": {
43
+ "LABEL_0": 0,
44
+ "LABEL_1": 1
45
+ },
46
+ "layer_norm_eps": 1e-05,
47
+ "length_penalty": 1.0,
48
+ "max_length": 20,
49
+ "max_position_embeddings": 77,
50
+ "min_length": 0,
51
+ "model_type": "clip_text_model",
52
+ "no_repeat_ngram_size": 0,
53
+ "num_attention_heads": 12,
54
+ "num_beam_groups": 1,
55
+ "num_beams": 1,
56
+ "num_hidden_layers": 12,
57
+ "num_return_sequences": 1,
58
+ "output_attentions": false,
59
+ "output_hidden_states": false,
60
+ "output_scores": false,
61
+ "pad_token_id": 1,
62
+ "prefix": null,
63
+ "problem_type": null,
64
+ "pruned_heads": {},
65
+ "remove_invalid_values": false,
66
+ "repetition_penalty": 1.0,
67
+ "return_dict": true,
68
+ "return_dict_in_generate": false,
69
+ "sep_token_id": null,
70
+ "task_specific_params": null,
71
+ "temperature": 1.0,
72
+ "tf_legacy_loss": false,
73
+ "tie_encoder_decoder": false,
74
+ "tie_word_embeddings": true,
75
+ "tokenizer_class": null,
76
+ "top_k": 50,
77
+ "top_p": 1.0,
78
+ "torch_dtype": null,
79
+ "torchscript": false,
80
+ "transformers_version": "4.22.0.dev0",
81
+ "typical_p": 1.0,
82
+ "use_bfloat16": false,
83
+ "vocab_size": 49408
84
+ },
85
+ "text_config_dict": {
86
+ "hidden_size": 768,
87
+ "intermediate_size": 3072,
88
+ "num_attention_heads": 12,
89
+ "num_hidden_layers": 12
90
+ },
91
+ "torch_dtype": "float32",
92
+ "transformers_version": null,
93
+ "vision_config": {
94
+ "_name_or_path": "",
95
+ "add_cross_attention": false,
96
+ "architectures": null,
97
+ "attention_dropout": 0.0,
98
+ "bad_words_ids": null,
99
+ "bos_token_id": null,
100
+ "chunk_size_feed_forward": 0,
101
+ "cross_attention_hidden_size": null,
102
+ "decoder_start_token_id": null,
103
+ "diversity_penalty": 0.0,
104
+ "do_sample": false,
105
+ "dropout": 0.0,
106
+ "early_stopping": false,
107
+ "encoder_no_repeat_ngram_size": 0,
108
+ "eos_token_id": null,
109
+ "exponential_decay_length_penalty": null,
110
+ "finetuning_task": null,
111
+ "forced_bos_token_id": null,
112
+ "forced_eos_token_id": null,
113
+ "hidden_act": "quick_gelu",
114
+ "hidden_size": 1024,
115
+ "id2label": {
116
+ "0": "LABEL_0",
117
+ "1": "LABEL_1"
118
+ },
119
+ "image_size": 224,
120
+ "initializer_factor": 1.0,
121
+ "initializer_range": 0.02,
122
+ "intermediate_size": 4096,
123
+ "is_decoder": false,
124
+ "is_encoder_decoder": false,
125
+ "label2id": {
126
+ "LABEL_0": 0,
127
+ "LABEL_1": 1
128
+ },
129
+ "layer_norm_eps": 1e-05,
130
+ "length_penalty": 1.0,
131
+ "max_length": 20,
132
+ "min_length": 0,
133
+ "model_type": "clip_vision_model",
134
+ "no_repeat_ngram_size": 0,
135
+ "num_attention_heads": 16,
136
+ "num_beam_groups": 1,
137
+ "num_beams": 1,
138
+ "num_channels": 3,
139
+ "num_hidden_layers": 24,
140
+ "num_return_sequences": 1,
141
+ "output_attentions": false,
142
+ "output_hidden_states": false,
143
+ "output_scores": false,
144
+ "pad_token_id": null,
145
+ "patch_size": 14,
146
+ "prefix": null,
147
+ "problem_type": null,
148
+ "pruned_heads": {},
149
+ "remove_invalid_values": false,
150
+ "repetition_penalty": 1.0,
151
+ "return_dict": true,
152
+ "return_dict_in_generate": false,
153
+ "sep_token_id": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.22.0.dev0",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ },
168
+ "vision_config_dict": {
169
+ "hidden_size": 1024,
170
+ "intermediate_size": 4096,
171
+ "num_attention_heads": 16,
172
+ "num_hidden_layers": 24,
173
+ "patch_size": 14
174
+ }
175
+ }
animatediff/sd/safety_checker/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974
3
+ size 1216061799
animatediff/sd/scheduler/scheduler_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PNDMScheduler",
3
+ "_diffusers_version": "0.6.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "num_train_timesteps": 1000,
8
+ "set_alpha_to_one": false,
9
+ "skip_prk_steps": true,
10
+ "steps_offset": 1,
11
+ "trained_betas": null,
12
+ "clip_sample": false
13
+ }
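This scheduler config already sets `steps_offset` to 1 and `clip_sample` to false, so the two deprecation shims at the top of the animation pipeline's `__init__` earlier in this diff will not fire for this checkpoint. A hedged loading sketch (local path as laid out in this commit):

from diffusers import PNDMScheduler

scheduler = PNDMScheduler.from_pretrained("animatediff/sd", subfolder="scheduler")
assert scheduler.config.steps_offset == 1   # matches the value the pipeline checks for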
animatediff/sd/text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.22.0.dev0",
24
+ "vocab_size": 49408
25
+ }
animatediff/sd/text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67
3
+ size 492305335
animatediff/sd/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
animatediff/sd/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
animatediff/sd/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 77,
22
+ "name_or_path": "openai/clip-vit-large-patch14",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
animatediff/sd/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
animatediff/sd/unet/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "attention_head_dim": 8,
6
+ "block_out_channels": [
7
+ 320,
8
+ 640,
9
+ 1280,
10
+ 1280
11
+ ],
12
+ "center_input_sample": false,
13
+ "cross_attention_dim": 768,
14
+ "down_block_types": [
15
+ "CrossAttnDownBlock2D",
16
+ "CrossAttnDownBlock2D",
17
+ "CrossAttnDownBlock2D",
18
+ "DownBlock2D"
19
+ ],
20
+ "downsample_padding": 1,
21
+ "flip_sin_to_cos": true,
22
+ "freq_shift": 0,
23
+ "in_channels": 4,
24
+ "layers_per_block": 2,
25
+ "mid_block_scale_factor": 1,
26
+ "norm_eps": 1e-05,
27
+ "norm_num_groups": 32,
28
+ "out_channels": 4,
29
+ "sample_size": 64,
30
+ "up_block_types": [
31
+ "UpBlock2D",
32
+ "CrossAttnUpBlock2D",
33
+ "CrossAttnUpBlock2D",
34
+ "CrossAttnUpBlock2D"
35
+ ]
36
+ }
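The UNet's `sample_size` of 64 combines with the VAE scale factor computed in the pipeline's `__init__` to give the default 512x512 resolution. A quick arithmetic check using values copied from this diff (VAE channels from vae/config.json further below):

vae_block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(vae_block_out_channels) - 1)   # == 8, as in the pipeline __init__
unet_sample_size = 64                                       # from the config above
assert unet_sample_size * vae_scale_factor == 512           # default height/width in __call__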
animatediff/sd/unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
3
+ size 3438354725
animatediff/sd/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
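The `scale_factor` of 0.18215 here is the same constant hard-coded in `decode_latents` above, which divides latents by it before calling `vae.decode`. A tiny round-trip sketch with a dummy latent:

import torch

scale_factor = 0.18215
latents = torch.randn(1, 4, 8, 8)
scaled = latents * scale_factor            # what gets stored after VAE encoding
recovered = 1 / scale_factor * scaled      # what decode_latents undoes before decoding
assert torch.allclose(recovered, latents)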
animatediff/sd/vae/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.6.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "in_channels": 3,
18
+ "latent_channels": 4,
19
+ "layers_per_block": 2,
20
+ "norm_num_groups": 32,
21
+ "out_channels": 3,
22
+ "sample_size": 512,
23
+ "up_block_types": [
24
+ "UpDecoderBlock2D",
25
+ "UpDecoderBlock2D",
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D"
28
+ ]
29
+ }
animatediff/sd/vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
3
+ size 334707217
animatediff/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,959 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
61
+
62
+
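`shave_segments` just trims dotted key segments from one end; a quick illustration of its effect on a typical checkpoint key (the key is chosen for illustration only):

key = "model.diffusion_model.input_blocks.0.0.weight"
print(".".join(key.split(".")[1:]))    # shave 1 prefix segment -> diffusion_model.input_blocks.0.0.weight
print(".".join(key.split(".")[:-1]))   # shave 1 suffix segment -> model.diffusion_model.input_blocks.0.0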
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
+
80
+ mapping.append({"old": old_item, "new": new_item})
81
+
82
+ return mapping
83
+
84
+
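For a concrete picture of the local renaming above, the replacements map a typical LDM resnet key as follows (key chosen for illustration only):

old_key = "input_blocks.1.0.in_layers.0.weight"
new_key = (old_key.replace("in_layers.0", "norm1").replace("in_layers.2", "conv1")
                  .replace("out_layers.0", "norm2").replace("out_layers.3", "conv2")
                  .replace("emb_layers.1", "time_emb_proj").replace("skip_connection", "conv_shortcut"))
print({"old": old_key, "new": new_key})   # new == "input_blocks.1.0.norm1.weight"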
85
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
+ """
87
+ Updates paths inside resnets to the new naming scheme (local renaming)
88
+ """
89
+ mapping = []
90
+ for old_item in old_list:
91
+ new_item = old_item
92
+
93
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
+
96
+ mapping.append({"old": old_item, "new": new_item})
97
+
98
+ return mapping
99
+
100
+
101
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
+ """
103
+ Updates paths inside attentions to the new naming scheme (local renaming)
104
+ """
105
+ mapping = []
106
+ for old_item in old_list:
107
+ new_item = old_item
108
+
109
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
+
112
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
+
115
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
+
117
+ mapping.append({"old": old_item, "new": new_item})
118
+
119
+ return mapping
120
+
121
+
122
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
+ """
124
+ Updates paths inside attentions to the new naming scheme (local renaming)
125
+ """
126
+ mapping = []
127
+ for old_item in old_list:
128
+ new_item = old_item
129
+
130
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
131
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
132
+
133
+ new_item = new_item.replace("q.weight", "query.weight")
134
+ new_item = new_item.replace("q.bias", "query.bias")
135
+
136
+ new_item = new_item.replace("k.weight", "key.weight")
137
+ new_item = new_item.replace("k.bias", "key.bias")
138
+
139
+ new_item = new_item.replace("v.weight", "value.weight")
140
+ new_item = new_item.replace("v.bias", "value.bias")
141
+
142
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
143
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
144
+
145
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
+
147
+ mapping.append({"old": old_item, "new": new_item})
148
+
149
+ return mapping
150
+
151
+
152
+ def assign_to_checkpoint(
153
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
154
+ ):
155
+ """
156
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
157
+ attention layers, and takes into account additional replacements that may arise.
158
+
159
+ Assigns the weights to the new checkpoint.
160
+ """
161
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
162
+
163
+ # Splits the attention layers into three variables.
164
+ if attention_paths_to_split is not None:
165
+ for path, path_map in attention_paths_to_split.items():
166
+ old_tensor = old_checkpoint[path]
167
+ channels = old_tensor.shape[0] // 3
168
+
169
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
170
+
171
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
172
+
173
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
174
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
175
+
176
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
177
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
178
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
179
+
180
+ for path in paths:
181
+ new_path = path["new"]
182
+
183
+ # These have already been assigned
184
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
185
+ continue
186
+
187
+ # Global renaming happens here
188
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
189
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
190
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
191
+
192
+ if additional_replacements is not None:
193
+ for replacement in additional_replacements:
194
+ new_path = new_path.replace(replacement["old"], replacement["new"])
195
+
196
+ # proj_attn.weight has to be converted from conv 1D to linear
197
+ if "proj_attn.weight" in new_path:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
199
+ else:
200
+ checkpoint[new_path] = old_checkpoint[path["old"]]
201
+
202
+
203
+ def conv_attn_to_linear(checkpoint):
204
+ keys = list(checkpoint.keys())
205
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
206
+ for key in keys:
207
+ if ".".join(key.split(".")[-2:]) in attn_keys:
208
+ if checkpoint[key].ndim > 2:
209
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
210
+ elif "proj_attn.weight" in key:
211
+ if checkpoint[key].ndim > 2:
212
+ checkpoint[key] = checkpoint[key][:, :, 0]
213
+
214
+
215
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
216
+ """
217
+ Creates a config for the diffusers based on the config of the LDM model.
218
+ """
219
+ if controlnet:
220
+ unet_params = original_config.model.params.control_stage_config.params
221
+ else:
222
+ unet_params = original_config.model.params.unet_config.params
223
+
224
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
225
+
226
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
227
+
228
+ down_block_types = []
229
+ resolution = 1
230
+ for i in range(len(block_out_channels)):
231
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
232
+ down_block_types.append(block_type)
233
+ if i != len(block_out_channels) - 1:
234
+ resolution *= 2
235
+
236
+ up_block_types = []
237
+ for i in range(len(block_out_channels)):
238
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
239
+ up_block_types.append(block_type)
240
+ resolution //= 2
241
+
242
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
243
+
244
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
245
+ use_linear_projection = (
246
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
247
+ )
248
+ if use_linear_projection:
249
+ # stable diffusion 2-base-512 and 2-768
250
+ if head_dim is None:
251
+ head_dim = [5, 10, 20, 20]
252
+
253
+ class_embed_type = None
254
+ projection_class_embeddings_input_dim = None
255
+
256
+ if "num_classes" in unet_params:
257
+ if unet_params.num_classes == "sequential":
258
+ class_embed_type = "projection"
259
+ assert "adm_in_channels" in unet_params
260
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
261
+ else:
262
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
263
+
264
+ config = {
265
+ "sample_size": image_size // vae_scale_factor,
266
+ "in_channels": unet_params.in_channels,
267
+ "down_block_types": tuple(down_block_types),
268
+ "block_out_channels": tuple(block_out_channels),
269
+ "layers_per_block": unet_params.num_res_blocks,
270
+ "cross_attention_dim": unet_params.context_dim,
271
+ "attention_head_dim": head_dim,
272
+ "use_linear_projection": use_linear_projection,
273
+ "class_embed_type": class_embed_type,
274
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
275
+ }
276
+
277
+ if not controlnet:
278
+ config["out_channels"] = unet_params.out_channels
279
+ config["up_block_types"] = tuple(up_block_types)
280
+
281
+ return config
282
+
283
+
284
+ def create_vae_diffusers_config(original_config, image_size: int):
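Plugging the v1-inference.yaml values from earlier in this diff into the loop above reproduces the block layout seen in unet/config.json: attention at the first three resolutions, then a plain DownBlock2D. A hedged re-derivation:

channel_mult = [1, 2, 4, 4]            # from v1-inference.yaml
attention_resolutions = [4, 2, 1]      # from v1-inference.yaml
down_block_types, resolution = [], 1
for i in range(len(channel_mult)):
    down_block_types.append("CrossAttnDownBlock2D" if resolution in attention_resolutions else "DownBlock2D")
    if i != len(channel_mult) - 1:
        resolution *= 2
print(down_block_types)   # 3x CrossAttnDownBlock2D, then DownBlock2D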
285
+ """
286
+ Creates a config for the diffusers based on the config of the LDM model.
287
+ """
288
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
289
+ _ = original_config.model.params.first_stage_config.params.embed_dim
290
+
291
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
292
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
293
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
294
+
295
+ config = {
296
+ "sample_size": image_size,
297
+ "in_channels": vae_params.in_channels,
298
+ "out_channels": vae_params.out_ch,
299
+ "down_block_types": tuple(down_block_types),
300
+ "up_block_types": tuple(up_block_types),
301
+ "block_out_channels": tuple(block_out_channels),
302
+ "latent_channels": vae_params.z_channels,
303
+ "layers_per_block": vae_params.num_res_blocks,
304
+ }
305
+ return config
306
+
307
+
308
+ def create_diffusers_schedular(original_config):
309
+ schedular = DDIMScheduler(
310
+ num_train_timesteps=original_config.model.params.timesteps,
311
+ beta_start=original_config.model.params.linear_start,
312
+ beta_end=original_config.model.params.linear_end,
313
+ beta_schedule="scaled_linear",
314
+ )
315
+ return schedular
316
+
317
+
318
+ def create_ldm_bert_config(original_config):
319
+ bert_params = original_config.model.params.cond_stage_config.params
320
+ config = LDMBertConfig(
321
+ d_model=bert_params.n_embed,
322
+ encoder_layers=bert_params.n_layer,
323
+ encoder_ffn_dim=bert_params.n_embed * 4,
324
+ )
325
+ return config
326
+
327
+
328
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
329
+ """
330
+ Takes a state dict and a config, and returns a converted checkpoint.
331
+ """
332
+
333
+ # extract state_dict for UNet
334
+ unet_state_dict = {}
335
+ keys = list(checkpoint.keys())
336
+
337
+ if controlnet:
338
+ unet_key = "control_model."
339
+ else:
340
+ unet_key = "model.diffusion_model."
341
+
342
+ # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
343
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
344
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
345
+ print(
346
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
347
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
348
+ )
349
+ for key in keys:
350
+ if key.startswith("model.diffusion_model"):
351
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
352
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
353
+ else:
354
+ if sum(k.startswith("model_ema") for k in keys) > 100:
355
+ print(
356
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
357
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
358
+ )
359
+
360
+ for key in keys:
361
+ if key.startswith(unet_key):
362
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
363
+
364
+ new_checkpoint = {}
365
+
366
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
367
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
368
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
369
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
370
+
371
+ if config["class_embed_type"] is None:
372
+ # No parameters to port
373
+ ...
374
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
375
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
376
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
377
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
378
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
379
+ else:
380
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
381
+
382
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
383
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
384
+
385
+ if not controlnet:
386
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
387
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
388
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
389
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
390
+
391
+ # Retrieves the keys for the input blocks only
392
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
393
+ input_blocks = {
394
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
395
+ for layer_id in range(num_input_blocks)
396
+ }
397
+
398
+ # Retrieves the keys for the middle blocks only
399
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
400
+ middle_blocks = {
401
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
402
+ for layer_id in range(num_middle_blocks)
403
+ }
404
+
405
+ # Retrieves the keys for the output blocks only
406
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
407
+ output_blocks = {
408
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
409
+ for layer_id in range(num_output_blocks)
410
+ }
411
+
412
+ for i in range(1, num_input_blocks):
413
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
414
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
415
+
416
+ resnets = [
417
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
418
+ ]
419
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
420
+
421
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
422
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
423
+ f"input_blocks.{i}.0.op.weight"
424
+ )
425
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
426
+ f"input_blocks.{i}.0.op.bias"
427
+ )
428
+
429
+ paths = renew_resnet_paths(resnets)
430
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
431
+ assign_to_checkpoint(
432
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
433
+ )
434
+
435
+ if len(attentions):
436
+ paths = renew_attention_paths(attentions)
437
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
438
+ assign_to_checkpoint(
439
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
440
+ )
441
+
442
+ resnet_0 = middle_blocks[0]
443
+ attentions = middle_blocks[1]
444
+ resnet_1 = middle_blocks[2]
445
+
446
+ resnet_0_paths = renew_resnet_paths(resnet_0)
447
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
448
+
449
+ resnet_1_paths = renew_resnet_paths(resnet_1)
450
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
451
+
452
+ attentions_paths = renew_attention_paths(attentions)
453
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
454
+ assign_to_checkpoint(
455
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
456
+ )
457
+
458
+ for i in range(num_output_blocks):
459
+ block_id = i // (config["layers_per_block"] + 1)
460
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
461
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
462
+ output_block_list = {}
463
+
464
+ for layer in output_block_layers:
465
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
466
+ if layer_id in output_block_list:
467
+ output_block_list[layer_id].append(layer_name)
468
+ else:
469
+ output_block_list[layer_id] = [layer_name]
470
+
471
+ if len(output_block_list) > 1:
472
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
473
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
474
+
475
+ resnet_0_paths = renew_resnet_paths(resnets)
476
+ paths = renew_resnet_paths(resnets)
477
+
478
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
479
+ assign_to_checkpoint(
480
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
481
+ )
482
+
483
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
484
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
485
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
486
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
487
+ f"output_blocks.{i}.{index}.conv.weight"
488
+ ]
489
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
490
+ f"output_blocks.{i}.{index}.conv.bias"
491
+ ]
492
+
493
+ # Clear attentions as they have been attributed above.
494
+ if len(attentions) == 2:
495
+ attentions = []
496
+
497
+ if len(attentions):
498
+ paths = renew_attention_paths(attentions)
499
+ meta_path = {
500
+ "old": f"output_blocks.{i}.1",
501
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
502
+ }
503
+ assign_to_checkpoint(
504
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
505
+ )
506
+ else:
507
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
508
+ for path in resnet_0_paths:
509
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
510
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
511
+
512
+ new_checkpoint[new_path] = unet_state_dict[old_path]
513
+
514
+ if controlnet:
515
+ # conditioning embedding
516
+
517
+ orig_index = 0
518
+
519
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
520
+ f"input_hint_block.{orig_index}.weight"
521
+ )
522
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
523
+ f"input_hint_block.{orig_index}.bias"
524
+ )
525
+
526
+ orig_index += 2
527
+
528
+ diffusers_index = 0
529
+
530
+ while diffusers_index < 6:
531
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
532
+ f"input_hint_block.{orig_index}.weight"
533
+ )
534
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
535
+ f"input_hint_block.{orig_index}.bias"
536
+ )
537
+ diffusers_index += 1
538
+ orig_index += 2
539
+
540
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
541
+ f"input_hint_block.{orig_index}.weight"
542
+ )
543
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
544
+ f"input_hint_block.{orig_index}.bias"
545
+ )
546
+
547
+ # down blocks
548
+ for i in range(num_input_blocks):
549
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
550
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
551
+
552
+ # mid block
553
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
554
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
555
+
556
+ return new_checkpoint
557
+
558
+
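
Putting the config builder and this converter together, a hedged end-to-end sketch of turning an LDM-format UNet into a diffusers `UNet2DConditionModel` could look like this (the checkpoint and config paths are placeholders):

```python
import torch
from omegaconf import OmegaConf
from diffusers import UNet2DConditionModel

# Assumed inputs: an LDM-style YAML config and a .ckpt file on disk (placeholders).
original_config = OmegaConf.load("v1-inference.yaml")
checkpoint = torch.load("sd15_finetune.ckpt", map_location="cpu")
checkpoint = checkpoint.get("state_dict", checkpoint)

unet_config = create_unet_diffusers_config(original_config, image_size=512)
converted_unet = convert_ldm_unet_checkpoint(checkpoint, unet_config, path="sd15_finetune.ckpt")
unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet)
```

Note that `path` is only used in the EMA-related log messages; the actual weights come from the `checkpoint` dict.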
559
+ def convert_ldm_vae_checkpoint(checkpoint, config):
560
+ # extract state dict for VAE
561
+ vae_state_dict = {}
562
+ vae_key = "first_stage_model."
563
+ keys = list(checkpoint.keys())
564
+ for key in keys:
565
+ if key.startswith(vae_key):
566
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
567
+
568
+ new_checkpoint = {}
569
+
570
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
571
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
572
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
573
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
574
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
575
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
576
+
577
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
578
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
579
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
580
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
581
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
582
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
583
+
584
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
585
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
586
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
587
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
588
+
589
+ # Retrieves the keys for the encoder down blocks only
590
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
591
+ down_blocks = {
592
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
593
+ }
594
+
595
+ # Retrieves the keys for the decoder up blocks only
596
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
597
+ up_blocks = {
598
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
599
+ }
600
+
601
+ for i in range(num_down_blocks):
602
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
603
+
604
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
605
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
606
+ f"encoder.down.{i}.downsample.conv.weight"
607
+ )
608
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
609
+ f"encoder.down.{i}.downsample.conv.bias"
610
+ )
611
+
612
+ paths = renew_vae_resnet_paths(resnets)
613
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
614
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
615
+
616
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
617
+ num_mid_res_blocks = 2
618
+ for i in range(1, num_mid_res_blocks + 1):
619
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
620
+
621
+ paths = renew_vae_resnet_paths(resnets)
622
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
623
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
624
+
625
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
626
+ paths = renew_vae_attention_paths(mid_attentions)
627
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
628
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
629
+ conv_attn_to_linear(new_checkpoint)
630
+
631
+ for i in range(num_up_blocks):
632
+ block_id = num_up_blocks - 1 - i
633
+ resnets = [
634
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
635
+ ]
636
+
637
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
638
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
639
+ f"decoder.up.{block_id}.upsample.conv.weight"
640
+ ]
641
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
642
+ f"decoder.up.{block_id}.upsample.conv.bias"
643
+ ]
644
+
645
+ paths = renew_vae_resnet_paths(resnets)
646
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
647
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
648
+
649
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
650
+ num_mid_res_blocks = 2
651
+ for i in range(1, num_mid_res_blocks + 1):
652
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
653
+
654
+ paths = renew_vae_resnet_paths(resnets)
655
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
656
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
657
+
658
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
659
+ paths = renew_vae_attention_paths(mid_attentions)
660
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
661
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
662
+ conv_attn_to_linear(new_checkpoint)
663
+ return new_checkpoint
664
+
665
+
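
The VAE path mirrors the UNet path. A hypothetical sketch, reusing the same placeholder checkpoint:

```python
import torch
from omegaconf import OmegaConf
from diffusers import AutoencoderKL

original_config = OmegaConf.load("v1-inference.yaml")        # placeholder path
checkpoint = torch.load("sd15_finetune.ckpt", map_location="cpu")
checkpoint = checkpoint.get("state_dict", checkpoint)

vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae)
```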
666
+ def convert_ldm_bert_checkpoint(checkpoint, config):
667
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
668
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
669
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
670
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
671
+
672
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
673
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
674
+
675
+ def _copy_linear(hf_linear, pt_linear):
676
+ hf_linear.weight = pt_linear.weight
677
+ hf_linear.bias = pt_linear.bias
678
+
679
+ def _copy_layer(hf_layer, pt_layer):
680
+ # copy layer norms
681
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
682
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
683
+
684
+ # copy attn
685
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
686
+
687
+ # copy MLP
688
+ pt_mlp = pt_layer[1][1]
689
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
690
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
691
+
692
+ def _copy_layers(hf_layers, pt_layers):
693
+ for i, hf_layer in enumerate(hf_layers):
694
+ if i != 0:
695
+ i += i
696
+ pt_layer = pt_layers[i : i + 2]
697
+ _copy_layer(hf_layer, pt_layer)
698
+
699
+ hf_model = LDMBertModel(config).eval()
700
+
701
+ # copy embeds
702
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
703
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
704
+
705
+ # copy layer norm
706
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
707
+
708
+ # copy hidden layers
709
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
710
+
711
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
712
+
713
+ return hf_model
714
+
715
+
716
+ def convert_ldm_clip_checkpoint(checkpoint, dtype=torch.float16):
717
+ text_model = CLIPTextModel.from_pretrained("animatediff/sd/text_encoder", torch_dtype=dtype)
718
+ keys = list(checkpoint.keys())
719
+
720
+ text_model_dict = {}
721
+
722
+ for key in keys:
723
+ if key.startswith("cond_stage_model.transformer"):
724
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
725
+
726
+ text_model.load_state_dict(text_model_dict, strict=False)
727
+
728
+ return text_model
729
+
730
+
731
+ textenc_conversion_lst = [
732
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
733
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
734
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
735
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
736
+ ]
737
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
738
+
739
+ textenc_transformer_conversion_lst = [
740
+ # (stable-diffusion, HF Diffusers)
741
+ ("resblocks.", "text_model.encoder.layers."),
742
+ ("ln_1", "layer_norm1"),
743
+ ("ln_2", "layer_norm2"),
744
+ (".c_fc.", ".fc1."),
745
+ (".c_proj.", ".fc2."),
746
+ (".attn", ".self_attn"),
747
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
748
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
749
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
750
+ ]
751
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
752
+ textenc_pattern = re.compile("|".join(protected.keys()))
753
+
754
+
755
+ def convert_paint_by_example_checkpoint(checkpoint):
756
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
757
+ model = PaintByExampleImageEncoder(config)
758
+
759
+ keys = list(checkpoint.keys())
760
+
761
+ text_model_dict = {}
762
+
763
+ for key in keys:
764
+ if key.startswith("cond_stage_model.transformer"):
765
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
766
+
767
+ # load clip vision
768
+ model.model.load_state_dict(text_model_dict)
769
+
770
+ # load mapper
771
+ keys_mapper = {
772
+ k[len("cond_stage_model.mapper.res") :]: v
773
+ for k, v in checkpoint.items()
774
+ if k.startswith("cond_stage_model.mapper")
775
+ }
776
+
777
+ MAPPING = {
778
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
779
+ "attn.c_proj": ["attn1.to_out.0"],
780
+ "ln_1": ["norm1"],
781
+ "ln_2": ["norm3"],
782
+ "mlp.c_fc": ["ff.net.0.proj"],
783
+ "mlp.c_proj": ["ff.net.2"],
784
+ }
785
+
786
+ mapped_weights = {}
787
+ for key, value in keys_mapper.items():
788
+ prefix = key[: len("blocks.i")]
789
+ suffix = key.split(prefix)[-1].split(".")[-1]
790
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
791
+ mapped_names = MAPPING[name]
792
+
793
+ num_splits = len(mapped_names)
794
+ for i, mapped_name in enumerate(mapped_names):
795
+ new_name = ".".join([prefix, mapped_name, suffix])
796
+ shape = value.shape[0] // num_splits
797
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
798
+
799
+ model.mapper.load_state_dict(mapped_weights)
800
+
801
+ # load final layer norm
802
+ model.final_layer_norm.load_state_dict(
803
+ {
804
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
805
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
806
+ }
807
+ )
808
+
809
+ # load final proj
810
+ model.proj_out.load_state_dict(
811
+ {
812
+ "bias": checkpoint["proj_out.bias"],
813
+ "weight": checkpoint["proj_out.weight"],
814
+ }
815
+ )
816
+
817
+ # load uncond vector
818
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
819
+ return model
820
+
821
+
822
+ def convert_open_clip_checkpoint(checkpoint):
823
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
824
+
825
+ keys = list(checkpoint.keys())
826
+
827
+ text_model_dict = {}
828
+
829
+ if "cond_stage_model.model.text_projection" in checkpoint:
830
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
831
+ else:
832
+ d_model = 1024
833
+
834
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
835
+
836
+ for key in keys:
837
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
838
+ continue
839
+ if key in textenc_conversion_map:
840
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
841
+ if key.startswith("cond_stage_model.model.transformer."):
842
+ new_key = key[len("cond_stage_model.model.transformer.") :]
843
+ if new_key.endswith(".in_proj_weight"):
844
+ new_key = new_key[: -len(".in_proj_weight")]
845
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
846
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
847
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
848
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
849
+ elif new_key.endswith(".in_proj_bias"):
850
+ new_key = new_key[: -len(".in_proj_bias")]
851
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
852
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
853
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
854
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
855
+ else:
856
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
857
+
858
+ text_model_dict[new_key] = checkpoint[key]
859
+
860
+ text_model.load_state_dict(text_model_dict)
861
+
862
+ return text_model
863
+
864
+
865
+ def stable_unclip_image_encoder(original_config):
866
+ """
867
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
868
+
869
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
870
+ encoders.
871
+ """
872
+
873
+ image_embedder_config = original_config.model.params.embedder_config
874
+
875
+ sd_clip_image_embedder_class = image_embedder_config.target
876
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
877
+
878
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
879
+ clip_model_name = image_embedder_config.params.model
880
+
881
+ if clip_model_name == "ViT-L/14":
882
+ feature_extractor = CLIPImageProcessor()
883
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
884
+ else:
885
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
886
+
887
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
888
+ feature_extractor = CLIPImageProcessor()
889
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
890
+ else:
891
+ raise NotImplementedError(
892
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
893
+ )
894
+
895
+ return feature_extractor, image_encoder
896
+
897
+
898
+ def stable_unclip_image_noising_components(
899
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
900
+ ):
901
+ """
902
+ Returns the noising components for the img2img and txt2img unclip pipelines.
903
+
904
+ Converts the stability noise augmentor into
905
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
906
+ 2. a `DDPMScheduler` for holding the noise schedule
907
+
908
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
909
+ """
910
+ noise_aug_config = original_config.model.params.noise_aug_config
911
+ noise_aug_class = noise_aug_config.target
912
+ noise_aug_class = noise_aug_class.split(".")[-1]
913
+
914
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
915
+ noise_aug_config = noise_aug_config.params
916
+ embedding_dim = noise_aug_config.timestep_dim
917
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
918
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
919
+
920
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
921
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
922
+
923
+ if "clip_stats_path" in noise_aug_config:
924
+ if clip_stats_path is None:
925
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
926
+
927
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
928
+ clip_mean = clip_mean[None, :]
929
+ clip_std = clip_std[None, :]
930
+
931
+ clip_stats_state_dict = {
932
+ "mean": clip_mean,
933
+ "std": clip_std,
934
+ }
935
+
936
+ image_normalizer.load_state_dict(clip_stats_state_dict)
937
+ else:
938
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
939
+
940
+ return image_normalizer, image_noising_scheduler
941
+
942
+
943
+ def convert_controlnet_checkpoint(
944
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
945
+ ):
946
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
947
+ ctrlnet_config["upcast_attention"] = upcast_attention
948
+
949
+ ctrlnet_config.pop("sample_size")
950
+
951
+ controlnet_model = ControlNetModel(**ctrlnet_config)
952
+
953
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
954
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
955
+ )
956
+
957
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
958
+
959
+ return controlnet_model
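
To close out this file, a hedged sketch of the ControlNet path (the config and weight paths are placeholders, and the exact YAML depends on the ControlNet variant being converted):

```python
import torch
from omegaconf import OmegaConf

original_config = OmegaConf.load("cldm_v15.yaml")            # placeholder ControlNet-style LDM config
ckpt = torch.load("control_sd15_canny.pth", map_location="cpu")
ckpt = ckpt.get("state_dict", ckpt)

controlnet = convert_controlnet_checkpoint(
    ckpt, original_config, "control_sd15_canny.pth",
    image_size=512, upcast_attention=False, extract_ema=False,
)
```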
animatediff/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,152 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ #
16
+ # Changes were made to this source code by Yuwei Guo.
17
+ """ Conversion script for the LoRA's safetensors checkpoints. """
18
+
19
+ import argparse
20
+
21
+ import torch
22
+ from safetensors.torch import load_file
23
+
24
+ from diffusers import StableDiffusionPipeline
25
+
26
+
27
+ def load_diffusers_lora(pipeline, state_dict, alpha=1.0):
28
+ # directly update weight in diffusers model
29
+ for key in state_dict:
30
+ # only process lora down key
31
+ if "up." in key: continue
32
+
33
+ up_key = key.replace(".down.", ".up.")
34
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
35
+ model_key = model_key.replace("to_out.", "to_out.0.")
36
+ layer_infos = model_key.split(".")[:-1]
37
+
38
+ curr_layer = pipeline.unet
39
+ while len(layer_infos) > 0:
40
+ temp_name = layer_infos.pop(0)
41
+ curr_layer = curr_layer.__getattr__(temp_name)
42
+
43
+ weight_down = state_dict[key]
44
+ weight_up = state_dict[up_key]
45
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
46
+
47
+ return pipeline
48
+
49
+
50
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
51
+ # load base model
52
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
53
+
54
+ # load LoRA weight from .safetensors
55
+ # state_dict = load_file(checkpoint_path)
56
+
57
+ visited = []
58
+
59
+ # directly update weight in diffusers model
60
+ for key in state_dict:
61
+ # it is suggested to print out the key, it usually will be something like below
62
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
63
+
64
+ # as we have set the alpha beforehand, so just skip
65
+ if ".alpha" in key or key in visited:
66
+ continue
67
+
68
+ if "text" in key:
69
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
70
+ curr_layer = pipeline.text_encoder
71
+ else:
72
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
73
+ curr_layer = pipeline.unet
74
+
75
+ # find the target layer
76
+ temp_name = layer_infos.pop(0)
77
+ while len(layer_infos) > -1:
78
+ try:
79
+ curr_layer = curr_layer.__getattr__(temp_name)
80
+ if len(layer_infos) > 0:
81
+ temp_name = layer_infos.pop(0)
82
+ elif len(layer_infos) == 0:
83
+ break
84
+ except Exception:
85
+ if len(temp_name) > 0:
86
+ temp_name += "_" + layer_infos.pop(0)
87
+ else:
88
+ temp_name = layer_infos.pop(0)
89
+
90
+ pair_keys = []
91
+ if "lora_down" in key:
92
+ pair_keys.append(key.replace("lora_down", "lora_up"))
93
+ pair_keys.append(key)
94
+ else:
95
+ pair_keys.append(key)
96
+ pair_keys.append(key.replace("lora_up", "lora_down"))
97
+
98
+ # update weight
99
+ if len(state_dict[pair_keys[0]].shape) == 4:
100
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
101
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
102
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
103
+ else:
104
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
105
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
106
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
107
+
108
+ # update visited list
109
+ for item in pair_keys:
110
+ visited.append(item)
111
+
112
+ return pipeline
113
+
114
+
115
+ if __name__ == "__main__":
116
+ parser = argparse.ArgumentParser()
117
+
118
+ parser.add_argument(
119
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
120
+ )
121
+ parser.add_argument(
122
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
123
+ )
124
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
125
+ parser.add_argument(
126
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
127
+ )
128
+ parser.add_argument(
129
+ "--lora_prefix_text_encoder",
130
+ default="lora_te",
131
+ type=str,
132
+ help="The prefix of text encoder weight in safetensors",
133
+ )
134
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
135
+ parser.add_argument(
136
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
137
+ )
138
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
139
+
140
+ args = parser.parse_args()
141
+
142
+ base_model_path = args.base_model_path
143
+ checkpoint_path = args.checkpoint_path
144
+ dump_path = args.dump_path
145
+ lora_prefix_unet = args.lora_prefix_unet
146
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
147
+ alpha = args.alpha
148
+
149
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
150
+
151
+ pipe = pipe.to(args.device)
152
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
animatediff/utils/convert_original_stable_diffusion_to_diffusers.py ADDED
@@ -0,0 +1,188 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Conversion script for the LDM checkpoints."""
16
+
17
+ import argparse
18
+ import importlib
19
+
20
+ import torch
21
+
22
+ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
23
+
24
+
25
+ if __name__ == "__main__":
26
+ parser = argparse.ArgumentParser()
27
+
28
+ parser.add_argument(
29
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
30
+ )
31
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
32
+ parser.add_argument(
33
+ "--original_config_file",
34
+ default=None,
35
+ type=str,
36
+ help="The YAML config file corresponding to the original architecture.",
37
+ )
38
+ parser.add_argument(
39
+ "--config_files",
40
+ default=None,
41
+ type=str,
42
+ help="The YAML config file corresponding to the architecture.",
43
+ )
44
+ parser.add_argument(
45
+ "--num_in_channels",
46
+ default=None,
47
+ type=int,
48
+ help="The number of input channels. If `None` number of input channels will be automatically inferred.",
49
+ )
50
+ parser.add_argument(
51
+ "--scheduler_type",
52
+ default="pndm",
53
+ type=str,
54
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
55
+ )
56
+ parser.add_argument(
57
+ "--pipeline_type",
58
+ default=None,
59
+ type=str,
60
+ help=(
61
+ "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
62
+ ". If `None` pipeline will be automatically inferred."
63
+ ),
64
+ )
65
+ parser.add_argument(
66
+ "--image_size",
67
+ default=None,
68
+ type=int,
69
+ help=(
70
+ "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
71
+ " Base. Use 768 for Stable Diffusion v2."
72
+ ),
73
+ )
74
+ parser.add_argument(
75
+ "--prediction_type",
76
+ default=None,
77
+ type=str,
78
+ help=(
79
+ "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
80
+ " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
81
+ ),
82
+ )
83
+ parser.add_argument(
84
+ "--extract_ema",
85
+ action="store_true",
86
+ help=(
87
+ "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
88
+ " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
89
+ " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
90
+ ),
91
+ )
92
+ parser.add_argument(
93
+ "--upcast_attention",
94
+ action="store_true",
95
+ help=(
96
+ "Whether the attention computation should always be upcasted. This is necessary when running stable"
97
+ " diffusion 2.1."
98
+ ),
99
+ )
100
+ parser.add_argument(
101
+ "--from_safetensors",
102
+ action="store_true",
103
+ help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
104
+ )
105
+ parser.add_argument(
106
+ "--to_safetensors",
107
+ action="store_true",
108
+ help="Whether to store pipeline in safetensors format or not.",
109
+ )
110
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
111
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
112
+ parser.add_argument(
113
+ "--stable_unclip",
114
+ type=str,
115
+ default=None,
116
+ required=False,
117
+ help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
118
+ )
119
+ parser.add_argument(
120
+ "--stable_unclip_prior",
121
+ type=str,
122
+ default=None,
123
+ required=False,
124
+ help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
125
+ )
126
+ parser.add_argument(
127
+ "--clip_stats_path",
128
+ type=str,
129
+ help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
130
+ required=False,
131
+ )
132
+ parser.add_argument(
133
+ "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
134
+ )
135
+ parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
136
+ parser.add_argument(
137
+ "--vae_path",
138
+ type=str,
139
+ default=None,
140
+ required=False,
141
+ help="Set to a path, hub id to an already converted vae to not convert it again.",
142
+ )
143
+ parser.add_argument(
144
+ "--pipeline_class_name",
145
+ type=str,
146
+ default=None,
147
+ required=False,
148
+ help="Specify the pipeline class name",
149
+ )
150
+
151
+ args = parser.parse_args()
152
+
153
+ if args.pipeline_class_name is not None:
154
+ library = importlib.import_module("diffusers")
155
+ class_obj = getattr(library, args.pipeline_class_name)
156
+ pipeline_class = class_obj
157
+ else:
158
+ pipeline_class = None
159
+
160
+ pipe = download_from_original_stable_diffusion_ckpt(
161
+ checkpoint_path_or_dict=args.checkpoint_path,
162
+ original_config_file=args.original_config_file,
163
+ config_files=args.config_files,
164
+ image_size=args.image_size,
165
+ prediction_type=args.prediction_type,
166
+ model_type=args.pipeline_type,
167
+ extract_ema=args.extract_ema,
168
+ scheduler_type=args.scheduler_type,
169
+ num_in_channels=args.num_in_channels,
170
+ upcast_attention=args.upcast_attention,
171
+ from_safetensors=args.from_safetensors,
172
+ device=args.device,
173
+ stable_unclip=args.stable_unclip,
174
+ stable_unclip_prior=args.stable_unclip_prior,
175
+ clip_stats_path=args.clip_stats_path,
176
+ controlnet=args.controlnet,
177
+ vae_path=args.vae_path,
178
+ pipeline_class=pipeline_class,
179
+ )
180
+
181
+ if args.half:
182
+ pipe.to(dtype=torch.float16)
183
+
184
+ if args.controlnet:
185
+ # only save the controlnet model
186
+ pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
187
+ else:
188
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
animatediff/utils/util.py ADDED
@@ -0,0 +1,225 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+ import cv2
6
+
7
+ import torch
8
+ import torchvision
9
+ import torch.distributed as dist
10
+
11
+ from safetensors import safe_open
12
+ from tqdm import tqdm
13
+ from einops import rearrange
14
+ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
15
+ from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
16
+
17
+
18
+ def zero_rank_print(s):
19
+ # Print only when torch.distributed is uninitialized, or on rank 0 of a distributed run.
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
20
+ from typing import List
21
+ import PIL
22
+ def export_to_video(
23
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
24
+ ) -> str:
25
+ # if output_video_path is None:
26
+ # output_video_path = tempfile.NamedTemporaryFile(suffix=".webm").name
27
+
28
+ if isinstance(video_frames[0], PIL.Image.Image):
29
+ video_frames = [np.array(frame) for frame in video_frames]
30
+
31
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
32
+ # fourcc = cv2.VideoWriter_fourcc(*'VP90')
33
+ h, w, c = video_frames[0].shape
34
+ video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
35
+ for i in range(len(video_frames)):
36
+ img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
37
+ video_writer.write(img)
38
+
39
+ return output_video_path
40
+
41
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=9):
42
+ videos = rearrange(videos, "b c t h w -> t b c h w")
43
+ outputs = []
44
+ for x in videos:
45
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
46
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
47
+ if rescale:
48
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
49
+ x = (x * 255).numpy().astype(np.uint8)
50
+ outputs.append(x)
51
+ os.makedirs(os.path.dirname(path), exist_ok=True)
52
+ # export_to_video(outputs, output_video_path=path, fps=fps)
53
+
54
+ imageio.mimsave(path, outputs, fps=fps)
55
+
56
+
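
A quick, hypothetical usage sketch for `save_videos_grid`; the tensor layout `[batch, channels, frames, height, width]` with values in `[0, 1]` is inferred from the `rearrange` pattern above:

```python
import torch

videos = torch.rand(2, 3, 16, 256, 256)        # dummy [b, c, t, h, w] tensor in [0, 1]
save_videos_grid(videos, "samples/demo.gif", n_rows=2, fps=8)
```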
57
+ # DDIM Inversion
58
+ @torch.no_grad()
59
+ def init_prompt(prompt, pipeline):
60
+ uncond_input = pipeline.tokenizer(
61
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
62
+ return_tensors="pt"
63
+ )
64
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
65
+ text_input = pipeline.tokenizer(
66
+ [prompt],
67
+ padding="max_length",
68
+ max_length=pipeline.tokenizer.model_max_length,
69
+ truncation=True,
70
+ return_tensors="pt",
71
+ )
72
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
73
+ context = torch.cat([uncond_embeddings, text_embeddings])
74
+
75
+ return context
76
+
77
+
78
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
79
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
80
+ timestep, next_timestep = min(
81
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
82
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
83
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
84
+ beta_prod_t = 1 - alpha_prod_t
85
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
86
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
87
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
88
+ return next_sample
89
+
90
+
91
+ def get_noise_pred_single(latents, t, context, unet):
92
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
93
+ return noise_pred
94
+
95
+
96
+ @torch.no_grad()
97
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
98
+ context = init_prompt(prompt, pipeline)
99
+ uncond_embeddings, cond_embeddings = context.chunk(2)
100
+ all_latent = [latent]
101
+ latent = latent.clone().detach()
102
+ for i in tqdm(range(num_inv_steps)):
103
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
104
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
105
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
106
+ all_latent.append(latent)
107
+ return all_latent
108
+
109
+
110
+ @torch.no_grad()
111
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
112
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
113
+ return ddim_latents
114
+
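
A hedged sketch of how the DDIM-inversion helpers above are typically invoked; `pipeline` (an AnimateDiff `AnimationPipeline`) and `video_latents` (shape `[b, c, t, h, w]`) are assumed to exist already:

```python
from diffusers import DDIMScheduler

ddim_scheduler = DDIMScheduler(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                               beta_schedule="scaled_linear", clip_sample=False)
ddim_scheduler.set_timesteps(50)               # next_step() reads num_inference_steps from the scheduler
inv_latents = ddim_inversion(pipeline, ddim_scheduler, video_latents, num_inv_steps=50, prompt="")
start_latent = inv_latents[-1]                 # noisiest latent; reuse it to reconstruct or edit the clip
```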
115
+ def load_weights(
116
+ animation_pipeline,
117
+ # motion module
118
+ motion_module_path = "",
119
+ motion_module_lora_configs = [],
120
+ # domain adapter
121
+ adapter_lora_path = "",
122
+ adapter_lora_scale = 1.0,
123
+ # image layers
124
+ dreambooth_model_path = "",
125
+ lora_model_path = "",
126
+ lora_alpha = 0.8,
127
+ ):
128
+ # motion module
129
+ unet_state_dict = {}
130
+ if motion_module_path != "":
131
+ print(f"load motion module from {motion_module_path}")
132
+ motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
133
+ motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
134
+ unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
135
+ unet_state_dict.pop("animatediff_config", "")
136
+
137
+ missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
138
+ print("motion_module missing:",len(missing))
139
+ print("motion_module unexpe:",len(unexpected))
140
+ assert len(unexpected) == 0
141
+ del unet_state_dict
142
+
143
+ # base model
144
+ # if dreambooth_model_path != "":
145
+ # print(f"load dreambooth model from {dreambooth_model_path}")
146
+ # # if dreambooth_model_path.endswith(".safetensors"):
147
+ # # dreambooth_state_dict = {}
148
+ # # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
149
+ # # for key in f.keys():
150
+ # # dreambooth_state_dict[key] = f.get_tensor(key)
151
+ # # elif dreambooth_model_path.endswith(".ckpt"):
152
+ # # dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
153
+
154
+ # # # 1. vae
155
+ # # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
156
+ # # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
157
+ # # # 2. unet
158
+ # # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
159
+ # # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
160
+ # # # 3. text_model
161
+ # # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
162
+ # # del dreambooth_state_dict
163
+ # dreambooth_state_dict = {}
164
+ # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
165
+ # for key in f.keys():
166
+ # dreambooth_state_dict[key] = f.get_tensor(key)
167
+
168
+ # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
169
+ # # print(vae)
170
+ # #vae ->to_q,to_k,to_v
171
+ # # print(converted_vae_checkpoint)
172
+ # convert_vae_keys = list(converted_vae_checkpoint.keys())
173
+ # for key in convert_vae_keys:
174
+ # if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key:
175
+ # new_key = None
176
+ # if "key" in key:
177
+ # new_key = key.replace("key","to_k")
178
+ # elif "query" in key:
179
+ # new_key = key.replace("query","to_q")
180
+ # elif "value" in key:
181
+ # new_key = key.replace("value","to_v")
182
+ # elif "proj_attn" in key:
183
+ # new_key = key.replace("proj_attn","to_out.0")
184
+ # if new_key:
185
+ # converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key)
186
+
187
+ # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
188
+
189
+ # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
190
+ # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
191
+
192
+ # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
193
+ # del dreambooth_state_dict
194
+ # lora layers
195
+ if lora_model_path != "":
196
+ print(f"load lora model from {lora_model_path}")
197
+ assert lora_model_path.endswith(".safetensors")
198
+ lora_state_dict = {}
199
+ with safe_open(lora_model_path, framework="pt", device="cpu") as f:
200
+ for key in f.keys():
201
+ lora_state_dict[key] = f.get_tensor(key)
202
+
203
+ animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
204
+ del lora_state_dict
205
+
206
+ # domain adapter lora
207
+ if adapter_lora_path != "":
208
+ print(f"load domain lora from {adapter_lora_path}")
209
+ domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
210
+ domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
211
+ domain_lora_state_dict.pop("animatediff_config", "")
212
+
213
+ animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
214
+
215
+ # motion module lora
216
+ for motion_module_lora_config in motion_module_lora_configs:
217
+ path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
218
+ print(f"load motion LoRA from {path}")
219
+ motion_lora_state_dict = torch.load(path, map_location="cpu")
220
+ motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
221
+ motion_lora_state_dict.pop("animatediff_config", "")
222
+
223
+ animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
224
+
225
+ return animation_pipeline
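
Finally, a hypothetical sketch of wiring `load_weights` into an already-constructed AnimateDiff pipeline; the motion-module path is a placeholder and the optional LoRA slots are left empty:

```python
pipeline = load_weights(
    animation_pipeline,                                           # an AnimationPipeline built elsewhere (assumed)
    motion_module_path="models/Motion_Module/mm_sd_v15_v2.ckpt",  # placeholder path to a motion module .ckpt
    adapter_lora_path="",                                         # optional domain-adapter LoRA (.ckpt)
    lora_model_path="",                                           # optional style LoRA (.safetensors)
    lora_alpha=0.8,
)
```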
app.py ADDED
@@ -0,0 +1,402 @@
1
+ import gradio as gr
2
+ import spaces
3
+ css = '''
4
+ .gradio-container {width: 85% !important}
5
+ '''
6
+ from animatediff.utils.util import save_videos_grid
7
+
8
+ import random
9
+ from infer import load_model
10
+ MAX_SEED=10000
11
+ import uuid
12
+ from insightface.app import FaceAnalysis
13
+ import os
14
+ import os
15
+ import cv2
16
+ from diffusers.utils import load_image
17
+ from insightface.utils import face_align
18
+ from PIL import Image
19
+ import torch
20
+ import argparse
21
+ # From command line read command adaface_ckpt_path
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--adaface_ckpt_path', type=str,
24
+ default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt')
25
+ # Don't use 'sd15' for base_model_type; it just generates messy videos.
26
+ parser.add_argument('--base_model_type', type=str, default='sar')
27
+ parser.add_argument('--adaface_base_model_type', type=str, default='sar')
28
+ parser.add_argument('--gpu', type=int, default=None)
29
+ parser.add_argument('--ip', type=str, default="0.0.0.0")
30
+ args = parser.parse_args()
31
+
32
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
33
+ if randomize_seed:
34
+ seed = random.randint(0, MAX_SEED)
35
+ return seed
36
+
37
+ # model = load_model()
38
+ # This FaceAnalysis uses a different model from what AdaFace uses, but it's fine.
39
+ # This is just to crop the face areas from the uploaded images.
40
+ app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
41
+ app.prepare(ctx_id=0, det_size=(320, 320))
42
+ device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
43
+
44
+ id_animator, adaface = load_model(base_model_type=args.base_model_type,
45
+ adaface_base_model_type=args.adaface_base_model_type,
46
+ adaface_ckpt_path=args.adaface_ckpt_path,
47
+ device=device)
48
+ basedir = os.getcwd()
49
+ savedir = os.path.join(basedir,'samples')
50
+ os.makedirs(savedir, exist_ok=True)
51
+
52
+ #print(f"### Cleaning cached examples ...")
53
+ #os.system(f"rm -rf gradio_cached_examples/")
54
+
+ def swap_to_gallery(images):
+     # Update uploaded_files_gallery, show files, hide clear_button_column
+     # Or:
+     # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
+
+ def remove_back_to_files():
+     # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
+     # Or:
+     # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
+     return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")
+
+ def get_clicked_image(data: gr.SelectData):
+     return data.index
+
+ @spaces.GPU
+ def gen_init_images(uploaded_image_paths, prompt, adaface_id_cfg_scale, out_image_count=3):
+     if uploaded_image_paths is None:
+         print("No image uploaded")
+         return None, None, None
+     # uploaded_image_paths is a list of tuples:
+     # [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
+     # Extract the file paths.
+     uploaded_image_paths = [path[0] for path in uploaded_image_paths]
+     adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths,
+                                         out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
+     # Generate out_image_count images each time for the user to select from.
+     noise = torch.randn(out_image_count, 3, 512, 512)
+     # samples: A list of PIL Image instances.
+     samples = adaface(noise, prompt, out_image_count=out_image_count, verbose=True)
+
+     face_paths = []
+     for sample in samples:
+         random_name = str(uuid.uuid4())
+         face_path = os.path.join(savedir, f"{random_name}.jpg")
+         face_paths.append(face_path)
+         sample.save(face_path)
+         print(f"Generated init image: {face_path}")
+
+     # Update uploaded_init_img_gallery, update and hide init_img_files, hide init_clear_button_column
+     return gr.update(value=face_paths, visible=True), gr.update(value=face_paths, visible=False), gr.update(visible=True)
+
+ @spaces.GPU
+ def generate_image(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
+                    init_image_strength, init_image_final_weight,
+                    prompt, negative_prompt, num_steps, video_length, guidance_scale, seed, attn_scale, image_embed_scale,
+                    is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale,
+                    adaface_anneal_steps, progress=gr.Progress(track_tqdm=True)):
+
+     prompt = prompt + " 8k uhd, high quality"
+     if " shot" not in prompt:
+         prompt = prompt + ", medium shot"
+
+     prompt_img_lists = []
+     for path in uploaded_image_paths:
+         img = cv2.imread(path)
+         faces = app.get(img)
+         face_roi = face_align.norm_crop(img, faces[0]['kps'], 112)
+         random_name = str(uuid.uuid4())
+         face_path = os.path.join(savedir, f"{random_name}.jpg")
+         cv2.imwrite(face_path, face_roi)
+         # prompt_img_lists is a list of PIL images.
+         prompt_img_lists.append(load_image(face_path).resize((224,224)))
+
+     if adaface is None or not is_adaface_enabled:
+         adaface_prompt_embeds = None
+     else:
+         if adaface_ckpt_path != args.adaface_ckpt_path:
+             # Reload the embedding manager
+             adaface.load_subj_basis_generator(adaface_ckpt_path)
+
+         adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths,
+                                             out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
+         # adaface_prompt_embeds: [1, 77, 768].
+         adaface_prompt_embeds, _ = adaface.encode_prompt(prompt)
+
+     # init_img_file_paths is a list of image paths. If none is chosen, init_img_file_paths is None.
+     if init_img_file_paths is not None:
+         init_img_selected_idx = int(init_img_selected_idx)
+         init_img_file_path = init_img_file_paths[init_img_selected_idx]
+         init_image = cv2.imread(init_img_file_path)
+         init_image = cv2.resize(init_image, (512, 512))
+         init_image = Image.fromarray(cv2.cvtColor(init_image, cv2.COLOR_BGR2RGB))
+         print(f"init_image: {init_img_file_path}")
+     else:
+         init_image = None
+
+     sample = id_animator.generate(prompt_img_lists,
+                                   init_image = init_image,
+                                   init_image_strength = (init_image_strength, init_image_final_weight),
+                                   prompt = prompt,
+                                   negative_prompt = negative_prompt,
+                                   adaface_embeds = adaface_prompt_embeds,
+                                   # adaface_scale is not so useful, and when it's set >= 2, weird artifacts appear.
+                                   # Here it's limited to 0.7~1.3.
+                                   adaface_scale = adaface_power_scale,
+                                   num_inference_steps = num_steps,
+                                   adaface_anneal_steps = adaface_anneal_steps,
+                                   seed=seed,
+                                   guidance_scale = guidance_scale,
+                                   width = 512,
+                                   height = 512,
+                                   video_length = video_length,
+                                   attn_scale = attn_scale,
+                                   image_embed_scale = image_embed_scale,
+                                   )
+
+     # Reuse the random name of the last cropped face image to name the output video.
+     save_sample_path = os.path.join(savedir, f"{random_name}.mp4")
+     save_videos_grid(sample, save_sample_path)
+     return save_sample_path
+
+ def validate(prompt):
+     if not prompt:
+         raise gr.Error("Prompt cannot be blank")
+
+ examples = [
+     [
+         "demo/ann.png",
+         ["demo/ann.png"],
+         "A young girl with a passion for reading, curled up with a book in a cozy nook near a window",
+         "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck,",
+         30,
+         8, 8290, 1, 16
+     ],
+     [
+         "demo/lecun.png",
+         ["demo/lecun.png"],
+         "Iron Man soars through the clouds, his repulsors blazing",
+         "worst quality, low quality, jpeg artifacts, ugly, duplicate, blurry, long neck",
+         30,
+         8, 4993, 0.7, 16
+     ],
+     [
+         "demo/mix.png",
+         ["demo/lecun.png", "demo/ann.png"],
+         "A musician playing a guitar, fingers deftly moving across the strings, producing a soulful melody",
+         "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+         30,
+         8, 1897, 0.9, 16
+     ],
+     [
+         "demo/zendaya.png",
+         ["demo/zendaya.png"],
+         "A woman on a serene beach at sunset, the sky ablaze with hues of orange and purple.",
+         "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+         30,
+         8, 5992, 1, 16
+     ],
+     [
+         "demo/qianlong.png",
+         ["demo/qianlong.png"],
+         "A chef in a white apron, complete with a toque blanche, garnishing a gourmet dish",
+         "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, UnrealisticDream",
+         30,
+         8, 1844, 0.8, 16
+     ],
+     [
+         "demo/augustus.png",
+         ["demo/augustus.png"],
+         "A man with dyed pink and purple hair, styled in a high ponytail",
+         "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
+         30,
+         8, 870, 0.7, 16
+     ]
+ ]
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(
+         """
+         # AdaFace-Animate: Zero-Shot Subject-Driven Video Generation for Humans
+         """
+     )
+     gr.Markdown(
+         """
+         ❗️❗️❗️**Tips:**
+         - You can upload one or more subject images to generate an ID-specific video.
+         - Try different parameter combinations for the best generation quality.
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             files = gr.File(
+                 label="Drag (Select) 1 or more photos of a person's face",
+                 file_types=["image"],
+                 file_count="multiple"
+             )
+             image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
+             uploaded_files_gallery = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
+             with gr.Column(visible=False) as clear_button_column:
+                 remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=files, size="sm")
+
+             init_img_files = gr.File(
+                 label="Drag (Select) 1 image for initialization",
+                 file_types=["image"],
+                 file_count="multiple"
+             )
+             init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
+             # Although there's only one image, we still use columns=3 to scale down the image size.
+             # Otherwise it will occupy the full width, and the gallery won't show the whole image.
+             uploaded_init_img_gallery = gr.Gallery(label="Init image", visible=False, columns=3, rows=1, height=200)
+             # placeholder is just a hint, not the real value. So we use "value='0'" instead of "placeholder='0'".
+             init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
+
+             init_image_strength = gr.Slider(
+                 label="Init Image Strength",
+                 minimum=0,
+                 maximum=3,
+                 step=0.25,
+                 value=1.5,
+             )
+             init_image_final_weight = gr.Slider(
+                 label="Final Weight of the Init Image",
+                 minimum=0,
+                 maximum=0.25,
+                 step=0.025,
+                 value=0.1,
+             )
+
+             with gr.Column(visible=False) as init_clear_button_column:
+                 remove_init_and_reupload = gr.ClearButton(value="Remove and upload new init image", components=init_img_files, size="sm")
+             with gr.Column(visible=True) as init_gen_button_column:
+                 gen_init = gr.Button(value="Generate 3 new init images")
+
+             prompt = gr.Textbox(label="Prompt",
+                                 # info="Try something like 'a photo of a man/woman img', 'img' is the trigger word.",
+                                 placeholder="Iron Man soars through the clouds, his repulsors blazing.")
+
+             image_embed_scale = gr.Slider(
+                 label="Image Embedding Scale",
+                 minimum=0,
+                 maximum=2,
+                 step=0.1,
+                 value=0.8,
+             )
+             attn_scale = gr.Slider(
+                 label="Attention Processor Scale",
+                 minimum=0,
+                 maximum=2,
+                 step=0.1,
+                 value=0.8,
+             )
+             adaface_id_cfg_scale = gr.Slider(
+                 label="AdaFace Embedding ID CFG Scale",
+                 minimum=0.5,
+                 maximum=6,
+                 step=0.25,
+                 value=1.5,
+             )
+
+             submit = gr.Button("Generate Video")
+
+             with gr.Accordion(open=False, label="Advanced Options"):
+                 video_length = gr.Slider(
+                     label="video_length",
+                     minimum=16,
+                     maximum=21,
+                     step=1,
+                     value=16,
+                 )
+                 is_adaface_enabled = gr.Checkbox(label="Enable AdaFace", value=True)
+                 adaface_ckpt_path = gr.Textbox(
+                     label="AdaFace ckpt Path",
+                     placeholder=args.adaface_ckpt_path,
+                     value=args.adaface_ckpt_path,
+                 )
+
+                 adaface_power_scale = gr.Slider(
+                     label="AdaFace Embedding Power Scale",
+                     minimum=0.7,
+                     maximum=1.3,
+                     step=0.1,
+                     value=1,
+                 )
+
+                 # adaface_anneal_steps is no longer necessary, but we keep it here for future use.
+                 adaface_anneal_steps = gr.Slider(
+                     label="AdaFace Anneal Steps",
+                     minimum=0,
+                     maximum=2,
+                     step=1,
+                     value=0,
+                     visible=False,
+                 )
+
+                 negative_prompt = gr.Textbox(
+                     label="Negative Prompt",
+                     placeholder="low quality",
+                     value="face portrait, (deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream",
+                 )
+                 num_steps = gr.Slider(
+                     label="Number of sample steps",
+                     minimum=25,
+                     maximum=100,
+                     step=1,
+                     value=40,
+                 )
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=1.0,
+                     maximum=10.0,
+                     step=0.5,
+                     value=4,
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=985,
+                 )
+                 randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
+         with gr.Column():
+             result_video = gr.Video(label="Generated Animation", interactive=False)
+
+     files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files_gallery, clear_button_column, files])
+     remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, files, init_img_selected_idx])
+
+     init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files, outputs=[uploaded_init_img_gallery, init_clear_button_column, init_img_files])
+     remove_init_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_init_img_gallery, init_clear_button_column,
+                                                                      init_img_files, init_img_selected_idx])
+     gen_init.click(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, adaface_id_cfg_scale],
+                    outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
+     uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
+
+     submit.click(fn=validate,
+                  inputs=[prompt], outputs=None).success(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=seed,
+         queue=False,
+         api_name=False,
+     ).then(
+         fn=generate_image,
+         inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
+                 prompt, negative_prompt, num_steps, video_length, guidance_scale,
+                 seed, attn_scale, image_embed_scale,
+                 is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps],
+         outputs=[result_video]
+     )
+     gr.Examples( fn=generate_image, examples=[], #examples,
+                  inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
+                          prompt, negative_prompt, num_steps, video_length, guidance_scale,
+                          seed, attn_scale, image_embed_scale,
+                          is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps],
+                  outputs=[result_video], cache_examples=True )
+
+ demo.launch(share=True, server_name=args.ip, ssl_verify=False)
assets/alita/alita armor orig.mp4 ADDED
Binary file (241 kB).

assets/alita/alita armor.mp4 ADDED
Binary file (207 kB).

assets/alita/alita beach orig.mp4 ADDED
Binary file (127 kB).

assets/alita/alita beach.mp4 ADDED
Binary file (137 kB).

assets/alita/alita cooking orig.mp4 ADDED
Binary file (225 kB).

assets/alita/alita cooking.mp4 ADDED
Binary file (172 kB).

assets/alita/alita dancing orig.mp4 ADDED
Binary file (173 kB).

assets/alita/alita dancing.mp4 ADDED
Binary file (255 kB).

assets/alita/alita iron man orig.mp4 ADDED
Binary file (166 kB).