Spaces:
Runtime error
Runtime error
float32 unet
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +7 -0
- Dockerfile +15 -0
- README.md +5 -7
- README2.md +241 -0
- adaface/adaface-infer.py +131 -0
- adaface/adaface-translate.py +208 -0
- adaface/adaface_wrapper.py +286 -0
- adaface/arc2face_models.py +303 -0
- adaface/subj_basis_generator.py +758 -0
- adaface/util.py +341 -0
- animatediff/models/attention.py +327 -0
- animatediff/models/attention_bkp.py +326 -0
- animatediff/models/motion_module.py +552 -0
- animatediff/models/motion_module_bkp.py +331 -0
- animatediff/models/resnet.py +217 -0
- animatediff/models/sparse_controlnet.py +587 -0
- animatediff/models/unet.py +600 -0
- animatediff/models/unet_blocks.py +760 -0
- animatediff/pipelines/pipeline_animation.py +793 -0
- animatediff/sd/.gitattributes +35 -0
- animatediff/sd/feature_extractor/preprocessor_config.json +20 -0
- animatediff/sd/model_index.json +32 -0
- animatediff/sd/safety_checker/config.json +175 -0
- animatediff/sd/safety_checker/pytorch_model.bin +3 -0
- animatediff/sd/scheduler/scheduler_config.json +13 -0
- animatediff/sd/text_encoder/config.json +25 -0
- animatediff/sd/text_encoder/pytorch_model.bin +3 -0
- animatediff/sd/tokenizer/merges.txt +0 -0
- animatediff/sd/tokenizer/special_tokens_map.json +24 -0
- animatediff/sd/tokenizer/tokenizer_config.json +34 -0
- animatediff/sd/tokenizer/vocab.json +0 -0
- animatediff/sd/unet/config.json +36 -0
- animatediff/sd/unet/diffusion_pytorch_model.bin +3 -0
- animatediff/sd/v1-inference.yaml +70 -0
- animatediff/sd/vae/config.json +29 -0
- animatediff/sd/vae/diffusion_pytorch_model.bin +3 -0
- animatediff/utils/convert_from_ckpt.py +959 -0
- animatediff/utils/convert_lora_safetensor_to_diffusers.py +152 -0
- animatediff/utils/convert_original_stable_diffusion_to_diffusers.py +188 -0
- animatediff/utils/util.py +225 -0
- app.py +412 -0
- assets/alita/alita armor orig.mp4 +0 -0
- assets/alita/alita armor.mp4 +0 -0
- assets/alita/alita beach orig.mp4 +0 -0
- assets/alita/alita beach.mp4 +0 -0
- assets/alita/alita cooking orig.mp4 +0 -0
- assets/alita/alita cooking.mp4 +0 -0
- assets/alita/alita dancing orig.mp4 +0 -0
- assets/alita/alita dancing.mp4 +0 -0
- assets/alita/alita iron man orig.mp4 +0 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/*
|
2 |
+
__pycache__/
|
3 |
+
*.pyc
|
4 |
+
gradio_cached_examples/*
|
5 |
+
gradio_cached_examples/
|
6 |
+
samples/*
|
7 |
+
samples/
|
Dockerfile
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.10
|
2 |
+
ENV PYTHONUNBUFFERED=1
|
3 |
+
|
4 |
+
RUN RUN apt-get update && \
|
5 |
+
apt-get install -y \
|
6 |
+
bash \
|
7 |
+
git git-lfs \
|
8 |
+
wget curl procps \
|
9 |
+
htop vim nano && \
|
10 |
+
rm -rf /var/lib/apt/lists/*
|
11 |
+
|
12 |
+
WORKDIR /app
|
13 |
+
COPY --link --chown=1000 ./ /app
|
14 |
+
|
15 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
10 |
-
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: AdaFace-Animate
|
3 |
+
emoji: 🎨
|
4 |
colorFrom: yellow
|
5 |
colorTo: green
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.36.1
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
---
|
|
|
|
README2.md
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# AdaFace-Animate
|
2 |
+
|
3 |
+
This folder contains the preliminary implementation of **AdaFace-Animate**.
|
4 |
+
It is a zero-shot subject-guided animation generator conditioned with human subject images, by combining AnimateDiff, ID-Animator and AdaFace. The ID-Animator provides AnimateDiff with rough subject characteristics, and AdaFace provides refined and more authentic subject facial details.
|
5 |
+
|
6 |
+
Please refer to our NeurIPS 2024 submission for more details about AdaFace:
|
7 |
+
|
8 |
+
**AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization**
|
9 |
+
</br>
|
10 |
+
|
11 |
+
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/adaface-neurips/adaface-animate)
|
12 |
+
|
13 |
+
This pipeline uses 4 pretrained models: [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [AnimateDiff v3](https://github.com/guoyww/animatediff), [ID-Animator](https://github.com/ID-Animator/ID-Animator) and [AdaFace](https://huggingface.co/adaface-neurips/adaface).
|
14 |
+
|
15 |
+
AnimateDiff uses a SD-1.5 type checkpoint, referred to as a "DreamBooth" model. The DreamBooth model we use is an average of three SD-1.5 models named as "SAR": the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). In our experiments, this average model performs better than any of the individual models.
|
16 |
+
|
17 |
+
## Procedures of Generation
|
18 |
+
We find that using an initial image helps stablize the animation sequence and improve the quality. When generating each example video, an initial image is first generated by AdaFace with the same prompt as used to generate the video. This image is blended with multiple frames of random noises with weights decreasing with $t$. The multi-frame blended noises are converted to a 1-second animation with AnimateDiff, conditioned by both AdaFace and ID-Animator embeddings.
|
19 |
+
|
20 |
+
## Gallery
|
21 |
+
[Gallery](./assets/) contains 100 subject videos generated by us. They belong to 10 celebrities, each with 10 different prompts. The (shortened) prompts are: "Armor Suit", "Iron Man Costume", "Superman Costume", "Wielding a Lightsaber", "Walking on the beach", "Cooking", "Dancing", "Playing Guitar", "Reading", and "Running".
|
22 |
+
|
23 |
+
Some example videos are shown below. The full set of videos can be found in [Gallery](./assets/).
|
24 |
+
|
25 |
+
(Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
|
26 |
+
|
27 |
+
<table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
|
28 |
+
<tr style="line-height: 1">
|
29 |
+
<td width="25%" style="text-align: center">Input (Celebrities)</td>
|
30 |
+
<td width="25%" style="text-align: center">Animation 1: Playing Guitar</td>
|
31 |
+
<td width="25%" style="text-align: center">Animation 2: Cooking</td>
|
32 |
+
<td width="25%" style="text-align: center">Animation 3: Dancing</td>
|
33 |
+
</tr>
|
34 |
+
<tr>
|
35 |
+
<td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
|
36 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
|
37 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/83a08691-4f4e-4898-b4ae-be5dfcd1fb85" type="video/mp4"></video></td>
|
38 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a" type="video/mp4"></video></td>
|
39 |
+
</tr>
|
40 |
+
<tr>
|
41 |
+
<td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
|
42 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
|
43 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/54f3745f-abf6-4608-93c5-d8e103d05dc7" type="video/mp4"></video></td>
|
44 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681" type="video/mp4"></video></td>
|
45 |
+
</tr>
|
46 |
+
<tr>
|
47 |
+
<td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
|
48 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
|
49 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/68ad643c-8a2b-43a8-9c7b-10c36c4912d4" type="video/mp4"></video></td>
|
50 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
|
51 |
+
</tr>
|
52 |
+
|
53 |
+
</table>
|
54 |
+
|
55 |
+
To illustrate the wide range of applications of our method, we animated 8 internet memes. 4 of them are shown in the table below. The full gallery can be found in [memes](./assets/memes/).
|
56 |
+
|
57 |
+
<table class="center">
|
58 |
+
<tr style="line-height: 1">
|
59 |
+
<td width=25% style="text-align: center">Input (Memes)</td>
|
60 |
+
<td width=25% style="text-align: center">Animation</td>
|
61 |
+
<td width=25% style="text-align: center">Input</td>
|
62 |
+
<td width=25% style="text-align: center">Animation</td>
|
63 |
+
</tr>
|
64 |
+
<tr>
|
65 |
+
<td style="text-align: center">Yao Ming Laugh</td><td></td><td style="text-align: center">Girl Burning House</td></td><td>
|
66 |
+
</tr>
|
67 |
+
<tr>
|
68 |
+
<td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
|
69 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
|
70 |
+
<td><img src="assets/memes/girl burning house.jpg" style="width:100%"></td>
|
71 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/11c83ae1-dece-4798-baf5-e608ab8709e3" type="video/mp4"></video></td>
|
72 |
+
</tr>
|
73 |
+
<tr>
|
74 |
+
<td style="text-align: center">Girl with a Pearl Earring</td><td></td><td style="text-align: center">Great Gatsby</td></td><td>
|
75 |
+
</tr>
|
76 |
+
<tr>
|
77 |
+
<td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
|
78 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
|
79 |
+
<td><img src="assets/memes/great gatsby.jpg" style="width:100%"></td>
|
80 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/db941805-8a6c-4596-ba3a-247b54baa5ef" type="video/mp4"></video></td>
|
81 |
+
</tr>
|
82 |
+
</table>
|
83 |
+
|
84 |
+
## Comparison with ID-Animator, with AdaFace Initial Images
|
85 |
+
To compare with the baseline method "ID-Animator", for each video, we disable AdaFace, and generate the corresponding video with ID-Animator, using otherwise identical settings: the same subject image(s) and initial image, and the same random seed and prompt. The table below compares some of these videos side-by-side with the AdaFace-Animate videos. The full set of ID-Animator videos can be found in each subject folder in [Gallery](./assets/), named as "* orig.mp4".
|
86 |
+
|
87 |
+
**NOTE** Since ID-Animator videos utilize initial images generated by **AdaFace**, this gives ID-Animator an advantage over the original ID-Animator.
|
88 |
+
|
89 |
+
(Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
|
90 |
+
|
91 |
+
<table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
|
92 |
+
<tr style="line-height: 1">
|
93 |
+
<td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image: Playing Guitar</td>
|
94 |
+
<td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
|
95 |
+
<td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
|
96 |
+
<td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image Dancing</td>
|
97 |
+
<td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
|
98 |
+
<td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
|
99 |
+
</tr>
|
100 |
+
<tr>
|
101 |
+
<td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence playing guitar.jpg" style="width:100%"></td>
|
102 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f5d0f2c6-f4bd-4517-bfa1-021db1577895" type="video/mp4"></video></td>
|
103 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
|
104 |
+
<td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence dancing.jpg" style="width:100%"></td>
|
105 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/421f5b81-e1a7-459a-869a-f7f6dc51a74e" type="video/mp4"></video></td>
|
106 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a"
|
107 |
+
type="video/mp4"></video></td>
|
108 |
+
</tr>
|
109 |
+
<tr>
|
110 |
+
<td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun playing guitar.jpg" style="width:100%"></td>
|
111 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3bbfc15b-4205-4052-b5cc-c4f8d6d17027" type="video/mp4"></video></td>
|
112 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
|
113 |
+
<td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun dancing.jpg" style="width:100%"></td>
|
114 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/75e191f4-87e2-486c-90e7-c9e21a1bf494" type="video/mp4"></video></td>
|
115 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681"
|
116 |
+
type="video/mp4"></video></td>
|
117 |
+
</tr>
|
118 |
+
<tr>
|
119 |
+
<td style="text-align: center"><img src="assets/gakki/init images/gakki playing guitar.jpg" style="width:100%"></td>
|
120 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/6a5579ce-23e3-4603-8917-00a16d6a3682" type="video/mp4"></video></td>
|
121 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
|
122 |
+
<td style="text-align: center"><img src="assets/gakki/init images/gakki dancing.jpg" style="width:100%"></td>
|
123 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28082e58-a0ed-4492-8c51-cb563f92baeb" type="video/mp4"></video></td>
|
124 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
|
125 |
+
</tr>
|
126 |
+
|
127 |
+
</table>
|
128 |
+
|
129 |
+
The table below compares the animated internet memes. The initial image for each video is the meme image itself. For "Yao Ming laughing" and "Great Gatsby", 2~3 extra portrait photos of the subject are included as the subject images to enhance the facial fidelity. For other memes, the subject image is only the meme image. The full set of ID-Animator meme videos can be found in [memes](./assets/memes/), named as "* orig.mp4".
|
130 |
+
|
131 |
+
<table class="center" style="width: 60%;">
|
132 |
+
<tr style="line-height: 1">
|
133 |
+
<td width=20% style="text-align: center">Input (Memes)</td>
|
134 |
+
<td width=20% style="text-align: center">ID-Animator</td>
|
135 |
+
<td width=20% style="text-align: center">AdaFace-Animate</td>
|
136 |
+
</tr>
|
137 |
+
<tr>
|
138 |
+
<td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
|
139 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/9daf814c-ae8a-476d-9c32-fa9ef6be16d9" type="video/mp4"></video></td>
|
140 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
|
141 |
+
<tr>
|
142 |
+
<td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
|
143 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/05ed29d5-4eaa-4a0a-bee2-bc77e5649f58" type="video/mp4"></video></td>
|
144 |
+
<td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
|
145 |
+
</tr>
|
146 |
+
</table>
|
147 |
+
|
148 |
+
We can see that the subjects in AdaFace-Animate videos have more authentic facial features and better preserve the facial expressions, while the subjects in ID-Animator videos are less authentic and faithful to the original images.
|
149 |
+
|
150 |
+
## Comparison with ID-Animator, without AdaFace Initial Images
|
151 |
+
To exclude the effects of AdaFace, we generate a subset of videos with AdaFace-Animate / ID-Animator *without initial images*. These videos were generated under the same settings as above, except not using initial images. The table below shows a selection of the videos. The complete set of such videos can be found in [no-init](./assets/no-init/). It can be seen that without the help of AdaFace initial images, the compositionality, or the overall layout deteriorates on some prompts. In particular, some background objects are suppressed by over-expressed facial features. Moreover, the performance discrepancy between AdaFace-Animate and ID-Animator becomes more pronounced.
|
152 |
+
|
153 |
+
(Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
|
154 |
+
|
155 |
+
<table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
|
156 |
+
<tr style="line-height: 1">
|
157 |
+
<td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">Input (Celebrities)</td>
|
158 |
+
<td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
|
159 |
+
<td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
|
160 |
+
<td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
|
161 |
+
<td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
|
162 |
+
</tr>
|
163 |
+
<tr>
|
164 |
+
<td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
|
165 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2c3fa70b-4a38-48d1-aead-cd94976f6beb" type="video/mp4"></video></td>
|
166 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f658f9e6-c3b6-4c4a-920c-00a89b98d97a" type="video/mp4"></video></td>
|
167 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2de5cb38-f62c-4e9d-90ad-9bbb72d1ba7a" type="video/mp4"></video></td>
|
168 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b39e66d-c696-4022-81ff-6afae8147981" type="video/mp4"></video></td>
|
169 |
+
</tr>
|
170 |
+
<tr>
|
171 |
+
<td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
|
172 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/7f7f8cd0-7ca3-47b4-a44d-8c6b399bdbc4" type="video/mp4"></video></td>
|
173 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/eb173058-2314-470a-8cf4-3702036022ad" type="video/mp4"></video></td>
|
174 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/cd5a9687-bae0-47fd-b82c-febc0d343ac2" type="video/mp4"></video></td>
|
175 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/e08c778c-5e87-40f6-a7a1-328f5d0d016f" type="video/mp4"></video></td>
|
176 |
+
</tr>
|
177 |
+
<tr>
|
178 |
+
<td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
|
179 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0370714b-d10c-422d-adee-76f6221aa1be" type="video/mp4"></video></td>
|
180 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/79cd95d2-95ea-4854-816e-2caf0cbebf94" type="video/mp4"></video></td>
|
181 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/60fa6b2a-6e1a-48c0-a777-4b62504ff679" type="video/mp4"></video></td>
|
182 |
+
<td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/c72836ee-9d7a-4525-a48a-7017b60f83f3" type="video/mp4"></video></td>
|
183 |
+
</tr>
|
184 |
+
|
185 |
+
</table>
|
186 |
+
|
187 |
+
## Installation
|
188 |
+
|
189 |
+
### Manually Download Model Checkpoints
|
190 |
+
- Download Stable Diffusion V1.5 into ``animatediff/sd``:
|
191 |
+
|
192 |
+
``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 animatediff/sd``
|
193 |
+
- Download AnimateDiff motion module into ``models/v3_sd15_mm.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt
|
194 |
+
- Download Animatediff adapter into ``models/v3_adapter_sd_v15.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt
|
195 |
+
- Download ID-Animator checkpoint into ``models/animator.ckpt`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/animator.ckpt
|
196 |
+
- Download CLIP Image encoder into ``models/image_encoder/`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/tree/main/image_encoder
|
197 |
+
- Download AdaFace checkpoint into ``models/adaface/`` from: https://huggingface.co/adaface-neurips/adaface/tree/main/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt
|
198 |
+
|
199 |
+
### Prepare the SAR Model
|
200 |
+
|
201 |
+
Manually download the three `.safetensors` models: the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). Save them to `models/sar`.
|
202 |
+
|
203 |
+
Run the following command to generate an average of the three models:
|
204 |
+
```
|
205 |
+
python3 scripts/avg_models.py --input models/sar/absolutereality_v181.safetensors models/sar/realisticVisionV40_v40VAE.safetensors models/sar/v1-5-pruned.safetensors --output models/sar/sar.safetensors
|
206 |
+
```
|
207 |
+
|
208 |
+
\[Optional Improvement\]
|
209 |
+
1. You can replace the VAE of the SAR model with the [MSE-840000 finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/tree/main) for slightly better video details:
|
210 |
+
```
|
211 |
+
python3 scripts/repl_vae.py --base_ckpt models/sar/sar.safetensors --vae_ckpt models/sar/vae-ft-mse-840000-ema-pruned.ckpt --out_ckpt models/sar/sar-vae.safetensors
|
212 |
+
mv models/sar/sar-vae.safetensors models/sar/sar.safetensors
|
213 |
+
```
|
214 |
+
|
215 |
+
2. You can replace the text encoder of the SAR model with the text encoder of [DreamShaper V8](https://civitai.com/models/4384?modelVersionId=252914) for slightly more authentic facial features:
|
216 |
+
```
|
217 |
+
python3 scripts/repl_textencoder.py --base_ckpt models/sar/sar.safetensors --te_ckpt models/sar/dreamshaper_8.safetensors --out_ckpt models/sar/sar2.safetensors
|
218 |
+
mv models/sar/sar2.safetensors models/sar/sar.safetensors
|
219 |
+
```
|
220 |
+
### Inference
|
221 |
+
|
222 |
+
Run the demo inference scripts:
|
223 |
+
```
|
224 |
+
python3 app.py
|
225 |
+
```
|
226 |
+
Then connect to the Gradio interface at `local-ip-address:7860` or `https://*.gradio.live` shown in the terminal.
|
227 |
+
|
228 |
+
#### Use of Initial Image
|
229 |
+
The use of an initial image is optional. It usually helps stabilize the animation sequence and improve the quality.
|
230 |
+
|
231 |
+
You can generate 3 initial images in one go by clicking "Generate 3 new init images". The images will be based on the same prompt as the video generation. You can also use different prompts for the initial images and the video generation. Select the desired initial image by clicking on the image, and then click "Generate Video". If none of the initial images are good enough, you can generate again by clicking "Generate 3 new init images" again.
|
232 |
+
|
233 |
+
### Common Issues
|
234 |
+
1. **Defocus**. This is the biggest possible issue. When the subject is far from the camera, the model may not be able to generate a clear face and control the subject's facial details. In this situation, consider to increase the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also add a prefix "face portrait of" to the prompt to help the model focus on the face.
|
235 |
+
2. **Motion Degeneration**. When the subject is too close to the camera, the model may not be able to generate correct motions and poses, and only generate the face. In this situation, consider to decrease the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also adjust the prompt slightly to let it focus on the whole body.
|
236 |
+
3. **Lesser Facial Characteristics**. If the subject's facial characteristics is not so distinctive, you can increase the weights of "AdaFace Embedding ID CFG Scale".
|
237 |
+
4. **Unstable Motions**. If the generated video has unstable motions, this is probably due to the limitations of AnimateDiff. Nonetheless, you can make it more stable by using a carefully selected initial image, and optionally increase the "Init Image Strength" and "Final Weight of the Init Image". Note that when "Final Weight of the Init Image" is larger, the motion in the generated video will be less dynamic.
|
238 |
+
|
239 |
+
|
240 |
+
## Disclaimer
|
241 |
+
This project is intended for academic purposes only. We do not accept responsibility for user-generated content. Users are solely responsible for their own actions. The contributors to this project are not legally affiliated with, nor are they liable for, the actions of users. Please use this generative model responsibly, in accordance with ethical and legal standards.
|
adaface/adaface-infer.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from adaface.adaface_wrapper import AdaFaceWrapper
|
2 |
+
import torch
|
3 |
+
#import torch.nn.functional as F
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import os, argparse, glob, re
|
7 |
+
|
8 |
+
def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
|
9 |
+
if num_images_per_row > len(images):
|
10 |
+
num_images_per_row = len(images)
|
11 |
+
|
12 |
+
os.makedirs(save_dir, exist_ok=True)
|
13 |
+
|
14 |
+
num_columns = int(np.ceil(len(images) / num_images_per_row))
|
15 |
+
# Save 4 images as a grid image in save_dir
|
16 |
+
grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
|
17 |
+
for i, image in enumerate(images):
|
18 |
+
image = image.resize((512, 512))
|
19 |
+
grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))
|
20 |
+
|
21 |
+
prompt_sig = prompt.replace(" ", "_").replace(",", "_")
|
22 |
+
grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
|
23 |
+
if os.path.exists(grid_filepath):
|
24 |
+
grid_count = 2
|
25 |
+
grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
|
26 |
+
while os.path.exists(grid_filepath):
|
27 |
+
grid_count += 1
|
28 |
+
grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
|
29 |
+
|
30 |
+
grid_image.save(grid_filepath)
|
31 |
+
print(f"Saved to {grid_filepath}")
|
32 |
+
|
33 |
+
def seed_everything(seed):
|
34 |
+
np.random.seed(seed)
|
35 |
+
torch.manual_seed(seed)
|
36 |
+
torch.cuda.manual_seed_all(seed)
|
37 |
+
torch.backends.cudnn.deterministic = True
|
38 |
+
torch.backends.cudnn.benchmark = False
|
39 |
+
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
40 |
+
|
41 |
+
def parse_args():
|
42 |
+
parser = argparse.ArgumentParser()
|
43 |
+
parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
|
44 |
+
help="Type of checkpoints to use (default: SD 1.5)")
|
45 |
+
parser.add_argument("--embman_ckpt", type=str, required=True,
|
46 |
+
help="Path to the checkpoint of the embedding manager")
|
47 |
+
parser.add_argument("--subject", type=str, required=True)
|
48 |
+
parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
|
49 |
+
parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
|
50 |
+
parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
|
51 |
+
parser.add_argument("--noise", dest='noise_level', type=float, default=0)
|
52 |
+
parser.add_argument("--randface", action="store_true")
|
53 |
+
parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
|
54 |
+
help="Guidance scale for the diffusion model")
|
55 |
+
parser.add_argument("--id_cfg_scale", type=float, default=1,
|
56 |
+
help="CFG scale when generating the identity embeddings")
|
57 |
+
|
58 |
+
parser.add_argument("--subject_string",
|
59 |
+
type=str, default="z",
|
60 |
+
help="Subject placeholder string used in prompts to denote the concept.")
|
61 |
+
parser.add_argument("--num_vectors", type=int, default=16,
|
62 |
+
help="Number of vectors used to represent the subject.")
|
63 |
+
parser.add_argument("--num_images_per_row", type=int, default=4,
|
64 |
+
help="Number of images to display in a row in the output grid image.")
|
65 |
+
parser.add_argument("--num_inference_steps", type=int, default=50,
|
66 |
+
help="Number of DDIM inference steps")
|
67 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
|
68 |
+
parser.add_argument("--seed", type=int, default=42,
|
69 |
+
help="the seed (for reproducible sampling). Set to -1 to disable.")
|
70 |
+
args = parser.parse_args()
|
71 |
+
|
72 |
+
return args
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
args = parse_args()
|
76 |
+
if args.seed != -1:
|
77 |
+
seed_everything(args.seed)
|
78 |
+
|
79 |
+
if re.match(r"^\d+$", args.device):
|
80 |
+
args.device = f"cuda:{args.device}"
|
81 |
+
print(f"Using device {args.device}")
|
82 |
+
|
83 |
+
adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
|
84 |
+
args.subject_string, args.num_vectors, args.num_inference_steps)
|
85 |
+
|
86 |
+
if not args.randface:
|
87 |
+
image_folder = args.subject
|
88 |
+
if image_folder.endswith("/"):
|
89 |
+
image_folder = image_folder[:-1]
|
90 |
+
|
91 |
+
if os.path.isfile(image_folder):
|
92 |
+
# Get the second to the last part of the path
|
93 |
+
subject_name = os.path.basename(os.path.dirname(image_folder))
|
94 |
+
image_paths = [image_folder]
|
95 |
+
|
96 |
+
else:
|
97 |
+
subject_name = os.path.basename(image_folder)
|
98 |
+
image_types = ["*.jpg", "*.png", "*.jpeg"]
|
99 |
+
alltype_image_paths = []
|
100 |
+
for image_type in image_types:
|
101 |
+
# glob returns the full path.
|
102 |
+
image_paths = glob.glob(os.path.join(image_folder, image_type))
|
103 |
+
if len(image_paths) > 0:
|
104 |
+
alltype_image_paths.extend(image_paths)
|
105 |
+
|
106 |
+
# Filter out images of "*_mask.png"
|
107 |
+
alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
|
108 |
+
|
109 |
+
# image_paths contain at most args.example_image_count full image paths.
|
110 |
+
if args.example_image_count > 0:
|
111 |
+
image_paths = alltype_image_paths[:args.example_image_count]
|
112 |
+
else:
|
113 |
+
image_paths = alltype_image_paths
|
114 |
+
else:
|
115 |
+
subject_name = None
|
116 |
+
image_paths = None
|
117 |
+
image_folder = None
|
118 |
+
|
119 |
+
subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
|
120 |
+
rand_face_embs = torch.randn(1, 512)
|
121 |
+
|
122 |
+
pre_face_embs = rand_face_embs if args.randface else None
|
123 |
+
noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
|
124 |
+
# args.noise_level: the *relative* std of the noise added to the face embeddings.
|
125 |
+
# A noise level of 0.08 could change gender, but 0.06 is usually safe.
|
126 |
+
# adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
|
127 |
+
adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
|
128 |
+
out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
|
129 |
+
update_text_encoder=True)
|
130 |
+
images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
|
131 |
+
save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
|
adaface/adaface-translate.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from adaface.adaface_wrapper import AdaFaceWrapper
|
2 |
+
import torch
|
3 |
+
#import torch.nn.functional as F
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
import os, argparse, glob, re, shutil
|
7 |
+
|
8 |
+
def str2bool(v):
|
9 |
+
if isinstance(v, bool):
|
10 |
+
return v
|
11 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
12 |
+
return True
|
13 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
14 |
+
return False
|
15 |
+
else:
|
16 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
17 |
+
|
18 |
+
def seed_everything(seed):
|
19 |
+
np.random.seed(seed)
|
20 |
+
torch.manual_seed(seed)
|
21 |
+
torch.cuda.manual_seed_all(seed)
|
22 |
+
torch.backends.cudnn.deterministic = True
|
23 |
+
torch.backends.cudnn.benchmark = False
|
24 |
+
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
25 |
+
|
26 |
+
def parse_args():
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
|
29 |
+
help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
|
30 |
+
parser.add_argument("--embman_ckpt", type=str, required=True,
|
31 |
+
help="Path to the checkpoint of the embedding manager")
|
32 |
+
parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
|
33 |
+
# If True, the input folder contains images of mixed subjects.
|
34 |
+
# If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
|
35 |
+
parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
|
36 |
+
help="Whether the input folder contains images of mixed subjects")
|
37 |
+
parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
|
38 |
+
parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
|
39 |
+
parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
|
40 |
+
parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
|
41 |
+
parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
|
42 |
+
parser.add_argument("--noise", dest='noise_level', type=float, default=0)
|
43 |
+
parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
|
44 |
+
help="Guidance scale for the diffusion model")
|
45 |
+
parser.add_argument("--ref_img_strength", type=float, default=0.8,
|
46 |
+
help="Strength of the reference image in the output image.")
|
47 |
+
parser.add_argument("--subject_string",
|
48 |
+
type=str, default="z",
|
49 |
+
help="Subject placeholder string used in prompts to denote the concept.")
|
50 |
+
parser.add_argument("--num_vectors", type=int, default=16,
|
51 |
+
help="Number of vectors used to represent the subject.")
|
52 |
+
parser.add_argument("--prompt", type=str, default="a person z")
|
53 |
+
parser.add_argument("--num_images_per_row", type=int, default=4,
|
54 |
+
help="Number of images to display in a row in the output grid image.")
|
55 |
+
parser.add_argument("--num_inference_steps", type=int, default=50,
|
56 |
+
help="Number of DDIM inference steps")
|
57 |
+
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
|
58 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
|
59 |
+
parser.add_argument("--seed", type=int, default=42,
|
60 |
+
help="the seed (for reproducible sampling). Set to -1 to disable.")
|
61 |
+
args = parser.parse_args()
|
62 |
+
|
63 |
+
return args
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
args = parse_args()
|
67 |
+
if args.seed != -1:
|
68 |
+
seed_everything(args.seed)
|
69 |
+
|
70 |
+
# screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
|
71 |
+
# --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
|
72 |
+
# --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/shaohua/VGGface2_HQ_masks/
|
73 |
+
# --is_mix_subj_folder 0 --out_folder /data/shaohua/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
|
74 |
+
if args.num_gpus > 1:
|
75 |
+
from accelerate import PartialState
|
76 |
+
distributed_state = PartialState()
|
77 |
+
args.device = distributed_state.device
|
78 |
+
process_index = distributed_state.process_index
|
79 |
+
elif re.match(r"^\d+$", args.device):
|
80 |
+
args.device = f"cuda:{args.device}"
|
81 |
+
distributed_state = None
|
82 |
+
process_index = 0
|
83 |
+
|
84 |
+
adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
|
85 |
+
args.subject_string, args.num_vectors, args.num_inference_steps)
|
86 |
+
|
87 |
+
in_folder = args.in_folder
|
88 |
+
if os.path.isfile(in_folder):
|
89 |
+
subject_folders = [ os.path.dirname(in_folder) ]
|
90 |
+
images_by_subject = [[in_folder]]
|
91 |
+
else:
|
92 |
+
if not args.is_mix_subj_folder:
|
93 |
+
in_folders = [in_folder]
|
94 |
+
else:
|
95 |
+
in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]
|
96 |
+
|
97 |
+
images_by_subject = []
|
98 |
+
subject_folders = []
|
99 |
+
for in_folder in in_folders:
|
100 |
+
image_types = ["*.jpg", "*.png", "*.jpeg"]
|
101 |
+
alltype_image_paths = []
|
102 |
+
for image_type in image_types:
|
103 |
+
# glob returns the full path.
|
104 |
+
image_paths = glob.glob(os.path.join(in_folder, image_type))
|
105 |
+
if len(image_paths) > 0:
|
106 |
+
alltype_image_paths.extend(image_paths)
|
107 |
+
|
108 |
+
# Filter out images of "*_mask.png"
|
109 |
+
alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
|
110 |
+
alltype_image_paths = sorted(alltype_image_paths)
|
111 |
+
|
112 |
+
if not args.is_mix_subj_folder:
|
113 |
+
# image_paths contain at most args.max_images_per_subject full image paths.
|
114 |
+
if args.max_images_per_subject > 0:
|
115 |
+
image_paths = alltype_image_paths[:args.max_images_per_subject]
|
116 |
+
else:
|
117 |
+
image_paths = alltype_image_paths
|
118 |
+
|
119 |
+
images_by_subject.append(image_paths)
|
120 |
+
subject_folders.append(in_folder)
|
121 |
+
else:
|
122 |
+
# Each image in the folder is treated as an individual subject.
|
123 |
+
images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
|
124 |
+
subject_folders.extend([in_folder] * len(alltype_image_paths))
|
125 |
+
|
126 |
+
if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
|
127 |
+
break
|
128 |
+
|
129 |
+
if args.trans_subject_count > 0:
|
130 |
+
images_by_subject = images_by_subject[:args.trans_subject_count]
|
131 |
+
subject_folders = subject_folders[:args.trans_subject_count]
|
132 |
+
|
133 |
+
out_image_count = 0
|
134 |
+
out_mask_count = 0
|
135 |
+
if not args.out_folder.endswith("/"):
|
136 |
+
args.out_folder += "/"
|
137 |
+
|
138 |
+
if args.num_gpus > 1:
|
139 |
+
# Split the subjects across the GPUs.
|
140 |
+
subject_folders = subject_folders[process_index::args.num_gpus]
|
141 |
+
images_by_subject = images_by_subject[process_index::args.num_gpus]
|
142 |
+
#subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
|
143 |
+
|
144 |
+
for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
|
145 |
+
# If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
|
146 |
+
# Otherwise, we use the folder name as the signature of the images.
|
147 |
+
images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])
|
148 |
+
|
149 |
+
print(f"Translating {images_sig}...")
|
150 |
+
with torch.no_grad():
|
151 |
+
adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
|
152 |
+
out_id_embs_scale=1, noise_level=args.noise_level,
|
153 |
+
update_text_encoder=True)
|
154 |
+
|
155 |
+
# Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
|
156 |
+
subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
|
157 |
+
if not os.path.exists(subject_out_folder):
|
158 |
+
os.makedirs(subject_out_folder)
|
159 |
+
print(f"Output images will be saved to {subject_out_folder}")
|
160 |
+
|
161 |
+
in_images = []
|
162 |
+
for image_path in image_paths:
|
163 |
+
image = Image.open(image_path).convert("RGB").resize((512, 512))
|
164 |
+
# [512, 512, 3] -> [3, 512, 512].
|
165 |
+
image = np.array(image).transpose(2, 0, 1)
|
166 |
+
# Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
|
167 |
+
image = torch.tensor(image).unsqueeze(0).float().cuda()
|
168 |
+
in_images.append(image)
|
169 |
+
|
170 |
+
# Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
|
171 |
+
# NOTE: For simplicity, we do not check overly large batch sizes.
|
172 |
+
in_images = torch.cat(in_images, dim=0)
|
173 |
+
# in_images: [5, 3, 512, 512].
|
174 |
+
# Normalize the pixel values to [0, 1].
|
175 |
+
in_images = in_images / 255.0
|
176 |
+
num_out_images = len(in_images) * args.out_count_per_input_image
|
177 |
+
|
178 |
+
with torch.no_grad():
|
179 |
+
# args.noise_level: the *relative* std of the noise added to the face embeddings.
|
180 |
+
# A noise level of 0.08 could change gender, but 0.06 is usually safe.
|
181 |
+
# The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
|
182 |
+
# NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
|
183 |
+
out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
|
184 |
+
|
185 |
+
for img_i, img in enumerate(out_images):
|
186 |
+
# out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
|
187 |
+
subj_i = img_i % len(in_images)
|
188 |
+
copy_i = img_i // len(in_images)
|
189 |
+
image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
|
190 |
+
if copy_i == 0:
|
191 |
+
img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
|
192 |
+
else:
|
193 |
+
img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
|
194 |
+
|
195 |
+
if args.copy_masks:
|
196 |
+
mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
|
197 |
+
if os.path.exists(mask_path):
|
198 |
+
if copy_i == 0:
|
199 |
+
shutil.copy(mask_path, subject_out_folder)
|
200 |
+
else:
|
201 |
+
mask_filename_stem = image_filename_stem
|
202 |
+
shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))
|
203 |
+
|
204 |
+
out_mask_count += 1
|
205 |
+
|
206 |
+
out_image_count += len(out_images)
|
207 |
+
|
208 |
+
print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
|
adaface/adaface_wrapper.py
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPTextModel
|
4 |
+
from diffusers import (
|
5 |
+
StableDiffusionPipeline,
|
6 |
+
StableDiffusionImg2ImgPipeline,
|
7 |
+
UNet2DConditionModel,
|
8 |
+
DDIMScheduler,
|
9 |
+
AutoencoderKL,
|
10 |
+
)
|
11 |
+
from insightface.app import FaceAnalysis
|
12 |
+
from adaface.arc2face_models import CLIPTextModelWrapper
|
13 |
+
from adaface.util import get_arc2face_id_prompt_embs
|
14 |
+
import re, os
|
15 |
+
|
16 |
+
class AdaFaceWrapper(nn.Module):
|
17 |
+
def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
|
18 |
+
subject_string='z', num_vectors=16,
|
19 |
+
num_inference_steps=50, negative_prompt=None,
|
20 |
+
use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
|
21 |
+
'''
|
22 |
+
pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
|
23 |
+
removed from the pipeline to release RAM.
|
24 |
+
'''
|
25 |
+
super().__init__()
|
26 |
+
self.pipeline_name = pipeline_name
|
27 |
+
self.base_model_path = base_model_path
|
28 |
+
self.adaface_ckpt_path = adaface_ckpt_path
|
29 |
+
self.use_840k_vae = use_840k_vae
|
30 |
+
self.use_ds_text_encoder = use_ds_text_encoder
|
31 |
+
self.subject_string = subject_string
|
32 |
+
self.num_vectors = num_vectors
|
33 |
+
self.num_inference_steps = num_inference_steps
|
34 |
+
self.device = device
|
35 |
+
self.is_training = is_training
|
36 |
+
self.initialize_pipeline()
|
37 |
+
self.extend_tokenizer_and_text_encoder()
|
38 |
+
if negative_prompt is None:
|
39 |
+
self.negative_prompt = \
|
40 |
+
"flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
|
41 |
+
"mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
|
42 |
+
"mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
|
43 |
+
"nude, naked, nsfw, topless, bare breasts"
|
44 |
+
else:
|
45 |
+
self.negative_prompt = negative_prompt
|
46 |
+
|
47 |
+
def load_subj_basis_generator(self, adaface_ckpt_path):
|
48 |
+
ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
|
49 |
+
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
50 |
+
if self.subject_string not in string_to_subj_basis_generator_dict:
|
51 |
+
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
52 |
+
breakpoint()
|
53 |
+
|
54 |
+
self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
55 |
+
# In the original ckpt, num_out_layers is 16 for layerwise embeddings.
|
56 |
+
# But we don't do layerwise embeddings here, so we set it to 1.
|
57 |
+
self.subj_basis_generator.num_out_layers = 1
|
58 |
+
print(f"Loaded subject basis generator for '{self.subject_string}'.")
|
59 |
+
print(repr(self.subj_basis_generator))
|
60 |
+
self.subj_basis_generator.to(self.device)
|
61 |
+
if self.is_training:
|
62 |
+
self.subj_basis_generator.train()
|
63 |
+
else:
|
64 |
+
self.subj_basis_generator.eval()
|
65 |
+
|
66 |
+
def initialize_pipeline(self):
|
67 |
+
self.load_subj_basis_generator(self.adaface_ckpt_path)
|
68 |
+
# arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
|
69 |
+
# in the UNet image space.
|
70 |
+
arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
|
71 |
+
'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
|
72 |
+
)
|
73 |
+
self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
|
74 |
+
|
75 |
+
if self.use_840k_vae:
|
76 |
+
# The 840000-step vae model is slightly better in face details than the original vae model.
|
77 |
+
# https://huggingface.co/stabilityai/sd-vae-ft-mse-original
|
78 |
+
vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
|
79 |
+
else:
|
80 |
+
vae = None
|
81 |
+
|
82 |
+
if self.use_ds_text_encoder:
|
83 |
+
# The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
|
84 |
+
# https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
|
85 |
+
text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
|
86 |
+
else:
|
87 |
+
text_encoder = None
|
88 |
+
|
89 |
+
remove_unet = False
|
90 |
+
|
91 |
+
if self.pipeline_name == "img2img":
|
92 |
+
PipelineClass = StableDiffusionImg2ImgPipeline
|
93 |
+
elif self.pipeline_name == "text2img":
|
94 |
+
PipelineClass = StableDiffusionPipeline
|
95 |
+
# pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
|
96 |
+
elif self.pipeline_name is None:
|
97 |
+
PipelineClass = StableDiffusionPipeline
|
98 |
+
remove_unet = True
|
99 |
+
else:
|
100 |
+
raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
|
101 |
+
|
102 |
+
if os.path.isfile(self.base_model_path):
|
103 |
+
pipeline = PipelineClass.from_single_file(
|
104 |
+
self.base_model_path,
|
105 |
+
torch_dtype=torch.float16
|
106 |
+
)
|
107 |
+
else:
|
108 |
+
pipeline = PipelineClass.from_pretrained(
|
109 |
+
self.base_model_path,
|
110 |
+
torch_dtype=torch.float16,
|
111 |
+
safety_checker=None
|
112 |
+
)
|
113 |
+
print(f"Loaded pipeline from {self.base_model_path}.")
|
114 |
+
|
115 |
+
if self.use_840k_vae:
|
116 |
+
pipeline.vae = vae
|
117 |
+
print("Replaced the VAE with the 840k-step VAE.")
|
118 |
+
|
119 |
+
if self.use_ds_text_encoder:
|
120 |
+
pipeline.text_encoder = text_encoder
|
121 |
+
print("Replaced the text encoder with the DreamShaper text encoder.")
|
122 |
+
|
123 |
+
if remove_unet:
|
124 |
+
# Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
|
125 |
+
pipeline.unet = None
|
126 |
+
pipeline.vae = None
|
127 |
+
print("Removed UNet and VAE from the pipeline.")
|
128 |
+
|
129 |
+
noise_scheduler = DDIMScheduler(
|
130 |
+
num_train_timesteps=1000,
|
131 |
+
beta_start=0.00085,
|
132 |
+
beta_end=0.012,
|
133 |
+
beta_schedule="scaled_linear",
|
134 |
+
clip_sample=False,
|
135 |
+
set_alpha_to_one=False,
|
136 |
+
steps_offset=1,
|
137 |
+
)
|
138 |
+
|
139 |
+
pipeline.scheduler = noise_scheduler
|
140 |
+
self.pipeline = pipeline.to(self.device)
|
141 |
+
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
142 |
+
# Note there's a second "model" in the path.
|
143 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
144 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
145 |
+
# Patch the missing tokenizer in the subj_basis_generator.
|
146 |
+
if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
|
147 |
+
self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
|
148 |
+
print("Patched the missing tokenizer in the subj_basis_generator.")
|
149 |
+
|
150 |
+
def extend_tokenizer_and_text_encoder(self):
|
151 |
+
if self.num_vectors < 1:
|
152 |
+
raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
|
153 |
+
|
154 |
+
tokenizer = self.pipeline.tokenizer
|
155 |
+
# Add z0, z1, z2, ..., z15.
|
156 |
+
self.placeholder_tokens = []
|
157 |
+
for i in range(0, self.num_vectors):
|
158 |
+
self.placeholder_tokens.append(f"{self.subject_string}_{i}")
|
159 |
+
|
160 |
+
self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
|
161 |
+
|
162 |
+
# Add the new tokens to the tokenizer.
|
163 |
+
num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
|
164 |
+
if num_added_tokens != self.num_vectors:
|
165 |
+
raise ValueError(
|
166 |
+
f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
|
167 |
+
" `subject_string` that is not already in the tokenizer.")
|
168 |
+
|
169 |
+
print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
|
170 |
+
|
171 |
+
# placeholder_token_ids: [49408, ..., 49423].
|
172 |
+
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
|
173 |
+
# print(self.placeholder_token_ids)
|
174 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
175 |
+
old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
176 |
+
self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
|
177 |
+
new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
178 |
+
print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
|
179 |
+
|
180 |
+
# Extend pipeline.text_encoder with the adaface subject emeddings.
|
181 |
+
# subj_embs: [16, 768].
|
182 |
+
def update_text_encoder_subj_embs(self, subj_embs):
|
183 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
184 |
+
token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
|
185 |
+
with torch.no_grad():
|
186 |
+
for i, token_id in enumerate(self.placeholder_token_ids):
|
187 |
+
token_embeds[token_id] = subj_embs[i]
|
188 |
+
print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
|
189 |
+
|
190 |
+
def update_prompt(self, prompt):
|
191 |
+
# If the placeholder tokens are already in the prompt, then return the prompt as is.
|
192 |
+
if self.placeholder_tokens_str in prompt:
|
193 |
+
return prompt
|
194 |
+
|
195 |
+
# If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
|
196 |
+
if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
|
197 |
+
print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
|
198 |
+
comp_prompt = self.placeholder_tokens_str + " " + prompt
|
199 |
+
else:
|
200 |
+
# Replace the subject string 'z' with the placeholder tokens.
|
201 |
+
comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
|
202 |
+
return comp_prompt
|
203 |
+
|
204 |
+
# image_paths: a list of image paths. image_folder: the parent folder name.
|
205 |
+
def generate_adaface_embeddings(self, image_paths, image_folder=None,
|
206 |
+
pre_face_embs=None, gen_rand_face=False,
|
207 |
+
out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
|
208 |
+
# faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
|
209 |
+
# If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
|
210 |
+
# Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
|
211 |
+
# The same applies to id_prompt_emb.
|
212 |
+
# faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
|
213 |
+
# Here id_batch_size = 1, so
|
214 |
+
# faceid_embeds: [1, 512]. NOT used later.
|
215 |
+
# id_prompt_emb: [1, 16, 768].
|
216 |
+
# NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
|
217 |
+
# arc2face prompt template: "photo of a id person"
|
218 |
+
# ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
|
219 |
+
faceid_embeds, id_prompt_emb \
|
220 |
+
= get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
|
221 |
+
extract_faceid_embeds=not gen_rand_face,
|
222 |
+
pre_face_embs=pre_face_embs,
|
223 |
+
# image_folder is passed only for logging purpose.
|
224 |
+
# image_paths contains the paths of the images.
|
225 |
+
image_folder=image_folder, image_paths=image_paths,
|
226 |
+
images_np=None,
|
227 |
+
id_batch_size=1,
|
228 |
+
device=self.device,
|
229 |
+
# input_max_length == 22: only keep the first 22 tokens,
|
230 |
+
# including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
|
231 |
+
# The results are indistinguishable from input_max_length=77.
|
232 |
+
input_max_length=22,
|
233 |
+
noise_level=noise_level,
|
234 |
+
return_core_id_embs=True,
|
235 |
+
gen_neg_prompt=False,
|
236 |
+
verbose=True)
|
237 |
+
|
238 |
+
# adaface_subj_embs: [1, 1, 16, 768].
|
239 |
+
# adaface_prompt_embs: [1, 77, 768] (not used).
|
240 |
+
adaface_subj_embs, adaface_prompt_embs = \
|
241 |
+
self.subj_basis_generator(id_prompt_emb, None, None,
|
242 |
+
out_id_embs_scale=out_id_embs_scale,
|
243 |
+
is_face=True, is_training=False,
|
244 |
+
adaface_prompt_embs_inf_type='full_half_pad')
|
245 |
+
# adaface_subj_embs: [16, 768]
|
246 |
+
adaface_subj_embs = adaface_subj_embs.squeeze()
|
247 |
+
if update_text_encoder:
|
248 |
+
self.update_text_encoder_subj_embs(adaface_subj_embs)
|
249 |
+
return adaface_subj_embs
|
250 |
+
|
251 |
+
def encode_prompt(self, prompt, device="cuda", verbose=False):
|
252 |
+
prompt = self.update_prompt(prompt)
|
253 |
+
if verbose:
|
254 |
+
print(f"Prompt: {prompt}")
|
255 |
+
|
256 |
+
# For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
|
257 |
+
# So we manually move it to GPU here.
|
258 |
+
self.pipeline.text_encoder.to(device)
|
259 |
+
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
260 |
+
prompt_embeds_, negative_prompt_embeds_ = \
|
261 |
+
self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
|
262 |
+
do_classifier_free_guidance=True, negative_prompt=self.negative_prompt)
|
263 |
+
return prompt_embeds_, negative_prompt_embeds_
|
264 |
+
|
265 |
+
# ref_img_strength is used only in the img2img pipeline.
|
266 |
+
def forward(self, noise, prompt, guidance_scale=4.0, out_image_count=4, ref_img_strength=0.8, verbose=False):
|
267 |
+
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
268 |
+
prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, device=self.device, verbose=verbose)
|
269 |
+
|
270 |
+
# Repeat the prompt embeddings for all images in the batch.
|
271 |
+
prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
|
272 |
+
negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
|
273 |
+
noise = noise.to(self.device).to(torch.float16)
|
274 |
+
|
275 |
+
# noise: [BS, 4, 64, 64]
|
276 |
+
# When the pipeline is text2img, strength is ignored.
|
277 |
+
images = self.pipeline(image=noise,
|
278 |
+
prompt_embeds=prompt_embeds_,
|
279 |
+
negative_prompt_embeds=negative_prompt_embeds_,
|
280 |
+
num_inference_steps=self.num_inference_steps,
|
281 |
+
guidance_scale=guidance_scale,
|
282 |
+
num_images_per_prompt=1,
|
283 |
+
strength=ref_img_strength).images
|
284 |
+
# images: [BS, 3, 512, 512]
|
285 |
+
return images
|
286 |
+
|
adaface/arc2face_models.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPTextModel
|
4 |
+
from transformers.models.clip.modeling_clip import CLIPAttention
|
5 |
+
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
|
6 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
7 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
8 |
+
# from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
|
9 |
+
_make_causal_mask = AttentionMaskConverter._make_causal_mask
|
10 |
+
_expand_mask = AttentionMaskConverter._expand_mask
|
11 |
+
|
12 |
+
from adaface.util import add_noise_to_tensor
|
13 |
+
|
14 |
+
# Extend CLIPAttention by using multiple k_proj and v_proj in each head.
|
15 |
+
# To avoid too much increase of computation, we don't extend q_proj.
|
16 |
+
class CLIPAttentionMKV(nn.Module):
|
17 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
18 |
+
|
19 |
+
def __init__(self, config, multiplier=2):
|
20 |
+
super().__init__()
|
21 |
+
self.config = config
|
22 |
+
self.embed_dim = config.hidden_size
|
23 |
+
self.num_heads = config.num_attention_heads
|
24 |
+
self.head_dim = self.embed_dim // self.num_heads
|
25 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
26 |
+
raise ValueError(
|
27 |
+
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
28 |
+
f" {self.num_heads})."
|
29 |
+
)
|
30 |
+
self.scale = self.head_dim**-0.5
|
31 |
+
self.dropout = config.attention_dropout
|
32 |
+
self.multiplier = multiplier
|
33 |
+
|
34 |
+
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
|
35 |
+
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
|
36 |
+
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
37 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
38 |
+
|
39 |
+
# The (approximately) repeated token features are repeated along the last dim in tensor
|
40 |
+
# (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
|
41 |
+
# Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
|
42 |
+
# [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
|
43 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
44 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
45 |
+
|
46 |
+
def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
|
47 |
+
noise_std_is_relative=True, keep_norm=False, verbose=False):
|
48 |
+
self.multiplier *= multiplier
|
49 |
+
# q_proj and out_proj are the same as the original CLIPAttention.
|
50 |
+
self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
|
51 |
+
self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
|
52 |
+
self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
|
53 |
+
self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()
|
54 |
+
|
55 |
+
# bias doesn't need noise perturbation, as after the weights are noised,
|
56 |
+
# different copies of the weight/bias will receive different gradients,
|
57 |
+
# making the bias terms diverge and identifiable after training.
|
58 |
+
self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
|
59 |
+
self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)
|
60 |
+
|
61 |
+
self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
|
62 |
+
self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)
|
63 |
+
|
64 |
+
if noise_std > 0:
|
65 |
+
ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
|
66 |
+
ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
|
67 |
+
# Adding noise to the extra copies of the weights (keep the first copy unchanged).
|
68 |
+
self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
|
69 |
+
add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
|
70 |
+
noise_std, noise_std_is_relative, keep_norm)
|
71 |
+
if verbose:
|
72 |
+
NEW_V_SHAPE = list(self.v_proj.weight.shape)
|
73 |
+
NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
|
74 |
+
print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {noise_std} noise")
|
75 |
+
|
76 |
+
ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
|
77 |
+
ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
|
78 |
+
# Adding noise to the extra copies of the weights.
|
79 |
+
self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
|
80 |
+
add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
|
81 |
+
noise_std, noise_std_is_relative, keep_norm)
|
82 |
+
if verbose:
|
83 |
+
NEW_K_SHAPE = list(self.k_proj.weight.shape)
|
84 |
+
NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
|
85 |
+
print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {noise_std} noise")
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
hidden_states: torch.Tensor,
|
90 |
+
attention_mask: Optional[torch.Tensor] = None,
|
91 |
+
causal_attention_mask: Optional[torch.Tensor] = None,
|
92 |
+
output_attentions: Optional[bool] = False,
|
93 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
94 |
+
"""Input shape: Batch x Time x Channel"""
|
95 |
+
|
96 |
+
bsz, tgt_len, embed_dim = hidden_states.size()
|
97 |
+
|
98 |
+
query_states = self.q_proj(hidden_states) * self.scale
|
99 |
+
# For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
|
100 |
+
# [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
|
101 |
+
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
102 |
+
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
103 |
+
|
104 |
+
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
105 |
+
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
106 |
+
key_states = key_states.view(*proj_shape)
|
107 |
+
value_states = value_states.view(*proj_shape)
|
108 |
+
|
109 |
+
src_len = key_states.size(1)
|
110 |
+
# src_len0 is the original src_len without the multiplier.
|
111 |
+
src_len0 = src_len // self.multiplier
|
112 |
+
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
113 |
+
|
114 |
+
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
115 |
+
raise ValueError(
|
116 |
+
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
117 |
+
f" {attn_weights.size()}"
|
118 |
+
)
|
119 |
+
|
120 |
+
# apply the causal_attention_mask first
|
121 |
+
if causal_attention_mask is not None:
|
122 |
+
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
|
123 |
+
raise ValueError(
|
124 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
|
125 |
+
f" {causal_attention_mask.size()}"
|
126 |
+
)
|
127 |
+
# The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
|
128 |
+
# If reshaping it as (self.multiplier, src_len0), it will become
|
129 |
+
# [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
|
130 |
+
# and the mask will be applied to wrong elements.
|
131 |
+
# If reshaping it as (src_len0, self.multiplier), it will become
|
132 |
+
# [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then
|
133 |
+
# the mask at element i will mask all the multiplier elements at i, which is desired.
|
134 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
|
135 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
136 |
+
|
137 |
+
if attention_mask is not None:
|
138 |
+
if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
|
139 |
+
raise ValueError(
|
140 |
+
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
|
141 |
+
)
|
142 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
|
143 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
144 |
+
|
145 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
146 |
+
|
147 |
+
if output_attentions:
|
148 |
+
# this operation is a bit awkward, but it's required to
|
149 |
+
# make sure that attn_weights keeps its gradient.
|
150 |
+
# In order to do so, attn_weights have to reshaped
|
151 |
+
# twice and have to be reused in the following
|
152 |
+
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
153 |
+
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
154 |
+
else:
|
155 |
+
attn_weights_reshaped = None
|
156 |
+
|
157 |
+
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
158 |
+
|
159 |
+
attn_output = torch.bmm(attn_probs, value_states)
|
160 |
+
|
161 |
+
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
162 |
+
raise ValueError(
|
163 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
164 |
+
f" {attn_output.size()}"
|
165 |
+
)
|
166 |
+
|
167 |
+
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
168 |
+
attn_output = attn_output.transpose(1, 2)
|
169 |
+
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
|
170 |
+
|
171 |
+
attn_output = self.out_proj(attn_output)
|
172 |
+
|
173 |
+
return attn_output, attn_weights_reshaped
|
174 |
+
|
175 |
+
class CLIPTextModelWrapper(CLIPTextModel):
|
176 |
+
# Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
|
177 |
+
# Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
|
178 |
+
def forward(
|
179 |
+
self,
|
180 |
+
input_ids: Optional[torch.Tensor] = None,
|
181 |
+
attention_mask: Optional[torch.Tensor] = None,
|
182 |
+
position_ids: Optional[torch.Tensor] = None,
|
183 |
+
output_attentions: Optional[bool] = None,
|
184 |
+
output_hidden_states: Optional[bool] = None,
|
185 |
+
return_dict: Optional[bool] = None,
|
186 |
+
input_token_embs: Optional[torch.Tensor] = None,
|
187 |
+
hidden_state_layer_weights: Optional[torch.Tensor] = None,
|
188 |
+
return_token_embs: Optional[bool] = False,
|
189 |
+
) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
|
190 |
+
|
191 |
+
if return_token_embs:
|
192 |
+
return self.text_model.embeddings.token_embedding(input_ids)
|
193 |
+
|
194 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
195 |
+
|
196 |
+
output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
|
197 |
+
output_hidden_states = (
|
198 |
+
output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
|
199 |
+
)
|
200 |
+
if hidden_state_layer_weights is not None:
|
201 |
+
output_hidden_states = True
|
202 |
+
return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
|
203 |
+
|
204 |
+
if input_ids is None:
|
205 |
+
raise ValueError("You have to specify input_ids")
|
206 |
+
|
207 |
+
input_shape = input_ids.size()
|
208 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
209 |
+
|
210 |
+
hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
|
211 |
+
|
212 |
+
# CLIP's text model uses causal mask, prepare it here.
|
213 |
+
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
|
214 |
+
causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
|
215 |
+
# expand attention_mask
|
216 |
+
if attention_mask is not None:
|
217 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
218 |
+
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
|
219 |
+
|
220 |
+
encoder_outputs = self.text_model.encoder(
|
221 |
+
inputs_embeds=hidden_states,
|
222 |
+
attention_mask=attention_mask,
|
223 |
+
causal_attention_mask=causal_attention_mask,
|
224 |
+
output_attentions=output_attentions,
|
225 |
+
# output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
|
226 |
+
output_hidden_states=output_hidden_states,
|
227 |
+
return_dict=return_dict,
|
228 |
+
)
|
229 |
+
|
230 |
+
# If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
|
231 |
+
# encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
|
232 |
+
# encoder_outputs[0] == encoder_outputs[1][12].
|
233 |
+
if hidden_state_layer_weights is None:
|
234 |
+
last_hidden_state = encoder_outputs[0]
|
235 |
+
else:
|
236 |
+
num_hidden_state_layers = len(hidden_state_layer_weights)
|
237 |
+
last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
|
238 |
+
hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
|
239 |
+
# Normalize the weights of to sum to 1 across layers.
|
240 |
+
# hidden_state_layer_weights: [3, 1] or [3, 768].
|
241 |
+
hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
|
242 |
+
# [3, 1/768] -> [3, 1, 1, 1/768]
|
243 |
+
hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
|
244 |
+
# A weighted sum of last_hidden_states.
|
245 |
+
# [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
|
246 |
+
last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)
|
247 |
+
|
248 |
+
last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
|
249 |
+
|
250 |
+
# self.text_model.eos_token_id == 2 is True.
|
251 |
+
if self.text_model.eos_token_id == 2:
|
252 |
+
# The `eos_token_id` was incorrect before PR #24773: Let's keep what have been done here.
|
253 |
+
# A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
|
254 |
+
# ------------------------------------------------------------
|
255 |
+
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
256 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
257 |
+
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
258 |
+
pooled_output = last_hidden_state[
|
259 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
260 |
+
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
261 |
+
]
|
262 |
+
else:
|
263 |
+
# The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible)
|
264 |
+
pooled_output = last_hidden_state[
|
265 |
+
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
266 |
+
# We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
|
267 |
+
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
|
268 |
+
.int()
|
269 |
+
.argmax(dim=-1),
|
270 |
+
]
|
271 |
+
|
272 |
+
if not return_dict:
|
273 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
274 |
+
|
275 |
+
return BaseModelOutputWithPooling(
|
276 |
+
last_hidden_state=last_hidden_state,
|
277 |
+
pooler_output=pooled_output,
|
278 |
+
hidden_states=encoder_outputs.hidden_states,
|
279 |
+
attentions=encoder_outputs.attentions,
|
280 |
+
)
|
281 |
+
|
282 |
+
# Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
|
283 |
+
# The layer indexed by end_layer_idx is not included.
|
284 |
+
# If both layer indices are -1, then apply to all layers (0-11).
|
285 |
+
def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
|
286 |
+
num_extended_layers = 0
|
287 |
+
|
288 |
+
for layer_idx, layer in enumerate(self.text_model.encoder.layers):
|
289 |
+
if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
|
290 |
+
continue
|
291 |
+
if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
|
292 |
+
break
|
293 |
+
# This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
|
294 |
+
if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
|
295 |
+
breakpoint()
|
296 |
+
old_attn_layer = layer.self_attn
|
297 |
+
if not isinstance(old_attn_layer, CLIPAttentionMKV):
|
298 |
+
layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
|
299 |
+
layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
|
300 |
+
num_extended_layers += 1
|
301 |
+
|
302 |
+
return num_extended_layers
|
303 |
+
|
adaface/subj_basis_generator.py
ADDED
@@ -0,0 +1,758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Borrowed from ip-adapter resampler.py.
|
2 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
3 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
4 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange
|
12 |
+
from einops.layers.torch import Rearrange
|
13 |
+
from transformers import CLIPVisionModel, CLIPTokenizer
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
from torch import einsum
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
from transformers.utils import ModelOutput
|
20 |
+
from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
|
21 |
+
from adaface.arc2face_models import CLIPTextModelWrapper
|
22 |
+
import sys
|
23 |
+
sys.modules['ldm'] = sys.modules['adaface']
|
24 |
+
|
25 |
+
def reshape_tensor(x, num_heads):
|
26 |
+
bs, length, width = x.shape
|
27 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
28 |
+
x = x.view(bs, length, num_heads, -1)
|
29 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
30 |
+
x = x.transpose(1, 2)
|
31 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
32 |
+
x = x.reshape(bs, num_heads, length, -1)
|
33 |
+
return x
|
34 |
+
|
35 |
+
# FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
|
36 |
+
def FeedForward(dim, mult=4, p_dropout=0.1):
|
37 |
+
inner_dim = int(dim * mult)
|
38 |
+
return nn.Sequential(
|
39 |
+
nn.LayerNorm(dim),
|
40 |
+
nn.Linear(dim, inner_dim, bias=False),
|
41 |
+
nn.GELU(),
|
42 |
+
nn.Linear(inner_dim, dim, bias=False),
|
43 |
+
nn.Dropout(p_dropout),
|
44 |
+
)
|
45 |
+
|
46 |
+
# IP-Adapter FaceID class. Only used in knn-faces.py.
|
47 |
+
# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
|
48 |
+
class IP_MLPProjModel(nn.Module):
|
49 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.cross_attention_dim = cross_attention_dim
|
53 |
+
self.num_tokens = num_tokens
|
54 |
+
|
55 |
+
self.proj = nn.Sequential(
|
56 |
+
nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
57 |
+
nn.GELU(),
|
58 |
+
nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
59 |
+
)
|
60 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
61 |
+
|
62 |
+
def forward(self, id_embeds):
|
63 |
+
x = self.proj(id_embeds)
|
64 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
65 |
+
x = self.norm(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
69 |
+
class LearnedSoftAggregate(nn.Module):
|
70 |
+
def __init__(self, num_feat, group_dim, keepdim=False):
|
71 |
+
super(LearnedSoftAggregate, self).__init__()
|
72 |
+
self.group_dim = group_dim
|
73 |
+
# num_feat = 1: element-wise score function & softmax.
|
74 |
+
# num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
|
75 |
+
self.num_feat = num_feat
|
76 |
+
self.feat2score = nn.Linear(num_feat, 1, bias=False)
|
77 |
+
self.keepdim = keepdim
|
78 |
+
|
79 |
+
def forward(self, x, score_basis=None):
|
80 |
+
# If there's only one mode, do nothing.
|
81 |
+
if x.shape[self.group_dim] == 1:
|
82 |
+
if self.keepdim:
|
83 |
+
return x
|
84 |
+
else:
|
85 |
+
return x.squeeze(self.group_dim)
|
86 |
+
|
87 |
+
# Assume the last dim of x is the feature dim.
|
88 |
+
if score_basis is None:
|
89 |
+
score_basis = x
|
90 |
+
|
91 |
+
if self.num_feat == 1:
|
92 |
+
mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
|
93 |
+
else:
|
94 |
+
mode_scores = self.feat2score(score_basis)
|
95 |
+
attn_probs = mode_scores.softmax(dim=self.group_dim)
|
96 |
+
x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
|
97 |
+
return x_aggr
|
98 |
+
|
99 |
+
def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
|
100 |
+
num_output_vecs, elementwise_affine=True, p_dropout=0.1):
|
101 |
+
return nn.Sequential(
|
102 |
+
# Project to [BS, lora_rank * output_dim * num_modes].
|
103 |
+
# It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456.
|
104 |
+
nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
|
105 |
+
# Reshape to [BS, lora_rank, output_dim].
|
106 |
+
Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
|
107 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
108 |
+
# Aggregate [BS, num_modes, loar_rank, output_dim] -> [BS, lora_rank, output_dim].
|
109 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
|
110 |
+
else Rearrange('b () q d -> b q d'),
|
111 |
+
nn.Dropout(p_dropout),
|
112 |
+
# Permute to [BS, output_dim, lora_rank].
|
113 |
+
Rearrange('b q d -> b d q'),
|
114 |
+
# Project to [BS, output_dim, num_output_vecs].
|
115 |
+
nn.Linear(lora_rank, num_output_vecs, bias=False),
|
116 |
+
# Permute to [BS, num_output_vecs, output_dim].
|
117 |
+
Rearrange('b d q -> b q d'),
|
118 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
119 |
+
nn.Dropout(p_dropout),
|
120 |
+
)
|
121 |
+
|
122 |
+
def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
|
123 |
+
return nn.Sequential(
|
124 |
+
# Project to [BS, num_output_vecs * output_dim].
|
125 |
+
nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
|
126 |
+
# Reshape to [BS, num_output_vecs, output_dim].
|
127 |
+
Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
|
128 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
129 |
+
nn.Dropout(p_dropout),
|
130 |
+
)
|
131 |
+
|
132 |
+
# Input: [BS, N, D].
|
133 |
+
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
|
134 |
+
if output_dim == -1:
|
135 |
+
output_dim = input_dim
|
136 |
+
|
137 |
+
return nn.Sequential(
|
138 |
+
nn.Linear(input_dim, output_dim * num_modes, bias=False),
|
139 |
+
# Reshape to [BS, num_output_vecs, output_dim].
|
140 |
+
Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
|
141 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
142 |
+
# If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
|
143 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
|
144 |
+
else Rearrange('b n () d -> b n d'),
|
145 |
+
nn.Dropout(p_dropout),
|
146 |
+
)
|
147 |
+
|
148 |
+
# Low-rank to high-rank transformation.
|
149 |
+
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
|
150 |
+
return nn.Sequential(
|
151 |
+
# Permute to [BS, output_dim, lora_rank].
|
152 |
+
Rearrange('b q d -> b d q'),
|
153 |
+
# Project to [BS, output_dim, hira_rank].
|
154 |
+
nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
|
155 |
+
# Reshape and permute to [BS, num_modes, num_output_vecs, output_dim].
|
156 |
+
Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
|
157 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
158 |
+
# Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
|
159 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
|
160 |
+
else Rearrange('b () q d -> b q d'),
|
161 |
+
nn.Dropout(p_dropout),
|
162 |
+
)
|
163 |
+
|
164 |
+
class PerceiverAttention(nn.Module):
|
165 |
+
def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
|
166 |
+
super().__init__()
|
167 |
+
self.scale = dim_head**-0.5
|
168 |
+
self.dim_head = dim_head
|
169 |
+
self.num_heads = num_heads
|
170 |
+
inner_dim = dim_head * num_heads
|
171 |
+
|
172 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
|
173 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
|
174 |
+
|
175 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
176 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
177 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
178 |
+
|
179 |
+
def forward(self, x, latent_queries):
|
180 |
+
"""
|
181 |
+
Args:
|
182 |
+
x (torch.Tensor): image features
|
183 |
+
shape (b, n1, D)
|
184 |
+
latent (torch.Tensor): latent features
|
185 |
+
shape (b, n2, D)
|
186 |
+
"""
|
187 |
+
x = self.norm1(x)
|
188 |
+
latent_queries = self.norm2(latent_queries)
|
189 |
+
|
190 |
+
b, l, _ = latent_queries.shape
|
191 |
+
|
192 |
+
q = self.to_q(latent_queries)
|
193 |
+
kv_input = torch.cat((x, latent_queries), dim=-2)
|
194 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
195 |
+
|
196 |
+
q = reshape_tensor(q, self.num_heads)
|
197 |
+
k = reshape_tensor(k, self.num_heads)
|
198 |
+
v = reshape_tensor(v, self.num_heads)
|
199 |
+
|
200 |
+
# attention
|
201 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
202 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
203 |
+
attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
204 |
+
out = attn @ v
|
205 |
+
|
206 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
207 |
+
|
208 |
+
return self.to_out(out)
|
209 |
+
|
210 |
+
|
211 |
+
class CrossAttention(nn.Module):
|
212 |
+
# output_dim is always the same as input_dim.
|
213 |
+
# num_q only matters when q_aware_to_v is True.
|
214 |
+
# If q_aware_to_v is False, query x in forward() is still usable.
|
215 |
+
def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
|
216 |
+
identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
|
217 |
+
q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
|
218 |
+
identity_to_out=False, out_has_skip=False):
|
219 |
+
super().__init__()
|
220 |
+
dim_head = input_dim // num_heads
|
221 |
+
inner_dim = dim_head * num_heads
|
222 |
+
|
223 |
+
self.num_heads = num_heads
|
224 |
+
self.q_aware_to_v = q_aware_to_v
|
225 |
+
self.v_has_skip = v_has_skip
|
226 |
+
self.to_q = nn.Sequential(
|
227 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
228 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
229 |
+
) if not identity_to_q else nn.Identity()
|
230 |
+
self.to_k = nn.Sequential(
|
231 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
232 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
233 |
+
) if not identity_to_k else nn.Identity()
|
234 |
+
|
235 |
+
self.v_repeat = v_repeat
|
236 |
+
self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104.
|
237 |
+
|
238 |
+
# If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
|
239 |
+
# Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
|
240 |
+
if q_aware_to_v:
|
241 |
+
# all_q_mid: 104 * 64 = 6656.
|
242 |
+
all_q_mid = num_q_group * q_aware_to_v_lora_rank
|
243 |
+
self.to_v = nn.Sequential(
|
244 |
+
# number of params: 768 * 6656 = 5,111,808.
|
245 |
+
# Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
|
246 |
+
# Each 768-dim vec is dispersed into 104 64-dim vecs.
|
247 |
+
nn.Linear(input_dim, all_q_mid, bias=False),
|
248 |
+
nn.LayerNorm(all_q_mid, elementwise_affine=True),
|
249 |
+
# Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
|
250 |
+
Rearrange('b n q -> b q n', q=all_q_mid),
|
251 |
+
# Each q_aware_to_v projection has its own linear layer.
|
252 |
+
# The total number of parameters will be 6656*768 = 5,111,808.
|
253 |
+
# Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
|
254 |
+
nn.Conv1d(
|
255 |
+
in_channels=all_q_mid,
|
256 |
+
out_channels=num_q_group * input_dim,
|
257 |
+
kernel_size=1,
|
258 |
+
groups=num_q_group,
|
259 |
+
bias=False,
|
260 |
+
),
|
261 |
+
# Output: [BS, 104, 16, 768].
|
262 |
+
Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
|
263 |
+
nn.LayerNorm(input_dim, elementwise_affine=True),
|
264 |
+
)
|
265 |
+
else:
|
266 |
+
self.to_v = nn.Sequential(
|
267 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
268 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
269 |
+
) if not identity_to_v else nn.Identity()
|
270 |
+
|
271 |
+
if identity_to_out:
|
272 |
+
assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."
|
273 |
+
|
274 |
+
if identity_to_out:
|
275 |
+
self.to_out = nn.Identity()
|
276 |
+
else:
|
277 |
+
self.to_out = nn.Sequential(
|
278 |
+
nn.Linear(input_dim, input_dim, bias=False),
|
279 |
+
nn.Dropout(p_dropout),
|
280 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
281 |
+
)
|
282 |
+
|
283 |
+
self.out_has_skip = out_has_skip
|
284 |
+
self.attn_drop = nn.Dropout(p_dropout)
|
285 |
+
|
286 |
+
def forward(self, x, context=None, attn_mat=None, return_attn=False):
|
287 |
+
h = self.num_heads
|
288 |
+
|
289 |
+
if context is None:
|
290 |
+
context = x
|
291 |
+
|
292 |
+
if attn_mat is None:
|
293 |
+
# q: [BS, Q, D] -> [BS, Q, D].
|
294 |
+
q = self.to_q(x)
|
295 |
+
# k: [BS, L, D] -> [BS, L, D].
|
296 |
+
k = self.to_k(context)
|
297 |
+
# q: [6, 512, 128], k: [6, 17, 128].
|
298 |
+
q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
|
299 |
+
|
300 |
+
if self.q_aware_to_v:
|
301 |
+
# context: [BS, L, D]. v: [BS, Q, L, D].
|
302 |
+
# There are effectively Q to_v projections.
|
303 |
+
v = self.to_v(context)
|
304 |
+
if self.v_has_skip:
|
305 |
+
v = v + context.unsqueeze(1)
|
306 |
+
else:
|
307 |
+
# v: [BS, L, D].
|
308 |
+
v = self.to_v(context)
|
309 |
+
if self.v_has_skip:
|
310 |
+
v = v + context
|
311 |
+
|
312 |
+
#print(v.shape)
|
313 |
+
|
314 |
+
if self.q_aware_to_v:
|
315 |
+
# v: [6, 64, 17, 128].
|
316 |
+
# v is query-specific, so there's an extra dim for the query.
|
317 |
+
v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
|
318 |
+
# Each v is for a query group with 512/64 = 8 queries.
|
319 |
+
# So each v is repeated 8 times to match the number of queries.
|
320 |
+
# v: [6, 64, 17, 128] -> [6, 512, 17, 128].
|
321 |
+
v = v.repeat(1, self.v_repeat, 1, 1)
|
322 |
+
else:
|
323 |
+
v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)
|
324 |
+
|
325 |
+
if attn_mat is None:
|
326 |
+
scale = q.size(-1) ** -0.25
|
327 |
+
sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
|
328 |
+
# sim: [6, 64, 17]. 6: bs 1 * h 6.
|
329 |
+
# attention, what we cannot get enough of
|
330 |
+
# NOTE: the normalization is done across tokens, not across pixels.
|
331 |
+
# So for each pixel, the sum of attention scores across tokens is 1.
|
332 |
+
attn = sim.softmax(dim=-1)
|
333 |
+
attn = self.attn_drop(attn)
|
334 |
+
#print(attn.std())
|
335 |
+
else:
|
336 |
+
attn = attn_mat
|
337 |
+
|
338 |
+
if self.q_aware_to_v:
|
339 |
+
# attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
|
340 |
+
# out is combined with different attn weights and v for different queries.
|
341 |
+
out = einsum('b i j, b i j d -> b i d', attn, v)
|
342 |
+
else:
|
343 |
+
# v: [6, 17, 128]. out: [6, 32, 128].
|
344 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
345 |
+
|
346 |
+
# [6, 32, 128] -> [1, 32, 768].
|
347 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
348 |
+
|
349 |
+
if self.out_has_skip:
|
350 |
+
out = self.to_out(out) + out
|
351 |
+
else:
|
352 |
+
out = self.to_out(out)
|
353 |
+
|
354 |
+
if return_attn:
|
355 |
+
return out, attn
|
356 |
+
else:
|
357 |
+
return out
|
358 |
+
|
359 |
+
class SubjBasisGenerator(nn.Module):
|
360 |
+
def __init__(
|
361 |
+
self,
|
362 |
+
# number of cross-attention heads. Half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
363 |
+
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
364 |
+
num_heads=6,
|
365 |
+
num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
|
366 |
+
num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4.
|
367 |
+
num_out_layers=16, # number of layers of output embeddings.
|
368 |
+
image_embedding_dim=768, # CLIP image feature dimension, as per config.json above.
|
369 |
+
# DINO vits16 has 6 attention heads:
|
370 |
+
# https://huggingface.co/facebook/dino-vits16/blob/main/config.json
|
371 |
+
dino_embedding_dim=384, # DINO object feature dimension for objects.
|
372 |
+
output_dim=768, # CLIP text embedding input dimension.
|
373 |
+
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background.
|
374 |
+
prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
|
375 |
+
zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj.
|
376 |
+
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
377 |
+
bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
|
378 |
+
):
|
379 |
+
super().__init__()
|
380 |
+
|
381 |
+
self.placeholder_is_bg = placeholder_is_bg
|
382 |
+
self.num_out_layers = num_out_layers
|
383 |
+
self.num_out_embs_per_layer = num_out_embs_per_layer
|
384 |
+
# subj: 64, bg: 32.
|
385 |
+
self.num_out_embs = num_out_layers * num_out_embs_per_layer
|
386 |
+
self.output_dim = output_dim
|
387 |
+
# num_id_vecs should be the number of core ID embs, 16.
|
388 |
+
# However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
|
389 |
+
self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
|
390 |
+
self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
|
391 |
+
self.pos_embs_ln = nn.LayerNorm(output_dim)
|
392 |
+
self.zs_extra_words_scale = zs_extra_words_scale
|
393 |
+
self.output_scale = output_dim ** -0.5
|
394 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
395 |
+
|
396 |
+
if not self.placeholder_is_bg:
|
397 |
+
# [1, 384] -> [1, 16, 768].
|
398 |
+
# TODO: use CLIPTextModelWrapper as obj_proj_in.
|
399 |
+
self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)
|
400 |
+
|
401 |
+
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
|
402 |
+
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
403 |
+
self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
|
404 |
+
self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
|
405 |
+
self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
|
406 |
+
print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
|
407 |
+
# Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
|
408 |
+
# Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
|
409 |
+
if prompt2token_proj_grad_scale == 0:
|
410 |
+
self.freeze_prompt2token_proj()
|
411 |
+
|
412 |
+
self.prompt2token_proj_attention_multiplier = -1
|
413 |
+
self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
|
414 |
+
self.pad_embeddings = None
|
415 |
+
self.bg_proj_in = None
|
416 |
+
else:
|
417 |
+
# For background placeholders, face and object embeddings are not used as they are foreground.
|
418 |
+
self.obj_proj_in = None
|
419 |
+
self.prompt2token_proj = None
|
420 |
+
print("Bg prompt2token_proj is set to None.")
|
421 |
+
|
422 |
+
self.bg_proj_in = nn.Sequential(
|
423 |
+
nn.Linear(image_embedding_dim, output_dim, bias=False),
|
424 |
+
nn.LayerNorm(output_dim),
|
425 |
+
)
|
426 |
+
|
427 |
+
self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
|
428 |
+
self.latent_queries_ln = nn.LayerNorm(output_dim)
|
429 |
+
|
430 |
+
self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
|
431 |
+
identity_to_v = False
|
432 |
+
v_has_skip = not identity_to_v # True
|
433 |
+
identity_to_out = not bg_prompt_translator_has_to_out_proj # True
|
434 |
+
out_has_skip = not identity_to_out # False
|
435 |
+
# prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
|
436 |
+
# dim=768, num_heads=6.
|
437 |
+
self.prompt_translator = \
|
438 |
+
CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
|
439 |
+
identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
|
440 |
+
q_aware_to_v=False, v_has_skip=v_has_skip,
|
441 |
+
num_q=0, # When not q_aware_to_v, num_q is not referenced.
|
442 |
+
identity_to_out=identity_to_out,
|
443 |
+
out_has_skip=out_has_skip)
|
444 |
+
'''
|
445 |
+
prompt_translator: CLIPEncoder
|
446 |
+
# https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
|
447 |
+
# CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
|
448 |
+
(0): CLIPEncoderLayer(
|
449 |
+
(self_attn): CLIPAttention(
|
450 |
+
(k_proj): Linear(in_features=768, out_features=768, bias=True)
|
451 |
+
(v_proj): Linear(in_features=768, out_features=768, bias=True)
|
452 |
+
(q_proj): Linear(in_features=768, out_features=768, bias=True)
|
453 |
+
(out_proj): Linear(in_features=768, out_features=768, bias=True)
|
454 |
+
)
|
455 |
+
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
456 |
+
(mlp): CLIPMLP(
|
457 |
+
(activation_fn): QuickGELUActivation()
|
458 |
+
(fc1): Linear(in_features=768, out_features=3072, bias=True)
|
459 |
+
(fc2): Linear(in_features=3072, out_features=768, bias=True)
|
460 |
+
)
|
461 |
+
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
462 |
+
)
|
463 |
+
'''
|
464 |
+
|
465 |
+
print(repr(self))
|
466 |
+
|
467 |
+
# raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
|
468 |
+
# or DINO embeddings for objects.
|
469 |
+
# arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
|
470 |
+
def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
|
471 |
+
is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
|
472 |
+
|
473 |
+
if not self.placeholder_is_bg:
|
474 |
+
BS = arc2face_id_embs.shape[0]
|
475 |
+
else:
|
476 |
+
# If bg, then arc2face_id_embs is set to None, but clip_features is not None.
|
477 |
+
BS = clip_features.shape[0]
|
478 |
+
|
479 |
+
adaface_prompt_embs = None
|
480 |
+
if not hasattr(self, 'clip_tokenizer'):
|
481 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
482 |
+
|
483 |
+
# No need to use raw_id_embs if placeholder_is_bg.
|
484 |
+
if not self.placeholder_is_bg:
|
485 |
+
if is_face:
|
486 |
+
assert arc2face_id_embs is not None
|
487 |
+
# arc2face_embs has been projected to the (modified) prompt embedding space
|
488 |
+
# by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
|
489 |
+
# the text encoder and the U-Net.
|
490 |
+
# in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
|
491 |
+
# arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
|
492 |
+
# adaface_prompt_embs is projected to the prompt embedding spaces. This is the
|
493 |
+
# original U-Net prompt embedding space.
|
494 |
+
|
495 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
496 |
+
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
497 |
+
# return_emb_types: a list of strings, each string is among
|
498 |
+
# ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
|
499 |
+
# Using b_core_e is more computationally efficient than using full_zeroed_extra.
|
500 |
+
# But there is an unknow BUG that causes crash when using b_core_e.
|
501 |
+
if is_training:
|
502 |
+
return_emb_types = ['full_pad', 'core']
|
503 |
+
else:
|
504 |
+
# adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
|
505 |
+
return_emb_types = [adaface_prompt_embs_inf_type, 'core']
|
506 |
+
|
507 |
+
if self.pad_embeddings is None:
|
508 |
+
self.generate_pad_embeddings()
|
509 |
+
else:
|
510 |
+
self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)
|
511 |
+
|
512 |
+
with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
|
513 |
+
# If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
|
514 |
+
# and (at most) two extra words in full_prompt_embs, without BOS and EOS.
|
515 |
+
# If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
|
516 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
517 |
+
# zs_extra_words_scale is only effective when list_extra_words is not None.
|
518 |
+
# adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
|
519 |
+
adaface_prompt_embs, core_id_embs = \
|
520 |
+
arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
|
521 |
+
self.prompt2token_proj,
|
522 |
+
arc2face_id_embs,
|
523 |
+
list_extra_words=None,
|
524 |
+
return_emb_types=return_emb_types,
|
525 |
+
pad_embeddings=self.pad_embeddings,
|
526 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
527 |
+
input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
|
528 |
+
# Reduce the update rate to prompt2token_proj.
|
529 |
+
adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
|
530 |
+
core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
|
531 |
+
elif raw_id_embs is not None:
|
532 |
+
# id_embs: [BS, 384] -> [BS, 18, 768].
|
533 |
+
# obj_proj_in is expected to project the DINO object features to
|
534 |
+
# the token embedding space. So no need to use prompt2token_proj.
|
535 |
+
id_embs = self.obj_proj_in(raw_id_embs)
|
536 |
+
else:
|
537 |
+
breakpoint()
|
538 |
+
else:
|
539 |
+
# Otherwise, context is the ad-hoc CLIP image features.
|
540 |
+
# id_embs: [BS, 257, 768].
|
541 |
+
id_embs = self.bg_proj_in(clip_features)
|
542 |
+
|
543 |
+
if self.placeholder_is_bg:
|
544 |
+
id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
|
545 |
+
latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
|
546 |
+
# If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
|
547 |
+
# and we take the first 16*4=64.
|
548 |
+
# Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
|
549 |
+
# prompt_translator: better named as bg_prompt_translator. It maps the bg features
|
550 |
+
# to bg prompt embeddings.
|
551 |
+
with torch.set_grad_enabled(self.training):
|
552 |
+
id_embs_out = self.prompt_translator(latent_queries, id_embs)
|
553 |
+
# [BS, 64, 768] -> [BS, 16, 4, 768]
|
554 |
+
id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
|
555 |
+
adaface_subj_embs = id_embs_out * self.output_scale # * 0.036
|
556 |
+
else:
|
557 |
+
# adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
|
558 |
+
adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)
|
559 |
+
|
560 |
+
# If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
|
561 |
+
if out_id_embs_scale != 1:
|
562 |
+
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
|
563 |
+
pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
|
564 |
+
adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
|
565 |
+
+ pad_embeddings * (1 - out_id_embs_scale)
|
566 |
+
|
567 |
+
return adaface_subj_embs, adaface_prompt_embs
|
568 |
+
|
569 |
+
def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
|
570 |
+
if learnable_hidden_state_weights_scheme == 'none':
|
571 |
+
self.hidden_state_layer_weights = None
|
572 |
+
# A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
|
573 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
|
574 |
+
print("hidden_state_layer_weights is set to None.")
|
575 |
+
|
576 |
+
elif learnable_hidden_state_weights_scheme == 'per-layer':
|
577 |
+
# Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
|
578 |
+
# 'per-layer': Different weights for different layers, but the same for different channels.
|
579 |
+
# hidden_state_layer_weights: [3, 1].
|
580 |
+
self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
|
581 |
+
requires_grad=True)
|
582 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
|
583 |
+
print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
|
584 |
+
else:
|
585 |
+
breakpoint()
|
586 |
+
|
587 |
+
def generate_pad_embeddings(self):
|
588 |
+
# clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
|
589 |
+
# prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
|
590 |
+
# slightly better than the original pad embeddings.
|
591 |
+
clip_embeddings = self.prompt2token_proj.text_model.embeddings
|
592 |
+
# clip_embeddings() and clip_embeddings.token_embedding() differ in that
|
593 |
+
# clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
|
594 |
+
# Adding positional embeddings seems to help somewhat.
|
595 |
+
# pad_tokens: pad_token_id 49407 repeated 77 times.
|
596 |
+
# pad_token_id is the EOS token. But BOS is 49406.
|
597 |
+
pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
|
598 |
+
# pad_embeddings: [77, 768].
|
599 |
+
pad_embeddings = clip_embeddings(pad_tokens)[0]
|
600 |
+
# We don't allow face recon to influence the pad embeddings.
|
601 |
+
# Otherwise, face identity will leak into the pad embeddings.
|
602 |
+
self.pad_embeddings = pad_embeddings.detach()
|
603 |
+
|
604 |
+
def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
|
605 |
+
if multiplier > 1:
|
606 |
+
num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
|
607 |
+
self.prompt2token_proj_attention_multiplier = multiplier
|
608 |
+
print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")
|
609 |
+
|
610 |
+
def freeze_prompt2token_proj(self):
|
611 |
+
# If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
|
612 |
+
# Then we don't have to check whether it's for subj or bg.
|
613 |
+
if self.prompt2token_proj is not None:
|
614 |
+
frozen_param_names = []
|
615 |
+
for param_name, param in self.prompt2token_proj.named_parameters():
|
616 |
+
if param.requires_grad:
|
617 |
+
param.requires_grad = False
|
618 |
+
frozen_param_names.append(param_name)
|
619 |
+
# If param is already frozen, then no need to freeze it again.
|
620 |
+
print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
|
621 |
+
#print(f"Frozen parameters:\n{frozen_param_names}")
|
622 |
+
|
623 |
+
def __repr__(self):
|
624 |
+
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
625 |
+
# Fix compatability with the previous version.
|
626 |
+
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
627 |
+
self.bg_prompt_translator_has_to_out_proj = False
|
628 |
+
if not hasattr(self, 'num_out_embs'):
|
629 |
+
self.num_out_embs = -1
|
630 |
+
return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
|
631 |
+
f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
|
632 |
+
|
633 |
+
@dataclass
|
634 |
+
class BaseModelOutputWithPooling2(ModelOutput):
|
635 |
+
"""
|
636 |
+
Base class for model's outputs that also contains a pooling of the last hidden states.
|
637 |
+
|
638 |
+
Args:
|
639 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
640 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
641 |
+
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
642 |
+
Last layer hidden-state of the first token of the sequence (classification token) after further processing
|
643 |
+
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
|
644 |
+
the classification token after processing through a linear layer and a tanh activation function. The linear
|
645 |
+
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
|
646 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
647 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
648 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
649 |
+
|
650 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
651 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
652 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
653 |
+
sequence_length)`.
|
654 |
+
|
655 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
656 |
+
heads.
|
657 |
+
"""
|
658 |
+
|
659 |
+
last_hidden_state: torch.FloatTensor = None
|
660 |
+
pooler_output: torch.FloatTensor = None
|
661 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
662 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
663 |
+
attn_mask: Optional[torch.FloatTensor] = None
|
664 |
+
|
665 |
+
# Revised from CLIPVisionTransformer to support attention mask.
|
666 |
+
# self: a CLIPVisionTransformer instance.
|
667 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
|
668 |
+
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
|
669 |
+
# attn_mask: B*H*W attention mask.
|
670 |
+
def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
|
671 |
+
output_attentions = None,
|
672 |
+
output_hidden_states = None, return_dict = None):
|
673 |
+
|
674 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
675 |
+
output_hidden_states = (
|
676 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
677 |
+
)
|
678 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
679 |
+
|
680 |
+
if pixel_values is None:
|
681 |
+
raise ValueError("You have to specify pixel_values")
|
682 |
+
|
683 |
+
# Visual tokens are flattended in embeddings().
|
684 |
+
# self.embeddings: CLIPVisionEmbeddings.
|
685 |
+
# hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
|
686 |
+
# 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
|
687 |
+
hidden_states = self.embeddings(pixel_values)
|
688 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
689 |
+
|
690 |
+
if attn_mask is not None:
|
691 |
+
# feat_edge_size: 16.
|
692 |
+
feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
|
693 |
+
# attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
|
694 |
+
attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
|
695 |
+
# Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
|
696 |
+
attn_mask = attn_mask.flatten(2)
|
697 |
+
# Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
|
698 |
+
# This 1 corresponds to class_embeds, which is always attended to.
|
699 |
+
attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
|
700 |
+
attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
|
701 |
+
else:
|
702 |
+
attn_mask_pairs = None
|
703 |
+
|
704 |
+
# encoder: CLIPEncoder.
|
705 |
+
encoder_outputs = self.encoder(
|
706 |
+
inputs_embeds=hidden_states,
|
707 |
+
# New feature: (***The official documentation is wrong***)
|
708 |
+
# attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
|
709 |
+
# Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`:
|
710 |
+
# - 1 for pairs that are **not masked**,
|
711 |
+
# - 0 for pairs that are **masked**.
|
712 |
+
# attention_mask is eventually used by CLIPEncoderLayer:
|
713 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
|
714 |
+
attention_mask=attn_mask_pairs,
|
715 |
+
output_attentions=output_attentions, # False
|
716 |
+
output_hidden_states=output_hidden_states, # True
|
717 |
+
return_dict=return_dict, # True
|
718 |
+
)
|
719 |
+
|
720 |
+
# last_hidden_state: [BS, 257, 1280]
|
721 |
+
last_hidden_state = encoder_outputs[0]
|
722 |
+
pooled_output = last_hidden_state[:, 0, :]
|
723 |
+
pooled_output = self.post_layernorm(pooled_output)
|
724 |
+
|
725 |
+
# return_dict is True.
|
726 |
+
if not return_dict:
|
727 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
728 |
+
|
729 |
+
return BaseModelOutputWithPooling2(
|
730 |
+
last_hidden_state=last_hidden_state,
|
731 |
+
pooler_output=pooled_output,
|
732 |
+
hidden_states=encoder_outputs.hidden_states,
|
733 |
+
attentions=encoder_outputs.attentions,
|
734 |
+
# Newly added: return resized flattened attention mask.
|
735 |
+
# [BS, 1, 257] -> [BS, 257, 1]
|
736 |
+
attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
|
737 |
+
)
|
738 |
+
|
739 |
+
|
740 |
+
class CLIPVisionModelWithMask(CLIPVisionModel):
|
741 |
+
def __init__(self, config):
|
742 |
+
super().__init__(config)
|
743 |
+
# Replace vision_model.forward() with the new one that supports mask.
|
744 |
+
self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
|
745 |
+
|
746 |
+
def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
|
747 |
+
output_hidden_states = None, return_dict = None):
|
748 |
+
|
749 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
750 |
+
|
751 |
+
return self.vision_model(
|
752 |
+
pixel_values=pixel_values,
|
753 |
+
attn_mask=attn_mask,
|
754 |
+
output_attentions=output_attentions,
|
755 |
+
output_hidden_states=output_hidden_states,
|
756 |
+
return_dict=return_dict,
|
757 |
+
)
|
758 |
+
|
adaface/util.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
# add_noise_to_tensor() adds a fixed amount of noise to the tensor.
|
9 |
+
def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
|
10 |
+
std_dim=-1, norm_dim=-1):
|
11 |
+
if noise_std_is_relative:
|
12 |
+
ts_std_mean = ts.std(dim=std_dim).mean().detach()
|
13 |
+
noise_std *= ts_std_mean
|
14 |
+
|
15 |
+
noise = torch.randn_like(ts) * noise_std
|
16 |
+
if keep_norm:
|
17 |
+
orig_norm = ts.norm(dim=norm_dim, keepdim=True)
|
18 |
+
ts = ts + noise
|
19 |
+
new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
|
20 |
+
ts = ts * orig_norm / (new_norm + 1e-8)
|
21 |
+
else:
|
22 |
+
ts = ts + noise
|
23 |
+
|
24 |
+
return ts
|
25 |
+
|
26 |
+
|
27 |
+
# Revised from RevGrad, by removing the grad negation.
|
28 |
+
class ScaleGrad(torch.autograd.Function):
|
29 |
+
@staticmethod
|
30 |
+
def forward(ctx, input_, alpha_, debug=False):
|
31 |
+
ctx.save_for_backward(alpha_, debug)
|
32 |
+
output = input_
|
33 |
+
if debug:
|
34 |
+
print(f"input: {input_.abs().mean().item()}")
|
35 |
+
return output
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def backward(ctx, grad_output): # pragma: no cover
|
39 |
+
# saved_tensors returns a tuple of tensors.
|
40 |
+
alpha_, debug = ctx.saved_tensors
|
41 |
+
if ctx.needs_input_grad[0]:
|
42 |
+
grad_output2 = grad_output * alpha_
|
43 |
+
if debug:
|
44 |
+
print(f"grad_output2: {grad_output2.abs().mean().item()}")
|
45 |
+
else:
|
46 |
+
grad_output2 = None
|
47 |
+
return grad_output2, None, None
|
48 |
+
|
49 |
+
class GradientScaler(nn.Module):
|
50 |
+
def __init__(self, alpha=1., debug=False, *args, **kwargs):
|
51 |
+
"""
|
52 |
+
A gradient scaling layer.
|
53 |
+
This layer has no parameters, and simply scales the gradient in the backward pass.
|
54 |
+
"""
|
55 |
+
super().__init__(*args, **kwargs)
|
56 |
+
|
57 |
+
self._alpha = torch.tensor(alpha, requires_grad=False)
|
58 |
+
self._debug = torch.tensor(debug, requires_grad=False)
|
59 |
+
|
60 |
+
def forward(self, input_):
|
61 |
+
_debug = self._debug if hasattr(self, '_debug') else False
|
62 |
+
return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
|
63 |
+
|
64 |
+
def gen_gradient_scaler(alpha, debug=False):
|
65 |
+
if alpha == 1:
|
66 |
+
return nn.Identity()
|
67 |
+
if alpha > 0:
|
68 |
+
return GradientScaler(alpha, debug=debug)
|
69 |
+
else:
|
70 |
+
assert alpha == 0
|
71 |
+
# Don't use lambda function here, otherwise the object can't be pickled.
|
72 |
+
return torch.detach
|
73 |
+
|
74 |
+
#@torch.autocast(device_type="cuda")
|
75 |
+
# In AdaFaceWrapper, input_max_length is 22.
|
76 |
+
def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
|
77 |
+
input_max_length=77, return_full_and_core_embs=True):
|
78 |
+
|
79 |
+
'''
|
80 |
+
arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
|
81 |
+
face_embs: (N, 512) normalized ArcFace embeddings.
|
82 |
+
return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
|
83 |
+
If False, return only the core embeddings.
|
84 |
+
|
85 |
+
'''
|
86 |
+
|
87 |
+
# arcface_token_id: 1014
|
88 |
+
arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
|
89 |
+
|
90 |
+
# This step should be quite fast, and there's no need to cache the input_ids.
|
91 |
+
input_ids = tokenizer(
|
92 |
+
"photo of a id person",
|
93 |
+
truncation=True,
|
94 |
+
padding="max_length",
|
95 |
+
max_length=input_max_length, #tokenizer.model_max_length,
|
96 |
+
return_tensors="pt",
|
97 |
+
).input_ids.to(face_embs.device)
|
98 |
+
# input_ids: [1, 77] or [3, 77] (during training).
|
99 |
+
input_ids = input_ids.repeat(len(face_embs), 1)
|
100 |
+
face_embs_dtype = face_embs.dtype
|
101 |
+
face_embs = face_embs.to(arc2face_text_encoder.dtype)
|
102 |
+
# face_embs_padded: [1, 512] -> [1, 768].
|
103 |
+
face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
|
104 |
+
# arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
|
105 |
+
# The second call does the ordinary CLIP text encoding pass.
|
106 |
+
token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
|
107 |
+
token_embs[input_ids==arcface_token_id] = face_embs_padded
|
108 |
+
|
109 |
+
prompt_embeds = arc2face_text_encoder(
|
110 |
+
input_ids=input_ids,
|
111 |
+
input_token_embs=token_embs,
|
112 |
+
return_token_embs=False
|
113 |
+
)[0]
|
114 |
+
|
115 |
+
# Restore the original dtype of prompt_embeds: float16 -> float32.
|
116 |
+
prompt_embeds = prompt_embeds.to(face_embs_dtype)
|
117 |
+
|
118 |
+
if return_full_and_core_embs:
|
119 |
+
# token 4: 'id' in "photo of a id person".
|
120 |
+
# 4:20 are the most important 16 embeddings that contain the subject's identity.
|
121 |
+
# [N, 77, 768] -> [N, 16, 768]
|
122 |
+
return prompt_embeds, prompt_embeds[:, 4:20]
|
123 |
+
else:
|
124 |
+
# [N, 16, 768]
|
125 |
+
return prompt_embeds[:, 4:20]
|
126 |
+
|
127 |
+
def get_b_core_e_embeddings(prompt_embeds, length=22):
|
128 |
+
b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
|
129 |
+
return b_core_e_embs
|
130 |
+
|
131 |
+
# return_emb_types: a list of strings, each string is among ['full', 'core', 'full_zeroed_extra', 'b_core_e'].
|
132 |
+
def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
|
133 |
+
return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
|
134 |
+
input_max_length=77, zs_extra_words_scale=0.5):
|
135 |
+
|
136 |
+
'''
|
137 |
+
inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
138 |
+
inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
|
139 |
+
face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
|
140 |
+
list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt.
|
141 |
+
return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
|
142 |
+
If False, return only the core embeddings.
|
143 |
+
'''
|
144 |
+
|
145 |
+
if list_extra_words is not None:
|
146 |
+
if len(list_extra_words) != len(face_prompt_embs):
|
147 |
+
if len(face_prompt_embs) > 1:
|
148 |
+
print("Warn: list_extra_words has different length as face_prompt_embs.")
|
149 |
+
if len(list_extra_words) == 1:
|
150 |
+
list_extra_words = list_extra_words * len(face_prompt_embs)
|
151 |
+
else:
|
152 |
+
breakpoint()
|
153 |
+
else:
|
154 |
+
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
|
155 |
+
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
156 |
+
list_extra_words = list_extra_words[:1]
|
157 |
+
|
158 |
+
for extra_words in list_extra_words:
|
159 |
+
assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
|
160 |
+
# 16 ", " are placeholders for face_prompt_embs.
|
161 |
+
prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
|
162 |
+
else:
|
163 |
+
# 16 ", " are placeholders for face_prompt_embs.
|
164 |
+
# No extra words are added to the prompt.
|
165 |
+
prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]
|
166 |
+
|
167 |
+
# This step should be quite fast, and there's no need to cache the input_ids.
|
168 |
+
# input_ids: [BS, 77].
|
169 |
+
input_ids = clip_tokenizer(
|
170 |
+
prompt_templates,
|
171 |
+
truncation=True,
|
172 |
+
padding="max_length",
|
173 |
+
max_length=input_max_length,
|
174 |
+
return_tensors="pt",
|
175 |
+
).input_ids.to(face_prompt_embs.device)
|
176 |
+
|
177 |
+
face_prompt_embs_dtype = face_prompt_embs.dtype
|
178 |
+
face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)
|
179 |
+
|
180 |
+
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
181 |
+
token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
|
182 |
+
# token 4: first ", " in the template prompt.
|
183 |
+
# Replace embeddings of 16 placeholder ", " with face_prompt_embs.
|
184 |
+
token_embs[:, 4:20] = face_prompt_embs
|
185 |
+
|
186 |
+
# This call does the ordinary CLIP text encoding pass.
|
187 |
+
prompt_embeds = inverse_text_encoder(
|
188 |
+
input_ids=input_ids,
|
189 |
+
input_token_embs=token_embs,
|
190 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
191 |
+
return_token_embs=False
|
192 |
+
)[0]
|
193 |
+
|
194 |
+
# Restore the original dtype of prompt_embeds: float16 -> float32.
|
195 |
+
prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
|
196 |
+
# token 4: first ", " in the template prompt.
|
197 |
+
# 4:20 are the most important 16 embeddings that contain the subject's identity.
|
198 |
+
# 20:22 are embeddings of the (at most) two extra words.
|
199 |
+
# [N, 77, 768] -> [N, 16, 768]
|
200 |
+
core_prompt_embs = prompt_embeds[:, 4:20]
|
201 |
+
if list_extra_words is not None:
|
202 |
+
# [N, 16, 768] -> [N, 18, 768]
|
203 |
+
extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
|
204 |
+
core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)
|
205 |
+
|
206 |
+
return_prompts = []
|
207 |
+
for emb_type in return_emb_types:
|
208 |
+
if emb_type == 'full':
|
209 |
+
return_prompts.append(prompt_embeds)
|
210 |
+
elif emb_type == 'full_half_pad':
|
211 |
+
prompt_embeds2 = prompt_embeds.clone()
|
212 |
+
PADS = prompt_embeds2.shape[1] - 23
|
213 |
+
if PADS >= 2:
|
214 |
+
# Fill half of the remaining embeddings with pad embeddings.
|
215 |
+
prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
|
216 |
+
return_prompts.append(prompt_embeds2)
|
217 |
+
elif emb_type == 'full_pad':
|
218 |
+
prompt_embeds2 = prompt_embeds.clone()
|
219 |
+
# Fill the 22nd to the second last embeddings with pad embeddings.
|
220 |
+
prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
|
221 |
+
return_prompts.append(prompt_embeds2)
|
222 |
+
elif emb_type == 'core':
|
223 |
+
return_prompts.append(core_prompt_embs)
|
224 |
+
elif emb_type == 'full_zeroed_extra':
|
225 |
+
prompt_embeds2 = prompt_embeds.clone()
|
226 |
+
# Only add two pad embeddings. The remaining embeddings are set to 0.
|
227 |
+
# Make the positional embeddings align with the actual positions.
|
228 |
+
prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
|
229 |
+
prompt_embeds2[:, 24:-1] = 0
|
230 |
+
return_prompts.append(prompt_embeds2)
|
231 |
+
elif emb_type == 'b_core_e':
|
232 |
+
# The first 22 embeddings, plus the last EOS embedding.
|
233 |
+
b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
|
234 |
+
return_prompts.append(b_core_e_embs)
|
235 |
+
else:
|
236 |
+
breakpoint()
|
237 |
+
|
238 |
+
return return_prompts
|
239 |
+
|
240 |
+
# if pre_face_embs is None, generate random face embeddings [BS, 512].
|
241 |
+
# image_folder is passed only for logging purpose. image_paths contains the paths of the images.
|
242 |
+
def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
|
243 |
+
extract_faceid_embeds, pre_face_embs,
|
244 |
+
image_folder, image_paths, images_np,
|
245 |
+
id_batch_size, device,
|
246 |
+
input_max_length=77, noise_level=0.0,
|
247 |
+
return_core_id_embs=False,
|
248 |
+
gen_neg_prompt=False, verbose=False):
|
249 |
+
if extract_faceid_embeds:
|
250 |
+
image_count = 0
|
251 |
+
faceid_embeds = []
|
252 |
+
if image_paths is not None:
|
253 |
+
images_np = []
|
254 |
+
for image_path in image_paths:
|
255 |
+
image_np = np.array(Image.open(image_path))
|
256 |
+
images_np.append(image_np)
|
257 |
+
|
258 |
+
for i, image_np in enumerate(images_np):
|
259 |
+
image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
|
260 |
+
# Remove alpha channel if it exists.
|
261 |
+
if image_obj.mode == 'RGBA':
|
262 |
+
image_obj = image_obj.convert('RGB')
|
263 |
+
# This seems NOT a bug. The input image should be in BGR format, as per
|
264 |
+
# https://github.com/deepinsight/insightface/issues/524
|
265 |
+
image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
|
266 |
+
image_np = np.array(image_obj)
|
267 |
+
|
268 |
+
face_infos = face_app.get(image_np)
|
269 |
+
if verbose and image_paths is not None:
|
270 |
+
print(image_paths[i], len(face_infos))
|
271 |
+
# Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
|
272 |
+
if len(face_infos) == 0:
|
273 |
+
continue
|
274 |
+
# only use the maximum face
|
275 |
+
face_info = sorted(face_infos, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1]
|
276 |
+
# Each faceid_embed: [1, 512]
|
277 |
+
faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
|
278 |
+
image_count += 1
|
279 |
+
|
280 |
+
if verbose:
|
281 |
+
if image_folder is not None:
|
282 |
+
print(f"Extracted ID embeddings from {image_count} images in {image_folder}")
|
283 |
+
else:
|
284 |
+
print(f"Extracted ID embeddings from {image_count} images")
|
285 |
+
|
286 |
+
if len(faceid_embeds) == 0:
|
287 |
+
print("No face detected. Use a random face instead.")
|
288 |
+
faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
|
289 |
+
else:
|
290 |
+
# faceid_embeds: [10, 512]
|
291 |
+
faceid_embeds = torch.cat(faceid_embeds, dim=0)
|
292 |
+
# faceid_embeds: [10, 512] -> [1, 512].
|
293 |
+
# and the resulted prompt embeddings are the same.
|
294 |
+
faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
|
295 |
+
else:
|
296 |
+
# Random face embeddings. faceid_embeds: [BS, 512].
|
297 |
+
if pre_face_embs is None:
|
298 |
+
faceid_embeds = torch.randn(id_batch_size, 512)
|
299 |
+
else:
|
300 |
+
faceid_embeds = pre_face_embs
|
301 |
+
if pre_face_embs.shape[0] == 1:
|
302 |
+
faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
|
303 |
+
|
304 |
+
faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)
|
305 |
+
|
306 |
+
if noise_level > 0:
|
307 |
+
# If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different.
|
308 |
+
faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)
|
309 |
+
|
310 |
+
faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)
|
311 |
+
|
312 |
+
# arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
|
313 |
+
with torch.no_grad():
|
314 |
+
arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
|
315 |
+
arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
|
316 |
+
faceid_embeds, input_max_length=input_max_length,
|
317 |
+
return_full_and_core_embs=True)
|
318 |
+
if return_core_id_embs:
|
319 |
+
arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
|
320 |
+
# If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
|
321 |
+
# So we need to repeat faceid_embeds.
|
322 |
+
if extract_faceid_embeds:
|
323 |
+
faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
|
324 |
+
arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)
|
325 |
+
|
326 |
+
if gen_neg_prompt:
|
327 |
+
with torch.no_grad():
|
328 |
+
arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
|
329 |
+
arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
|
330 |
+
torch.zeros_like(faceid_embeds),
|
331 |
+
input_max_length=input_max_length,
|
332 |
+
return_full_and_core_embs=True)
|
333 |
+
if return_core_id_embs:
|
334 |
+
arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb
|
335 |
+
|
336 |
+
#if extract_faceid_embeds:
|
337 |
+
# arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
|
338 |
+
return faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
|
339 |
+
else:
|
340 |
+
return faceid_embeds, arc2face_pos_prompt_emb
|
341 |
+
|
animatediff/models/attention.py
ADDED
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import FeedForward, AdaLayerNorm,Attention
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import pdb
|
18 |
+
|
19 |
+
from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0
|
20 |
+
@dataclass
|
21 |
+
class Transformer3DModelOutput(BaseOutput):
|
22 |
+
sample: torch.FloatTensor
|
23 |
+
from diffusers.utils import logging
|
24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
25 |
+
|
26 |
+
if is_xformers_available():
|
27 |
+
import xformers
|
28 |
+
import xformers.ops
|
29 |
+
else:
|
30 |
+
xformers = None
|
31 |
+
|
32 |
+
|
33 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
34 |
+
@register_to_config
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
num_attention_heads: int = 16,
|
38 |
+
attention_head_dim: int = 88,
|
39 |
+
in_channels: Optional[int] = None,
|
40 |
+
num_layers: int = 1,
|
41 |
+
dropout: float = 0.0,
|
42 |
+
norm_num_groups: int = 32,
|
43 |
+
cross_attention_dim: Optional[int] = None,
|
44 |
+
attention_bias: bool = False,
|
45 |
+
activation_fn: str = "geglu",
|
46 |
+
num_embeds_ada_norm: Optional[int] = None,
|
47 |
+
use_linear_projection: bool = False,
|
48 |
+
only_cross_attention: bool = False,
|
49 |
+
upcast_attention: bool = False,
|
50 |
+
unet_use_cross_frame_attention=None,
|
51 |
+
unet_use_temporal_attention=None,
|
52 |
+
processor: Optional["AttnProcessor"] = None,
|
53 |
+
):
|
54 |
+
super().__init__()
|
55 |
+
self.use_linear_projection = use_linear_projection
|
56 |
+
self.num_attention_heads = num_attention_heads
|
57 |
+
self.attention_head_dim = attention_head_dim
|
58 |
+
inner_dim = num_attention_heads * attention_head_dim
|
59 |
+
|
60 |
+
# Define input layers
|
61 |
+
self.in_channels = in_channels
|
62 |
+
|
63 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
64 |
+
if use_linear_projection:
|
65 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
66 |
+
else:
|
67 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
68 |
+
|
69 |
+
# Define transformers blocks
|
70 |
+
self.transformer_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
BasicTransformerBlock(
|
73 |
+
inner_dim,
|
74 |
+
num_attention_heads,
|
75 |
+
attention_head_dim,
|
76 |
+
dropout=dropout,
|
77 |
+
cross_attention_dim=cross_attention_dim,
|
78 |
+
activation_fn=activation_fn,
|
79 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
80 |
+
attention_bias=attention_bias,
|
81 |
+
only_cross_attention=only_cross_attention,
|
82 |
+
upcast_attention=upcast_attention,
|
83 |
+
|
84 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
85 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
86 |
+
)
|
87 |
+
for d in range(num_layers)
|
88 |
+
]
|
89 |
+
)
|
90 |
+
|
91 |
+
# 4. Define output layers
|
92 |
+
if use_linear_projection:
|
93 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
94 |
+
else:
|
95 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
96 |
+
# if processor is None:
|
97 |
+
# processor = (
|
98 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
99 |
+
# )
|
100 |
+
# self.set_processor(processor)
|
101 |
+
# def set_processor(self, processor: "AttnProcessor") -> None:
|
102 |
+
# r"""
|
103 |
+
# Set the attention processor to use.
|
104 |
+
|
105 |
+
# Args:
|
106 |
+
# processor (`AttnProcessor`):
|
107 |
+
# The attention processor to use.
|
108 |
+
# """
|
109 |
+
# # if current processor is in `self._modules` and if passed `processor` is not, we need to
|
110 |
+
# # pop `processor` from `self._modules`
|
111 |
+
# if (
|
112 |
+
# hasattr(self, "processor")
|
113 |
+
# and isinstance(self.processor, torch.nn.Module)
|
114 |
+
# and not isinstance(processor, torch.nn.Module)
|
115 |
+
# ):
|
116 |
+
# logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
117 |
+
# self._modules.pop("processor")
|
118 |
+
|
119 |
+
# self.processor = processor
|
120 |
+
|
121 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
122 |
+
# Input
|
123 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
124 |
+
video_length = hidden_states.shape[2]
|
125 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
126 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
127 |
+
|
128 |
+
batch, channel, height, weight = hidden_states.shape
|
129 |
+
residual = hidden_states
|
130 |
+
|
131 |
+
hidden_states = self.norm(hidden_states)
|
132 |
+
if not self.use_linear_projection:
|
133 |
+
hidden_states = self.proj_in(hidden_states)
|
134 |
+
inner_dim = hidden_states.shape[1]
|
135 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
136 |
+
else:
|
137 |
+
inner_dim = hidden_states.shape[1]
|
138 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
139 |
+
hidden_states = self.proj_in(hidden_states)
|
140 |
+
|
141 |
+
# Blocks
|
142 |
+
for block in self.transformer_blocks:
|
143 |
+
hidden_states = block(
|
144 |
+
hidden_states,
|
145 |
+
encoder_hidden_states=encoder_hidden_states,
|
146 |
+
timestep=timestep,
|
147 |
+
video_length=video_length
|
148 |
+
)
|
149 |
+
|
150 |
+
# Output
|
151 |
+
if not self.use_linear_projection:
|
152 |
+
hidden_states = (
|
153 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
154 |
+
)
|
155 |
+
hidden_states = self.proj_out(hidden_states)
|
156 |
+
else:
|
157 |
+
hidden_states = self.proj_out(hidden_states)
|
158 |
+
hidden_states = (
|
159 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
160 |
+
)
|
161 |
+
|
162 |
+
output = hidden_states + residual
|
163 |
+
|
164 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
165 |
+
if not return_dict:
|
166 |
+
return (output,)
|
167 |
+
|
168 |
+
return Transformer3DModelOutput(sample=output)
|
169 |
+
|
170 |
+
|
171 |
+
class BasicTransformerBlock(nn.Module):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
dim: int,
|
175 |
+
num_attention_heads: int,
|
176 |
+
attention_head_dim: int,
|
177 |
+
dropout=0.0,
|
178 |
+
cross_attention_dim: Optional[int] = None,
|
179 |
+
activation_fn: str = "geglu",
|
180 |
+
num_embeds_ada_norm: Optional[int] = None,
|
181 |
+
attention_bias: bool = False,
|
182 |
+
only_cross_attention: bool = False,
|
183 |
+
upcast_attention: bool = False,
|
184 |
+
|
185 |
+
unet_use_cross_frame_attention = None,
|
186 |
+
unet_use_temporal_attention = None,
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
self.only_cross_attention = only_cross_attention
|
190 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
191 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
192 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
193 |
+
|
194 |
+
# SC-Attn
|
195 |
+
assert unet_use_cross_frame_attention is not None
|
196 |
+
if unet_use_cross_frame_attention:
|
197 |
+
self.attn1 = SparseCausalAttention2D(
|
198 |
+
query_dim=dim,
|
199 |
+
heads=num_attention_heads,
|
200 |
+
dim_head=attention_head_dim,
|
201 |
+
dropout=dropout,
|
202 |
+
bias=attention_bias,
|
203 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
204 |
+
upcast_attention=upcast_attention,
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
#self-attention
|
208 |
+
self.attn1 = Attention(
|
209 |
+
query_dim=dim,
|
210 |
+
heads=num_attention_heads,
|
211 |
+
dim_head=attention_head_dim,
|
212 |
+
dropout=dropout,
|
213 |
+
bias=attention_bias,
|
214 |
+
upcast_attention=upcast_attention,
|
215 |
+
cross_attention_dim=None,
|
216 |
+
)
|
217 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
218 |
+
|
219 |
+
# Cross-Attn
|
220 |
+
if cross_attention_dim is not None:
|
221 |
+
self.attn2 = Attention(
|
222 |
+
query_dim=dim,
|
223 |
+
cross_attention_dim=cross_attention_dim,
|
224 |
+
heads=num_attention_heads,
|
225 |
+
dim_head=attention_head_dim,
|
226 |
+
dropout=dropout,
|
227 |
+
bias=attention_bias,
|
228 |
+
upcast_attention=upcast_attention,
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
self.attn2 = None
|
232 |
+
|
233 |
+
if cross_attention_dim is not None:
|
234 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
235 |
+
else:
|
236 |
+
self.norm2 = None
|
237 |
+
|
238 |
+
# Feed-forward
|
239 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
240 |
+
self.norm3 = nn.LayerNorm(dim)
|
241 |
+
|
242 |
+
# Temp-Attn
|
243 |
+
assert unet_use_temporal_attention is not None
|
244 |
+
if unet_use_temporal_attention:
|
245 |
+
self.attn_temp = Attention(
|
246 |
+
query_dim=dim,
|
247 |
+
heads=num_attention_heads,
|
248 |
+
dim_head=attention_head_dim,
|
249 |
+
dropout=dropout,
|
250 |
+
bias=attention_bias,
|
251 |
+
upcast_attention=upcast_attention,
|
252 |
+
)
|
253 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
254 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
255 |
+
|
256 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None):
|
257 |
+
if not is_xformers_available():
|
258 |
+
print("Here is how to install it")
|
259 |
+
raise ModuleNotFoundError(
|
260 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
261 |
+
" xformers",
|
262 |
+
name="xformers",
|
263 |
+
)
|
264 |
+
elif not torch.cuda.is_available():
|
265 |
+
raise ValueError(
|
266 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
267 |
+
" available for GPU "
|
268 |
+
)
|
269 |
+
else:
|
270 |
+
try:
|
271 |
+
# Make sure we can run the memory efficient attention
|
272 |
+
_ = xformers.ops.memory_efficient_attention(
|
273 |
+
torch.randn((1, 2, 40), device="cuda"),
|
274 |
+
torch.randn((1, 2, 40), device="cuda"),
|
275 |
+
torch.randn((1, 2, 40), device="cuda"),
|
276 |
+
)
|
277 |
+
except Exception as e:
|
278 |
+
raise e
|
279 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
280 |
+
if self.attn2 is not None:
|
281 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
282 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
283 |
+
|
284 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
|
285 |
+
# SparseCausal-Attention
|
286 |
+
norm_hidden_states = (
|
287 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
288 |
+
)
|
289 |
+
|
290 |
+
# if self.only_cross_attention:
|
291 |
+
# hidden_states = (
|
292 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
293 |
+
# )
|
294 |
+
# else:
|
295 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
296 |
+
|
297 |
+
# pdb.set_trace()
|
298 |
+
if self.unet_use_cross_frame_attention:
|
299 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
300 |
+
else:
|
301 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
302 |
+
|
303 |
+
if self.attn2 is not None:
|
304 |
+
# Cross-Attention
|
305 |
+
norm_hidden_states = (
|
306 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
307 |
+
)
|
308 |
+
hidden_states = (
|
309 |
+
self.attn2(
|
310 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
311 |
+
)
|
312 |
+
+ hidden_states
|
313 |
+
)
|
314 |
+
# Feed-forward
|
315 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
316 |
+
|
317 |
+
# Temporal-Attention
|
318 |
+
if self.unet_use_temporal_attention:
|
319 |
+
d = hidden_states.shape[1]
|
320 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
321 |
+
norm_hidden_states = (
|
322 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
323 |
+
)
|
324 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
325 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
326 |
+
|
327 |
+
return hidden_states
|
animatediff/models/attention_bkp.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torch import nn
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import pdb
|
18 |
+
|
19 |
+
from diffusers.models.attention_processor import AttnProcessor,AttnProcessor2_0
|
20 |
+
@dataclass
|
21 |
+
class Transformer3DModelOutput(BaseOutput):
|
22 |
+
sample: torch.FloatTensor
|
23 |
+
from diffusers.utils import logging
|
24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
25 |
+
|
26 |
+
if is_xformers_available():
|
27 |
+
import xformers
|
28 |
+
import xformers.ops
|
29 |
+
else:
|
30 |
+
xformers = None
|
31 |
+
|
32 |
+
|
33 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
34 |
+
@register_to_config
|
35 |
+
def __init__(
|
36 |
+
self,
|
37 |
+
num_attention_heads: int = 16,
|
38 |
+
attention_head_dim: int = 88,
|
39 |
+
in_channels: Optional[int] = None,
|
40 |
+
num_layers: int = 1,
|
41 |
+
dropout: float = 0.0,
|
42 |
+
norm_num_groups: int = 32,
|
43 |
+
cross_attention_dim: Optional[int] = None,
|
44 |
+
attention_bias: bool = False,
|
45 |
+
activation_fn: str = "geglu",
|
46 |
+
num_embeds_ada_norm: Optional[int] = None,
|
47 |
+
use_linear_projection: bool = False,
|
48 |
+
only_cross_attention: bool = False,
|
49 |
+
upcast_attention: bool = False,
|
50 |
+
unet_use_cross_frame_attention=None,
|
51 |
+
unet_use_temporal_attention=None,
|
52 |
+
processor: Optional["AttnProcessor"] = None,
|
53 |
+
):
|
54 |
+
super().__init__()
|
55 |
+
self.use_linear_projection = use_linear_projection
|
56 |
+
self.num_attention_heads = num_attention_heads
|
57 |
+
self.attention_head_dim = attention_head_dim
|
58 |
+
inner_dim = num_attention_heads * attention_head_dim
|
59 |
+
|
60 |
+
# Define input layers
|
61 |
+
self.in_channels = in_channels
|
62 |
+
|
63 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
64 |
+
if use_linear_projection:
|
65 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
66 |
+
else:
|
67 |
+
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
68 |
+
|
69 |
+
# Define transformers blocks
|
70 |
+
self.transformer_blocks = nn.ModuleList(
|
71 |
+
[
|
72 |
+
BasicTransformerBlock(
|
73 |
+
inner_dim,
|
74 |
+
num_attention_heads,
|
75 |
+
attention_head_dim,
|
76 |
+
dropout=dropout,
|
77 |
+
cross_attention_dim=cross_attention_dim,
|
78 |
+
activation_fn=activation_fn,
|
79 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
80 |
+
attention_bias=attention_bias,
|
81 |
+
only_cross_attention=only_cross_attention,
|
82 |
+
upcast_attention=upcast_attention,
|
83 |
+
|
84 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
85 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
86 |
+
)
|
87 |
+
for d in range(num_layers)
|
88 |
+
]
|
89 |
+
)
|
90 |
+
|
91 |
+
# 4. Define output layers
|
92 |
+
if use_linear_projection:
|
93 |
+
self.proj_out = nn.Linear(in_channels, inner_dim)
|
94 |
+
else:
|
95 |
+
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
96 |
+
# if processor is None:
|
97 |
+
# processor = (
|
98 |
+
# AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
|
99 |
+
# )
|
100 |
+
# self.set_processor(processor)
|
101 |
+
def set_processor(self, processor: "AttnProcessor") -> None:
|
102 |
+
r"""
|
103 |
+
Set the attention processor to use.
|
104 |
+
|
105 |
+
Args:
|
106 |
+
processor (`AttnProcessor`):
|
107 |
+
The attention processor to use.
|
108 |
+
"""
|
109 |
+
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
110 |
+
# pop `processor` from `self._modules`
|
111 |
+
if (
|
112 |
+
hasattr(self, "processor")
|
113 |
+
and isinstance(self.processor, torch.nn.Module)
|
114 |
+
and not isinstance(processor, torch.nn.Module)
|
115 |
+
):
|
116 |
+
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
117 |
+
self._modules.pop("processor")
|
118 |
+
|
119 |
+
self.processor = processor
|
120 |
+
|
121 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
122 |
+
# Input
|
123 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
124 |
+
video_length = hidden_states.shape[2]
|
125 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
126 |
+
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
|
127 |
+
|
128 |
+
batch, channel, height, weight = hidden_states.shape
|
129 |
+
residual = hidden_states
|
130 |
+
|
131 |
+
hidden_states = self.norm(hidden_states)
|
132 |
+
if not self.use_linear_projection:
|
133 |
+
hidden_states = self.proj_in(hidden_states)
|
134 |
+
inner_dim = hidden_states.shape[1]
|
135 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
136 |
+
else:
|
137 |
+
inner_dim = hidden_states.shape[1]
|
138 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
139 |
+
hidden_states = self.proj_in(hidden_states)
|
140 |
+
|
141 |
+
# Blocks
|
142 |
+
for block in self.transformer_blocks:
|
143 |
+
hidden_states = block(
|
144 |
+
hidden_states,
|
145 |
+
encoder_hidden_states=encoder_hidden_states,
|
146 |
+
timestep=timestep,
|
147 |
+
video_length=video_length
|
148 |
+
)
|
149 |
+
|
150 |
+
# Output
|
151 |
+
if not self.use_linear_projection:
|
152 |
+
hidden_states = (
|
153 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
154 |
+
)
|
155 |
+
hidden_states = self.proj_out(hidden_states)
|
156 |
+
else:
|
157 |
+
hidden_states = self.proj_out(hidden_states)
|
158 |
+
hidden_states = (
|
159 |
+
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
160 |
+
)
|
161 |
+
|
162 |
+
output = hidden_states + residual
|
163 |
+
|
164 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
165 |
+
if not return_dict:
|
166 |
+
return (output,)
|
167 |
+
|
168 |
+
return Transformer3DModelOutput(sample=output)
|
169 |
+
|
170 |
+
|
171 |
+
class BasicTransformerBlock(nn.Module):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
dim: int,
|
175 |
+
num_attention_heads: int,
|
176 |
+
attention_head_dim: int,
|
177 |
+
dropout=0.0,
|
178 |
+
cross_attention_dim: Optional[int] = None,
|
179 |
+
activation_fn: str = "geglu",
|
180 |
+
num_embeds_ada_norm: Optional[int] = None,
|
181 |
+
attention_bias: bool = False,
|
182 |
+
only_cross_attention: bool = False,
|
183 |
+
upcast_attention: bool = False,
|
184 |
+
|
185 |
+
unet_use_cross_frame_attention = None,
|
186 |
+
unet_use_temporal_attention = None,
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
self.only_cross_attention = only_cross_attention
|
190 |
+
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
191 |
+
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
|
192 |
+
self.unet_use_temporal_attention = unet_use_temporal_attention
|
193 |
+
|
194 |
+
# SC-Attn
|
195 |
+
assert unet_use_cross_frame_attention is not None
|
196 |
+
if unet_use_cross_frame_attention:
|
197 |
+
self.attn1 = SparseCausalAttention2D(
|
198 |
+
query_dim=dim,
|
199 |
+
heads=num_attention_heads,
|
200 |
+
dim_head=attention_head_dim,
|
201 |
+
dropout=dropout,
|
202 |
+
bias=attention_bias,
|
203 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
204 |
+
upcast_attention=upcast_attention,
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
self.attn1 = CrossAttention(
|
208 |
+
query_dim=dim,
|
209 |
+
heads=num_attention_heads,
|
210 |
+
dim_head=attention_head_dim,
|
211 |
+
dropout=dropout,
|
212 |
+
bias=attention_bias,
|
213 |
+
upcast_attention=upcast_attention,
|
214 |
+
)
|
215 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
216 |
+
|
217 |
+
# Cross-Attn
|
218 |
+
if cross_attention_dim is not None:
|
219 |
+
self.attn2 = CrossAttention(
|
220 |
+
query_dim=dim,
|
221 |
+
cross_attention_dim=cross_attention_dim,
|
222 |
+
heads=num_attention_heads,
|
223 |
+
dim_head=attention_head_dim,
|
224 |
+
dropout=dropout,
|
225 |
+
bias=attention_bias,
|
226 |
+
upcast_attention=upcast_attention,
|
227 |
+
)
|
228 |
+
else:
|
229 |
+
self.attn2 = None
|
230 |
+
|
231 |
+
if cross_attention_dim is not None:
|
232 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
233 |
+
else:
|
234 |
+
self.norm2 = None
|
235 |
+
|
236 |
+
# Feed-forward
|
237 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
238 |
+
self.norm3 = nn.LayerNorm(dim)
|
239 |
+
|
240 |
+
# Temp-Attn
|
241 |
+
assert unet_use_temporal_attention is not None
|
242 |
+
if unet_use_temporal_attention:
|
243 |
+
self.attn_temp = CrossAttention(
|
244 |
+
query_dim=dim,
|
245 |
+
heads=num_attention_heads,
|
246 |
+
dim_head=attention_head_dim,
|
247 |
+
dropout=dropout,
|
248 |
+
bias=attention_bias,
|
249 |
+
upcast_attention=upcast_attention,
|
250 |
+
)
|
251 |
+
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
|
252 |
+
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
253 |
+
|
254 |
+
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool,attention_op = None):
|
255 |
+
if not is_xformers_available():
|
256 |
+
print("Here is how to install it")
|
257 |
+
raise ModuleNotFoundError(
|
258 |
+
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
259 |
+
" xformers",
|
260 |
+
name="xformers",
|
261 |
+
)
|
262 |
+
elif not torch.cuda.is_available():
|
263 |
+
raise ValueError(
|
264 |
+
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
|
265 |
+
" available for GPU "
|
266 |
+
)
|
267 |
+
else:
|
268 |
+
try:
|
269 |
+
# Make sure we can run the memory efficient attention
|
270 |
+
_ = xformers.ops.memory_efficient_attention(
|
271 |
+
torch.randn((1, 2, 40), device="cuda"),
|
272 |
+
torch.randn((1, 2, 40), device="cuda"),
|
273 |
+
torch.randn((1, 2, 40), device="cuda"),
|
274 |
+
)
|
275 |
+
except Exception as e:
|
276 |
+
raise e
|
277 |
+
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
278 |
+
if self.attn2 is not None:
|
279 |
+
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
280 |
+
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
281 |
+
|
282 |
+
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
|
283 |
+
# SparseCausal-Attention
|
284 |
+
norm_hidden_states = (
|
285 |
+
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
286 |
+
)
|
287 |
+
|
288 |
+
# if self.only_cross_attention:
|
289 |
+
# hidden_states = (
|
290 |
+
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
291 |
+
# )
|
292 |
+
# else:
|
293 |
+
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
294 |
+
|
295 |
+
# pdb.set_trace()
|
296 |
+
if self.unet_use_cross_frame_attention:
|
297 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
|
298 |
+
else:
|
299 |
+
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
300 |
+
|
301 |
+
if self.attn2 is not None:
|
302 |
+
# Cross-Attention
|
303 |
+
norm_hidden_states = (
|
304 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
305 |
+
)
|
306 |
+
hidden_states = (
|
307 |
+
self.attn2(
|
308 |
+
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
309 |
+
)
|
310 |
+
+ hidden_states
|
311 |
+
)
|
312 |
+
|
313 |
+
# Feed-forward
|
314 |
+
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
315 |
+
|
316 |
+
# Temporal-Attention
|
317 |
+
if self.unet_use_temporal_attention:
|
318 |
+
d = hidden_states.shape[1]
|
319 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
320 |
+
norm_hidden_states = (
|
321 |
+
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
|
322 |
+
)
|
323 |
+
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
|
324 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
325 |
+
|
326 |
+
return hidden_states
|
animatediff/models/motion_module.py
ADDED
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import FeedForward,Attention
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import math
|
18 |
+
|
19 |
+
|
20 |
+
def zero_module(module):
|
21 |
+
# Zero out the parameters of a module and return it.
|
22 |
+
for p in module.parameters():
|
23 |
+
p.detach().zero_()
|
24 |
+
return module
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
29 |
+
sample: torch.FloatTensor
|
30 |
+
|
31 |
+
|
32 |
+
if is_xformers_available():
|
33 |
+
import xformers
|
34 |
+
import xformers.ops
|
35 |
+
else:
|
36 |
+
xformers = None
|
37 |
+
|
38 |
+
|
39 |
+
def get_motion_module(
|
40 |
+
in_channels,
|
41 |
+
motion_module_type: str,
|
42 |
+
motion_module_kwargs: dict
|
43 |
+
):
|
44 |
+
if motion_module_type == "Vanilla":
|
45 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
46 |
+
else:
|
47 |
+
raise ValueError
|
48 |
+
|
49 |
+
|
50 |
+
class VanillaTemporalModule(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
in_channels,
|
54 |
+
num_attention_heads = 8,
|
55 |
+
num_transformer_block = 2,
|
56 |
+
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
|
57 |
+
cross_frame_attention_mode = None,
|
58 |
+
temporal_position_encoding = False,
|
59 |
+
temporal_position_encoding_max_len = 24,
|
60 |
+
temporal_attention_dim_div = 1,
|
61 |
+
zero_initialize = True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
66 |
+
in_channels=in_channels,
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
69 |
+
num_layers=num_transformer_block,
|
70 |
+
attention_block_types=attention_block_types,
|
71 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
72 |
+
temporal_position_encoding=temporal_position_encoding,
|
73 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
74 |
+
)
|
75 |
+
|
76 |
+
if zero_initialize:
|
77 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
78 |
+
|
79 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
80 |
+
hidden_states = input_tensor
|
81 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
82 |
+
|
83 |
+
output = hidden_states
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class TemporalTransformer3DModel(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
in_channels,
|
91 |
+
num_attention_heads,
|
92 |
+
attention_head_dim,
|
93 |
+
|
94 |
+
num_layers,
|
95 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
96 |
+
dropout = 0.0,
|
97 |
+
norm_num_groups = 32,
|
98 |
+
cross_attention_dim = 768,
|
99 |
+
activation_fn = "geglu",
|
100 |
+
attention_bias = False,
|
101 |
+
upcast_attention = False,
|
102 |
+
|
103 |
+
cross_frame_attention_mode = None,
|
104 |
+
temporal_position_encoding = False,
|
105 |
+
temporal_position_encoding_max_len = 24,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
inner_dim = num_attention_heads * attention_head_dim
|
110 |
+
|
111 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
112 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
113 |
+
|
114 |
+
self.transformer_blocks = nn.ModuleList(
|
115 |
+
[
|
116 |
+
TemporalTransformerBlock(
|
117 |
+
dim=inner_dim,
|
118 |
+
num_attention_heads=num_attention_heads,
|
119 |
+
attention_head_dim=attention_head_dim,
|
120 |
+
attention_block_types=attention_block_types,
|
121 |
+
dropout=dropout,
|
122 |
+
norm_num_groups=norm_num_groups,
|
123 |
+
cross_attention_dim=cross_attention_dim,
|
124 |
+
activation_fn=activation_fn,
|
125 |
+
attention_bias=attention_bias,
|
126 |
+
upcast_attention=upcast_attention,
|
127 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
128 |
+
temporal_position_encoding=temporal_position_encoding,
|
129 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
130 |
+
)
|
131 |
+
for d in range(num_layers)
|
132 |
+
]
|
133 |
+
)
|
134 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
135 |
+
|
136 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
137 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
138 |
+
video_length = hidden_states.shape[2]
|
139 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
140 |
+
|
141 |
+
batch, channel, height, weight = hidden_states.shape
|
142 |
+
residual = hidden_states
|
143 |
+
|
144 |
+
hidden_states = self.norm(hidden_states)
|
145 |
+
inner_dim = hidden_states.shape[1]
|
146 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
147 |
+
hidden_states = self.proj_in(hidden_states)
|
148 |
+
|
149 |
+
# Transformer Blocks
|
150 |
+
for block in self.transformer_blocks:
|
151 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
|
152 |
+
|
153 |
+
# output
|
154 |
+
hidden_states = self.proj_out(hidden_states)
|
155 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
156 |
+
|
157 |
+
output = hidden_states + residual
|
158 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
159 |
+
|
160 |
+
return output
|
161 |
+
|
162 |
+
|
163 |
+
class TemporalTransformerBlock(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
dim,
|
167 |
+
num_attention_heads,
|
168 |
+
attention_head_dim,
|
169 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
170 |
+
dropout = 0.0,
|
171 |
+
norm_num_groups = 32,
|
172 |
+
cross_attention_dim = 768,
|
173 |
+
activation_fn = "geglu",
|
174 |
+
attention_bias = False,
|
175 |
+
upcast_attention = False,
|
176 |
+
cross_frame_attention_mode = None,
|
177 |
+
temporal_position_encoding = False,
|
178 |
+
temporal_position_encoding_max_len = 24,
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
|
182 |
+
attention_blocks = []
|
183 |
+
norms = []
|
184 |
+
|
185 |
+
for block_name in attention_block_types:
|
186 |
+
attention_blocks.append(
|
187 |
+
VersatileAttention(
|
188 |
+
attention_mode=block_name.split("_")[0],
|
189 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
190 |
+
|
191 |
+
query_dim=dim,
|
192 |
+
heads=num_attention_heads,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
dropout=dropout,
|
195 |
+
bias=attention_bias,
|
196 |
+
upcast_attention=upcast_attention,
|
197 |
+
|
198 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
199 |
+
temporal_position_encoding=temporal_position_encoding,
|
200 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
201 |
+
)
|
202 |
+
)
|
203 |
+
norms.append(nn.LayerNorm(dim))
|
204 |
+
|
205 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
206 |
+
self.norms = nn.ModuleList(norms)
|
207 |
+
|
208 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
209 |
+
self.ff_norm = nn.LayerNorm(dim)
|
210 |
+
|
211 |
+
|
212 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
213 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
214 |
+
norm_hidden_states = norm(hidden_states)
|
215 |
+
hidden_states = attention_block(
|
216 |
+
norm_hidden_states,
|
217 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
218 |
+
video_length=video_length,
|
219 |
+
) + hidden_states
|
220 |
+
|
221 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
222 |
+
|
223 |
+
output = hidden_states
|
224 |
+
return output
|
225 |
+
|
226 |
+
|
227 |
+
class PositionalEncoding(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
d_model,
|
231 |
+
dropout = 0.,
|
232 |
+
max_len = 24
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
self.dropout = nn.Dropout(p=dropout)
|
236 |
+
position = torch.arange(max_len).unsqueeze(1)
|
237 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
238 |
+
pe = torch.zeros(1, max_len, d_model)
|
239 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
240 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
241 |
+
self.register_buffer('pe', pe)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
x = x + self.pe[:, :x.size(1)]
|
245 |
+
return self.dropout(x)
|
246 |
+
|
247 |
+
class CrossAttention(nn.Module):
|
248 |
+
r"""
|
249 |
+
A cross attention layer.
|
250 |
+
|
251 |
+
Parameters:
|
252 |
+
query_dim (`int`): The number of channels in the query.
|
253 |
+
cross_attention_dim (`int`, *optional*):
|
254 |
+
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
255 |
+
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
256 |
+
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
257 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
258 |
+
bias (`bool`, *optional*, defaults to False):
|
259 |
+
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
260 |
+
"""
|
261 |
+
|
262 |
+
def __init__(
|
263 |
+
self,
|
264 |
+
query_dim: int,
|
265 |
+
cross_attention_dim: Optional[int] = None,
|
266 |
+
heads: int = 8,
|
267 |
+
dim_head: int = 64,
|
268 |
+
dropout: float = 0.0,
|
269 |
+
bias=False,
|
270 |
+
upcast_attention: bool = False,
|
271 |
+
upcast_softmax: bool = False,
|
272 |
+
added_kv_proj_dim: Optional[int] = None,
|
273 |
+
norm_num_groups: Optional[int] = None,
|
274 |
+
):
|
275 |
+
super().__init__()
|
276 |
+
inner_dim = dim_head * heads
|
277 |
+
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
278 |
+
self.upcast_attention = upcast_attention
|
279 |
+
self.upcast_softmax = upcast_softmax
|
280 |
+
|
281 |
+
self.scale = dim_head**-0.5
|
282 |
+
|
283 |
+
self.heads = heads
|
284 |
+
# for slice_size > 0 the attention score computation
|
285 |
+
# is split across the batch axis to save memory
|
286 |
+
# You can set slice_size with `set_attention_slice`
|
287 |
+
self.sliceable_head_dim = heads
|
288 |
+
self._slice_size = None
|
289 |
+
self._use_memory_efficient_attention_xformers = False
|
290 |
+
self.added_kv_proj_dim = added_kv_proj_dim
|
291 |
+
|
292 |
+
if norm_num_groups is not None:
|
293 |
+
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
294 |
+
else:
|
295 |
+
self.group_norm = None
|
296 |
+
|
297 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
298 |
+
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
299 |
+
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
300 |
+
|
301 |
+
if self.added_kv_proj_dim is not None:
|
302 |
+
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
303 |
+
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
304 |
+
|
305 |
+
self.to_out = nn.ModuleList([])
|
306 |
+
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
307 |
+
self.to_out.append(nn.Dropout(dropout))
|
308 |
+
|
309 |
+
def reshape_heads_to_batch_dim(self, tensor):
|
310 |
+
batch_size, seq_len, dim = tensor.shape
|
311 |
+
head_size = self.heads
|
312 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
313 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
|
314 |
+
return tensor
|
315 |
+
|
316 |
+
def reshape_batch_dim_to_heads(self, tensor):
|
317 |
+
batch_size, seq_len, dim = tensor.shape
|
318 |
+
head_size = self.heads
|
319 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
320 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
321 |
+
return tensor
|
322 |
+
|
323 |
+
def set_attention_slice(self, slice_size):
|
324 |
+
if slice_size is not None and slice_size > self.sliceable_head_dim:
|
325 |
+
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
326 |
+
|
327 |
+
self._slice_size = slice_size
|
328 |
+
|
329 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
330 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
331 |
+
|
332 |
+
encoder_hidden_states = encoder_hidden_states
|
333 |
+
|
334 |
+
if self.group_norm is not None:
|
335 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
336 |
+
|
337 |
+
query = self.to_q(hidden_states)
|
338 |
+
dim = query.shape[-1]
|
339 |
+
query = self.reshape_heads_to_batch_dim(query)
|
340 |
+
|
341 |
+
if self.added_kv_proj_dim is not None:
|
342 |
+
key = self.to_k(hidden_states)
|
343 |
+
value = self.to_v(hidden_states)
|
344 |
+
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
345 |
+
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
346 |
+
|
347 |
+
key = self.reshape_heads_to_batch_dim(key)
|
348 |
+
value = self.reshape_heads_to_batch_dim(value)
|
349 |
+
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
350 |
+
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
351 |
+
|
352 |
+
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
353 |
+
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
354 |
+
else:
|
355 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
356 |
+
key = self.to_k(encoder_hidden_states)
|
357 |
+
value = self.to_v(encoder_hidden_states)
|
358 |
+
|
359 |
+
key = self.reshape_heads_to_batch_dim(key)
|
360 |
+
value = self.reshape_heads_to_batch_dim(value)
|
361 |
+
|
362 |
+
if attention_mask is not None:
|
363 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
364 |
+
target_length = query.shape[1]
|
365 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
366 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
367 |
+
|
368 |
+
# attention, what we cannot get enough of
|
369 |
+
if self._use_memory_efficient_attention_xformers:
|
370 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
371 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
372 |
+
hidden_states = hidden_states.to(query.dtype)
|
373 |
+
else:
|
374 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
375 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
376 |
+
else:
|
377 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
378 |
+
|
379 |
+
# linear proj
|
380 |
+
hidden_states = self.to_out[0](hidden_states)
|
381 |
+
|
382 |
+
# dropout
|
383 |
+
hidden_states = self.to_out[1](hidden_states)
|
384 |
+
return hidden_states
|
385 |
+
|
386 |
+
def _attention(self, query, key, value, attention_mask=None):
|
387 |
+
if self.upcast_attention:
|
388 |
+
query = query.float()
|
389 |
+
key = key.float()
|
390 |
+
|
391 |
+
attention_scores = torch.baddbmm(
|
392 |
+
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
|
393 |
+
query,
|
394 |
+
key.transpose(-1, -2),
|
395 |
+
beta=0,
|
396 |
+
alpha=self.scale,
|
397 |
+
)
|
398 |
+
|
399 |
+
if attention_mask is not None:
|
400 |
+
attention_scores = attention_scores + attention_mask
|
401 |
+
|
402 |
+
if self.upcast_softmax:
|
403 |
+
attention_scores = attention_scores.float()
|
404 |
+
|
405 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
406 |
+
|
407 |
+
# cast back to the original dtype
|
408 |
+
attention_probs = attention_probs.to(value.dtype)
|
409 |
+
|
410 |
+
# compute attention output
|
411 |
+
hidden_states = torch.bmm(attention_probs, value)
|
412 |
+
|
413 |
+
# reshape hidden_states
|
414 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
415 |
+
return hidden_states
|
416 |
+
|
417 |
+
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
|
418 |
+
batch_size_attention = query.shape[0]
|
419 |
+
hidden_states = torch.zeros(
|
420 |
+
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
|
421 |
+
)
|
422 |
+
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
|
423 |
+
for i in range(hidden_states.shape[0] // slice_size):
|
424 |
+
start_idx = i * slice_size
|
425 |
+
end_idx = (i + 1) * slice_size
|
426 |
+
|
427 |
+
query_slice = query[start_idx:end_idx]
|
428 |
+
key_slice = key[start_idx:end_idx]
|
429 |
+
|
430 |
+
if self.upcast_attention:
|
431 |
+
query_slice = query_slice.float()
|
432 |
+
key_slice = key_slice.float()
|
433 |
+
|
434 |
+
attn_slice = torch.baddbmm(
|
435 |
+
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
|
436 |
+
query_slice,
|
437 |
+
key_slice.transpose(-1, -2),
|
438 |
+
beta=0,
|
439 |
+
alpha=self.scale,
|
440 |
+
)
|
441 |
+
|
442 |
+
if attention_mask is not None:
|
443 |
+
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
|
444 |
+
|
445 |
+
if self.upcast_softmax:
|
446 |
+
attn_slice = attn_slice.float()
|
447 |
+
|
448 |
+
attn_slice = attn_slice.softmax(dim=-1)
|
449 |
+
|
450 |
+
# cast back to the original dtype
|
451 |
+
attn_slice = attn_slice.to(value.dtype)
|
452 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
453 |
+
|
454 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
455 |
+
|
456 |
+
# reshape hidden_states
|
457 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
458 |
+
return hidden_states
|
459 |
+
|
460 |
+
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
461 |
+
# TODO attention_mask
|
462 |
+
query = query.contiguous()
|
463 |
+
key = key.contiguous()
|
464 |
+
value = value.contiguous()
|
465 |
+
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
466 |
+
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
467 |
+
return hidden_states
|
468 |
+
|
469 |
+
class VersatileAttention(CrossAttention):
|
470 |
+
def __init__(
|
471 |
+
self,
|
472 |
+
attention_mode = None,
|
473 |
+
cross_frame_attention_mode = None,
|
474 |
+
temporal_position_encoding = False,
|
475 |
+
temporal_position_encoding_max_len = 24,
|
476 |
+
*args, **kwargs
|
477 |
+
):
|
478 |
+
super().__init__(*args, **kwargs)
|
479 |
+
assert attention_mode == "Temporal"
|
480 |
+
|
481 |
+
self.attention_mode = attention_mode
|
482 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
483 |
+
|
484 |
+
self.pos_encoder = PositionalEncoding(
|
485 |
+
kwargs["query_dim"],
|
486 |
+
dropout=0.,
|
487 |
+
max_len=temporal_position_encoding_max_len
|
488 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
489 |
+
|
490 |
+
def extra_repr(self):
|
491 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
492 |
+
|
493 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
494 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
495 |
+
|
496 |
+
if self.attention_mode == "Temporal":
|
497 |
+
d = hidden_states.shape[1]
|
498 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
499 |
+
|
500 |
+
if self.pos_encoder is not None:
|
501 |
+
hidden_states = self.pos_encoder(hidden_states)
|
502 |
+
|
503 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
504 |
+
else:
|
505 |
+
raise NotImplementedError
|
506 |
+
|
507 |
+
encoder_hidden_states = encoder_hidden_states
|
508 |
+
|
509 |
+
if self.group_norm is not None:
|
510 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
511 |
+
|
512 |
+
query = self.to_q(hidden_states)
|
513 |
+
dim = query.shape[-1]
|
514 |
+
query = self.reshape_heads_to_batch_dim(query)
|
515 |
+
|
516 |
+
if self.added_kv_proj_dim is not None:
|
517 |
+
raise NotImplementedError
|
518 |
+
|
519 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
520 |
+
key = self.to_k(encoder_hidden_states)
|
521 |
+
value = self.to_v(encoder_hidden_states)
|
522 |
+
|
523 |
+
key = self.reshape_heads_to_batch_dim(key)
|
524 |
+
value = self.reshape_heads_to_batch_dim(value)
|
525 |
+
|
526 |
+
if attention_mask is not None:
|
527 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
528 |
+
target_length = query.shape[1]
|
529 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
530 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
531 |
+
|
532 |
+
# attention, what we cannot get enough of
|
533 |
+
if self._use_memory_efficient_attention_xformers:
|
534 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
535 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
536 |
+
hidden_states = hidden_states.to(query.dtype)
|
537 |
+
else:
|
538 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
539 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
540 |
+
else:
|
541 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
542 |
+
|
543 |
+
# linear proj
|
544 |
+
hidden_states = self.to_out[0](hidden_states)
|
545 |
+
|
546 |
+
# dropout
|
547 |
+
hidden_states = self.to_out[1](hidden_states)
|
548 |
+
|
549 |
+
if self.attention_mode == "Temporal":
|
550 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
551 |
+
|
552 |
+
return hidden_states
|
animatediff/models/motion_module_bkp.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Optional, Tuple, Union
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
import torchvision
|
9 |
+
|
10 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
11 |
+
from diffusers import ModelMixin
|
12 |
+
from diffusers.utils import BaseOutput
|
13 |
+
from diffusers.utils.import_utils import is_xformers_available
|
14 |
+
from diffusers.models.attention import CrossAttention, FeedForward
|
15 |
+
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
import math
|
18 |
+
|
19 |
+
|
20 |
+
def zero_module(module):
|
21 |
+
# Zero out the parameters of a module and return it.
|
22 |
+
for p in module.parameters():
|
23 |
+
p.detach().zero_()
|
24 |
+
return module
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class TemporalTransformer3DModelOutput(BaseOutput):
|
29 |
+
sample: torch.FloatTensor
|
30 |
+
|
31 |
+
|
32 |
+
if is_xformers_available():
|
33 |
+
import xformers
|
34 |
+
import xformers.ops
|
35 |
+
else:
|
36 |
+
xformers = None
|
37 |
+
|
38 |
+
|
39 |
+
def get_motion_module(
|
40 |
+
in_channels,
|
41 |
+
motion_module_type: str,
|
42 |
+
motion_module_kwargs: dict
|
43 |
+
):
|
44 |
+
if motion_module_type == "Vanilla":
|
45 |
+
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
|
46 |
+
else:
|
47 |
+
raise ValueError
|
48 |
+
|
49 |
+
|
50 |
+
class VanillaTemporalModule(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
in_channels,
|
54 |
+
num_attention_heads = 8,
|
55 |
+
num_transformer_block = 2,
|
56 |
+
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
|
57 |
+
cross_frame_attention_mode = None,
|
58 |
+
temporal_position_encoding = False,
|
59 |
+
temporal_position_encoding_max_len = 24,
|
60 |
+
temporal_attention_dim_div = 1,
|
61 |
+
zero_initialize = True,
|
62 |
+
):
|
63 |
+
super().__init__()
|
64 |
+
|
65 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
66 |
+
in_channels=in_channels,
|
67 |
+
num_attention_heads=num_attention_heads,
|
68 |
+
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
69 |
+
num_layers=num_transformer_block,
|
70 |
+
attention_block_types=attention_block_types,
|
71 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
72 |
+
temporal_position_encoding=temporal_position_encoding,
|
73 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
74 |
+
)
|
75 |
+
|
76 |
+
if zero_initialize:
|
77 |
+
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
78 |
+
|
79 |
+
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
80 |
+
hidden_states = input_tensor
|
81 |
+
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
82 |
+
|
83 |
+
output = hidden_states
|
84 |
+
return output
|
85 |
+
|
86 |
+
|
87 |
+
class TemporalTransformer3DModel(nn.Module):
|
88 |
+
def __init__(
|
89 |
+
self,
|
90 |
+
in_channels,
|
91 |
+
num_attention_heads,
|
92 |
+
attention_head_dim,
|
93 |
+
|
94 |
+
num_layers,
|
95 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
96 |
+
dropout = 0.0,
|
97 |
+
norm_num_groups = 32,
|
98 |
+
cross_attention_dim = 768,
|
99 |
+
activation_fn = "geglu",
|
100 |
+
attention_bias = False,
|
101 |
+
upcast_attention = False,
|
102 |
+
|
103 |
+
cross_frame_attention_mode = None,
|
104 |
+
temporal_position_encoding = False,
|
105 |
+
temporal_position_encoding_max_len = 24,
|
106 |
+
):
|
107 |
+
super().__init__()
|
108 |
+
|
109 |
+
inner_dim = num_attention_heads * attention_head_dim
|
110 |
+
|
111 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
112 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
113 |
+
|
114 |
+
self.transformer_blocks = nn.ModuleList(
|
115 |
+
[
|
116 |
+
TemporalTransformerBlock(
|
117 |
+
dim=inner_dim,
|
118 |
+
num_attention_heads=num_attention_heads,
|
119 |
+
attention_head_dim=attention_head_dim,
|
120 |
+
attention_block_types=attention_block_types,
|
121 |
+
dropout=dropout,
|
122 |
+
norm_num_groups=norm_num_groups,
|
123 |
+
cross_attention_dim=cross_attention_dim,
|
124 |
+
activation_fn=activation_fn,
|
125 |
+
attention_bias=attention_bias,
|
126 |
+
upcast_attention=upcast_attention,
|
127 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
128 |
+
temporal_position_encoding=temporal_position_encoding,
|
129 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
130 |
+
)
|
131 |
+
for d in range(num_layers)
|
132 |
+
]
|
133 |
+
)
|
134 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
135 |
+
|
136 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
137 |
+
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
138 |
+
video_length = hidden_states.shape[2]
|
139 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
140 |
+
|
141 |
+
batch, channel, height, weight = hidden_states.shape
|
142 |
+
residual = hidden_states
|
143 |
+
|
144 |
+
hidden_states = self.norm(hidden_states)
|
145 |
+
inner_dim = hidden_states.shape[1]
|
146 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
147 |
+
hidden_states = self.proj_in(hidden_states)
|
148 |
+
|
149 |
+
# Transformer Blocks
|
150 |
+
for block in self.transformer_blocks:
|
151 |
+
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
|
152 |
+
|
153 |
+
# output
|
154 |
+
hidden_states = self.proj_out(hidden_states)
|
155 |
+
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
156 |
+
|
157 |
+
output = hidden_states + residual
|
158 |
+
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
159 |
+
|
160 |
+
return output
|
161 |
+
|
162 |
+
|
163 |
+
class TemporalTransformerBlock(nn.Module):
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
dim,
|
167 |
+
num_attention_heads,
|
168 |
+
attention_head_dim,
|
169 |
+
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
170 |
+
dropout = 0.0,
|
171 |
+
norm_num_groups = 32,
|
172 |
+
cross_attention_dim = 768,
|
173 |
+
activation_fn = "geglu",
|
174 |
+
attention_bias = False,
|
175 |
+
upcast_attention = False,
|
176 |
+
cross_frame_attention_mode = None,
|
177 |
+
temporal_position_encoding = False,
|
178 |
+
temporal_position_encoding_max_len = 24,
|
179 |
+
):
|
180 |
+
super().__init__()
|
181 |
+
|
182 |
+
attention_blocks = []
|
183 |
+
norms = []
|
184 |
+
|
185 |
+
for block_name in attention_block_types:
|
186 |
+
attention_blocks.append(
|
187 |
+
VersatileAttention(
|
188 |
+
attention_mode=block_name.split("_")[0],
|
189 |
+
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
190 |
+
|
191 |
+
query_dim=dim,
|
192 |
+
heads=num_attention_heads,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
dropout=dropout,
|
195 |
+
bias=attention_bias,
|
196 |
+
upcast_attention=upcast_attention,
|
197 |
+
|
198 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
199 |
+
temporal_position_encoding=temporal_position_encoding,
|
200 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
201 |
+
)
|
202 |
+
)
|
203 |
+
norms.append(nn.LayerNorm(dim))
|
204 |
+
|
205 |
+
self.attention_blocks = nn.ModuleList(attention_blocks)
|
206 |
+
self.norms = nn.ModuleList(norms)
|
207 |
+
|
208 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
209 |
+
self.ff_norm = nn.LayerNorm(dim)
|
210 |
+
|
211 |
+
|
212 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
213 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
214 |
+
norm_hidden_states = norm(hidden_states)
|
215 |
+
hidden_states = attention_block(
|
216 |
+
norm_hidden_states,
|
217 |
+
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
218 |
+
video_length=video_length,
|
219 |
+
) + hidden_states
|
220 |
+
|
221 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
222 |
+
|
223 |
+
output = hidden_states
|
224 |
+
return output
|
225 |
+
|
226 |
+
|
227 |
+
class PositionalEncoding(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
d_model,
|
231 |
+
dropout = 0.,
|
232 |
+
max_len = 24
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
self.dropout = nn.Dropout(p=dropout)
|
236 |
+
position = torch.arange(max_len).unsqueeze(1)
|
237 |
+
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
238 |
+
pe = torch.zeros(1, max_len, d_model)
|
239 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
240 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
241 |
+
self.register_buffer('pe', pe)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
x = x + self.pe[:, :x.size(1)]
|
245 |
+
return self.dropout(x)
|
246 |
+
|
247 |
+
|
248 |
+
class VersatileAttention(CrossAttention):
|
249 |
+
def __init__(
|
250 |
+
self,
|
251 |
+
attention_mode = None,
|
252 |
+
cross_frame_attention_mode = None,
|
253 |
+
temporal_position_encoding = False,
|
254 |
+
temporal_position_encoding_max_len = 24,
|
255 |
+
*args, **kwargs
|
256 |
+
):
|
257 |
+
super().__init__(*args, **kwargs)
|
258 |
+
assert attention_mode == "Temporal"
|
259 |
+
|
260 |
+
self.attention_mode = attention_mode
|
261 |
+
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
262 |
+
|
263 |
+
self.pos_encoder = PositionalEncoding(
|
264 |
+
kwargs["query_dim"],
|
265 |
+
dropout=0.,
|
266 |
+
max_len=temporal_position_encoding_max_len
|
267 |
+
) if (temporal_position_encoding and attention_mode == "Temporal") else None
|
268 |
+
|
269 |
+
def extra_repr(self):
|
270 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
271 |
+
|
272 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
273 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
274 |
+
|
275 |
+
if self.attention_mode == "Temporal":
|
276 |
+
d = hidden_states.shape[1]
|
277 |
+
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
|
278 |
+
|
279 |
+
if self.pos_encoder is not None:
|
280 |
+
hidden_states = self.pos_encoder(hidden_states)
|
281 |
+
|
282 |
+
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
|
283 |
+
else:
|
284 |
+
raise NotImplementedError
|
285 |
+
|
286 |
+
encoder_hidden_states = encoder_hidden_states
|
287 |
+
|
288 |
+
if self.group_norm is not None:
|
289 |
+
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
290 |
+
|
291 |
+
query = self.to_q(hidden_states)
|
292 |
+
dim = query.shape[-1]
|
293 |
+
query = self.reshape_heads_to_batch_dim(query)
|
294 |
+
|
295 |
+
if self.added_kv_proj_dim is not None:
|
296 |
+
raise NotImplementedError
|
297 |
+
|
298 |
+
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
299 |
+
key = self.to_k(encoder_hidden_states)
|
300 |
+
value = self.to_v(encoder_hidden_states)
|
301 |
+
|
302 |
+
key = self.reshape_heads_to_batch_dim(key)
|
303 |
+
value = self.reshape_heads_to_batch_dim(value)
|
304 |
+
|
305 |
+
if attention_mask is not None:
|
306 |
+
if attention_mask.shape[-1] != query.shape[1]:
|
307 |
+
target_length = query.shape[1]
|
308 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
309 |
+
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
310 |
+
|
311 |
+
# attention, what we cannot get enough of
|
312 |
+
if self._use_memory_efficient_attention_xformers:
|
313 |
+
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
314 |
+
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
315 |
+
hidden_states = hidden_states.to(query.dtype)
|
316 |
+
else:
|
317 |
+
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
318 |
+
hidden_states = self._attention(query, key, value, attention_mask)
|
319 |
+
else:
|
320 |
+
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
|
321 |
+
|
322 |
+
# linear proj
|
323 |
+
hidden_states = self.to_out[0](hidden_states)
|
324 |
+
|
325 |
+
# dropout
|
326 |
+
hidden_states = self.to_out[1](hidden_states)
|
327 |
+
|
328 |
+
if self.attention_mode == "Temporal":
|
329 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
330 |
+
|
331 |
+
return hidden_states
|
animatediff/models/resnet.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
class InflatedConv3d(nn.Conv2d):
|
11 |
+
def forward(self, x):
|
12 |
+
video_length = x.shape[2]
|
13 |
+
|
14 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
15 |
+
x = super().forward(x)
|
16 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
17 |
+
|
18 |
+
return x
|
19 |
+
|
20 |
+
|
21 |
+
class InflatedGroupNorm(nn.GroupNorm):
|
22 |
+
def forward(self, x):
|
23 |
+
video_length = x.shape[2]
|
24 |
+
|
25 |
+
x = rearrange(x, "b c f h w -> (b f) c h w")
|
26 |
+
x = super().forward(x)
|
27 |
+
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
28 |
+
|
29 |
+
return x
|
30 |
+
|
31 |
+
|
32 |
+
class Upsample3D(nn.Module):
|
33 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
34 |
+
super().__init__()
|
35 |
+
self.channels = channels
|
36 |
+
self.out_channels = out_channels or channels
|
37 |
+
self.use_conv = use_conv
|
38 |
+
self.use_conv_transpose = use_conv_transpose
|
39 |
+
self.name = name
|
40 |
+
|
41 |
+
conv = None
|
42 |
+
if use_conv_transpose:
|
43 |
+
raise NotImplementedError
|
44 |
+
elif use_conv:
|
45 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
46 |
+
|
47 |
+
def forward(self, hidden_states, output_size=None):
|
48 |
+
assert hidden_states.shape[1] == self.channels
|
49 |
+
|
50 |
+
if self.use_conv_transpose:
|
51 |
+
raise NotImplementedError
|
52 |
+
|
53 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
54 |
+
dtype = hidden_states.dtype
|
55 |
+
if dtype == torch.bfloat16:
|
56 |
+
hidden_states = hidden_states.to(torch.float32)
|
57 |
+
|
58 |
+
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
59 |
+
if hidden_states.shape[0] >= 64:
|
60 |
+
hidden_states = hidden_states.contiguous()
|
61 |
+
|
62 |
+
# if `output_size` is passed we force the interpolation output
|
63 |
+
# size and do not make use of `scale_factor=2`
|
64 |
+
if output_size is None:
|
65 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
66 |
+
else:
|
67 |
+
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
68 |
+
|
69 |
+
# If the input is bfloat16, we cast back to bfloat16
|
70 |
+
if dtype == torch.bfloat16:
|
71 |
+
hidden_states = hidden_states.to(dtype)
|
72 |
+
|
73 |
+
# if self.use_conv:
|
74 |
+
# if self.name == "conv":
|
75 |
+
# hidden_states = self.conv(hidden_states)
|
76 |
+
# else:
|
77 |
+
# hidden_states = self.Conv2d_0(hidden_states)
|
78 |
+
hidden_states = self.conv(hidden_states)
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
|
83 |
+
class Downsample3D(nn.Module):
|
84 |
+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
85 |
+
super().__init__()
|
86 |
+
self.channels = channels
|
87 |
+
self.out_channels = out_channels or channels
|
88 |
+
self.use_conv = use_conv
|
89 |
+
self.padding = padding
|
90 |
+
stride = 2
|
91 |
+
self.name = name
|
92 |
+
|
93 |
+
if use_conv:
|
94 |
+
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
95 |
+
else:
|
96 |
+
raise NotImplementedError
|
97 |
+
|
98 |
+
def forward(self, hidden_states):
|
99 |
+
assert hidden_states.shape[1] == self.channels
|
100 |
+
if self.use_conv and self.padding == 0:
|
101 |
+
raise NotImplementedError
|
102 |
+
|
103 |
+
assert hidden_states.shape[1] == self.channels
|
104 |
+
hidden_states = self.conv(hidden_states)
|
105 |
+
|
106 |
+
return hidden_states
|
107 |
+
|
108 |
+
|
109 |
+
class ResnetBlock3D(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
*,
|
113 |
+
in_channels,
|
114 |
+
out_channels=None,
|
115 |
+
conv_shortcut=False,
|
116 |
+
dropout=0.0,
|
117 |
+
temb_channels=512,
|
118 |
+
groups=32,
|
119 |
+
groups_out=None,
|
120 |
+
pre_norm=True,
|
121 |
+
eps=1e-6,
|
122 |
+
non_linearity="swish",
|
123 |
+
time_embedding_norm="default",
|
124 |
+
output_scale_factor=1.0,
|
125 |
+
use_in_shortcut=None,
|
126 |
+
use_inflated_groupnorm=False,
|
127 |
+
):
|
128 |
+
super().__init__()
|
129 |
+
self.pre_norm = pre_norm
|
130 |
+
self.pre_norm = True
|
131 |
+
self.in_channels = in_channels
|
132 |
+
out_channels = in_channels if out_channels is None else out_channels
|
133 |
+
self.out_channels = out_channels
|
134 |
+
self.use_conv_shortcut = conv_shortcut
|
135 |
+
self.time_embedding_norm = time_embedding_norm
|
136 |
+
self.output_scale_factor = output_scale_factor
|
137 |
+
|
138 |
+
if groups_out is None:
|
139 |
+
groups_out = groups
|
140 |
+
|
141 |
+
assert use_inflated_groupnorm != None
|
142 |
+
if use_inflated_groupnorm:
|
143 |
+
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
144 |
+
else:
|
145 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
146 |
+
|
147 |
+
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
148 |
+
|
149 |
+
if temb_channels is not None:
|
150 |
+
if self.time_embedding_norm == "default":
|
151 |
+
time_emb_proj_out_channels = out_channels
|
152 |
+
elif self.time_embedding_norm == "scale_shift":
|
153 |
+
time_emb_proj_out_channels = out_channels * 2
|
154 |
+
else:
|
155 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
156 |
+
|
157 |
+
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
158 |
+
else:
|
159 |
+
self.time_emb_proj = None
|
160 |
+
|
161 |
+
if use_inflated_groupnorm:
|
162 |
+
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
163 |
+
else:
|
164 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
165 |
+
|
166 |
+
self.dropout = torch.nn.Dropout(dropout)
|
167 |
+
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
168 |
+
|
169 |
+
if non_linearity == "swish":
|
170 |
+
self.nonlinearity = lambda x: F.silu(x)
|
171 |
+
elif non_linearity == "mish":
|
172 |
+
self.nonlinearity = Mish()
|
173 |
+
elif non_linearity == "silu":
|
174 |
+
self.nonlinearity = nn.SiLU()
|
175 |
+
|
176 |
+
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
177 |
+
|
178 |
+
self.conv_shortcut = None
|
179 |
+
if self.use_in_shortcut:
|
180 |
+
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
181 |
+
|
182 |
+
def forward(self, input_tensor, temb):
|
183 |
+
hidden_states = input_tensor
|
184 |
+
|
185 |
+
hidden_states = self.norm1(hidden_states)
|
186 |
+
hidden_states = self.nonlinearity(hidden_states)
|
187 |
+
|
188 |
+
hidden_states = self.conv1(hidden_states)
|
189 |
+
|
190 |
+
if temb is not None:
|
191 |
+
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
192 |
+
|
193 |
+
if temb is not None and self.time_embedding_norm == "default":
|
194 |
+
hidden_states = hidden_states + temb
|
195 |
+
|
196 |
+
hidden_states = self.norm2(hidden_states)
|
197 |
+
|
198 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
199 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
200 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
201 |
+
|
202 |
+
hidden_states = self.nonlinearity(hidden_states)
|
203 |
+
|
204 |
+
hidden_states = self.dropout(hidden_states)
|
205 |
+
hidden_states = self.conv2(hidden_states)
|
206 |
+
|
207 |
+
if self.conv_shortcut is not None:
|
208 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
209 |
+
|
210 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
211 |
+
|
212 |
+
return output_tensor
|
213 |
+
|
214 |
+
|
215 |
+
class Mish(torch.nn.Module):
|
216 |
+
def forward(self, hidden_states):
|
217 |
+
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
animatediff/models/sparse_controlnet.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
#
|
15 |
+
# Changes were made to this source code by Yuwei Guo.
|
16 |
+
from dataclasses import dataclass
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18 |
+
|
19 |
+
import torch
|
20 |
+
from torch import nn
|
21 |
+
from torch.nn import functional as F
|
22 |
+
|
23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
24 |
+
from diffusers.utils import BaseOutput, logging
|
25 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
26 |
+
# from diffusers.modeling_utils import ModelMixin
|
27 |
+
from diffusers import ModelMixin
|
28 |
+
|
29 |
+
|
30 |
+
from .unet_blocks import (
|
31 |
+
CrossAttnDownBlock3D,
|
32 |
+
DownBlock3D,
|
33 |
+
UNetMidBlock3DCrossAttn,
|
34 |
+
get_down_block,
|
35 |
+
)
|
36 |
+
from einops import repeat, rearrange
|
37 |
+
from .resnet import InflatedConv3d
|
38 |
+
|
39 |
+
from diffusers import UNet2DConditionModel
|
40 |
+
|
41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class SparseControlNetOutput(BaseOutput):
|
46 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
47 |
+
mid_block_res_sample: torch.Tensor
|
48 |
+
|
49 |
+
|
50 |
+
class SparseControlNetConditioningEmbedding(nn.Module):
|
51 |
+
def __init__(
|
52 |
+
self,
|
53 |
+
conditioning_embedding_channels: int,
|
54 |
+
conditioning_channels: int = 3,
|
55 |
+
block_out_channels: Tuple[int] = (16, 32, 96, 256),
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
|
59 |
+
self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
|
60 |
+
|
61 |
+
self.blocks = nn.ModuleList([])
|
62 |
+
|
63 |
+
for i in range(len(block_out_channels) - 1):
|
64 |
+
channel_in = block_out_channels[i]
|
65 |
+
channel_out = block_out_channels[i + 1]
|
66 |
+
self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
|
67 |
+
self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
|
68 |
+
|
69 |
+
self.conv_out = zero_module(
|
70 |
+
InflatedConv3d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
|
71 |
+
)
|
72 |
+
|
73 |
+
def forward(self, conditioning):
|
74 |
+
embedding = self.conv_in(conditioning)
|
75 |
+
embedding = F.silu(embedding)
|
76 |
+
|
77 |
+
for block in self.blocks:
|
78 |
+
embedding = block(embedding)
|
79 |
+
embedding = F.silu(embedding)
|
80 |
+
|
81 |
+
embedding = self.conv_out(embedding)
|
82 |
+
|
83 |
+
return embedding
|
84 |
+
|
85 |
+
|
86 |
+
class SparseControlNetModel(ModelMixin, ConfigMixin):
|
87 |
+
_supports_gradient_checkpointing = True
|
88 |
+
|
89 |
+
@register_to_config
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
in_channels: int = 4,
|
93 |
+
conditioning_channels: int = 3,
|
94 |
+
flip_sin_to_cos: bool = True,
|
95 |
+
freq_shift: int = 0,
|
96 |
+
down_block_types: Tuple[str] = (
|
97 |
+
"CrossAttnDownBlock2D",
|
98 |
+
"CrossAttnDownBlock2D",
|
99 |
+
"CrossAttnDownBlock2D",
|
100 |
+
"DownBlock2D",
|
101 |
+
),
|
102 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
103 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
104 |
+
layers_per_block: int = 2,
|
105 |
+
downsample_padding: int = 1,
|
106 |
+
mid_block_scale_factor: float = 1,
|
107 |
+
act_fn: str = "silu",
|
108 |
+
norm_num_groups: Optional[int] = 32,
|
109 |
+
norm_eps: float = 1e-5,
|
110 |
+
cross_attention_dim: int = 1280,
|
111 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
112 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
113 |
+
use_linear_projection: bool = False,
|
114 |
+
class_embed_type: Optional[str] = None,
|
115 |
+
num_class_embeds: Optional[int] = None,
|
116 |
+
upcast_attention: bool = False,
|
117 |
+
resnet_time_scale_shift: str = "default",
|
118 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
119 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
120 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
121 |
+
global_pool_conditions: bool = False,
|
122 |
+
|
123 |
+
use_motion_module = True,
|
124 |
+
motion_module_resolutions = ( 1,2,4,8 ),
|
125 |
+
motion_module_mid_block = False,
|
126 |
+
motion_module_type = "Vanilla",
|
127 |
+
motion_module_kwargs = {
|
128 |
+
"num_attention_heads": 8,
|
129 |
+
"num_transformer_block": 1,
|
130 |
+
"attention_block_types": ["Temporal_Self"],
|
131 |
+
"temporal_position_encoding": True,
|
132 |
+
"temporal_position_encoding_max_len": 32,
|
133 |
+
"temporal_attention_dim_div": 1,
|
134 |
+
"causal_temporal_attention": False,
|
135 |
+
},
|
136 |
+
|
137 |
+
concate_conditioning_mask: bool = True,
|
138 |
+
use_simplified_condition_embedding: bool = False,
|
139 |
+
|
140 |
+
set_noisy_sample_input_to_zero: bool = False,
|
141 |
+
):
|
142 |
+
super().__init__()
|
143 |
+
|
144 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
145 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
146 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
147 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
148 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
149 |
+
# which is why we correct for the naming here.
|
150 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
151 |
+
|
152 |
+
# Check inputs
|
153 |
+
if len(block_out_channels) != len(down_block_types):
|
154 |
+
raise ValueError(
|
155 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
156 |
+
)
|
157 |
+
|
158 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
159 |
+
raise ValueError(
|
160 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
161 |
+
)
|
162 |
+
|
163 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
164 |
+
raise ValueError(
|
165 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
166 |
+
)
|
167 |
+
|
168 |
+
# input
|
169 |
+
self.set_noisy_sample_input_to_zero = set_noisy_sample_input_to_zero
|
170 |
+
|
171 |
+
conv_in_kernel = 3
|
172 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
173 |
+
self.conv_in = InflatedConv3d(
|
174 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
175 |
+
)
|
176 |
+
|
177 |
+
if concate_conditioning_mask:
|
178 |
+
conditioning_channels = conditioning_channels + 1
|
179 |
+
self.concate_conditioning_mask = concate_conditioning_mask
|
180 |
+
|
181 |
+
# control net conditioning embedding
|
182 |
+
if use_simplified_condition_embedding:
|
183 |
+
self.controlnet_cond_embedding = zero_module(
|
184 |
+
InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
|
185 |
+
)
|
186 |
+
else:
|
187 |
+
self.controlnet_cond_embedding = SparseControlNetConditioningEmbedding(
|
188 |
+
conditioning_embedding_channels=block_out_channels[0],
|
189 |
+
block_out_channels=conditioning_embedding_out_channels,
|
190 |
+
conditioning_channels=conditioning_channels,
|
191 |
+
)
|
192 |
+
self.use_simplified_condition_embedding = use_simplified_condition_embedding
|
193 |
+
|
194 |
+
# time
|
195 |
+
time_embed_dim = block_out_channels[0] * 4
|
196 |
+
|
197 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
198 |
+
timestep_input_dim = block_out_channels[0]
|
199 |
+
|
200 |
+
self.time_embedding = TimestepEmbedding(
|
201 |
+
timestep_input_dim,
|
202 |
+
time_embed_dim,
|
203 |
+
act_fn=act_fn,
|
204 |
+
)
|
205 |
+
|
206 |
+
# class embedding
|
207 |
+
if class_embed_type is None and num_class_embeds is not None:
|
208 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
209 |
+
elif class_embed_type == "timestep":
|
210 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
211 |
+
elif class_embed_type == "identity":
|
212 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
213 |
+
elif class_embed_type == "projection":
|
214 |
+
if projection_class_embeddings_input_dim is None:
|
215 |
+
raise ValueError(
|
216 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
217 |
+
)
|
218 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
219 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
220 |
+
# 2. it projects from an arbitrary input dimension.
|
221 |
+
#
|
222 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
223 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
224 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
225 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
226 |
+
else:
|
227 |
+
self.class_embedding = None
|
228 |
+
|
229 |
+
|
230 |
+
self.down_blocks = nn.ModuleList([])
|
231 |
+
self.controlnet_down_blocks = nn.ModuleList([])
|
232 |
+
|
233 |
+
if isinstance(only_cross_attention, bool):
|
234 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
235 |
+
|
236 |
+
if isinstance(attention_head_dim, int):
|
237 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
238 |
+
|
239 |
+
if isinstance(num_attention_heads, int):
|
240 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
241 |
+
|
242 |
+
# down
|
243 |
+
output_channel = block_out_channels[0]
|
244 |
+
|
245 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
246 |
+
controlnet_block = zero_module(controlnet_block)
|
247 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
248 |
+
|
249 |
+
for i, down_block_type in enumerate(down_block_types):
|
250 |
+
res = 2 ** i
|
251 |
+
input_channel = output_channel
|
252 |
+
output_channel = block_out_channels[i]
|
253 |
+
is_final_block = i == len(block_out_channels) - 1
|
254 |
+
|
255 |
+
down_block = get_down_block(
|
256 |
+
down_block_type,
|
257 |
+
num_layers=layers_per_block,
|
258 |
+
in_channels=input_channel,
|
259 |
+
out_channels=output_channel,
|
260 |
+
temb_channels=time_embed_dim,
|
261 |
+
add_downsample=not is_final_block,
|
262 |
+
resnet_eps=norm_eps,
|
263 |
+
resnet_act_fn=act_fn,
|
264 |
+
resnet_groups=norm_num_groups,
|
265 |
+
cross_attention_dim=cross_attention_dim,
|
266 |
+
attn_num_head_channels=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
267 |
+
downsample_padding=downsample_padding,
|
268 |
+
use_linear_projection=use_linear_projection,
|
269 |
+
only_cross_attention=only_cross_attention[i],
|
270 |
+
upcast_attention=upcast_attention,
|
271 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
272 |
+
|
273 |
+
use_inflated_groupnorm=True,
|
274 |
+
|
275 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
276 |
+
motion_module_type=motion_module_type,
|
277 |
+
motion_module_kwargs=motion_module_kwargs,
|
278 |
+
)
|
279 |
+
self.down_blocks.append(down_block)
|
280 |
+
|
281 |
+
for _ in range(layers_per_block):
|
282 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
283 |
+
controlnet_block = zero_module(controlnet_block)
|
284 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
285 |
+
|
286 |
+
if not is_final_block:
|
287 |
+
controlnet_block = InflatedConv3d(output_channel, output_channel, kernel_size=1)
|
288 |
+
controlnet_block = zero_module(controlnet_block)
|
289 |
+
self.controlnet_down_blocks.append(controlnet_block)
|
290 |
+
|
291 |
+
# mid
|
292 |
+
mid_block_channel = block_out_channels[-1]
|
293 |
+
|
294 |
+
controlnet_block = InflatedConv3d(mid_block_channel, mid_block_channel, kernel_size=1)
|
295 |
+
controlnet_block = zero_module(controlnet_block)
|
296 |
+
self.controlnet_mid_block = controlnet_block
|
297 |
+
|
298 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
299 |
+
in_channels=mid_block_channel,
|
300 |
+
temb_channels=time_embed_dim,
|
301 |
+
resnet_eps=norm_eps,
|
302 |
+
resnet_act_fn=act_fn,
|
303 |
+
output_scale_factor=mid_block_scale_factor,
|
304 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
305 |
+
cross_attention_dim=cross_attention_dim,
|
306 |
+
attn_num_head_channels=num_attention_heads[-1],
|
307 |
+
resnet_groups=norm_num_groups,
|
308 |
+
use_linear_projection=use_linear_projection,
|
309 |
+
upcast_attention=upcast_attention,
|
310 |
+
|
311 |
+
use_inflated_groupnorm=True,
|
312 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
313 |
+
motion_module_type=motion_module_type,
|
314 |
+
motion_module_kwargs=motion_module_kwargs,
|
315 |
+
)
|
316 |
+
|
317 |
+
@classmethod
|
318 |
+
def from_unet(
|
319 |
+
cls,
|
320 |
+
unet: UNet2DConditionModel,
|
321 |
+
controlnet_conditioning_channel_order: str = "rgb",
|
322 |
+
conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
|
323 |
+
load_weights_from_unet: bool = True,
|
324 |
+
|
325 |
+
controlnet_additional_kwargs: dict = {},
|
326 |
+
):
|
327 |
+
controlnet = cls(
|
328 |
+
in_channels=unet.config.in_channels,
|
329 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
330 |
+
freq_shift=unet.config.freq_shift,
|
331 |
+
down_block_types=unet.config.down_block_types,
|
332 |
+
only_cross_attention=unet.config.only_cross_attention,
|
333 |
+
block_out_channels=unet.config.block_out_channels,
|
334 |
+
layers_per_block=unet.config.layers_per_block,
|
335 |
+
downsample_padding=unet.config.downsample_padding,
|
336 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
337 |
+
act_fn=unet.config.act_fn,
|
338 |
+
norm_num_groups=unet.config.norm_num_groups,
|
339 |
+
norm_eps=unet.config.norm_eps,
|
340 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
341 |
+
attention_head_dim=unet.config.attention_head_dim,
|
342 |
+
num_attention_heads=unet.config.num_attention_heads,
|
343 |
+
use_linear_projection=unet.config.use_linear_projection,
|
344 |
+
class_embed_type=unet.config.class_embed_type,
|
345 |
+
num_class_embeds=unet.config.num_class_embeds,
|
346 |
+
upcast_attention=unet.config.upcast_attention,
|
347 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
348 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
349 |
+
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
350 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
351 |
+
|
352 |
+
**controlnet_additional_kwargs,
|
353 |
+
)
|
354 |
+
|
355 |
+
if load_weights_from_unet:
|
356 |
+
m, u = controlnet.conv_in.load_state_dict(cls.image_layer_filter(unet.conv_in.state_dict()), strict=False)
|
357 |
+
assert len(u) == 0
|
358 |
+
m, u = controlnet.time_proj.load_state_dict(cls.image_layer_filter(unet.time_proj.state_dict()), strict=False)
|
359 |
+
assert len(u) == 0
|
360 |
+
m, u = controlnet.time_embedding.load_state_dict(cls.image_layer_filter(unet.time_embedding.state_dict()), strict=False)
|
361 |
+
assert len(u) == 0
|
362 |
+
|
363 |
+
if controlnet.class_embedding:
|
364 |
+
m, u = controlnet.class_embedding.load_state_dict(cls.image_layer_filter(unet.class_embedding.state_dict()), strict=False)
|
365 |
+
assert len(u) == 0
|
366 |
+
m, u = controlnet.down_blocks.load_state_dict(cls.image_layer_filter(unet.down_blocks.state_dict()), strict=False)
|
367 |
+
assert len(u) == 0
|
368 |
+
m, u = controlnet.mid_block.load_state_dict(cls.image_layer_filter(unet.mid_block.state_dict()), strict=False)
|
369 |
+
assert len(u) == 0
|
370 |
+
|
371 |
+
return controlnet
|
372 |
+
|
373 |
+
@staticmethod
|
374 |
+
def image_layer_filter(state_dict):
|
375 |
+
new_state_dict = {}
|
376 |
+
for name, param in state_dict.items():
|
377 |
+
if "motion_modules." in name or "lora" in name: continue
|
378 |
+
new_state_dict[name] = param
|
379 |
+
return new_state_dict
|
380 |
+
|
381 |
+
# Copied from diffusers.models.UNet2DConditionModel.set_attention_slice
|
382 |
+
def set_attention_slice(self, slice_size):
|
383 |
+
r"""
|
384 |
+
Enable sliced attention computation.
|
385 |
+
|
386 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
387 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
388 |
+
|
389 |
+
Args:
|
390 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
391 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
392 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
393 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
394 |
+
must be a multiple of `slice_size`.
|
395 |
+
"""
|
396 |
+
sliceable_head_dims = []
|
397 |
+
|
398 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
399 |
+
if hasattr(module, "set_attention_slice"):
|
400 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
401 |
+
|
402 |
+
for child in module.children():
|
403 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
404 |
+
|
405 |
+
# retrieve number of attention layers
|
406 |
+
for module in self.children():
|
407 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
408 |
+
|
409 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
410 |
+
|
411 |
+
if slice_size == "auto":
|
412 |
+
# half the attention head size is usually a good trade-off between
|
413 |
+
# speed and memory
|
414 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
415 |
+
elif slice_size == "max":
|
416 |
+
# make smallest slice possible
|
417 |
+
slice_size = num_sliceable_layers * [1]
|
418 |
+
|
419 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
420 |
+
|
421 |
+
if len(slice_size) != len(sliceable_head_dims):
|
422 |
+
raise ValueError(
|
423 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
424 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
425 |
+
)
|
426 |
+
|
427 |
+
for i in range(len(slice_size)):
|
428 |
+
size = slice_size[i]
|
429 |
+
dim = sliceable_head_dims[i]
|
430 |
+
if size is not None and size > dim:
|
431 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
432 |
+
|
433 |
+
# Recursively walk through all the children.
|
434 |
+
# Any children which exposes the set_attention_slice method
|
435 |
+
# gets the message
|
436 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
437 |
+
if hasattr(module, "set_attention_slice"):
|
438 |
+
module.set_attention_slice(slice_size.pop())
|
439 |
+
|
440 |
+
for child in module.children():
|
441 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
442 |
+
|
443 |
+
reversed_slice_size = list(reversed(slice_size))
|
444 |
+
for module in self.children():
|
445 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
446 |
+
|
447 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
448 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
449 |
+
module.gradient_checkpointing = value
|
450 |
+
|
451 |
+
def forward(
|
452 |
+
self,
|
453 |
+
sample: torch.FloatTensor,
|
454 |
+
timestep: Union[torch.Tensor, float, int],
|
455 |
+
encoder_hidden_states: torch.Tensor,
|
456 |
+
|
457 |
+
controlnet_cond: torch.FloatTensor,
|
458 |
+
conditioning_mask: Optional[torch.FloatTensor] = None,
|
459 |
+
|
460 |
+
conditioning_scale: float = 1.0,
|
461 |
+
class_labels: Optional[torch.Tensor] = None,
|
462 |
+
attention_mask: Optional[torch.Tensor] = None,
|
463 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
464 |
+
guess_mode: bool = False,
|
465 |
+
return_dict: bool = True,
|
466 |
+
) -> Union[SparseControlNetOutput, Tuple]:
|
467 |
+
|
468 |
+
# set input noise to zero
|
469 |
+
if self.set_noisy_sample_input_to_zero:
|
470 |
+
sample = torch.zeros_like(sample).to(sample.device)
|
471 |
+
|
472 |
+
# prepare attention_mask
|
473 |
+
if attention_mask is not None:
|
474 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
475 |
+
attention_mask = attention_mask.unsqueeze(1)
|
476 |
+
|
477 |
+
# 1. time
|
478 |
+
timesteps = timestep
|
479 |
+
if not torch.is_tensor(timesteps):
|
480 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
481 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
482 |
+
is_mps = sample.device.type == "mps"
|
483 |
+
if isinstance(timestep, float):
|
484 |
+
dtype = torch.float32 if is_mps else torch.float64
|
485 |
+
else:
|
486 |
+
dtype = torch.int32 if is_mps else torch.int64
|
487 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
488 |
+
elif len(timesteps.shape) == 0:
|
489 |
+
timesteps = timesteps[None].to(sample.device)
|
490 |
+
|
491 |
+
timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0])
|
492 |
+
encoder_hidden_states = encoder_hidden_states.repeat(sample.shape[0] // encoder_hidden_states.shape[0], 1, 1)
|
493 |
+
|
494 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
495 |
+
timesteps = timesteps.expand(sample.shape[0])
|
496 |
+
|
497 |
+
t_emb = self.time_proj(timesteps)
|
498 |
+
|
499 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
500 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
501 |
+
# there might be better ways to encapsulate this.
|
502 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
503 |
+
emb = self.time_embedding(t_emb)
|
504 |
+
|
505 |
+
if self.class_embedding is not None:
|
506 |
+
if class_labels is None:
|
507 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
508 |
+
|
509 |
+
if self.config.class_embed_type == "timestep":
|
510 |
+
class_labels = self.time_proj(class_labels)
|
511 |
+
|
512 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
513 |
+
emb = emb + class_emb
|
514 |
+
|
515 |
+
# 2. pre-process
|
516 |
+
sample = self.conv_in(sample)
|
517 |
+
if self.concate_conditioning_mask:
|
518 |
+
controlnet_cond = torch.cat([controlnet_cond, conditioning_mask], dim=1)
|
519 |
+
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
|
520 |
+
|
521 |
+
sample = sample + controlnet_cond
|
522 |
+
|
523 |
+
# 3. down
|
524 |
+
down_block_res_samples = (sample,)
|
525 |
+
for downsample_block in self.down_blocks:
|
526 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
527 |
+
sample, res_samples = downsample_block(
|
528 |
+
hidden_states=sample,
|
529 |
+
temb=emb,
|
530 |
+
encoder_hidden_states=encoder_hidden_states,
|
531 |
+
attention_mask=attention_mask,
|
532 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
533 |
+
)
|
534 |
+
else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
535 |
+
|
536 |
+
down_block_res_samples += res_samples
|
537 |
+
|
538 |
+
# 4. mid
|
539 |
+
if self.mid_block is not None:
|
540 |
+
sample = self.mid_block(
|
541 |
+
sample,
|
542 |
+
emb,
|
543 |
+
encoder_hidden_states=encoder_hidden_states,
|
544 |
+
attention_mask=attention_mask,
|
545 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
546 |
+
)
|
547 |
+
|
548 |
+
# 5. controlnet blocks
|
549 |
+
controlnet_down_block_res_samples = ()
|
550 |
+
|
551 |
+
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
552 |
+
down_block_res_sample = controlnet_block(down_block_res_sample)
|
553 |
+
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
554 |
+
|
555 |
+
down_block_res_samples = controlnet_down_block_res_samples
|
556 |
+
|
557 |
+
mid_block_res_sample = self.controlnet_mid_block(sample)
|
558 |
+
|
559 |
+
# 6. scaling
|
560 |
+
if guess_mode and not self.config.global_pool_conditions:
|
561 |
+
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
562 |
+
|
563 |
+
scales = scales * conditioning_scale
|
564 |
+
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
565 |
+
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
566 |
+
else:
|
567 |
+
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
568 |
+
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
569 |
+
|
570 |
+
if self.config.global_pool_conditions:
|
571 |
+
down_block_res_samples = [
|
572 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
573 |
+
]
|
574 |
+
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
575 |
+
|
576 |
+
if not return_dict:
|
577 |
+
return (down_block_res_samples, mid_block_res_sample)
|
578 |
+
|
579 |
+
return SparseControlNetOutput(
|
580 |
+
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
581 |
+
)
|
582 |
+
|
583 |
+
|
584 |
+
def zero_module(module):
|
585 |
+
for p in module.parameters():
|
586 |
+
nn.init.zeros_(p)
|
587 |
+
return module
|
animatediff/models/unet.py
ADDED
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
|
2 |
+
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import List, Optional, Tuple, Union,Dict
|
5 |
+
|
6 |
+
import os
|
7 |
+
import json
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.utils.checkpoint
|
13 |
+
|
14 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
15 |
+
from diffusers import ModelMixin
|
16 |
+
from diffusers.utils import BaseOutput, logging
|
17 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
18 |
+
from .unet_blocks import (
|
19 |
+
CrossAttnDownBlock3D,
|
20 |
+
CrossAttnUpBlock3D,
|
21 |
+
DownBlock3D,
|
22 |
+
UNetMidBlock3DCrossAttn,
|
23 |
+
UpBlock3D,
|
24 |
+
get_down_block,
|
25 |
+
get_up_block,
|
26 |
+
)
|
27 |
+
from .resnet import InflatedConv3d, InflatedGroupNorm
|
28 |
+
from diffusers.models.attention_processor import (
|
29 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
30 |
+
CROSS_ATTENTION_PROCESSORS,
|
31 |
+
Attention,
|
32 |
+
AttentionProcessor,
|
33 |
+
AttnAddedKVProcessor,
|
34 |
+
AttnProcessor,
|
35 |
+
)
|
36 |
+
|
37 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class UNet3DConditionOutput(BaseOutput):
|
42 |
+
sample: torch.FloatTensor
|
43 |
+
|
44 |
+
|
45 |
+
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
46 |
+
_supports_gradient_checkpointing = True
|
47 |
+
|
48 |
+
@register_to_config
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
sample_size: Optional[int] = None,
|
52 |
+
in_channels: int = 4,
|
53 |
+
out_channels: int = 4,
|
54 |
+
center_input_sample: bool = False,
|
55 |
+
flip_sin_to_cos: bool = True,
|
56 |
+
freq_shift: int = 0,
|
57 |
+
down_block_types: Tuple[str] = (
|
58 |
+
"CrossAttnDownBlock3D",
|
59 |
+
"CrossAttnDownBlock3D",
|
60 |
+
"CrossAttnDownBlock3D",
|
61 |
+
"DownBlock3D",
|
62 |
+
),
|
63 |
+
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
64 |
+
up_block_types: Tuple[str] = (
|
65 |
+
"UpBlock3D",
|
66 |
+
"CrossAttnUpBlock3D",
|
67 |
+
"CrossAttnUpBlock3D",
|
68 |
+
"CrossAttnUpBlock3D"
|
69 |
+
),
|
70 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
71 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
72 |
+
layers_per_block: int = 2,
|
73 |
+
downsample_padding: int = 1,
|
74 |
+
mid_block_scale_factor: float = 1,
|
75 |
+
act_fn: str = "silu",
|
76 |
+
norm_num_groups: int = 32,
|
77 |
+
norm_eps: float = 1e-5,
|
78 |
+
cross_attention_dim: int = 1280,
|
79 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
80 |
+
dual_cross_attention: bool = False,
|
81 |
+
use_linear_projection: bool = False,
|
82 |
+
class_embed_type: Optional[str] = None,
|
83 |
+
num_class_embeds: Optional[int] = None,
|
84 |
+
upcast_attention: bool = False,
|
85 |
+
resnet_time_scale_shift: str = "default",
|
86 |
+
|
87 |
+
use_inflated_groupnorm=False,
|
88 |
+
|
89 |
+
# Additional
|
90 |
+
use_motion_module = False,
|
91 |
+
motion_module_resolutions = ( 1,2,4,8 ),
|
92 |
+
motion_module_mid_block = False,
|
93 |
+
motion_module_decoder_only = False,
|
94 |
+
motion_module_type = None,
|
95 |
+
motion_module_kwargs = {},
|
96 |
+
unet_use_cross_frame_attention = False,
|
97 |
+
unet_use_temporal_attention = False,
|
98 |
+
):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.sample_size = sample_size
|
102 |
+
time_embed_dim = block_out_channels[0] * 4
|
103 |
+
|
104 |
+
# input
|
105 |
+
self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
|
106 |
+
|
107 |
+
# time
|
108 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
109 |
+
timestep_input_dim = block_out_channels[0]
|
110 |
+
|
111 |
+
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
112 |
+
|
113 |
+
# class embedding
|
114 |
+
if class_embed_type is None and num_class_embeds is not None:
|
115 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
116 |
+
elif class_embed_type == "timestep":
|
117 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
118 |
+
elif class_embed_type == "identity":
|
119 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
120 |
+
else:
|
121 |
+
self.class_embedding = None
|
122 |
+
|
123 |
+
self.down_blocks = nn.ModuleList([])
|
124 |
+
self.mid_block = None
|
125 |
+
self.up_blocks = nn.ModuleList([])
|
126 |
+
|
127 |
+
if isinstance(only_cross_attention, bool):
|
128 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
129 |
+
|
130 |
+
if isinstance(attention_head_dim, int):
|
131 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
132 |
+
|
133 |
+
# down
|
134 |
+
output_channel = block_out_channels[0]
|
135 |
+
for i, down_block_type in enumerate(down_block_types):
|
136 |
+
res = 2 ** i
|
137 |
+
input_channel = output_channel
|
138 |
+
output_channel = block_out_channels[i]
|
139 |
+
is_final_block = i == len(block_out_channels) - 1
|
140 |
+
|
141 |
+
down_block = get_down_block(
|
142 |
+
down_block_type,
|
143 |
+
num_layers=layers_per_block,
|
144 |
+
in_channels=input_channel,
|
145 |
+
out_channels=output_channel,
|
146 |
+
temb_channels=time_embed_dim,
|
147 |
+
add_downsample=not is_final_block,
|
148 |
+
resnet_eps=norm_eps,
|
149 |
+
resnet_act_fn=act_fn,
|
150 |
+
resnet_groups=norm_num_groups,
|
151 |
+
cross_attention_dim=cross_attention_dim,
|
152 |
+
attn_num_head_channels=attention_head_dim[i],
|
153 |
+
downsample_padding=downsample_padding,
|
154 |
+
dual_cross_attention=dual_cross_attention,
|
155 |
+
use_linear_projection=use_linear_projection,
|
156 |
+
only_cross_attention=only_cross_attention[i],
|
157 |
+
upcast_attention=upcast_attention,
|
158 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
159 |
+
|
160 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
161 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
162 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
163 |
+
|
164 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
165 |
+
motion_module_type=motion_module_type,
|
166 |
+
motion_module_kwargs=motion_module_kwargs,
|
167 |
+
)
|
168 |
+
self.down_blocks.append(down_block)
|
169 |
+
|
170 |
+
# mid
|
171 |
+
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
172 |
+
self.mid_block = UNetMidBlock3DCrossAttn(
|
173 |
+
in_channels=block_out_channels[-1],
|
174 |
+
temb_channels=time_embed_dim,
|
175 |
+
resnet_eps=norm_eps,
|
176 |
+
resnet_act_fn=act_fn,
|
177 |
+
output_scale_factor=mid_block_scale_factor,
|
178 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
179 |
+
cross_attention_dim=cross_attention_dim,
|
180 |
+
attn_num_head_channels=attention_head_dim[-1],
|
181 |
+
resnet_groups=norm_num_groups,
|
182 |
+
dual_cross_attention=dual_cross_attention,
|
183 |
+
use_linear_projection=use_linear_projection,
|
184 |
+
upcast_attention=upcast_attention,
|
185 |
+
|
186 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
187 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
188 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
189 |
+
|
190 |
+
use_motion_module=use_motion_module and motion_module_mid_block,
|
191 |
+
motion_module_type=motion_module_type,
|
192 |
+
motion_module_kwargs=motion_module_kwargs,
|
193 |
+
)
|
194 |
+
else:
|
195 |
+
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
196 |
+
|
197 |
+
# count how many layers upsample the videos
|
198 |
+
self.num_upsamplers = 0
|
199 |
+
|
200 |
+
# up
|
201 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
202 |
+
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
203 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
204 |
+
output_channel = reversed_block_out_channels[0]
|
205 |
+
for i, up_block_type in enumerate(up_block_types):
|
206 |
+
res = 2 ** (3 - i)
|
207 |
+
is_final_block = i == len(block_out_channels) - 1
|
208 |
+
|
209 |
+
prev_output_channel = output_channel
|
210 |
+
output_channel = reversed_block_out_channels[i]
|
211 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
212 |
+
|
213 |
+
# add upsample block for all BUT final layer
|
214 |
+
if not is_final_block:
|
215 |
+
add_upsample = True
|
216 |
+
self.num_upsamplers += 1
|
217 |
+
else:
|
218 |
+
add_upsample = False
|
219 |
+
|
220 |
+
up_block = get_up_block(
|
221 |
+
up_block_type,
|
222 |
+
num_layers=layers_per_block + 1,
|
223 |
+
in_channels=input_channel,
|
224 |
+
out_channels=output_channel,
|
225 |
+
prev_output_channel=prev_output_channel,
|
226 |
+
temb_channels=time_embed_dim,
|
227 |
+
add_upsample=add_upsample,
|
228 |
+
resnet_eps=norm_eps,
|
229 |
+
resnet_act_fn=act_fn,
|
230 |
+
resnet_groups=norm_num_groups,
|
231 |
+
cross_attention_dim=cross_attention_dim,
|
232 |
+
attn_num_head_channels=reversed_attention_head_dim[i],
|
233 |
+
dual_cross_attention=dual_cross_attention,
|
234 |
+
use_linear_projection=use_linear_projection,
|
235 |
+
only_cross_attention=only_cross_attention[i],
|
236 |
+
upcast_attention=upcast_attention,
|
237 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
238 |
+
|
239 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
240 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
241 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
242 |
+
|
243 |
+
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
244 |
+
motion_module_type=motion_module_type,
|
245 |
+
motion_module_kwargs=motion_module_kwargs,
|
246 |
+
)
|
247 |
+
self.up_blocks.append(up_block)
|
248 |
+
prev_output_channel = output_channel
|
249 |
+
|
250 |
+
# out
|
251 |
+
if use_inflated_groupnorm:
|
252 |
+
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
253 |
+
else:
|
254 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
255 |
+
self.conv_act = nn.SiLU()
|
256 |
+
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
257 |
+
|
258 |
+
def set_attention_slice(self, slice_size):
|
259 |
+
r"""
|
260 |
+
Enable sliced attention computation.
|
261 |
+
|
262 |
+
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
263 |
+
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
264 |
+
|
265 |
+
Args:
|
266 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
267 |
+
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
268 |
+
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
269 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
270 |
+
must be a multiple of `slice_size`.
|
271 |
+
"""
|
272 |
+
sliceable_head_dims = []
|
273 |
+
|
274 |
+
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
275 |
+
if hasattr(module, "set_attention_slice"):
|
276 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
277 |
+
|
278 |
+
for child in module.children():
|
279 |
+
fn_recursive_retrieve_slicable_dims(child)
|
280 |
+
|
281 |
+
# retrieve number of attention layers
|
282 |
+
for module in self.children():
|
283 |
+
fn_recursive_retrieve_slicable_dims(module)
|
284 |
+
|
285 |
+
num_slicable_layers = len(sliceable_head_dims)
|
286 |
+
|
287 |
+
if slice_size == "auto":
|
288 |
+
# half the attention head size is usually a good trade-off between
|
289 |
+
# speed and memory
|
290 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
291 |
+
elif slice_size == "max":
|
292 |
+
# make smallest slice possible
|
293 |
+
slice_size = num_slicable_layers * [1]
|
294 |
+
|
295 |
+
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
296 |
+
|
297 |
+
if len(slice_size) != len(sliceable_head_dims):
|
298 |
+
raise ValueError(
|
299 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
300 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
301 |
+
)
|
302 |
+
|
303 |
+
for i in range(len(slice_size)):
|
304 |
+
size = slice_size[i]
|
305 |
+
dim = sliceable_head_dims[i]
|
306 |
+
if size is not None and size > dim:
|
307 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
308 |
+
|
309 |
+
# Recursively walk through all the children.
|
310 |
+
# Any children which exposes the set_attention_slice method
|
311 |
+
# gets the message
|
312 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
313 |
+
if hasattr(module, "set_attention_slice"):
|
314 |
+
module.set_attention_slice(slice_size.pop())
|
315 |
+
|
316 |
+
for child in module.children():
|
317 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
318 |
+
|
319 |
+
reversed_slice_size = list(reversed(slice_size))
|
320 |
+
for module in self.children():
|
321 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
322 |
+
|
323 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
324 |
+
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
325 |
+
module.gradient_checkpointing = value
|
326 |
+
|
327 |
+
def forward(
|
328 |
+
self,
|
329 |
+
sample: torch.FloatTensor,
|
330 |
+
timestep: Union[torch.Tensor, float, int],
|
331 |
+
encoder_hidden_states: torch.Tensor,
|
332 |
+
class_labels: Optional[torch.Tensor] = None,
|
333 |
+
attention_mask: Optional[torch.Tensor] = None,
|
334 |
+
|
335 |
+
# support controlnet
|
336 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
337 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
338 |
+
|
339 |
+
return_dict: bool = True,
|
340 |
+
) -> Union[UNet3DConditionOutput, Tuple]:
|
341 |
+
r"""
|
342 |
+
Args:
|
343 |
+
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
344 |
+
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
345 |
+
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
346 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
347 |
+
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
351 |
+
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
352 |
+
returning a tuple, the first element is the sample tensor.
|
353 |
+
"""
|
354 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
355 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
356 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
357 |
+
# on the fly if necessary.
|
358 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
359 |
+
|
360 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
361 |
+
forward_upsample_size = False
|
362 |
+
upsample_size = None
|
363 |
+
|
364 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
365 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
366 |
+
forward_upsample_size = True
|
367 |
+
|
368 |
+
# prepare attention_mask
|
369 |
+
if attention_mask is not None:
|
370 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
371 |
+
attention_mask = attention_mask.unsqueeze(1)
|
372 |
+
|
373 |
+
# center input if necessary
|
374 |
+
if self.config.center_input_sample:
|
375 |
+
sample = 2 * sample - 1.0
|
376 |
+
|
377 |
+
# time
|
378 |
+
timesteps = timestep
|
379 |
+
if not torch.is_tensor(timesteps):
|
380 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
381 |
+
is_mps = sample.device.type == "mps"
|
382 |
+
if isinstance(timestep, float):
|
383 |
+
dtype = torch.float32 if is_mps else torch.float64
|
384 |
+
else:
|
385 |
+
dtype = torch.int32 if is_mps else torch.int64
|
386 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
387 |
+
elif len(timesteps.shape) == 0:
|
388 |
+
timesteps = timesteps[None].to(sample.device)
|
389 |
+
|
390 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
391 |
+
timesteps = timesteps.expand(sample.shape[0])
|
392 |
+
|
393 |
+
t_emb = self.time_proj(timesteps)
|
394 |
+
|
395 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
396 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
397 |
+
# there might be better ways to encapsulate this.
|
398 |
+
t_emb = t_emb.to(dtype=self.dtype)
|
399 |
+
emb = self.time_embedding(t_emb)
|
400 |
+
|
401 |
+
if self.class_embedding is not None:
|
402 |
+
if class_labels is None:
|
403 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
404 |
+
|
405 |
+
if self.config.class_embed_type == "timestep":
|
406 |
+
class_labels = self.time_proj(class_labels)
|
407 |
+
|
408 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
409 |
+
emb = emb + class_emb
|
410 |
+
|
411 |
+
# pre-process
|
412 |
+
sample = self.conv_in(sample)
|
413 |
+
|
414 |
+
# down
|
415 |
+
down_block_res_samples = (sample,)
|
416 |
+
for downsample_block in self.down_blocks:
|
417 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
418 |
+
sample, res_samples = downsample_block(
|
419 |
+
hidden_states=sample,
|
420 |
+
temb=emb,
|
421 |
+
encoder_hidden_states=encoder_hidden_states,
|
422 |
+
attention_mask=attention_mask,
|
423 |
+
)
|
424 |
+
else:
|
425 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
|
426 |
+
|
427 |
+
down_block_res_samples += res_samples
|
428 |
+
|
429 |
+
# support controlnet
|
430 |
+
down_block_res_samples = list(down_block_res_samples)
|
431 |
+
if down_block_additional_residuals is not None:
|
432 |
+
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
433 |
+
if down_block_additional_residual.dim() == 4: # boardcast
|
434 |
+
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
435 |
+
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
436 |
+
|
437 |
+
# mid
|
438 |
+
sample = self.mid_block(
|
439 |
+
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
440 |
+
)
|
441 |
+
|
442 |
+
# support controlnet
|
443 |
+
if mid_block_additional_residual is not None:
|
444 |
+
if mid_block_additional_residual.dim() == 4: # boardcast
|
445 |
+
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
446 |
+
sample = sample + mid_block_additional_residual
|
447 |
+
|
448 |
+
# up
|
449 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
450 |
+
is_final_block = i == len(self.up_blocks) - 1
|
451 |
+
|
452 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
453 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
454 |
+
|
455 |
+
# if we have not reached the final block and need to forward the
|
456 |
+
# upsample size, we do it here
|
457 |
+
if not is_final_block and forward_upsample_size:
|
458 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
459 |
+
|
460 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
461 |
+
sample = upsample_block(
|
462 |
+
hidden_states=sample,
|
463 |
+
temb=emb,
|
464 |
+
res_hidden_states_tuple=res_samples,
|
465 |
+
encoder_hidden_states=encoder_hidden_states,
|
466 |
+
upsample_size=upsample_size,
|
467 |
+
attention_mask=attention_mask,
|
468 |
+
)
|
469 |
+
else:
|
470 |
+
sample = upsample_block(
|
471 |
+
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
|
472 |
+
)
|
473 |
+
|
474 |
+
# post-process
|
475 |
+
sample = self.conv_norm_out(sample)
|
476 |
+
sample = self.conv_act(sample)
|
477 |
+
sample = self.conv_out(sample)
|
478 |
+
|
479 |
+
if not return_dict:
|
480 |
+
return (sample,)
|
481 |
+
|
482 |
+
return UNet3DConditionOutput(sample=sample)
|
483 |
+
|
484 |
+
@classmethod
|
485 |
+
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
486 |
+
if subfolder is not None:
|
487 |
+
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
488 |
+
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
|
489 |
+
|
490 |
+
config_file = os.path.join(pretrained_model_path, 'config.json')
|
491 |
+
if not os.path.isfile(config_file):
|
492 |
+
raise RuntimeError(f"{config_file} does not exist")
|
493 |
+
with open(config_file, "r") as f:
|
494 |
+
config = json.load(f)
|
495 |
+
config["_class_name"] = cls.__name__
|
496 |
+
config["down_block_types"] = [
|
497 |
+
"CrossAttnDownBlock3D",
|
498 |
+
"CrossAttnDownBlock3D",
|
499 |
+
"CrossAttnDownBlock3D",
|
500 |
+
"DownBlock3D"
|
501 |
+
]
|
502 |
+
config["up_block_types"] = [
|
503 |
+
"UpBlock3D",
|
504 |
+
"CrossAttnUpBlock3D",
|
505 |
+
"CrossAttnUpBlock3D",
|
506 |
+
"CrossAttnUpBlock3D"
|
507 |
+
]
|
508 |
+
# config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
|
509 |
+
from diffusers.utils import WEIGHTS_NAME
|
510 |
+
model = cls.from_config(config, **unet_additional_kwargs)
|
511 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
512 |
+
# from safetensors import safe_open
|
513 |
+
# state_dict={}
|
514 |
+
# # model_file = "/ssd1/hexuanhua/AnimateDiff/outputs/training-2024-02-17T10-07-50/checkpoints/checkpoint.ckpt"
|
515 |
+
# with safe_open("/home/zjy/data/hexuanhua/huggingface_model/hub/models--SG161222--Realistic_Vision_V4.0_noVAE/snapshots/1bd8c538b40236e642a1427ed154a50ef5bdd3df/unet/diffusion_pytorch_model.safetensors", framework="pt", device="cpu") as f:
|
516 |
+
# for key in f.keys():
|
517 |
+
# state_dict[key] = f.get_tensor(key)
|
518 |
+
|
519 |
+
if not os.path.isfile(model_file):
|
520 |
+
raise RuntimeError(f"{model_file} does not exist")
|
521 |
+
state_dict = torch.load(model_file, map_location="cpu")
|
522 |
+
|
523 |
+
m, u = model.load_state_dict(state_dict, strict=False)
|
524 |
+
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
525 |
+
|
526 |
+
params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
|
527 |
+
print(f"### Motion Module Parameters: {sum(params) / 1e6} M")
|
528 |
+
|
529 |
+
return model
|
530 |
+
@property
|
531 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
532 |
+
r"""
|
533 |
+
Returns:
|
534 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
535 |
+
indexed by its weight name.
|
536 |
+
"""
|
537 |
+
# set recursively
|
538 |
+
processors = {}
|
539 |
+
|
540 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
541 |
+
if hasattr(module, "get_processor"):
|
542 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
543 |
+
|
544 |
+
for sub_name, child in module.named_children():
|
545 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
546 |
+
|
547 |
+
return processors
|
548 |
+
|
549 |
+
for name, module in self.named_children():
|
550 |
+
fn_recursive_add_processors(name, module, processors)
|
551 |
+
|
552 |
+
return processors
|
553 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
554 |
+
r"""
|
555 |
+
Sets the attention processor to use to compute attention.
|
556 |
+
|
557 |
+
Parameters:
|
558 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
559 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
560 |
+
for **all** `Attention` layers.
|
561 |
+
|
562 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
563 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
564 |
+
|
565 |
+
"""
|
566 |
+
count = len(self.attn_processors.keys())
|
567 |
+
|
568 |
+
if isinstance(processor, dict) and len(processor) != count:
|
569 |
+
raise ValueError(
|
570 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
571 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
572 |
+
)
|
573 |
+
|
574 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
575 |
+
if hasattr(module, "set_processor"):
|
576 |
+
if not isinstance(processor, dict):
|
577 |
+
module.set_processor(processor)
|
578 |
+
else:
|
579 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
580 |
+
|
581 |
+
for sub_name, child in module.named_children():
|
582 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
583 |
+
|
584 |
+
for name, module in self.named_children():
|
585 |
+
fn_recursive_attn_processor(name, module, processor)
|
586 |
+
|
587 |
+
def set_default_attn_processor(self):
|
588 |
+
"""
|
589 |
+
Disables custom attention processors and sets the default attention implementation.
|
590 |
+
"""
|
591 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
592 |
+
processor = AttnAddedKVProcessor()
|
593 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
594 |
+
processor = AttnProcessor()
|
595 |
+
else:
|
596 |
+
raise ValueError(
|
597 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
598 |
+
)
|
599 |
+
|
600 |
+
self.set_attn_processor(processor)
|
animatediff/models/unet_blocks.py
ADDED
@@ -0,0 +1,760 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from .attention import Transformer3DModel
|
7 |
+
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
8 |
+
from .motion_module import get_motion_module
|
9 |
+
|
10 |
+
import pdb
|
11 |
+
|
12 |
+
def get_down_block(
|
13 |
+
down_block_type,
|
14 |
+
num_layers,
|
15 |
+
in_channels,
|
16 |
+
out_channels,
|
17 |
+
temb_channels,
|
18 |
+
add_downsample,
|
19 |
+
resnet_eps,
|
20 |
+
resnet_act_fn,
|
21 |
+
attn_num_head_channels,
|
22 |
+
resnet_groups=None,
|
23 |
+
cross_attention_dim=None,
|
24 |
+
downsample_padding=None,
|
25 |
+
dual_cross_attention=False,
|
26 |
+
use_linear_projection=False,
|
27 |
+
only_cross_attention=False,
|
28 |
+
upcast_attention=False,
|
29 |
+
resnet_time_scale_shift="default",
|
30 |
+
|
31 |
+
unet_use_cross_frame_attention=False,
|
32 |
+
unet_use_temporal_attention=False,
|
33 |
+
use_inflated_groupnorm=False,
|
34 |
+
|
35 |
+
use_motion_module=None,
|
36 |
+
|
37 |
+
motion_module_type=None,
|
38 |
+
motion_module_kwargs=None,
|
39 |
+
):
|
40 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
41 |
+
if down_block_type == "DownBlock3D":
|
42 |
+
return DownBlock3D(
|
43 |
+
num_layers=num_layers,
|
44 |
+
in_channels=in_channels,
|
45 |
+
out_channels=out_channels,
|
46 |
+
temb_channels=temb_channels,
|
47 |
+
add_downsample=add_downsample,
|
48 |
+
resnet_eps=resnet_eps,
|
49 |
+
resnet_act_fn=resnet_act_fn,
|
50 |
+
resnet_groups=resnet_groups,
|
51 |
+
downsample_padding=downsample_padding,
|
52 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
53 |
+
|
54 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
55 |
+
|
56 |
+
use_motion_module=use_motion_module,
|
57 |
+
motion_module_type=motion_module_type,
|
58 |
+
motion_module_kwargs=motion_module_kwargs,
|
59 |
+
)
|
60 |
+
elif down_block_type == "CrossAttnDownBlock3D":
|
61 |
+
if cross_attention_dim is None:
|
62 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
63 |
+
return CrossAttnDownBlock3D(
|
64 |
+
num_layers=num_layers,
|
65 |
+
in_channels=in_channels,
|
66 |
+
out_channels=out_channels,
|
67 |
+
temb_channels=temb_channels,
|
68 |
+
add_downsample=add_downsample,
|
69 |
+
resnet_eps=resnet_eps,
|
70 |
+
resnet_act_fn=resnet_act_fn,
|
71 |
+
resnet_groups=resnet_groups,
|
72 |
+
downsample_padding=downsample_padding,
|
73 |
+
cross_attention_dim=cross_attention_dim,
|
74 |
+
attn_num_head_channels=attn_num_head_channels,
|
75 |
+
dual_cross_attention=dual_cross_attention,
|
76 |
+
use_linear_projection=use_linear_projection,
|
77 |
+
only_cross_attention=only_cross_attention,
|
78 |
+
upcast_attention=upcast_attention,
|
79 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
80 |
+
|
81 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
82 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
83 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
84 |
+
|
85 |
+
use_motion_module=use_motion_module,
|
86 |
+
motion_module_type=motion_module_type,
|
87 |
+
motion_module_kwargs=motion_module_kwargs,
|
88 |
+
)
|
89 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
90 |
+
|
91 |
+
|
92 |
+
def get_up_block(
|
93 |
+
up_block_type,
|
94 |
+
num_layers,
|
95 |
+
in_channels,
|
96 |
+
out_channels,
|
97 |
+
prev_output_channel,
|
98 |
+
temb_channels,
|
99 |
+
add_upsample,
|
100 |
+
resnet_eps,
|
101 |
+
resnet_act_fn,
|
102 |
+
attn_num_head_channels,
|
103 |
+
resnet_groups=None,
|
104 |
+
cross_attention_dim=None,
|
105 |
+
dual_cross_attention=False,
|
106 |
+
use_linear_projection=False,
|
107 |
+
only_cross_attention=False,
|
108 |
+
upcast_attention=False,
|
109 |
+
resnet_time_scale_shift="default",
|
110 |
+
|
111 |
+
unet_use_cross_frame_attention=False,
|
112 |
+
unet_use_temporal_attention=False,
|
113 |
+
use_inflated_groupnorm=False,
|
114 |
+
|
115 |
+
use_motion_module=None,
|
116 |
+
motion_module_type=None,
|
117 |
+
motion_module_kwargs=None,
|
118 |
+
):
|
119 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
120 |
+
if up_block_type == "UpBlock3D":
|
121 |
+
return UpBlock3D(
|
122 |
+
num_layers=num_layers,
|
123 |
+
in_channels=in_channels,
|
124 |
+
out_channels=out_channels,
|
125 |
+
prev_output_channel=prev_output_channel,
|
126 |
+
temb_channels=temb_channels,
|
127 |
+
add_upsample=add_upsample,
|
128 |
+
resnet_eps=resnet_eps,
|
129 |
+
resnet_act_fn=resnet_act_fn,
|
130 |
+
resnet_groups=resnet_groups,
|
131 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
132 |
+
|
133 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
134 |
+
|
135 |
+
use_motion_module=use_motion_module,
|
136 |
+
motion_module_type=motion_module_type,
|
137 |
+
motion_module_kwargs=motion_module_kwargs,
|
138 |
+
)
|
139 |
+
elif up_block_type == "CrossAttnUpBlock3D":
|
140 |
+
if cross_attention_dim is None:
|
141 |
+
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
142 |
+
return CrossAttnUpBlock3D(
|
143 |
+
num_layers=num_layers,
|
144 |
+
in_channels=in_channels,
|
145 |
+
out_channels=out_channels,
|
146 |
+
prev_output_channel=prev_output_channel,
|
147 |
+
temb_channels=temb_channels,
|
148 |
+
add_upsample=add_upsample,
|
149 |
+
resnet_eps=resnet_eps,
|
150 |
+
resnet_act_fn=resnet_act_fn,
|
151 |
+
resnet_groups=resnet_groups,
|
152 |
+
cross_attention_dim=cross_attention_dim,
|
153 |
+
attn_num_head_channels=attn_num_head_channels,
|
154 |
+
dual_cross_attention=dual_cross_attention,
|
155 |
+
use_linear_projection=use_linear_projection,
|
156 |
+
only_cross_attention=only_cross_attention,
|
157 |
+
upcast_attention=upcast_attention,
|
158 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
159 |
+
|
160 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
161 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
162 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
163 |
+
|
164 |
+
use_motion_module=use_motion_module,
|
165 |
+
motion_module_type=motion_module_type,
|
166 |
+
motion_module_kwargs=motion_module_kwargs,
|
167 |
+
)
|
168 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
169 |
+
|
170 |
+
|
171 |
+
class UNetMidBlock3DCrossAttn(nn.Module):
|
172 |
+
def __init__(
|
173 |
+
self,
|
174 |
+
in_channels: int,
|
175 |
+
temb_channels: int,
|
176 |
+
dropout: float = 0.0,
|
177 |
+
num_layers: int = 1,
|
178 |
+
resnet_eps: float = 1e-6,
|
179 |
+
resnet_time_scale_shift: str = "default",
|
180 |
+
resnet_act_fn: str = "swish",
|
181 |
+
resnet_groups: int = 32,
|
182 |
+
resnet_pre_norm: bool = True,
|
183 |
+
attn_num_head_channels=1,
|
184 |
+
output_scale_factor=1.0,
|
185 |
+
cross_attention_dim=1280,
|
186 |
+
dual_cross_attention=False,
|
187 |
+
use_linear_projection=False,
|
188 |
+
upcast_attention=False,
|
189 |
+
|
190 |
+
unet_use_cross_frame_attention=False,
|
191 |
+
unet_use_temporal_attention=False,
|
192 |
+
use_inflated_groupnorm=False,
|
193 |
+
|
194 |
+
use_motion_module=None,
|
195 |
+
|
196 |
+
motion_module_type=None,
|
197 |
+
motion_module_kwargs=None,
|
198 |
+
):
|
199 |
+
super().__init__()
|
200 |
+
|
201 |
+
self.has_cross_attention = True
|
202 |
+
self.attn_num_head_channels = attn_num_head_channels
|
203 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
204 |
+
|
205 |
+
# there is always at least one resnet
|
206 |
+
resnets = [
|
207 |
+
ResnetBlock3D(
|
208 |
+
in_channels=in_channels,
|
209 |
+
out_channels=in_channels,
|
210 |
+
temb_channels=temb_channels,
|
211 |
+
eps=resnet_eps,
|
212 |
+
groups=resnet_groups,
|
213 |
+
dropout=dropout,
|
214 |
+
time_embedding_norm=resnet_time_scale_shift,
|
215 |
+
non_linearity=resnet_act_fn,
|
216 |
+
output_scale_factor=output_scale_factor,
|
217 |
+
pre_norm=resnet_pre_norm,
|
218 |
+
|
219 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
220 |
+
)
|
221 |
+
]
|
222 |
+
attentions = []
|
223 |
+
motion_modules = []
|
224 |
+
|
225 |
+
for _ in range(num_layers):
|
226 |
+
if dual_cross_attention:
|
227 |
+
raise NotImplementedError
|
228 |
+
attentions.append(
|
229 |
+
Transformer3DModel(
|
230 |
+
attn_num_head_channels,
|
231 |
+
in_channels // attn_num_head_channels,
|
232 |
+
in_channels=in_channels,
|
233 |
+
num_layers=1,
|
234 |
+
cross_attention_dim=cross_attention_dim,
|
235 |
+
norm_num_groups=resnet_groups,
|
236 |
+
use_linear_projection=use_linear_projection,
|
237 |
+
upcast_attention=upcast_attention,
|
238 |
+
|
239 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
240 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
241 |
+
)
|
242 |
+
)
|
243 |
+
motion_modules.append(
|
244 |
+
get_motion_module(
|
245 |
+
in_channels=in_channels,
|
246 |
+
motion_module_type=motion_module_type,
|
247 |
+
motion_module_kwargs=motion_module_kwargs,
|
248 |
+
) if use_motion_module else None
|
249 |
+
)
|
250 |
+
resnets.append(
|
251 |
+
ResnetBlock3D(
|
252 |
+
in_channels=in_channels,
|
253 |
+
out_channels=in_channels,
|
254 |
+
temb_channels=temb_channels,
|
255 |
+
eps=resnet_eps,
|
256 |
+
groups=resnet_groups,
|
257 |
+
dropout=dropout,
|
258 |
+
time_embedding_norm=resnet_time_scale_shift,
|
259 |
+
non_linearity=resnet_act_fn,
|
260 |
+
output_scale_factor=output_scale_factor,
|
261 |
+
pre_norm=resnet_pre_norm,
|
262 |
+
|
263 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
264 |
+
)
|
265 |
+
)
|
266 |
+
|
267 |
+
self.attentions = nn.ModuleList(attentions)
|
268 |
+
self.resnets = nn.ModuleList(resnets)
|
269 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
270 |
+
|
271 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
272 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
273 |
+
for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
|
274 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
275 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
276 |
+
hidden_states = resnet(hidden_states, temb)
|
277 |
+
|
278 |
+
return hidden_states
|
279 |
+
|
280 |
+
|
281 |
+
class CrossAttnDownBlock3D(nn.Module):
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
in_channels: int,
|
285 |
+
out_channels: int,
|
286 |
+
temb_channels: int,
|
287 |
+
dropout: float = 0.0,
|
288 |
+
num_layers: int = 1,
|
289 |
+
resnet_eps: float = 1e-6,
|
290 |
+
resnet_time_scale_shift: str = "default",
|
291 |
+
resnet_act_fn: str = "swish",
|
292 |
+
resnet_groups: int = 32,
|
293 |
+
resnet_pre_norm: bool = True,
|
294 |
+
attn_num_head_channels=1,
|
295 |
+
cross_attention_dim=1280,
|
296 |
+
output_scale_factor=1.0,
|
297 |
+
downsample_padding=1,
|
298 |
+
add_downsample=True,
|
299 |
+
dual_cross_attention=False,
|
300 |
+
use_linear_projection=False,
|
301 |
+
only_cross_attention=False,
|
302 |
+
upcast_attention=False,
|
303 |
+
|
304 |
+
unet_use_cross_frame_attention=False,
|
305 |
+
unet_use_temporal_attention=False,
|
306 |
+
use_inflated_groupnorm=False,
|
307 |
+
|
308 |
+
use_motion_module=None,
|
309 |
+
|
310 |
+
motion_module_type=None,
|
311 |
+
motion_module_kwargs=None,
|
312 |
+
):
|
313 |
+
super().__init__()
|
314 |
+
resnets = []
|
315 |
+
attentions = []
|
316 |
+
motion_modules = []
|
317 |
+
|
318 |
+
self.has_cross_attention = True
|
319 |
+
self.attn_num_head_channels = attn_num_head_channels
|
320 |
+
|
321 |
+
for i in range(num_layers):
|
322 |
+
in_channels = in_channels if i == 0 else out_channels
|
323 |
+
resnets.append(
|
324 |
+
ResnetBlock3D(
|
325 |
+
in_channels=in_channels,
|
326 |
+
out_channels=out_channels,
|
327 |
+
temb_channels=temb_channels,
|
328 |
+
eps=resnet_eps,
|
329 |
+
groups=resnet_groups,
|
330 |
+
dropout=dropout,
|
331 |
+
time_embedding_norm=resnet_time_scale_shift,
|
332 |
+
non_linearity=resnet_act_fn,
|
333 |
+
output_scale_factor=output_scale_factor,
|
334 |
+
pre_norm=resnet_pre_norm,
|
335 |
+
|
336 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
337 |
+
)
|
338 |
+
)
|
339 |
+
if dual_cross_attention:
|
340 |
+
raise NotImplementedError
|
341 |
+
attentions.append(
|
342 |
+
Transformer3DModel(
|
343 |
+
attn_num_head_channels,
|
344 |
+
out_channels // attn_num_head_channels,
|
345 |
+
in_channels=out_channels,
|
346 |
+
num_layers=1,
|
347 |
+
cross_attention_dim=cross_attention_dim,
|
348 |
+
norm_num_groups=resnet_groups,
|
349 |
+
use_linear_projection=use_linear_projection,
|
350 |
+
only_cross_attention=only_cross_attention,
|
351 |
+
upcast_attention=upcast_attention,
|
352 |
+
|
353 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
354 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
355 |
+
)
|
356 |
+
)
|
357 |
+
motion_modules.append(
|
358 |
+
get_motion_module(
|
359 |
+
in_channels=out_channels,
|
360 |
+
motion_module_type=motion_module_type,
|
361 |
+
motion_module_kwargs=motion_module_kwargs,
|
362 |
+
) if use_motion_module else None
|
363 |
+
)
|
364 |
+
|
365 |
+
self.attentions = nn.ModuleList(attentions)
|
366 |
+
self.resnets = nn.ModuleList(resnets)
|
367 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
368 |
+
|
369 |
+
if add_downsample:
|
370 |
+
self.downsamplers = nn.ModuleList(
|
371 |
+
[
|
372 |
+
Downsample3D(
|
373 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
374 |
+
)
|
375 |
+
]
|
376 |
+
)
|
377 |
+
else:
|
378 |
+
self.downsamplers = None
|
379 |
+
|
380 |
+
self.gradient_checkpointing = False
|
381 |
+
|
382 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
383 |
+
output_states = ()
|
384 |
+
|
385 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
386 |
+
if self.training and self.gradient_checkpointing:
|
387 |
+
|
388 |
+
def create_custom_forward(module, return_dict=None):
|
389 |
+
def custom_forward(*inputs):
|
390 |
+
if return_dict is not None:
|
391 |
+
return module(*inputs, return_dict=return_dict)
|
392 |
+
else:
|
393 |
+
return module(*inputs)
|
394 |
+
|
395 |
+
return custom_forward
|
396 |
+
|
397 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
398 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
399 |
+
create_custom_forward(attn, return_dict=False),
|
400 |
+
hidden_states,
|
401 |
+
encoder_hidden_states,
|
402 |
+
)[0]
|
403 |
+
if motion_module is not None:
|
404 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
405 |
+
|
406 |
+
else:
|
407 |
+
hidden_states = resnet(hidden_states, temb)
|
408 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
409 |
+
|
410 |
+
# add motion module
|
411 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
412 |
+
|
413 |
+
output_states += (hidden_states,)
|
414 |
+
|
415 |
+
if self.downsamplers is not None:
|
416 |
+
for downsampler in self.downsamplers:
|
417 |
+
hidden_states = downsampler(hidden_states)
|
418 |
+
|
419 |
+
output_states += (hidden_states,)
|
420 |
+
|
421 |
+
return hidden_states, output_states
|
422 |
+
|
423 |
+
|
424 |
+
class DownBlock3D(nn.Module):
|
425 |
+
def __init__(
|
426 |
+
self,
|
427 |
+
in_channels: int,
|
428 |
+
out_channels: int,
|
429 |
+
temb_channels: int,
|
430 |
+
dropout: float = 0.0,
|
431 |
+
num_layers: int = 1,
|
432 |
+
resnet_eps: float = 1e-6,
|
433 |
+
resnet_time_scale_shift: str = "default",
|
434 |
+
resnet_act_fn: str = "swish",
|
435 |
+
resnet_groups: int = 32,
|
436 |
+
resnet_pre_norm: bool = True,
|
437 |
+
output_scale_factor=1.0,
|
438 |
+
add_downsample=True,
|
439 |
+
downsample_padding=1,
|
440 |
+
|
441 |
+
use_inflated_groupnorm=False,
|
442 |
+
|
443 |
+
use_motion_module=None,
|
444 |
+
motion_module_type=None,
|
445 |
+
motion_module_kwargs=None,
|
446 |
+
):
|
447 |
+
super().__init__()
|
448 |
+
resnets = []
|
449 |
+
motion_modules = []
|
450 |
+
|
451 |
+
for i in range(num_layers):
|
452 |
+
in_channels = in_channels if i == 0 else out_channels
|
453 |
+
resnets.append(
|
454 |
+
ResnetBlock3D(
|
455 |
+
in_channels=in_channels,
|
456 |
+
out_channels=out_channels,
|
457 |
+
temb_channels=temb_channels,
|
458 |
+
eps=resnet_eps,
|
459 |
+
groups=resnet_groups,
|
460 |
+
dropout=dropout,
|
461 |
+
time_embedding_norm=resnet_time_scale_shift,
|
462 |
+
non_linearity=resnet_act_fn,
|
463 |
+
output_scale_factor=output_scale_factor,
|
464 |
+
pre_norm=resnet_pre_norm,
|
465 |
+
|
466 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
467 |
+
)
|
468 |
+
)
|
469 |
+
motion_modules.append(
|
470 |
+
get_motion_module(
|
471 |
+
in_channels=out_channels,
|
472 |
+
motion_module_type=motion_module_type,
|
473 |
+
motion_module_kwargs=motion_module_kwargs,
|
474 |
+
) if use_motion_module else None
|
475 |
+
)
|
476 |
+
|
477 |
+
self.resnets = nn.ModuleList(resnets)
|
478 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
479 |
+
|
480 |
+
if add_downsample:
|
481 |
+
self.downsamplers = nn.ModuleList(
|
482 |
+
[
|
483 |
+
Downsample3D(
|
484 |
+
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
485 |
+
)
|
486 |
+
]
|
487 |
+
)
|
488 |
+
else:
|
489 |
+
self.downsamplers = None
|
490 |
+
|
491 |
+
self.gradient_checkpointing = False
|
492 |
+
|
493 |
+
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
494 |
+
output_states = ()
|
495 |
+
|
496 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
497 |
+
if self.training and self.gradient_checkpointing:
|
498 |
+
def create_custom_forward(module):
|
499 |
+
def custom_forward(*inputs):
|
500 |
+
return module(*inputs)
|
501 |
+
|
502 |
+
return custom_forward
|
503 |
+
|
504 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
505 |
+
if motion_module is not None:
|
506 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
507 |
+
else:
|
508 |
+
hidden_states = resnet(hidden_states, temb)
|
509 |
+
|
510 |
+
# add motion module
|
511 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
512 |
+
|
513 |
+
output_states += (hidden_states,)
|
514 |
+
|
515 |
+
if self.downsamplers is not None:
|
516 |
+
for downsampler in self.downsamplers:
|
517 |
+
hidden_states = downsampler(hidden_states)
|
518 |
+
|
519 |
+
output_states += (hidden_states,)
|
520 |
+
|
521 |
+
return hidden_states, output_states
|
522 |
+
|
523 |
+
|
524 |
+
class CrossAttnUpBlock3D(nn.Module):
|
525 |
+
def __init__(
|
526 |
+
self,
|
527 |
+
in_channels: int,
|
528 |
+
out_channels: int,
|
529 |
+
prev_output_channel: int,
|
530 |
+
temb_channels: int,
|
531 |
+
dropout: float = 0.0,
|
532 |
+
num_layers: int = 1,
|
533 |
+
resnet_eps: float = 1e-6,
|
534 |
+
resnet_time_scale_shift: str = "default",
|
535 |
+
resnet_act_fn: str = "swish",
|
536 |
+
resnet_groups: int = 32,
|
537 |
+
resnet_pre_norm: bool = True,
|
538 |
+
attn_num_head_channels=1,
|
539 |
+
cross_attention_dim=1280,
|
540 |
+
output_scale_factor=1.0,
|
541 |
+
add_upsample=True,
|
542 |
+
dual_cross_attention=False,
|
543 |
+
use_linear_projection=False,
|
544 |
+
only_cross_attention=False,
|
545 |
+
upcast_attention=False,
|
546 |
+
|
547 |
+
unet_use_cross_frame_attention=False,
|
548 |
+
unet_use_temporal_attention=False,
|
549 |
+
use_inflated_groupnorm=False,
|
550 |
+
|
551 |
+
use_motion_module=None,
|
552 |
+
|
553 |
+
motion_module_type=None,
|
554 |
+
motion_module_kwargs=None,
|
555 |
+
):
|
556 |
+
super().__init__()
|
557 |
+
resnets = []
|
558 |
+
attentions = []
|
559 |
+
motion_modules = []
|
560 |
+
|
561 |
+
self.has_cross_attention = True
|
562 |
+
self.attn_num_head_channels = attn_num_head_channels
|
563 |
+
|
564 |
+
for i in range(num_layers):
|
565 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
566 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
567 |
+
|
568 |
+
resnets.append(
|
569 |
+
ResnetBlock3D(
|
570 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
571 |
+
out_channels=out_channels,
|
572 |
+
temb_channels=temb_channels,
|
573 |
+
eps=resnet_eps,
|
574 |
+
groups=resnet_groups,
|
575 |
+
dropout=dropout,
|
576 |
+
time_embedding_norm=resnet_time_scale_shift,
|
577 |
+
non_linearity=resnet_act_fn,
|
578 |
+
output_scale_factor=output_scale_factor,
|
579 |
+
pre_norm=resnet_pre_norm,
|
580 |
+
|
581 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
582 |
+
)
|
583 |
+
)
|
584 |
+
if dual_cross_attention:
|
585 |
+
raise NotImplementedError
|
586 |
+
attentions.append(
|
587 |
+
Transformer3DModel(
|
588 |
+
attn_num_head_channels,
|
589 |
+
out_channels // attn_num_head_channels,
|
590 |
+
in_channels=out_channels,
|
591 |
+
num_layers=1,
|
592 |
+
cross_attention_dim=cross_attention_dim,
|
593 |
+
norm_num_groups=resnet_groups,
|
594 |
+
use_linear_projection=use_linear_projection,
|
595 |
+
only_cross_attention=only_cross_attention,
|
596 |
+
upcast_attention=upcast_attention,
|
597 |
+
|
598 |
+
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
599 |
+
unet_use_temporal_attention=unet_use_temporal_attention,
|
600 |
+
)
|
601 |
+
)
|
602 |
+
motion_modules.append(
|
603 |
+
get_motion_module(
|
604 |
+
in_channels=out_channels,
|
605 |
+
motion_module_type=motion_module_type,
|
606 |
+
motion_module_kwargs=motion_module_kwargs,
|
607 |
+
) if use_motion_module else None
|
608 |
+
)
|
609 |
+
|
610 |
+
self.attentions = nn.ModuleList(attentions)
|
611 |
+
self.resnets = nn.ModuleList(resnets)
|
612 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
613 |
+
|
614 |
+
if add_upsample:
|
615 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
616 |
+
else:
|
617 |
+
self.upsamplers = None
|
618 |
+
|
619 |
+
self.gradient_checkpointing = False
|
620 |
+
|
621 |
+
def forward(
|
622 |
+
self,
|
623 |
+
hidden_states,
|
624 |
+
res_hidden_states_tuple,
|
625 |
+
temb=None,
|
626 |
+
encoder_hidden_states=None,
|
627 |
+
upsample_size=None,
|
628 |
+
attention_mask=None,
|
629 |
+
):
|
630 |
+
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
631 |
+
# pop res hidden states
|
632 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
633 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
634 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
635 |
+
|
636 |
+
if self.training and self.gradient_checkpointing:
|
637 |
+
|
638 |
+
def create_custom_forward(module, return_dict=None):
|
639 |
+
def custom_forward(*inputs):
|
640 |
+
if return_dict is not None:
|
641 |
+
return module(*inputs, return_dict=return_dict)
|
642 |
+
else:
|
643 |
+
return module(*inputs)
|
644 |
+
|
645 |
+
return custom_forward
|
646 |
+
|
647 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
648 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
649 |
+
create_custom_forward(attn, return_dict=False),
|
650 |
+
hidden_states,
|
651 |
+
encoder_hidden_states,
|
652 |
+
)[0]
|
653 |
+
if motion_module is not None:
|
654 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
655 |
+
|
656 |
+
else:
|
657 |
+
hidden_states = resnet(hidden_states, temb)
|
658 |
+
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
659 |
+
|
660 |
+
# add motion module
|
661 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
662 |
+
|
663 |
+
if self.upsamplers is not None:
|
664 |
+
for upsampler in self.upsamplers:
|
665 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
666 |
+
|
667 |
+
return hidden_states
|
668 |
+
|
669 |
+
|
670 |
+
class UpBlock3D(nn.Module):
|
671 |
+
def __init__(
|
672 |
+
self,
|
673 |
+
in_channels: int,
|
674 |
+
prev_output_channel: int,
|
675 |
+
out_channels: int,
|
676 |
+
temb_channels: int,
|
677 |
+
dropout: float = 0.0,
|
678 |
+
num_layers: int = 1,
|
679 |
+
resnet_eps: float = 1e-6,
|
680 |
+
resnet_time_scale_shift: str = "default",
|
681 |
+
resnet_act_fn: str = "swish",
|
682 |
+
resnet_groups: int = 32,
|
683 |
+
resnet_pre_norm: bool = True,
|
684 |
+
output_scale_factor=1.0,
|
685 |
+
add_upsample=True,
|
686 |
+
|
687 |
+
use_inflated_groupnorm=False,
|
688 |
+
|
689 |
+
use_motion_module=None,
|
690 |
+
motion_module_type=None,
|
691 |
+
motion_module_kwargs=None,
|
692 |
+
):
|
693 |
+
super().__init__()
|
694 |
+
resnets = []
|
695 |
+
motion_modules = []
|
696 |
+
|
697 |
+
for i in range(num_layers):
|
698 |
+
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
699 |
+
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
700 |
+
|
701 |
+
resnets.append(
|
702 |
+
ResnetBlock3D(
|
703 |
+
in_channels=resnet_in_channels + res_skip_channels,
|
704 |
+
out_channels=out_channels,
|
705 |
+
temb_channels=temb_channels,
|
706 |
+
eps=resnet_eps,
|
707 |
+
groups=resnet_groups,
|
708 |
+
dropout=dropout,
|
709 |
+
time_embedding_norm=resnet_time_scale_shift,
|
710 |
+
non_linearity=resnet_act_fn,
|
711 |
+
output_scale_factor=output_scale_factor,
|
712 |
+
pre_norm=resnet_pre_norm,
|
713 |
+
|
714 |
+
use_inflated_groupnorm=use_inflated_groupnorm,
|
715 |
+
)
|
716 |
+
)
|
717 |
+
motion_modules.append(
|
718 |
+
get_motion_module(
|
719 |
+
in_channels=out_channels,
|
720 |
+
motion_module_type=motion_module_type,
|
721 |
+
motion_module_kwargs=motion_module_kwargs,
|
722 |
+
) if use_motion_module else None
|
723 |
+
)
|
724 |
+
|
725 |
+
self.resnets = nn.ModuleList(resnets)
|
726 |
+
self.motion_modules = nn.ModuleList(motion_modules)
|
727 |
+
|
728 |
+
if add_upsample:
|
729 |
+
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
730 |
+
else:
|
731 |
+
self.upsamplers = None
|
732 |
+
|
733 |
+
self.gradient_checkpointing = False
|
734 |
+
|
735 |
+
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
|
736 |
+
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
737 |
+
# pop res hidden states
|
738 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
739 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
740 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
741 |
+
|
742 |
+
if self.training and self.gradient_checkpointing:
|
743 |
+
def create_custom_forward(module):
|
744 |
+
def custom_forward(*inputs):
|
745 |
+
return module(*inputs)
|
746 |
+
|
747 |
+
return custom_forward
|
748 |
+
|
749 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
750 |
+
if motion_module is not None:
|
751 |
+
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
|
752 |
+
else:
|
753 |
+
hidden_states = resnet(hidden_states, temb)
|
754 |
+
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
|
755 |
+
|
756 |
+
if self.upsamplers is not None:
|
757 |
+
for upsampler in self.upsamplers:
|
758 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
759 |
+
|
760 |
+
return hidden_states
|
animatediff/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,793 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
|
2 |
+
|
3 |
+
import inspect
|
4 |
+
from typing import Callable, List, Optional, Union
|
5 |
+
from dataclasses import dataclass
|
6 |
+
|
7 |
+
import PIL.Image
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from diffusers.utils import is_accelerate_available
|
13 |
+
from packaging import version
|
14 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
15 |
+
|
16 |
+
from diffusers.configuration_utils import FrozenDict
|
17 |
+
from diffusers.models import AutoencoderKL
|
18 |
+
from diffusers import DiffusionPipeline
|
19 |
+
from diffusers.schedulers import (
|
20 |
+
DDIMScheduler,
|
21 |
+
DPMSolverMultistepScheduler,
|
22 |
+
EulerAncestralDiscreteScheduler,
|
23 |
+
EulerDiscreteScheduler,
|
24 |
+
LMSDiscreteScheduler,
|
25 |
+
PNDMScheduler,
|
26 |
+
)
|
27 |
+
from diffusers.utils import deprecate, logging, BaseOutput
|
28 |
+
|
29 |
+
from einops import rearrange
|
30 |
+
|
31 |
+
from ..models.unet import UNet3DConditionModel
|
32 |
+
from ..models.sparse_controlnet import SparseControlNetModel
|
33 |
+
import pdb
|
34 |
+
import PIL
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
# image: either PIL.Image.Image or torch.Tensor.
|
39 |
+
def preprocess_image(image, h=512, w=512):
|
40 |
+
if isinstance(image, torch.Tensor):
|
41 |
+
return image
|
42 |
+
elif isinstance(image, PIL.Image.Image):
|
43 |
+
# image: [1, 512, 512, 3]
|
44 |
+
image = np.array(image.resize((w, h), resample=PIL.Image.LANCZOS))[None, :]
|
45 |
+
image = image.astype(np.float16) * 2 / 255.0 - 1.0
|
46 |
+
# image: [1, 3, 512, 512]
|
47 |
+
image = image.transpose(0, 3, 1, 2)
|
48 |
+
image = torch.from_numpy(image)
|
49 |
+
else:
|
50 |
+
breakpoint()
|
51 |
+
return image
|
52 |
+
|
53 |
+
@dataclass
|
54 |
+
class AnimationPipelineOutput(BaseOutput):
|
55 |
+
videos: Union[torch.Tensor, np.ndarray]
|
56 |
+
|
57 |
+
|
58 |
+
class AnimationPipeline(DiffusionPipeline):
|
59 |
+
_optional_components = []
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
vae: AutoencoderKL,
|
64 |
+
text_encoder: CLIPTextModel,
|
65 |
+
tokenizer: CLIPTokenizer,
|
66 |
+
unet: UNet3DConditionModel,
|
67 |
+
scheduler: Union[
|
68 |
+
DDIMScheduler,
|
69 |
+
PNDMScheduler,
|
70 |
+
LMSDiscreteScheduler,
|
71 |
+
EulerDiscreteScheduler,
|
72 |
+
EulerAncestralDiscreteScheduler,
|
73 |
+
DPMSolverMultistepScheduler,
|
74 |
+
],
|
75 |
+
controlnet: Union[SparseControlNetModel, None] = None,
|
76 |
+
torch_dtype=torch.float32,
|
77 |
+
):
|
78 |
+
super().__init__()
|
79 |
+
|
80 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
81 |
+
deprecation_message = (
|
82 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
83 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
84 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
85 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
86 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
87 |
+
" file"
|
88 |
+
)
|
89 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
90 |
+
new_config = dict(scheduler.config)
|
91 |
+
new_config["steps_offset"] = 1
|
92 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
93 |
+
|
94 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
95 |
+
deprecation_message = (
|
96 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
97 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
98 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
99 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
100 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
101 |
+
)
|
102 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
103 |
+
new_config = dict(scheduler.config)
|
104 |
+
new_config["clip_sample"] = False
|
105 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
106 |
+
|
107 |
+
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
108 |
+
version.parse(unet.config._diffusers_version).base_version
|
109 |
+
) < version.parse("0.9.0.dev0")
|
110 |
+
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
111 |
+
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
112 |
+
deprecation_message = (
|
113 |
+
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
114 |
+
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
115 |
+
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
116 |
+
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
117 |
+
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
118 |
+
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
119 |
+
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
120 |
+
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
121 |
+
" the `unet/config.json` file"
|
122 |
+
)
|
123 |
+
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
124 |
+
new_config = dict(unet.config)
|
125 |
+
new_config["sample_size"] = 64
|
126 |
+
unet._internal_dict = FrozenDict(new_config)
|
127 |
+
self.torch_dtype=torch_dtype
|
128 |
+
self.register_modules(
|
129 |
+
vae=vae.to(self.torch_dtype),
|
130 |
+
text_encoder=text_encoder.to(self.torch_dtype),
|
131 |
+
tokenizer=tokenizer,
|
132 |
+
unet=unet.to(self.torch_dtype),
|
133 |
+
scheduler=scheduler,
|
134 |
+
# controlnet=controlnet.to(self.torch_dtype),
|
135 |
+
)
|
136 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
137 |
+
if controlnet!=None: self.controlnet=controlnet.to(self.torch_dtype)
|
138 |
+
def enable_vae_slicing(self):
|
139 |
+
self.vae.enable_slicing()
|
140 |
+
|
141 |
+
def disable_vae_slicing(self):
|
142 |
+
self.vae.disable_slicing()
|
143 |
+
|
144 |
+
def enable_sequential_cpu_offload(self, gpu_id=0):
|
145 |
+
if is_accelerate_available():
|
146 |
+
from accelerate import cpu_offload
|
147 |
+
else:
|
148 |
+
raise ImportError("Please install accelerate via `pip install accelerate`")
|
149 |
+
|
150 |
+
device = torch.device(f"cuda:{gpu_id}")
|
151 |
+
|
152 |
+
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
|
153 |
+
if cpu_offloaded_model is not None:
|
154 |
+
cpu_offload(cpu_offloaded_model, device)
|
155 |
+
|
156 |
+
|
157 |
+
@property
|
158 |
+
def _execution_device(self):
|
159 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
160 |
+
return self.device
|
161 |
+
for module in self.unet.modules():
|
162 |
+
if (
|
163 |
+
hasattr(module, "_hf_hook")
|
164 |
+
and hasattr(module._hf_hook, "execution_device")
|
165 |
+
and module._hf_hook.execution_device is not None
|
166 |
+
):
|
167 |
+
return torch.device(module._hf_hook.execution_device)
|
168 |
+
return self.device
|
169 |
+
|
170 |
+
def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
171 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
172 |
+
|
173 |
+
text_inputs = self.tokenizer(
|
174 |
+
prompt,
|
175 |
+
padding="max_length",
|
176 |
+
max_length=self.tokenizer.model_max_length,
|
177 |
+
truncation=True,
|
178 |
+
return_tensors="pt",
|
179 |
+
)
|
180 |
+
text_input_ids = text_inputs.input_ids
|
181 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
182 |
+
|
183 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
184 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
185 |
+
logger.warning(
|
186 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
187 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
188 |
+
)
|
189 |
+
|
190 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
191 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
192 |
+
else:
|
193 |
+
attention_mask = None
|
194 |
+
|
195 |
+
text_embeddings = self.text_encoder(
|
196 |
+
text_input_ids.to(device),
|
197 |
+
attention_mask=attention_mask,
|
198 |
+
)
|
199 |
+
text_embeddings = text_embeddings[0]
|
200 |
+
|
201 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
202 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
203 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
204 |
+
text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
205 |
+
|
206 |
+
# get unconditional embeddings for classifier free guidance
|
207 |
+
if do_classifier_free_guidance:
|
208 |
+
uncond_tokens: List[str]
|
209 |
+
if negative_prompt is None:
|
210 |
+
uncond_tokens = [""] * batch_size
|
211 |
+
elif type(prompt) is not type(negative_prompt):
|
212 |
+
raise TypeError(
|
213 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
214 |
+
f" {type(prompt)}."
|
215 |
+
)
|
216 |
+
elif isinstance(negative_prompt, str):
|
217 |
+
uncond_tokens = [negative_prompt]
|
218 |
+
elif batch_size != len(negative_prompt):
|
219 |
+
raise ValueError(
|
220 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
221 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
222 |
+
" the batch size of `prompt`."
|
223 |
+
)
|
224 |
+
else:
|
225 |
+
uncond_tokens = negative_prompt
|
226 |
+
|
227 |
+
max_length = text_input_ids.shape[-1]
|
228 |
+
uncond_input = self.tokenizer(
|
229 |
+
uncond_tokens,
|
230 |
+
padding="max_length",
|
231 |
+
max_length=max_length,
|
232 |
+
truncation=True,
|
233 |
+
return_tensors="pt",
|
234 |
+
)
|
235 |
+
|
236 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
237 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
238 |
+
else:
|
239 |
+
attention_mask = None
|
240 |
+
|
241 |
+
uncond_embeddings = self.text_encoder(
|
242 |
+
uncond_input.input_ids.to(device),
|
243 |
+
attention_mask=attention_mask,
|
244 |
+
)
|
245 |
+
uncond_embeddings = uncond_embeddings[0]
|
246 |
+
|
247 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
248 |
+
seq_len = uncond_embeddings.shape[1]
|
249 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
|
250 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
251 |
+
|
252 |
+
# For classifier free guidance, we need to do two forward passes.
|
253 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
254 |
+
# to avoid doing two forward passes
|
255 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
256 |
+
|
257 |
+
return text_embeddings
|
258 |
+
|
259 |
+
def encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
|
260 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
261 |
+
# print(batch_size)
|
262 |
+
# exit()
|
263 |
+
text_inputs = self.tokenizer(
|
264 |
+
prompt,
|
265 |
+
padding="max_length",
|
266 |
+
max_length=self.tokenizer.model_max_length,
|
267 |
+
truncation=True,
|
268 |
+
return_tensors="pt",
|
269 |
+
)
|
270 |
+
text_input_ids = text_inputs.input_ids
|
271 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
272 |
+
|
273 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
274 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
275 |
+
logger.warning(
|
276 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
277 |
+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
278 |
+
)
|
279 |
+
|
280 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
281 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
282 |
+
else:
|
283 |
+
attention_mask = None
|
284 |
+
|
285 |
+
text_embeddings = self.text_encoder(
|
286 |
+
text_input_ids.to(device),
|
287 |
+
attention_mask=attention_mask,
|
288 |
+
)
|
289 |
+
text_embeddings = text_embeddings[0]
|
290 |
+
|
291 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
292 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
293 |
+
text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
|
294 |
+
text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
295 |
+
|
296 |
+
# get unconditional embeddings for classifier free guidance
|
297 |
+
if do_classifier_free_guidance:
|
298 |
+
uncond_tokens: List[str]
|
299 |
+
if negative_prompt is None:
|
300 |
+
uncond_tokens = [""] * batch_size
|
301 |
+
elif type(prompt) is not type(negative_prompt):
|
302 |
+
raise TypeError(
|
303 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
304 |
+
f" {type(prompt)}."
|
305 |
+
)
|
306 |
+
elif isinstance(negative_prompt, str):
|
307 |
+
uncond_tokens = [negative_prompt]
|
308 |
+
elif batch_size != len(negative_prompt):
|
309 |
+
raise ValueError(
|
310 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
311 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
312 |
+
" the batch size of `prompt`."
|
313 |
+
)
|
314 |
+
else:
|
315 |
+
uncond_tokens = negative_prompt
|
316 |
+
max_length = text_input_ids.shape[-1]
|
317 |
+
uncond_input = self.tokenizer(
|
318 |
+
uncond_tokens,
|
319 |
+
padding="max_length",
|
320 |
+
max_length=max_length,
|
321 |
+
truncation=True,
|
322 |
+
return_tensors="pt",
|
323 |
+
)
|
324 |
+
|
325 |
+
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
326 |
+
attention_mask = uncond_input.attention_mask.to(device)
|
327 |
+
else:
|
328 |
+
attention_mask = None
|
329 |
+
|
330 |
+
uncond_embeddings = self.text_encoder(
|
331 |
+
uncond_input.input_ids.to(device),
|
332 |
+
attention_mask=attention_mask,
|
333 |
+
)
|
334 |
+
uncond_embeddings = uncond_embeddings[0]
|
335 |
+
|
336 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
337 |
+
seq_len = uncond_embeddings.shape[1]
|
338 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
|
339 |
+
uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
340 |
+
|
341 |
+
# For classifier free guidance, we need to do two forward passes.
|
342 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
343 |
+
# to avoid doing two forward passes
|
344 |
+
# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
345 |
+
# print("encode here!!!")
|
346 |
+
# print("shape of text_embeddings",text_embeddings.shape)
|
347 |
+
# print("shape of uncond_embeddings",uncond_embeddings.shape)
|
348 |
+
return text_embeddings,uncond_embeddings
|
349 |
+
|
350 |
+
|
351 |
+
def decode_latents(self, latents):
|
352 |
+
video_length = latents.shape[2]
|
353 |
+
latents = 1 / 0.18215 * latents
|
354 |
+
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
355 |
+
# video = self.vae.decode(latents).sample
|
356 |
+
video = []
|
357 |
+
for frame_idx in range(latents.shape[0]):
|
358 |
+
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
359 |
+
video = torch.cat(video)
|
360 |
+
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
361 |
+
video = (video / 2 + 0.5).clamp(0, 1)
|
362 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
363 |
+
video = video.cpu().float().numpy()
|
364 |
+
return video
|
365 |
+
|
366 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
367 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
368 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
369 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
370 |
+
# and should be between [0, 1]
|
371 |
+
|
372 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
373 |
+
extra_step_kwargs = {}
|
374 |
+
if accepts_eta:
|
375 |
+
extra_step_kwargs["eta"] = eta
|
376 |
+
|
377 |
+
# check if the scheduler accepts generator
|
378 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
379 |
+
if accepts_generator:
|
380 |
+
extra_step_kwargs["generator"] = generator
|
381 |
+
return extra_step_kwargs
|
382 |
+
|
383 |
+
def check_inputs(self, prompt, height, width, callback_steps,prompt_embedding):
|
384 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list) and prompt_embedding==None:
|
385 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
386 |
+
|
387 |
+
if height % 8 != 0 or width % 8 != 0:
|
388 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
389 |
+
|
390 |
+
if (callback_steps is None) or (
|
391 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
392 |
+
):
|
393 |
+
raise ValueError(
|
394 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
395 |
+
f" {type(callback_steps)}."
|
396 |
+
)
|
397 |
+
|
398 |
+
def prepare_latents(self, init_image, init_image_strength, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
399 |
+
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
400 |
+
|
401 |
+
if init_image is not None:
|
402 |
+
# init_image: either PIL.Image.Image or torch.Tensor.
|
403 |
+
image = preprocess_image(init_image, height, width)
|
404 |
+
image = image.to(device=device, dtype=dtype)
|
405 |
+
if isinstance(generator, list):
|
406 |
+
init_latents = [
|
407 |
+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
408 |
+
]
|
409 |
+
init_latents = torch.cat(init_latents, dim=0)
|
410 |
+
else:
|
411 |
+
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
412 |
+
else:
|
413 |
+
init_latents = None
|
414 |
+
|
415 |
+
|
416 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
417 |
+
raise ValueError(
|
418 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
419 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
420 |
+
)
|
421 |
+
if latents is None:
|
422 |
+
rand_device = "cpu" if device.type == "mps" else device
|
423 |
+
|
424 |
+
if isinstance(generator, list):
|
425 |
+
shape = shape
|
426 |
+
# shape = (1,) + shape[1:]
|
427 |
+
# ignore init latents for batch model
|
428 |
+
latents = [
|
429 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
430 |
+
for i in range(batch_size)
|
431 |
+
]
|
432 |
+
latents = torch.cat(latents, dim=0).to(device)
|
433 |
+
else:
|
434 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
435 |
+
if init_latents is not None:
|
436 |
+
blend_frames = video_length // 2
|
437 |
+
init_image_strength, init_image_final_weight = init_image_strength
|
438 |
+
for i in range(video_length):
|
439 |
+
dist_to_end = (blend_frames - float(i)) / blend_frames
|
440 |
+
# When i > 0.9 * blend_frames, dist_to_end < 0.1. Then it will be changed to 0.05,
|
441 |
+
# so that the last half of the video still is still initialized with a little bit of init_latents.
|
442 |
+
dist_to_end = max(dist_to_end, init_image_final_weight)
|
443 |
+
# Changed from /30 to /100.
|
444 |
+
# gradully reduce init alpha along video frames (loosen restriction)
|
445 |
+
init_alpha = dist_to_end * init_image_strength / 100
|
446 |
+
latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
|
447 |
+
else:
|
448 |
+
if latents.shape != shape:
|
449 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
450 |
+
latents = latents.to(device)
|
451 |
+
|
452 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
453 |
+
if init_latents is None:
|
454 |
+
latents = latents * self.scheduler.init_noise_sigma
|
455 |
+
return latents
|
456 |
+
|
457 |
+
@torch.no_grad()
|
458 |
+
def __call__(
|
459 |
+
self,
|
460 |
+
prompt: Union[str, List[str]],
|
461 |
+
video_length: Optional[int],
|
462 |
+
init_image: Union[PIL.Image.Image, torch.Tensor],
|
463 |
+
init_image_strength: float = 1.0,
|
464 |
+
height: Optional[int] = None,
|
465 |
+
width: Optional[int] = None,
|
466 |
+
num_inference_steps: int = 50,
|
467 |
+
guidance_scale: float = 7.5,
|
468 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
469 |
+
num_videos_per_prompt: Optional[int] = 1,
|
470 |
+
eta: float = 0.0,
|
471 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
472 |
+
latents: Optional[torch.FloatTensor] = None,
|
473 |
+
output_type: Optional[str] = "tensor",
|
474 |
+
return_dict: bool = True,
|
475 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
476 |
+
callback_steps: Optional[int] = 1,
|
477 |
+
#support embeddings
|
478 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
479 |
+
negative_prompt_embeds:Optional[torch.FloatTensor] = None,
|
480 |
+
# support controlnet
|
481 |
+
controlnet_images: torch.FloatTensor = None,
|
482 |
+
controlnet_image_index: list = [0],
|
483 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
484 |
+
**kwargs,
|
485 |
+
):
|
486 |
+
# Default height and width to unet
|
487 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
488 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
489 |
+
|
490 |
+
if isinstance(prompt_embeds, (list, tuple)):
|
491 |
+
prompt_embeds_begin, prompt_embeds_end, adaface_anneal_steps = prompt_embeds
|
492 |
+
prompt_embeds = prompt_embeds_begin
|
493 |
+
do_prompt_embeds_annealing = True
|
494 |
+
else:
|
495 |
+
do_prompt_embeds_annealing = False
|
496 |
+
|
497 |
+
# Check inputs. Raise error if not correct
|
498 |
+
self.check_inputs(prompt, height, width, callback_steps, prompt_embeds)
|
499 |
+
|
500 |
+
# Define call parameters
|
501 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
502 |
+
batch_size = 1
|
503 |
+
if latents is not None:
|
504 |
+
batch_size = latents.shape[0]
|
505 |
+
if isinstance(prompt, list):
|
506 |
+
batch_size = len(prompt)
|
507 |
+
|
508 |
+
device = self._execution_device
|
509 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
510 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
511 |
+
# corresponds to doing no classifier free guidance.
|
512 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
513 |
+
|
514 |
+
# Encode input prompt
|
515 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
516 |
+
if negative_prompt is not None:
|
517 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
518 |
+
if prompt_embeds is None:
|
519 |
+
text_embeddings = self._encode_prompt(
|
520 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
521 |
+
)
|
522 |
+
# If do_prompt_embeds_annealing is True, prompt_embeds and text_embeddings will be assigned in the loop below,
|
523 |
+
# and this is just to avoid type error.
|
524 |
+
# Otherwise, text_embeddings won't be replaced.
|
525 |
+
else:
|
526 |
+
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
|
527 |
+
|
528 |
+
# print(text_embeddings.shape)
|
529 |
+
# return
|
530 |
+
# Prepare timesteps
|
531 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
532 |
+
timesteps = self.scheduler.timesteps
|
533 |
+
|
534 |
+
# Prepare latent variables
|
535 |
+
num_channels_latents = self.unet.in_channels
|
536 |
+
latents = self.prepare_latents(
|
537 |
+
init_image,
|
538 |
+
init_image_strength,
|
539 |
+
batch_size * num_videos_per_prompt,
|
540 |
+
num_channels_latents,
|
541 |
+
video_length,
|
542 |
+
height,
|
543 |
+
width,
|
544 |
+
text_embeddings.dtype,
|
545 |
+
device,
|
546 |
+
generator,
|
547 |
+
latents,
|
548 |
+
).to(self.torch_dtype)
|
549 |
+
latents_dtype = latents.dtype
|
550 |
+
|
551 |
+
# Prepare extra step kwargs.
|
552 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
553 |
+
|
554 |
+
# Denoising loop
|
555 |
+
# num_warmup_steps = 0. num_inference_steps: 30.
|
556 |
+
# [958, 925, 892, 859, 826, 793, 760, 727, 694, 661, 628, 595, 562, 529,
|
557 |
+
# 496, 463, 430, 397, 364, 331, 298, 265, 232, 199, 166, 133, 100, 67,
|
558 |
+
# 34, 1]
|
559 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
560 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
561 |
+
for i, t in enumerate(timesteps):
|
562 |
+
# expand the latents if we are doing classifier free guidance
|
563 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
564 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
565 |
+
|
566 |
+
down_block_additional_residuals = mid_block_additional_residual = None
|
567 |
+
if (getattr(self, "controlnet", None) != None) and (controlnet_images != None):
|
568 |
+
assert controlnet_images.dim() == 5
|
569 |
+
|
570 |
+
controlnet_noisy_latents = latent_model_input
|
571 |
+
controlnet_prompt_embeds = text_embeddings
|
572 |
+
|
573 |
+
controlnet_images = controlnet_images.to(latents.device)
|
574 |
+
|
575 |
+
controlnet_cond_shape = list(controlnet_images.shape)
|
576 |
+
controlnet_cond_shape[2] = video_length
|
577 |
+
controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device).to(latents.dtype)
|
578 |
+
|
579 |
+
controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
|
580 |
+
controlnet_conditioning_mask_shape[1] = 1
|
581 |
+
controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device).to(latents.dtype)
|
582 |
+
|
583 |
+
assert controlnet_images.shape[2] >= len(controlnet_image_index)
|
584 |
+
controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
|
585 |
+
controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
|
586 |
+
|
587 |
+
|
588 |
+
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
|
589 |
+
controlnet_noisy_latents, t,
|
590 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
591 |
+
controlnet_cond=controlnet_cond,
|
592 |
+
conditioning_mask=controlnet_conditioning_mask,
|
593 |
+
conditioning_scale=controlnet_conditioning_scale,
|
594 |
+
guess_mode=False, return_dict=False,
|
595 |
+
)
|
596 |
+
|
597 |
+
if do_prompt_embeds_annealing:
|
598 |
+
# i: 0 to num_inference_steps. Anneal the first adaface_anneal_steps steps.
|
599 |
+
# If adaface_anneal_steps == 0, then anneal_factor is always 1.
|
600 |
+
anneal_factor = i / adaface_anneal_steps if i < adaface_anneal_steps else 1
|
601 |
+
prompt_embeds_annealed = prompt_embeds_begin + anneal_factor * (prompt_embeds_end - prompt_embeds_begin)
|
602 |
+
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds_annealed])
|
603 |
+
|
604 |
+
# predict the noise residual
|
605 |
+
noise_pred = self.unet(
|
606 |
+
latent_model_input, t,
|
607 |
+
encoder_hidden_states=text_embeddings,
|
608 |
+
down_block_additional_residuals = down_block_additional_residuals,
|
609 |
+
mid_block_additional_residual = mid_block_additional_residual,
|
610 |
+
).sample.to(dtype=latents_dtype)
|
611 |
+
|
612 |
+
# perform guidance
|
613 |
+
if do_classifier_free_guidance:
|
614 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
615 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
616 |
+
|
617 |
+
# compute the previous noisy sample x_t -> x_t-1
|
618 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
619 |
+
|
620 |
+
# call the callback, if provided
|
621 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
622 |
+
progress_bar.update()
|
623 |
+
if callback is not None and i % callback_steps == 0:
|
624 |
+
callback(i, t, latents)
|
625 |
+
|
626 |
+
# Post-processing
|
627 |
+
video = self.decode_latents(latents)
|
628 |
+
|
629 |
+
# Convert to tensor
|
630 |
+
if output_type == "tensor":
|
631 |
+
video = torch.from_numpy(video)
|
632 |
+
|
633 |
+
if not return_dict:
|
634 |
+
return video
|
635 |
+
|
636 |
+
return AnimationPipelineOutput(videos=video)
|
637 |
+
@torch.no_grad()
|
638 |
+
def video_edit(
|
639 |
+
self,
|
640 |
+
prompt: Union[str, List[str]],
|
641 |
+
video_length: Optional[int],
|
642 |
+
init_image: Union[PIL.Image.Image, torch.Tensor],
|
643 |
+
init_image_strength: float = 1.0,
|
644 |
+
height: Optional[int] = None,
|
645 |
+
width: Optional[int] = None,
|
646 |
+
num_inference_steps: int = 50,
|
647 |
+
guidance_scale: float = 7.5,
|
648 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
649 |
+
num_videos_per_prompt: Optional[int] = 1,
|
650 |
+
eta: float = 0.0,
|
651 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
652 |
+
latents: Optional[torch.FloatTensor] = None,
|
653 |
+
output_type: Optional[str] = "tensor",
|
654 |
+
return_dict: bool = True,
|
655 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
656 |
+
callback_steps: Optional[int] = 1,
|
657 |
+
#support embeddings
|
658 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
659 |
+
negative_prompt_embeds:Optional[torch.FloatTensor] = None,
|
660 |
+
# support controlnet
|
661 |
+
controlnet_images: torch.FloatTensor = None,
|
662 |
+
controlnet_image_index: list = [0],
|
663 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
664 |
+
**kwargs,
|
665 |
+
):
|
666 |
+
# Default height and width to unet
|
667 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
668 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
669 |
+
|
670 |
+
# Check inputs. Raise error if not correct
|
671 |
+
self.check_inputs(prompt, height, width, callback_steps, prompt_embeds)
|
672 |
+
|
673 |
+
# Define call parameters
|
674 |
+
# batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
675 |
+
batch_size = 1
|
676 |
+
if latents is not None:
|
677 |
+
batch_size = latents.shape[0]
|
678 |
+
if isinstance(prompt, list):
|
679 |
+
batch_size = len(prompt)
|
680 |
+
|
681 |
+
device = self._execution_device
|
682 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
683 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
684 |
+
# corresponds to doing no classifier free guidance.
|
685 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
686 |
+
|
687 |
+
# Encode input prompt
|
688 |
+
prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
|
689 |
+
if negative_prompt is not None:
|
690 |
+
negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
|
691 |
+
if prompt_embeds is None:
|
692 |
+
text_embeddings = self._encode_prompt(
|
693 |
+
prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
|
694 |
+
)
|
695 |
+
else:
|
696 |
+
text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
|
697 |
+
|
698 |
+
# print(text_embeddings.shape)
|
699 |
+
# return
|
700 |
+
# Prepare timesteps
|
701 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
702 |
+
timesteps = self.scheduler.timesteps
|
703 |
+
|
704 |
+
# Prepare latent variables
|
705 |
+
num_channels_latents = self.unet.in_channels
|
706 |
+
latents = self.prepare_latents(
|
707 |
+
init_image,
|
708 |
+
init_image_strength,
|
709 |
+
batch_size * num_videos_per_prompt,
|
710 |
+
num_channels_latents,
|
711 |
+
video_length,
|
712 |
+
height,
|
713 |
+
width,
|
714 |
+
text_embeddings.dtype,
|
715 |
+
device,
|
716 |
+
generator,
|
717 |
+
latents,
|
718 |
+
).to(self.torch_dtype)
|
719 |
+
latents_dtype = latents.dtype
|
720 |
+
|
721 |
+
# Prepare extra step kwargs.
|
722 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
723 |
+
|
724 |
+
# Denoising loop
|
725 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
726 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
727 |
+
for i, t in enumerate(timesteps):
|
728 |
+
# expand the latents if we are doing classifier free guidance
|
729 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
730 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
731 |
+
|
732 |
+
down_block_additional_residuals = mid_block_additional_residual = None
|
733 |
+
if (getattr(self, "controlnet", None) != None) and (controlnet_images != None):
|
734 |
+
assert controlnet_images.dim() == 5
|
735 |
+
|
736 |
+
controlnet_noisy_latents = latent_model_input
|
737 |
+
controlnet_prompt_embeds = text_embeddings
|
738 |
+
|
739 |
+
controlnet_images = controlnet_images.to(latents.device)
|
740 |
+
|
741 |
+
controlnet_cond_shape = list(controlnet_images.shape)
|
742 |
+
controlnet_cond_shape[2] = video_length
|
743 |
+
controlnet_cond = torch.zeros(controlnet_cond_shape).to(latents.device)
|
744 |
+
|
745 |
+
controlnet_conditioning_mask_shape = list(controlnet_cond.shape)
|
746 |
+
controlnet_conditioning_mask_shape[1] = 1
|
747 |
+
controlnet_conditioning_mask = torch.zeros(controlnet_conditioning_mask_shape).to(latents.device)
|
748 |
+
|
749 |
+
assert controlnet_images.shape[2] >= len(controlnet_image_index)
|
750 |
+
controlnet_cond[:,:,controlnet_image_index] = controlnet_images[:,:,:len(controlnet_image_index)]
|
751 |
+
controlnet_conditioning_mask[:,:,controlnet_image_index] = 1
|
752 |
+
|
753 |
+
down_block_additional_residuals, mid_block_additional_residual = self.controlnet(
|
754 |
+
controlnet_noisy_latents, t,
|
755 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
756 |
+
controlnet_cond=controlnet_cond,
|
757 |
+
conditioning_mask=controlnet_conditioning_mask,
|
758 |
+
conditioning_scale=controlnet_conditioning_scale,
|
759 |
+
guess_mode=False, return_dict=False,
|
760 |
+
)
|
761 |
+
# predict the noise residual
|
762 |
+
noise_pred = self.unet(
|
763 |
+
latent_model_input, t,
|
764 |
+
encoder_hidden_states=text_embeddings,
|
765 |
+
down_block_additional_residuals = down_block_additional_residuals,
|
766 |
+
mid_block_additional_residual = mid_block_additional_residual,
|
767 |
+
).sample.to(dtype=latents_dtype)
|
768 |
+
|
769 |
+
# perform guidance
|
770 |
+
if do_classifier_free_guidance:
|
771 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
772 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
773 |
+
|
774 |
+
# compute the previous noisy sample x_t -> x_t-1
|
775 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
776 |
+
|
777 |
+
# call the callback, if provided
|
778 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
779 |
+
progress_bar.update()
|
780 |
+
if callback is not None and i % callback_steps == 0:
|
781 |
+
callback(i, t, latents)
|
782 |
+
|
783 |
+
# Post-processing
|
784 |
+
video = self.decode_latents(latents)
|
785 |
+
|
786 |
+
# Convert to tensor
|
787 |
+
if output_type == "tensor":
|
788 |
+
video = torch.from_numpy(video)
|
789 |
+
|
790 |
+
if not return_dict:
|
791 |
+
return video
|
792 |
+
|
793 |
+
return AnimationPipelineOutput(videos=video)
|
animatediff/sd/.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
25 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
34 |
+
v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
|
35 |
+
v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
|
animatediff/sd/feature_extractor/preprocessor_config.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"crop_size": 224,
|
3 |
+
"do_center_crop": true,
|
4 |
+
"do_convert_rgb": true,
|
5 |
+
"do_normalize": true,
|
6 |
+
"do_resize": true,
|
7 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
8 |
+
"image_mean": [
|
9 |
+
0.48145466,
|
10 |
+
0.4578275,
|
11 |
+
0.40821073
|
12 |
+
],
|
13 |
+
"image_std": [
|
14 |
+
0.26862954,
|
15 |
+
0.26130258,
|
16 |
+
0.27577711
|
17 |
+
],
|
18 |
+
"resample": 3,
|
19 |
+
"size": 224
|
20 |
+
}
|
animatediff/sd/model_index.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "StableDiffusionPipeline",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"feature_extractor": [
|
5 |
+
"transformers",
|
6 |
+
"CLIPImageProcessor"
|
7 |
+
],
|
8 |
+
"safety_checker": [
|
9 |
+
"stable_diffusion",
|
10 |
+
"StableDiffusionSafetyChecker"
|
11 |
+
],
|
12 |
+
"scheduler": [
|
13 |
+
"diffusers",
|
14 |
+
"PNDMScheduler"
|
15 |
+
],
|
16 |
+
"text_encoder": [
|
17 |
+
"transformers",
|
18 |
+
"CLIPTextModel"
|
19 |
+
],
|
20 |
+
"tokenizer": [
|
21 |
+
"transformers",
|
22 |
+
"CLIPTokenizer"
|
23 |
+
],
|
24 |
+
"unet": [
|
25 |
+
"diffusers",
|
26 |
+
"UNet2DConditionModel"
|
27 |
+
],
|
28 |
+
"vae": [
|
29 |
+
"diffusers",
|
30 |
+
"AutoencoderKL"
|
31 |
+
]
|
32 |
+
}
|
animatediff/sd/safety_checker/config.json
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
|
3 |
+
"_name_or_path": "CompVis/stable-diffusion-safety-checker",
|
4 |
+
"architectures": [
|
5 |
+
"StableDiffusionSafetyChecker"
|
6 |
+
],
|
7 |
+
"initializer_factor": 1.0,
|
8 |
+
"logit_scale_init_value": 2.6592,
|
9 |
+
"model_type": "clip",
|
10 |
+
"projection_dim": 768,
|
11 |
+
"text_config": {
|
12 |
+
"_name_or_path": "",
|
13 |
+
"add_cross_attention": false,
|
14 |
+
"architectures": null,
|
15 |
+
"attention_dropout": 0.0,
|
16 |
+
"bad_words_ids": null,
|
17 |
+
"bos_token_id": 0,
|
18 |
+
"chunk_size_feed_forward": 0,
|
19 |
+
"cross_attention_hidden_size": null,
|
20 |
+
"decoder_start_token_id": null,
|
21 |
+
"diversity_penalty": 0.0,
|
22 |
+
"do_sample": false,
|
23 |
+
"dropout": 0.0,
|
24 |
+
"early_stopping": false,
|
25 |
+
"encoder_no_repeat_ngram_size": 0,
|
26 |
+
"eos_token_id": 2,
|
27 |
+
"exponential_decay_length_penalty": null,
|
28 |
+
"finetuning_task": null,
|
29 |
+
"forced_bos_token_id": null,
|
30 |
+
"forced_eos_token_id": null,
|
31 |
+
"hidden_act": "quick_gelu",
|
32 |
+
"hidden_size": 768,
|
33 |
+
"id2label": {
|
34 |
+
"0": "LABEL_0",
|
35 |
+
"1": "LABEL_1"
|
36 |
+
},
|
37 |
+
"initializer_factor": 1.0,
|
38 |
+
"initializer_range": 0.02,
|
39 |
+
"intermediate_size": 3072,
|
40 |
+
"is_decoder": false,
|
41 |
+
"is_encoder_decoder": false,
|
42 |
+
"label2id": {
|
43 |
+
"LABEL_0": 0,
|
44 |
+
"LABEL_1": 1
|
45 |
+
},
|
46 |
+
"layer_norm_eps": 1e-05,
|
47 |
+
"length_penalty": 1.0,
|
48 |
+
"max_length": 20,
|
49 |
+
"max_position_embeddings": 77,
|
50 |
+
"min_length": 0,
|
51 |
+
"model_type": "clip_text_model",
|
52 |
+
"no_repeat_ngram_size": 0,
|
53 |
+
"num_attention_heads": 12,
|
54 |
+
"num_beam_groups": 1,
|
55 |
+
"num_beams": 1,
|
56 |
+
"num_hidden_layers": 12,
|
57 |
+
"num_return_sequences": 1,
|
58 |
+
"output_attentions": false,
|
59 |
+
"output_hidden_states": false,
|
60 |
+
"output_scores": false,
|
61 |
+
"pad_token_id": 1,
|
62 |
+
"prefix": null,
|
63 |
+
"problem_type": null,
|
64 |
+
"pruned_heads": {},
|
65 |
+
"remove_invalid_values": false,
|
66 |
+
"repetition_penalty": 1.0,
|
67 |
+
"return_dict": true,
|
68 |
+
"return_dict_in_generate": false,
|
69 |
+
"sep_token_id": null,
|
70 |
+
"task_specific_params": null,
|
71 |
+
"temperature": 1.0,
|
72 |
+
"tf_legacy_loss": false,
|
73 |
+
"tie_encoder_decoder": false,
|
74 |
+
"tie_word_embeddings": true,
|
75 |
+
"tokenizer_class": null,
|
76 |
+
"top_k": 50,
|
77 |
+
"top_p": 1.0,
|
78 |
+
"torch_dtype": null,
|
79 |
+
"torchscript": false,
|
80 |
+
"transformers_version": "4.22.0.dev0",
|
81 |
+
"typical_p": 1.0,
|
82 |
+
"use_bfloat16": false,
|
83 |
+
"vocab_size": 49408
|
84 |
+
},
|
85 |
+
"text_config_dict": {
|
86 |
+
"hidden_size": 768,
|
87 |
+
"intermediate_size": 3072,
|
88 |
+
"num_attention_heads": 12,
|
89 |
+
"num_hidden_layers": 12
|
90 |
+
},
|
91 |
+
"torch_dtype": "float32",
|
92 |
+
"transformers_version": null,
|
93 |
+
"vision_config": {
|
94 |
+
"_name_or_path": "",
|
95 |
+
"add_cross_attention": false,
|
96 |
+
"architectures": null,
|
97 |
+
"attention_dropout": 0.0,
|
98 |
+
"bad_words_ids": null,
|
99 |
+
"bos_token_id": null,
|
100 |
+
"chunk_size_feed_forward": 0,
|
101 |
+
"cross_attention_hidden_size": null,
|
102 |
+
"decoder_start_token_id": null,
|
103 |
+
"diversity_penalty": 0.0,
|
104 |
+
"do_sample": false,
|
105 |
+
"dropout": 0.0,
|
106 |
+
"early_stopping": false,
|
107 |
+
"encoder_no_repeat_ngram_size": 0,
|
108 |
+
"eos_token_id": null,
|
109 |
+
"exponential_decay_length_penalty": null,
|
110 |
+
"finetuning_task": null,
|
111 |
+
"forced_bos_token_id": null,
|
112 |
+
"forced_eos_token_id": null,
|
113 |
+
"hidden_act": "quick_gelu",
|
114 |
+
"hidden_size": 1024,
|
115 |
+
"id2label": {
|
116 |
+
"0": "LABEL_0",
|
117 |
+
"1": "LABEL_1"
|
118 |
+
},
|
119 |
+
"image_size": 224,
|
120 |
+
"initializer_factor": 1.0,
|
121 |
+
"initializer_range": 0.02,
|
122 |
+
"intermediate_size": 4096,
|
123 |
+
"is_decoder": false,
|
124 |
+
"is_encoder_decoder": false,
|
125 |
+
"label2id": {
|
126 |
+
"LABEL_0": 0,
|
127 |
+
"LABEL_1": 1
|
128 |
+
},
|
129 |
+
"layer_norm_eps": 1e-05,
|
130 |
+
"length_penalty": 1.0,
|
131 |
+
"max_length": 20,
|
132 |
+
"min_length": 0,
|
133 |
+
"model_type": "clip_vision_model",
|
134 |
+
"no_repeat_ngram_size": 0,
|
135 |
+
"num_attention_heads": 16,
|
136 |
+
"num_beam_groups": 1,
|
137 |
+
"num_beams": 1,
|
138 |
+
"num_channels": 3,
|
139 |
+
"num_hidden_layers": 24,
|
140 |
+
"num_return_sequences": 1,
|
141 |
+
"output_attentions": false,
|
142 |
+
"output_hidden_states": false,
|
143 |
+
"output_scores": false,
|
144 |
+
"pad_token_id": null,
|
145 |
+
"patch_size": 14,
|
146 |
+
"prefix": null,
|
147 |
+
"problem_type": null,
|
148 |
+
"pruned_heads": {},
|
149 |
+
"remove_invalid_values": false,
|
150 |
+
"repetition_penalty": 1.0,
|
151 |
+
"return_dict": true,
|
152 |
+
"return_dict_in_generate": false,
|
153 |
+
"sep_token_id": null,
|
154 |
+
"task_specific_params": null,
|
155 |
+
"temperature": 1.0,
|
156 |
+
"tf_legacy_loss": false,
|
157 |
+
"tie_encoder_decoder": false,
|
158 |
+
"tie_word_embeddings": true,
|
159 |
+
"tokenizer_class": null,
|
160 |
+
"top_k": 50,
|
161 |
+
"top_p": 1.0,
|
162 |
+
"torch_dtype": null,
|
163 |
+
"torchscript": false,
|
164 |
+
"transformers_version": "4.22.0.dev0",
|
165 |
+
"typical_p": 1.0,
|
166 |
+
"use_bfloat16": false
|
167 |
+
},
|
168 |
+
"vision_config_dict": {
|
169 |
+
"hidden_size": 1024,
|
170 |
+
"intermediate_size": 4096,
|
171 |
+
"num_attention_heads": 16,
|
172 |
+
"num_hidden_layers": 24,
|
173 |
+
"patch_size": 14
|
174 |
+
}
|
175 |
+
}
|
animatediff/sd/safety_checker/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:193490b58ef62739077262e833bf091c66c29488058681ac25cf7df3d8190974
|
3 |
+
size 1216061799
|
animatediff/sd/scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "PNDMScheduler",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"beta_end": 0.012,
|
5 |
+
"beta_schedule": "scaled_linear",
|
6 |
+
"beta_start": 0.00085,
|
7 |
+
"num_train_timesteps": 1000,
|
8 |
+
"set_alpha_to_one": false,
|
9 |
+
"skip_prk_steps": true,
|
10 |
+
"steps_offset": 1,
|
11 |
+
"trained_betas": null,
|
12 |
+
"clip_sample": false
|
13 |
+
}
|
animatediff/sd/text_encoder/config.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "openai/clip-vit-large-patch14",
|
3 |
+
"architectures": [
|
4 |
+
"CLIPTextModel"
|
5 |
+
],
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"bos_token_id": 0,
|
8 |
+
"dropout": 0.0,
|
9 |
+
"eos_token_id": 2,
|
10 |
+
"hidden_act": "quick_gelu",
|
11 |
+
"hidden_size": 768,
|
12 |
+
"initializer_factor": 1.0,
|
13 |
+
"initializer_range": 0.02,
|
14 |
+
"intermediate_size": 3072,
|
15 |
+
"layer_norm_eps": 1e-05,
|
16 |
+
"max_position_embeddings": 77,
|
17 |
+
"model_type": "clip_text_model",
|
18 |
+
"num_attention_heads": 12,
|
19 |
+
"num_hidden_layers": 12,
|
20 |
+
"pad_token_id": 1,
|
21 |
+
"projection_dim": 768,
|
22 |
+
"torch_dtype": "float32",
|
23 |
+
"transformers_version": "4.22.0.dev0",
|
24 |
+
"vocab_size": 49408
|
25 |
+
}
|
animatediff/sd/text_encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:770a47a9ffdcfda0b05506a7888ed714d06131d60267e6cf52765d61cf59fd67
|
3 |
+
size 492305335
|
animatediff/sd/tokenizer/merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
animatediff/sd/tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<|startoftext|>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": true,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": true,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": "<|endoftext|>",
|
17 |
+
"unk_token": {
|
18 |
+
"content": "<|endoftext|>",
|
19 |
+
"lstrip": false,
|
20 |
+
"normalized": true,
|
21 |
+
"rstrip": false,
|
22 |
+
"single_word": false
|
23 |
+
}
|
24 |
+
}
|
animatediff/sd/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"bos_token": {
|
4 |
+
"__type": "AddedToken",
|
5 |
+
"content": "<|startoftext|>",
|
6 |
+
"lstrip": false,
|
7 |
+
"normalized": true,
|
8 |
+
"rstrip": false,
|
9 |
+
"single_word": false
|
10 |
+
},
|
11 |
+
"do_lower_case": true,
|
12 |
+
"eos_token": {
|
13 |
+
"__type": "AddedToken",
|
14 |
+
"content": "<|endoftext|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": true,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false
|
19 |
+
},
|
20 |
+
"errors": "replace",
|
21 |
+
"model_max_length": 77,
|
22 |
+
"name_or_path": "openai/clip-vit-large-patch14",
|
23 |
+
"pad_token": "<|endoftext|>",
|
24 |
+
"special_tokens_map_file": "./special_tokens_map.json",
|
25 |
+
"tokenizer_class": "CLIPTokenizer",
|
26 |
+
"unk_token": {
|
27 |
+
"__type": "AddedToken",
|
28 |
+
"content": "<|endoftext|>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": true,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false
|
33 |
+
}
|
34 |
+
}
|
animatediff/sd/tokenizer/vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
animatediff/sd/unet/config.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": 8,
|
6 |
+
"block_out_channels": [
|
7 |
+
320,
|
8 |
+
640,
|
9 |
+
1280,
|
10 |
+
1280
|
11 |
+
],
|
12 |
+
"center_input_sample": false,
|
13 |
+
"cross_attention_dim": 768,
|
14 |
+
"down_block_types": [
|
15 |
+
"CrossAttnDownBlock2D",
|
16 |
+
"CrossAttnDownBlock2D",
|
17 |
+
"CrossAttnDownBlock2D",
|
18 |
+
"DownBlock2D"
|
19 |
+
],
|
20 |
+
"downsample_padding": 1,
|
21 |
+
"flip_sin_to_cos": true,
|
22 |
+
"freq_shift": 0,
|
23 |
+
"in_channels": 4,
|
24 |
+
"layers_per_block": 2,
|
25 |
+
"mid_block_scale_factor": 1,
|
26 |
+
"norm_eps": 1e-05,
|
27 |
+
"norm_num_groups": 32,
|
28 |
+
"out_channels": 4,
|
29 |
+
"sample_size": 64,
|
30 |
+
"up_block_types": [
|
31 |
+
"UpBlock2D",
|
32 |
+
"CrossAttnUpBlock2D",
|
33 |
+
"CrossAttnUpBlock2D",
|
34 |
+
"CrossAttnUpBlock2D"
|
35 |
+
]
|
36 |
+
}
|
animatediff/sd/unet/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
|
3 |
+
size 3438354725
|
animatediff/sd/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 1.0e-04
|
3 |
+
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.0120
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: "jpg"
|
11 |
+
cond_stage_key: "txt"
|
12 |
+
image_size: 64
|
13 |
+
channels: 4
|
14 |
+
cond_stage_trainable: false # Note: different from the one we trained before
|
15 |
+
conditioning_key: crossattn
|
16 |
+
monitor: val/loss_simple_ema
|
17 |
+
scale_factor: 0.18215
|
18 |
+
use_ema: False
|
19 |
+
|
20 |
+
scheduler_config: # 10000 warmup steps
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps: [ 10000 ]
|
24 |
+
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
25 |
+
f_start: [ 1.e-6 ]
|
26 |
+
f_max: [ 1. ]
|
27 |
+
f_min: [ 1. ]
|
28 |
+
|
29 |
+
unet_config:
|
30 |
+
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
31 |
+
params:
|
32 |
+
image_size: 32 # unused
|
33 |
+
in_channels: 4
|
34 |
+
out_channels: 4
|
35 |
+
model_channels: 320
|
36 |
+
attention_resolutions: [ 4, 2, 1 ]
|
37 |
+
num_res_blocks: 2
|
38 |
+
channel_mult: [ 1, 2, 4, 4 ]
|
39 |
+
num_heads: 8
|
40 |
+
use_spatial_transformer: True
|
41 |
+
transformer_depth: 1
|
42 |
+
context_dim: 768
|
43 |
+
use_checkpoint: True
|
44 |
+
legacy: False
|
45 |
+
|
46 |
+
first_stage_config:
|
47 |
+
target: ldm.models.autoencoder.AutoencoderKL
|
48 |
+
params:
|
49 |
+
embed_dim: 4
|
50 |
+
monitor: val/rec_loss
|
51 |
+
ddconfig:
|
52 |
+
double_z: true
|
53 |
+
z_channels: 4
|
54 |
+
resolution: 256
|
55 |
+
in_channels: 3
|
56 |
+
out_ch: 3
|
57 |
+
ch: 128
|
58 |
+
ch_mult:
|
59 |
+
- 1
|
60 |
+
- 2
|
61 |
+
- 4
|
62 |
+
- 4
|
63 |
+
num_res_blocks: 2
|
64 |
+
attn_resolutions: []
|
65 |
+
dropout: 0.0
|
66 |
+
lossconfig:
|
67 |
+
target: torch.nn.Identity
|
68 |
+
|
69 |
+
cond_stage_config:
|
70 |
+
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
animatediff/sd/vae/config.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "AutoencoderKL",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"block_out_channels": [
|
6 |
+
128,
|
7 |
+
256,
|
8 |
+
512,
|
9 |
+
512
|
10 |
+
],
|
11 |
+
"down_block_types": [
|
12 |
+
"DownEncoderBlock2D",
|
13 |
+
"DownEncoderBlock2D",
|
14 |
+
"DownEncoderBlock2D",
|
15 |
+
"DownEncoderBlock2D"
|
16 |
+
],
|
17 |
+
"in_channels": 3,
|
18 |
+
"latent_channels": 4,
|
19 |
+
"layers_per_block": 2,
|
20 |
+
"norm_num_groups": 32,
|
21 |
+
"out_channels": 3,
|
22 |
+
"sample_size": 512,
|
23 |
+
"up_block_types": [
|
24 |
+
"UpDecoderBlock2D",
|
25 |
+
"UpDecoderBlock2D",
|
26 |
+
"UpDecoderBlock2D",
|
27 |
+
"UpDecoderBlock2D"
|
28 |
+
]
|
29 |
+
}
|
animatediff/sd/vae/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
|
3 |
+
size 334707217
|
animatediff/utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from io import BytesIO
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import requests
|
22 |
+
import torch
|
23 |
+
from transformers import (
|
24 |
+
AutoFeatureExtractor,
|
25 |
+
BertTokenizerFast,
|
26 |
+
CLIPImageProcessor,
|
27 |
+
CLIPTextModel,
|
28 |
+
CLIPTextModelWithProjection,
|
29 |
+
CLIPTokenizer,
|
30 |
+
CLIPVisionConfig,
|
31 |
+
CLIPVisionModelWithProjection,
|
32 |
+
)
|
33 |
+
|
34 |
+
from diffusers.models import (
|
35 |
+
AutoencoderKL,
|
36 |
+
PriorTransformer,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.schedulers import (
|
40 |
+
DDIMScheduler,
|
41 |
+
DDPMScheduler,
|
42 |
+
DPMSolverMultistepScheduler,
|
43 |
+
EulerAncestralDiscreteScheduler,
|
44 |
+
EulerDiscreteScheduler,
|
45 |
+
HeunDiscreteScheduler,
|
46 |
+
LMSDiscreteScheduler,
|
47 |
+
PNDMScheduler,
|
48 |
+
UnCLIPScheduler,
|
49 |
+
)
|
50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
51 |
+
|
52 |
+
|
53 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
54 |
+
"""
|
55 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
56 |
+
"""
|
57 |
+
if n_shave_prefix_segments >= 0:
|
58 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
59 |
+
else:
|
60 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
61 |
+
|
62 |
+
|
63 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
64 |
+
"""
|
65 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
66 |
+
"""
|
67 |
+
mapping = []
|
68 |
+
for old_item in old_list:
|
69 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
70 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
71 |
+
|
72 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
73 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
74 |
+
|
75 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
76 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
77 |
+
|
78 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
79 |
+
|
80 |
+
mapping.append({"old": old_item, "new": new_item})
|
81 |
+
|
82 |
+
return mapping
|
83 |
+
|
84 |
+
|
85 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
86 |
+
"""
|
87 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
88 |
+
"""
|
89 |
+
mapping = []
|
90 |
+
for old_item in old_list:
|
91 |
+
new_item = old_item
|
92 |
+
|
93 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
94 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
95 |
+
|
96 |
+
mapping.append({"old": old_item, "new": new_item})
|
97 |
+
|
98 |
+
return mapping
|
99 |
+
|
100 |
+
|
101 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
102 |
+
"""
|
103 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
104 |
+
"""
|
105 |
+
mapping = []
|
106 |
+
for old_item in old_list:
|
107 |
+
new_item = old_item
|
108 |
+
|
109 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
110 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
111 |
+
|
112 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
113 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
114 |
+
|
115 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
116 |
+
|
117 |
+
mapping.append({"old": old_item, "new": new_item})
|
118 |
+
|
119 |
+
return mapping
|
120 |
+
|
121 |
+
|
122 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
123 |
+
"""
|
124 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
125 |
+
"""
|
126 |
+
mapping = []
|
127 |
+
for old_item in old_list:
|
128 |
+
new_item = old_item
|
129 |
+
|
130 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
131 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
134 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
137 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
140 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
141 |
+
|
142 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
143 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
144 |
+
|
145 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
146 |
+
|
147 |
+
mapping.append({"old": old_item, "new": new_item})
|
148 |
+
|
149 |
+
return mapping
|
150 |
+
|
151 |
+
|
152 |
+
def assign_to_checkpoint(
|
153 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
157 |
+
attention layers, and takes into account additional replacements that may arise.
|
158 |
+
|
159 |
+
Assigns the weights to the new checkpoint.
|
160 |
+
"""
|
161 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
162 |
+
|
163 |
+
# Splits the attention layers into three variables.
|
164 |
+
if attention_paths_to_split is not None:
|
165 |
+
for path, path_map in attention_paths_to_split.items():
|
166 |
+
old_tensor = old_checkpoint[path]
|
167 |
+
channels = old_tensor.shape[0] // 3
|
168 |
+
|
169 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
170 |
+
|
171 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
172 |
+
|
173 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
174 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
175 |
+
|
176 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
177 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
178 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
179 |
+
|
180 |
+
for path in paths:
|
181 |
+
new_path = path["new"]
|
182 |
+
|
183 |
+
# These have already been assigned
|
184 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
185 |
+
continue
|
186 |
+
|
187 |
+
# Global renaming happens here
|
188 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
189 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
190 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
191 |
+
|
192 |
+
if additional_replacements is not None:
|
193 |
+
for replacement in additional_replacements:
|
194 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
195 |
+
|
196 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
197 |
+
if "proj_attn.weight" in new_path:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
199 |
+
else:
|
200 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
201 |
+
|
202 |
+
|
203 |
+
def conv_attn_to_linear(checkpoint):
|
204 |
+
keys = list(checkpoint.keys())
|
205 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
206 |
+
for key in keys:
|
207 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
208 |
+
if checkpoint[key].ndim > 2:
|
209 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
210 |
+
elif "proj_attn.weight" in key:
|
211 |
+
if checkpoint[key].ndim > 2:
|
212 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
213 |
+
|
214 |
+
|
215 |
+
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
|
216 |
+
"""
|
217 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
218 |
+
"""
|
219 |
+
if controlnet:
|
220 |
+
unet_params = original_config.model.params.control_stage_config.params
|
221 |
+
else:
|
222 |
+
unet_params = original_config.model.params.unet_config.params
|
223 |
+
|
224 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
225 |
+
|
226 |
+
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
|
227 |
+
|
228 |
+
down_block_types = []
|
229 |
+
resolution = 1
|
230 |
+
for i in range(len(block_out_channels)):
|
231 |
+
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
|
232 |
+
down_block_types.append(block_type)
|
233 |
+
if i != len(block_out_channels) - 1:
|
234 |
+
resolution *= 2
|
235 |
+
|
236 |
+
up_block_types = []
|
237 |
+
for i in range(len(block_out_channels)):
|
238 |
+
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
|
239 |
+
up_block_types.append(block_type)
|
240 |
+
resolution //= 2
|
241 |
+
|
242 |
+
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
|
243 |
+
|
244 |
+
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
|
245 |
+
use_linear_projection = (
|
246 |
+
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
|
247 |
+
)
|
248 |
+
if use_linear_projection:
|
249 |
+
# stable diffusion 2-base-512 and 2-768
|
250 |
+
if head_dim is None:
|
251 |
+
head_dim = [5, 10, 20, 20]
|
252 |
+
|
253 |
+
class_embed_type = None
|
254 |
+
projection_class_embeddings_input_dim = None
|
255 |
+
|
256 |
+
if "num_classes" in unet_params:
|
257 |
+
if unet_params.num_classes == "sequential":
|
258 |
+
class_embed_type = "projection"
|
259 |
+
assert "adm_in_channels" in unet_params
|
260 |
+
projection_class_embeddings_input_dim = unet_params.adm_in_channels
|
261 |
+
else:
|
262 |
+
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
|
263 |
+
|
264 |
+
config = {
|
265 |
+
"sample_size": image_size // vae_scale_factor,
|
266 |
+
"in_channels": unet_params.in_channels,
|
267 |
+
"down_block_types": tuple(down_block_types),
|
268 |
+
"block_out_channels": tuple(block_out_channels),
|
269 |
+
"layers_per_block": unet_params.num_res_blocks,
|
270 |
+
"cross_attention_dim": unet_params.context_dim,
|
271 |
+
"attention_head_dim": head_dim,
|
272 |
+
"use_linear_projection": use_linear_projection,
|
273 |
+
"class_embed_type": class_embed_type,
|
274 |
+
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
|
275 |
+
}
|
276 |
+
|
277 |
+
if not controlnet:
|
278 |
+
config["out_channels"] = unet_params.out_channels
|
279 |
+
config["up_block_types"] = tuple(up_block_types)
|
280 |
+
|
281 |
+
return config
|
282 |
+
|
283 |
+
|
284 |
+
def create_vae_diffusers_config(original_config, image_size: int):
|
285 |
+
"""
|
286 |
+
Creates a config for the diffusers based on the config of the LDM model.
|
287 |
+
"""
|
288 |
+
vae_params = original_config.model.params.first_stage_config.params.ddconfig
|
289 |
+
_ = original_config.model.params.first_stage_config.params.embed_dim
|
290 |
+
|
291 |
+
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
|
292 |
+
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
|
293 |
+
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
294 |
+
|
295 |
+
config = {
|
296 |
+
"sample_size": image_size,
|
297 |
+
"in_channels": vae_params.in_channels,
|
298 |
+
"out_channels": vae_params.out_ch,
|
299 |
+
"down_block_types": tuple(down_block_types),
|
300 |
+
"up_block_types": tuple(up_block_types),
|
301 |
+
"block_out_channels": tuple(block_out_channels),
|
302 |
+
"latent_channels": vae_params.z_channels,
|
303 |
+
"layers_per_block": vae_params.num_res_blocks,
|
304 |
+
}
|
305 |
+
return config
|
306 |
+
|
307 |
+
|
308 |
+
def create_diffusers_schedular(original_config):
|
309 |
+
schedular = DDIMScheduler(
|
310 |
+
num_train_timesteps=original_config.model.params.timesteps,
|
311 |
+
beta_start=original_config.model.params.linear_start,
|
312 |
+
beta_end=original_config.model.params.linear_end,
|
313 |
+
beta_schedule="scaled_linear",
|
314 |
+
)
|
315 |
+
return schedular
|
316 |
+
|
317 |
+
|
318 |
+
def create_ldm_bert_config(original_config):
|
319 |
+
bert_params = original_config.model.parms.cond_stage_config.params
|
320 |
+
config = LDMBertConfig(
|
321 |
+
d_model=bert_params.n_embed,
|
322 |
+
encoder_layers=bert_params.n_layer,
|
323 |
+
encoder_ffn_dim=bert_params.n_embed * 4,
|
324 |
+
)
|
325 |
+
return config
|
326 |
+
|
327 |
+
|
328 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
329 |
+
"""
|
330 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
331 |
+
"""
|
332 |
+
|
333 |
+
# extract state_dict for UNet
|
334 |
+
unet_state_dict = {}
|
335 |
+
keys = list(checkpoint.keys())
|
336 |
+
|
337 |
+
if controlnet:
|
338 |
+
unet_key = "control_model."
|
339 |
+
else:
|
340 |
+
unet_key = "model.diffusion_model."
|
341 |
+
|
342 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
343 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
344 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
345 |
+
print(
|
346 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
347 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
348 |
+
)
|
349 |
+
for key in keys:
|
350 |
+
if key.startswith("model.diffusion_model"):
|
351 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
352 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
353 |
+
else:
|
354 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
355 |
+
print(
|
356 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
357 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
358 |
+
)
|
359 |
+
|
360 |
+
for key in keys:
|
361 |
+
if key.startswith(unet_key):
|
362 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
363 |
+
|
364 |
+
new_checkpoint = {}
|
365 |
+
|
366 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
367 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
368 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
369 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
370 |
+
|
371 |
+
if config["class_embed_type"] is None:
|
372 |
+
# No parameters to port
|
373 |
+
...
|
374 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
375 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
376 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
377 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
378 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
379 |
+
else:
|
380 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
381 |
+
|
382 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
383 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
384 |
+
|
385 |
+
if not controlnet:
|
386 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
387 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
388 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
389 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
390 |
+
|
391 |
+
# Retrieves the keys for the input blocks only
|
392 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
393 |
+
input_blocks = {
|
394 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
395 |
+
for layer_id in range(num_input_blocks)
|
396 |
+
}
|
397 |
+
|
398 |
+
# Retrieves the keys for the middle blocks only
|
399 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
400 |
+
middle_blocks = {
|
401 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
402 |
+
for layer_id in range(num_middle_blocks)
|
403 |
+
}
|
404 |
+
|
405 |
+
# Retrieves the keys for the output blocks only
|
406 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
407 |
+
output_blocks = {
|
408 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
409 |
+
for layer_id in range(num_output_blocks)
|
410 |
+
}
|
411 |
+
|
412 |
+
for i in range(1, num_input_blocks):
|
413 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
414 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
415 |
+
|
416 |
+
resnets = [
|
417 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
418 |
+
]
|
419 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
420 |
+
|
421 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
422 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
423 |
+
f"input_blocks.{i}.0.op.weight"
|
424 |
+
)
|
425 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
426 |
+
f"input_blocks.{i}.0.op.bias"
|
427 |
+
)
|
428 |
+
|
429 |
+
paths = renew_resnet_paths(resnets)
|
430 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
431 |
+
assign_to_checkpoint(
|
432 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
433 |
+
)
|
434 |
+
|
435 |
+
if len(attentions):
|
436 |
+
paths = renew_attention_paths(attentions)
|
437 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
438 |
+
assign_to_checkpoint(
|
439 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
440 |
+
)
|
441 |
+
|
442 |
+
resnet_0 = middle_blocks[0]
|
443 |
+
attentions = middle_blocks[1]
|
444 |
+
resnet_1 = middle_blocks[2]
|
445 |
+
|
446 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
447 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
448 |
+
|
449 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
450 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
451 |
+
|
452 |
+
attentions_paths = renew_attention_paths(attentions)
|
453 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
454 |
+
assign_to_checkpoint(
|
455 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
456 |
+
)
|
457 |
+
|
458 |
+
for i in range(num_output_blocks):
|
459 |
+
block_id = i // (config["layers_per_block"] + 1)
|
460 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
461 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
462 |
+
output_block_list = {}
|
463 |
+
|
464 |
+
for layer in output_block_layers:
|
465 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
466 |
+
if layer_id in output_block_list:
|
467 |
+
output_block_list[layer_id].append(layer_name)
|
468 |
+
else:
|
469 |
+
output_block_list[layer_id] = [layer_name]
|
470 |
+
|
471 |
+
if len(output_block_list) > 1:
|
472 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
473 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
474 |
+
|
475 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
476 |
+
paths = renew_resnet_paths(resnets)
|
477 |
+
|
478 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
479 |
+
assign_to_checkpoint(
|
480 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
481 |
+
)
|
482 |
+
|
483 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
484 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
485 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
486 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
487 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
488 |
+
]
|
489 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
490 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
491 |
+
]
|
492 |
+
|
493 |
+
# Clear attentions as they have been attributed above.
|
494 |
+
if len(attentions) == 2:
|
495 |
+
attentions = []
|
496 |
+
|
497 |
+
if len(attentions):
|
498 |
+
paths = renew_attention_paths(attentions)
|
499 |
+
meta_path = {
|
500 |
+
"old": f"output_blocks.{i}.1",
|
501 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
502 |
+
}
|
503 |
+
assign_to_checkpoint(
|
504 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
508 |
+
for path in resnet_0_paths:
|
509 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
510 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
511 |
+
|
512 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
513 |
+
|
514 |
+
if controlnet:
|
515 |
+
# conditioning embedding
|
516 |
+
|
517 |
+
orig_index = 0
|
518 |
+
|
519 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
520 |
+
f"input_hint_block.{orig_index}.weight"
|
521 |
+
)
|
522 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
523 |
+
f"input_hint_block.{orig_index}.bias"
|
524 |
+
)
|
525 |
+
|
526 |
+
orig_index += 2
|
527 |
+
|
528 |
+
diffusers_index = 0
|
529 |
+
|
530 |
+
while diffusers_index < 6:
|
531 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
532 |
+
f"input_hint_block.{orig_index}.weight"
|
533 |
+
)
|
534 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
535 |
+
f"input_hint_block.{orig_index}.bias"
|
536 |
+
)
|
537 |
+
diffusers_index += 1
|
538 |
+
orig_index += 2
|
539 |
+
|
540 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
541 |
+
f"input_hint_block.{orig_index}.weight"
|
542 |
+
)
|
543 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
544 |
+
f"input_hint_block.{orig_index}.bias"
|
545 |
+
)
|
546 |
+
|
547 |
+
# down blocks
|
548 |
+
for i in range(num_input_blocks):
|
549 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
550 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
551 |
+
|
552 |
+
# mid block
|
553 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
554 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
555 |
+
|
556 |
+
return new_checkpoint
|
557 |
+
|
558 |
+
|
559 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
560 |
+
# extract state dict for VAE
|
561 |
+
vae_state_dict = {}
|
562 |
+
vae_key = "first_stage_model."
|
563 |
+
keys = list(checkpoint.keys())
|
564 |
+
for key in keys:
|
565 |
+
if key.startswith(vae_key):
|
566 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
567 |
+
|
568 |
+
new_checkpoint = {}
|
569 |
+
|
570 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
571 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
572 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
573 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
574 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
575 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
576 |
+
|
577 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
578 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
579 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
580 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
581 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
582 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
583 |
+
|
584 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
585 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
586 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
587 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
588 |
+
|
589 |
+
# Retrieves the keys for the encoder down blocks only
|
590 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
591 |
+
down_blocks = {
|
592 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
593 |
+
}
|
594 |
+
|
595 |
+
# Retrieves the keys for the decoder up blocks only
|
596 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
597 |
+
up_blocks = {
|
598 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
599 |
+
}
|
600 |
+
|
601 |
+
for i in range(num_down_blocks):
|
602 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
603 |
+
|
604 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
605 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
606 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
607 |
+
)
|
608 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
609 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
610 |
+
)
|
611 |
+
|
612 |
+
paths = renew_vae_resnet_paths(resnets)
|
613 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
614 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
615 |
+
|
616 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
617 |
+
num_mid_res_blocks = 2
|
618 |
+
for i in range(1, num_mid_res_blocks + 1):
|
619 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
620 |
+
|
621 |
+
paths = renew_vae_resnet_paths(resnets)
|
622 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
623 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
624 |
+
|
625 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
626 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
627 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
628 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
629 |
+
conv_attn_to_linear(new_checkpoint)
|
630 |
+
|
631 |
+
for i in range(num_up_blocks):
|
632 |
+
block_id = num_up_blocks - 1 - i
|
633 |
+
resnets = [
|
634 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
635 |
+
]
|
636 |
+
|
637 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
638 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
639 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
640 |
+
]
|
641 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
642 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
643 |
+
]
|
644 |
+
|
645 |
+
paths = renew_vae_resnet_paths(resnets)
|
646 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
647 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
648 |
+
|
649 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
650 |
+
num_mid_res_blocks = 2
|
651 |
+
for i in range(1, num_mid_res_blocks + 1):
|
652 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
653 |
+
|
654 |
+
paths = renew_vae_resnet_paths(resnets)
|
655 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
656 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
657 |
+
|
658 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
659 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
660 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
661 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
662 |
+
conv_attn_to_linear(new_checkpoint)
|
663 |
+
return new_checkpoint
|
664 |
+
|
665 |
+
|
666 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
667 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
668 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
669 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
670 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
671 |
+
|
672 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
673 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
674 |
+
|
675 |
+
def _copy_linear(hf_linear, pt_linear):
|
676 |
+
hf_linear.weight = pt_linear.weight
|
677 |
+
hf_linear.bias = pt_linear.bias
|
678 |
+
|
679 |
+
def _copy_layer(hf_layer, pt_layer):
|
680 |
+
# copy layer norms
|
681 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
682 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
683 |
+
|
684 |
+
# copy attn
|
685 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
686 |
+
|
687 |
+
# copy MLP
|
688 |
+
pt_mlp = pt_layer[1][1]
|
689 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
690 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
691 |
+
|
692 |
+
def _copy_layers(hf_layers, pt_layers):
|
693 |
+
for i, hf_layer in enumerate(hf_layers):
|
694 |
+
if i != 0:
|
695 |
+
i += i
|
696 |
+
pt_layer = pt_layers[i : i + 2]
|
697 |
+
_copy_layer(hf_layer, pt_layer)
|
698 |
+
|
699 |
+
hf_model = LDMBertModel(config).eval()
|
700 |
+
|
701 |
+
# copy embeds
|
702 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
703 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
704 |
+
|
705 |
+
# copy layer norm
|
706 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
707 |
+
|
708 |
+
# copy hidden layers
|
709 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
710 |
+
|
711 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
712 |
+
|
713 |
+
return hf_model
|
714 |
+
|
715 |
+
|
716 |
+
def convert_ldm_clip_checkpoint(checkpoint, dtype=torch.float16):
|
717 |
+
text_model = CLIPTextModel.from_pretrained("animatediff/sd/text_encoder", torch_dtype=dtype)
|
718 |
+
keys = list(checkpoint.keys())
|
719 |
+
|
720 |
+
text_model_dict = {}
|
721 |
+
|
722 |
+
for key in keys:
|
723 |
+
if key.startswith("cond_stage_model.transformer"):
|
724 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
725 |
+
|
726 |
+
text_model.load_state_dict(text_model_dict, strict=False)
|
727 |
+
|
728 |
+
return text_model
|
729 |
+
|
730 |
+
|
731 |
+
textenc_conversion_lst = [
|
732 |
+
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
|
733 |
+
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
|
734 |
+
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
|
735 |
+
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
|
736 |
+
]
|
737 |
+
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
|
738 |
+
|
739 |
+
textenc_transformer_conversion_lst = [
|
740 |
+
# (stable-diffusion, HF Diffusers)
|
741 |
+
("resblocks.", "text_model.encoder.layers."),
|
742 |
+
("ln_1", "layer_norm1"),
|
743 |
+
("ln_2", "layer_norm2"),
|
744 |
+
(".c_fc.", ".fc1."),
|
745 |
+
(".c_proj.", ".fc2."),
|
746 |
+
(".attn", ".self_attn"),
|
747 |
+
("ln_final.", "transformer.text_model.final_layer_norm."),
|
748 |
+
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
|
749 |
+
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
|
750 |
+
]
|
751 |
+
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
|
752 |
+
textenc_pattern = re.compile("|".join(protected.keys()))
|
753 |
+
|
754 |
+
|
755 |
+
def convert_paint_by_example_checkpoint(checkpoint):
|
756 |
+
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
757 |
+
model = PaintByExampleImageEncoder(config)
|
758 |
+
|
759 |
+
keys = list(checkpoint.keys())
|
760 |
+
|
761 |
+
text_model_dict = {}
|
762 |
+
|
763 |
+
for key in keys:
|
764 |
+
if key.startswith("cond_stage_model.transformer"):
|
765 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
766 |
+
|
767 |
+
# load clip vision
|
768 |
+
model.model.load_state_dict(text_model_dict)
|
769 |
+
|
770 |
+
# load mapper
|
771 |
+
keys_mapper = {
|
772 |
+
k[len("cond_stage_model.mapper.res") :]: v
|
773 |
+
for k, v in checkpoint.items()
|
774 |
+
if k.startswith("cond_stage_model.mapper")
|
775 |
+
}
|
776 |
+
|
777 |
+
MAPPING = {
|
778 |
+
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
779 |
+
"attn.c_proj": ["attn1.to_out.0"],
|
780 |
+
"ln_1": ["norm1"],
|
781 |
+
"ln_2": ["norm3"],
|
782 |
+
"mlp.c_fc": ["ff.net.0.proj"],
|
783 |
+
"mlp.c_proj": ["ff.net.2"],
|
784 |
+
}
|
785 |
+
|
786 |
+
mapped_weights = {}
|
787 |
+
for key, value in keys_mapper.items():
|
788 |
+
prefix = key[: len("blocks.i")]
|
789 |
+
suffix = key.split(prefix)[-1].split(".")[-1]
|
790 |
+
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
791 |
+
mapped_names = MAPPING[name]
|
792 |
+
|
793 |
+
num_splits = len(mapped_names)
|
794 |
+
for i, mapped_name in enumerate(mapped_names):
|
795 |
+
new_name = ".".join([prefix, mapped_name, suffix])
|
796 |
+
shape = value.shape[0] // num_splits
|
797 |
+
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
798 |
+
|
799 |
+
model.mapper.load_state_dict(mapped_weights)
|
800 |
+
|
801 |
+
# load final layer norm
|
802 |
+
model.final_layer_norm.load_state_dict(
|
803 |
+
{
|
804 |
+
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
805 |
+
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
806 |
+
}
|
807 |
+
)
|
808 |
+
|
809 |
+
# load final proj
|
810 |
+
model.proj_out.load_state_dict(
|
811 |
+
{
|
812 |
+
"bias": checkpoint["proj_out.bias"],
|
813 |
+
"weight": checkpoint["proj_out.weight"],
|
814 |
+
}
|
815 |
+
)
|
816 |
+
|
817 |
+
# load uncond vector
|
818 |
+
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
819 |
+
return model
|
820 |
+
|
821 |
+
|
822 |
+
def convert_open_clip_checkpoint(checkpoint):
|
823 |
+
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
824 |
+
|
825 |
+
keys = list(checkpoint.keys())
|
826 |
+
|
827 |
+
text_model_dict = {}
|
828 |
+
|
829 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
830 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
831 |
+
else:
|
832 |
+
d_model = 1024
|
833 |
+
|
834 |
+
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
835 |
+
|
836 |
+
for key in keys:
|
837 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
838 |
+
continue
|
839 |
+
if key in textenc_conversion_map:
|
840 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
841 |
+
if key.startswith("cond_stage_model.model.transformer."):
|
842 |
+
new_key = key[len("cond_stage_model.model.transformer.") :]
|
843 |
+
if new_key.endswith(".in_proj_weight"):
|
844 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
845 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
846 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
847 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
848 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
849 |
+
elif new_key.endswith(".in_proj_bias"):
|
850 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
851 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
852 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
853 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
854 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
855 |
+
else:
|
856 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
857 |
+
|
858 |
+
text_model_dict[new_key] = checkpoint[key]
|
859 |
+
|
860 |
+
text_model.load_state_dict(text_model_dict)
|
861 |
+
|
862 |
+
return text_model
|
863 |
+
|
864 |
+
|
865 |
+
def stable_unclip_image_encoder(original_config):
|
866 |
+
"""
|
867 |
+
Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
868 |
+
|
869 |
+
We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
870 |
+
encoders.
|
871 |
+
"""
|
872 |
+
|
873 |
+
image_embedder_config = original_config.model.params.embedder_config
|
874 |
+
|
875 |
+
sd_clip_image_embedder_class = image_embedder_config.target
|
876 |
+
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
877 |
+
|
878 |
+
if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
879 |
+
clip_model_name = image_embedder_config.params.model
|
880 |
+
|
881 |
+
if clip_model_name == "ViT-L/14":
|
882 |
+
feature_extractor = CLIPImageProcessor()
|
883 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
884 |
+
else:
|
885 |
+
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
886 |
+
|
887 |
+
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
888 |
+
feature_extractor = CLIPImageProcessor()
|
889 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
890 |
+
else:
|
891 |
+
raise NotImplementedError(
|
892 |
+
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
893 |
+
)
|
894 |
+
|
895 |
+
return feature_extractor, image_encoder
|
896 |
+
|
897 |
+
|
898 |
+
def stable_unclip_image_noising_components(
|
899 |
+
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
900 |
+
):
|
901 |
+
"""
|
902 |
+
Returns the noising components for the img2img and txt2img unclip pipelines.
|
903 |
+
|
904 |
+
Converts the stability noise augmentor into
|
905 |
+
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
906 |
+
2. a `DDPMScheduler` for holding the noise schedule
|
907 |
+
|
908 |
+
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
909 |
+
"""
|
910 |
+
noise_aug_config = original_config.model.params.noise_aug_config
|
911 |
+
noise_aug_class = noise_aug_config.target
|
912 |
+
noise_aug_class = noise_aug_class.split(".")[-1]
|
913 |
+
|
914 |
+
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
915 |
+
noise_aug_config = noise_aug_config.params
|
916 |
+
embedding_dim = noise_aug_config.timestep_dim
|
917 |
+
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
918 |
+
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
919 |
+
|
920 |
+
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
921 |
+
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
922 |
+
|
923 |
+
if "clip_stats_path" in noise_aug_config:
|
924 |
+
if clip_stats_path is None:
|
925 |
+
raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
926 |
+
|
927 |
+
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
928 |
+
clip_mean = clip_mean[None, :]
|
929 |
+
clip_std = clip_std[None, :]
|
930 |
+
|
931 |
+
clip_stats_state_dict = {
|
932 |
+
"mean": clip_mean,
|
933 |
+
"std": clip_std,
|
934 |
+
}
|
935 |
+
|
936 |
+
image_normalizer.load_state_dict(clip_stats_state_dict)
|
937 |
+
else:
|
938 |
+
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
939 |
+
|
940 |
+
return image_normalizer, image_noising_scheduler
|
941 |
+
|
942 |
+
|
943 |
+
def convert_controlnet_checkpoint(
|
944 |
+
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
945 |
+
):
|
946 |
+
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
947 |
+
ctrlnet_config["upcast_attention"] = upcast_attention
|
948 |
+
|
949 |
+
ctrlnet_config.pop("sample_size")
|
950 |
+
|
951 |
+
controlnet_model = ControlNetModel(**ctrlnet_config)
|
952 |
+
|
953 |
+
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
954 |
+
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
955 |
+
)
|
956 |
+
|
957 |
+
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
958 |
+
|
959 |
+
return controlnet_model
|
animatediff/utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
#
|
16 |
+
# Changes were made to this source code by Yuwei Guo.
|
17 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
|
21 |
+
import torch
|
22 |
+
from safetensors.torch import load_file
|
23 |
+
|
24 |
+
from diffusers import StableDiffusionPipeline
|
25 |
+
|
26 |
+
|
27 |
+
def load_diffusers_lora(pipeline, state_dict, alpha=1.0):
|
28 |
+
# directly update weight in diffusers model
|
29 |
+
for key in state_dict:
|
30 |
+
# only process lora down key
|
31 |
+
if "up." in key: continue
|
32 |
+
|
33 |
+
up_key = key.replace(".down.", ".up.")
|
34 |
+
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
35 |
+
model_key = model_key.replace("to_out.", "to_out.0.")
|
36 |
+
layer_infos = model_key.split(".")[:-1]
|
37 |
+
|
38 |
+
curr_layer = pipeline.unet
|
39 |
+
while len(layer_infos) > 0:
|
40 |
+
temp_name = layer_infos.pop(0)
|
41 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
42 |
+
|
43 |
+
weight_down = state_dict[key]
|
44 |
+
weight_up = state_dict[up_key]
|
45 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
46 |
+
|
47 |
+
return pipeline
|
48 |
+
|
49 |
+
|
50 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
51 |
+
# load base model
|
52 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
53 |
+
|
54 |
+
# load LoRA weight from .safetensors
|
55 |
+
# state_dict = load_file(checkpoint_path)
|
56 |
+
|
57 |
+
visited = []
|
58 |
+
|
59 |
+
# directly update weight in diffusers model
|
60 |
+
for key in state_dict:
|
61 |
+
# it is suggested to print out the key, it usually will be something like below
|
62 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
63 |
+
|
64 |
+
# as we have set the alpha beforehand, so just skip
|
65 |
+
if ".alpha" in key or key in visited:
|
66 |
+
continue
|
67 |
+
|
68 |
+
if "text" in key:
|
69 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
70 |
+
curr_layer = pipeline.text_encoder
|
71 |
+
else:
|
72 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
73 |
+
curr_layer = pipeline.unet
|
74 |
+
|
75 |
+
# find the target layer
|
76 |
+
temp_name = layer_infos.pop(0)
|
77 |
+
while len(layer_infos) > -1:
|
78 |
+
try:
|
79 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
80 |
+
if len(layer_infos) > 0:
|
81 |
+
temp_name = layer_infos.pop(0)
|
82 |
+
elif len(layer_infos) == 0:
|
83 |
+
break
|
84 |
+
except Exception:
|
85 |
+
if len(temp_name) > 0:
|
86 |
+
temp_name += "_" + layer_infos.pop(0)
|
87 |
+
else:
|
88 |
+
temp_name = layer_infos.pop(0)
|
89 |
+
|
90 |
+
pair_keys = []
|
91 |
+
if "lora_down" in key:
|
92 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
93 |
+
pair_keys.append(key)
|
94 |
+
else:
|
95 |
+
pair_keys.append(key)
|
96 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
97 |
+
|
98 |
+
# update weight
|
99 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
100 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
101 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
102 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
103 |
+
else:
|
104 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
105 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
106 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
107 |
+
|
108 |
+
# update visited list
|
109 |
+
for item in pair_keys:
|
110 |
+
visited.append(item)
|
111 |
+
|
112 |
+
return pipeline
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
parser = argparse.ArgumentParser()
|
117 |
+
|
118 |
+
parser.add_argument(
|
119 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
120 |
+
)
|
121 |
+
parser.add_argument(
|
122 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
123 |
+
)
|
124 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
125 |
+
parser.add_argument(
|
126 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
127 |
+
)
|
128 |
+
parser.add_argument(
|
129 |
+
"--lora_prefix_text_encoder",
|
130 |
+
default="lora_te",
|
131 |
+
type=str,
|
132 |
+
help="The prefix of text encoder weight in safetensors",
|
133 |
+
)
|
134 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
135 |
+
parser.add_argument(
|
136 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
137 |
+
)
|
138 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
139 |
+
|
140 |
+
args = parser.parse_args()
|
141 |
+
|
142 |
+
base_model_path = args.base_model_path
|
143 |
+
checkpoint_path = args.checkpoint_path
|
144 |
+
dump_path = args.dump_path
|
145 |
+
lora_prefix_unet = args.lora_prefix_unet
|
146 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
147 |
+
alpha = args.alpha
|
148 |
+
|
149 |
+
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
|
150 |
+
|
151 |
+
pipe = pipe.to(args.device)
|
152 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
animatediff/utils/convert_original_stable_diffusion_to_diffusers.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""Conversion script for the LDM checkpoints."""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import importlib
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
|
23 |
+
|
24 |
+
|
25 |
+
if __name__ == "__main__":
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
|
28 |
+
parser.add_argument(
|
29 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
30 |
+
)
|
31 |
+
# !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
32 |
+
parser.add_argument(
|
33 |
+
"--original_config_file",
|
34 |
+
default=None,
|
35 |
+
type=str,
|
36 |
+
help="The YAML config file corresponding to the original architecture.",
|
37 |
+
)
|
38 |
+
parser.add_argument(
|
39 |
+
"--config_files",
|
40 |
+
default=None,
|
41 |
+
type=str,
|
42 |
+
help="The YAML config file corresponding to the architecture.",
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--num_in_channels",
|
46 |
+
default=None,
|
47 |
+
type=int,
|
48 |
+
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--scheduler_type",
|
52 |
+
default="pndm",
|
53 |
+
type=str,
|
54 |
+
help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--pipeline_type",
|
58 |
+
default=None,
|
59 |
+
type=str,
|
60 |
+
help=(
|
61 |
+
"The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
|
62 |
+
". If `None` pipeline will be automatically inferred."
|
63 |
+
),
|
64 |
+
)
|
65 |
+
parser.add_argument(
|
66 |
+
"--image_size",
|
67 |
+
default=None,
|
68 |
+
type=int,
|
69 |
+
help=(
|
70 |
+
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
|
71 |
+
" Base. Use 768 for Stable Diffusion v2."
|
72 |
+
),
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--prediction_type",
|
76 |
+
default=None,
|
77 |
+
type=str,
|
78 |
+
help=(
|
79 |
+
"The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
|
80 |
+
" Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
|
81 |
+
),
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--extract_ema",
|
85 |
+
action="store_true",
|
86 |
+
help=(
|
87 |
+
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
|
88 |
+
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
|
89 |
+
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
|
90 |
+
),
|
91 |
+
)
|
92 |
+
parser.add_argument(
|
93 |
+
"--upcast_attention",
|
94 |
+
action="store_true",
|
95 |
+
help=(
|
96 |
+
"Whether the attention computation should always be upcasted. This is necessary when running stable"
|
97 |
+
" diffusion 2.1."
|
98 |
+
),
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--from_safetensors",
|
102 |
+
action="store_true",
|
103 |
+
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
|
104 |
+
)
|
105 |
+
parser.add_argument(
|
106 |
+
"--to_safetensors",
|
107 |
+
action="store_true",
|
108 |
+
help="Whether to store pipeline in safetensors format or not.",
|
109 |
+
)
|
110 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
111 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
112 |
+
parser.add_argument(
|
113 |
+
"--stable_unclip",
|
114 |
+
type=str,
|
115 |
+
default=None,
|
116 |
+
required=False,
|
117 |
+
help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--stable_unclip_prior",
|
121 |
+
type=str,
|
122 |
+
default=None,
|
123 |
+
required=False,
|
124 |
+
help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--clip_stats_path",
|
128 |
+
type=str,
|
129 |
+
help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
|
130 |
+
required=False,
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
|
134 |
+
)
|
135 |
+
parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
|
136 |
+
parser.add_argument(
|
137 |
+
"--vae_path",
|
138 |
+
type=str,
|
139 |
+
default=None,
|
140 |
+
required=False,
|
141 |
+
help="Set to a path, hub id to an already converted vae to not convert it again.",
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--pipeline_class_name",
|
145 |
+
type=str,
|
146 |
+
default=None,
|
147 |
+
required=False,
|
148 |
+
help="Specify the pipeline class name",
|
149 |
+
)
|
150 |
+
|
151 |
+
args = parser.parse_args()
|
152 |
+
|
153 |
+
if args.pipeline_class_name is not None:
|
154 |
+
library = importlib.import_module("diffusers")
|
155 |
+
class_obj = getattr(library, args.pipeline_class_name)
|
156 |
+
pipeline_class = class_obj
|
157 |
+
else:
|
158 |
+
pipeline_class = None
|
159 |
+
|
160 |
+
pipe = download_from_original_stable_diffusion_ckpt(
|
161 |
+
checkpoint_path_or_dict=args.checkpoint_path,
|
162 |
+
original_config_file=args.original_config_file,
|
163 |
+
config_files=args.config_files,
|
164 |
+
image_size=args.image_size,
|
165 |
+
prediction_type=args.prediction_type,
|
166 |
+
model_type=args.pipeline_type,
|
167 |
+
extract_ema=args.extract_ema,
|
168 |
+
scheduler_type=args.scheduler_type,
|
169 |
+
num_in_channels=args.num_in_channels,
|
170 |
+
upcast_attention=args.upcast_attention,
|
171 |
+
from_safetensors=args.from_safetensors,
|
172 |
+
device=args.device,
|
173 |
+
stable_unclip=args.stable_unclip,
|
174 |
+
stable_unclip_prior=args.stable_unclip_prior,
|
175 |
+
clip_stats_path=args.clip_stats_path,
|
176 |
+
controlnet=args.controlnet,
|
177 |
+
vae_path=args.vae_path,
|
178 |
+
pipeline_class=pipeline_class,
|
179 |
+
)
|
180 |
+
|
181 |
+
if args.half:
|
182 |
+
pipe.to(dtype=torch.float16)
|
183 |
+
|
184 |
+
if args.controlnet:
|
185 |
+
# only save the controlnet model
|
186 |
+
pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
187 |
+
else:
|
188 |
+
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
animatediff/utils/util.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torchvision
|
9 |
+
import torch.distributed as dist
|
10 |
+
|
11 |
+
from safetensors import safe_open
|
12 |
+
from tqdm import tqdm
|
13 |
+
from einops import rearrange
|
14 |
+
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
15 |
+
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
|
16 |
+
|
17 |
+
|
18 |
+
def zero_rank_print(s):
|
19 |
+
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
20 |
+
from typing import List
|
21 |
+
import PIL
|
22 |
+
def export_to_video(
|
23 |
+
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
|
24 |
+
) -> str:
|
25 |
+
# if output_video_path is None:
|
26 |
+
# output_video_path = tempfile.NamedTemporaryFile(suffix=".webm").name
|
27 |
+
|
28 |
+
if isinstance(video_frames[0], PIL.Image.Image):
|
29 |
+
video_frames = [np.array(frame) for frame in video_frames]
|
30 |
+
|
31 |
+
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
32 |
+
# fourcc = cv2.VideoWriter_fourcc(*'VP90')
|
33 |
+
h, w, c = video_frames[0].shape
|
34 |
+
video_writer = cv2.VideoWriter(output_video_path, fourcc, fps=fps, frameSize=(w, h))
|
35 |
+
for i in range(len(video_frames)):
|
36 |
+
img = cv2.cvtColor(video_frames[i], cv2.COLOR_RGB2BGR)
|
37 |
+
video_writer.write(img)
|
38 |
+
|
39 |
+
return output_video_path
|
40 |
+
|
41 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=9):
|
42 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
43 |
+
outputs = []
|
44 |
+
for x in videos:
|
45 |
+
x = torchvision.utils.make_grid(x, nrow=n_rows)
|
46 |
+
x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
47 |
+
if rescale:
|
48 |
+
x = (x + 1.0) / 2.0 # -1,1 -> 0,1
|
49 |
+
x = (x * 255).numpy().astype(np.uint8)
|
50 |
+
outputs.append(x)
|
51 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
52 |
+
# export_to_video(outputs, output_video_path=path, fps=fps)
|
53 |
+
|
54 |
+
imageio.mimsave(path, outputs, fps=fps)
|
55 |
+
|
56 |
+
|
57 |
+
# DDIM Inversion
|
58 |
+
@torch.no_grad()
|
59 |
+
def init_prompt(prompt, pipeline):
|
60 |
+
uncond_input = pipeline.tokenizer(
|
61 |
+
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
62 |
+
return_tensors="pt"
|
63 |
+
)
|
64 |
+
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
65 |
+
text_input = pipeline.tokenizer(
|
66 |
+
[prompt],
|
67 |
+
padding="max_length",
|
68 |
+
max_length=pipeline.tokenizer.model_max_length,
|
69 |
+
truncation=True,
|
70 |
+
return_tensors="pt",
|
71 |
+
)
|
72 |
+
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
73 |
+
context = torch.cat([uncond_embeddings, text_embeddings])
|
74 |
+
|
75 |
+
return context
|
76 |
+
|
77 |
+
|
78 |
+
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
79 |
+
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
80 |
+
timestep, next_timestep = min(
|
81 |
+
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
82 |
+
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
83 |
+
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
84 |
+
beta_prod_t = 1 - alpha_prod_t
|
85 |
+
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
86 |
+
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
87 |
+
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
88 |
+
return next_sample
|
89 |
+
|
90 |
+
|
91 |
+
def get_noise_pred_single(latents, t, context, unet):
|
92 |
+
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
93 |
+
return noise_pred
|
94 |
+
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
98 |
+
context = init_prompt(prompt, pipeline)
|
99 |
+
uncond_embeddings, cond_embeddings = context.chunk(2)
|
100 |
+
all_latent = [latent]
|
101 |
+
latent = latent.clone().detach()
|
102 |
+
for i in tqdm(range(num_inv_steps)):
|
103 |
+
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
104 |
+
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
105 |
+
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
106 |
+
all_latent.append(latent)
|
107 |
+
return all_latent
|
108 |
+
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
112 |
+
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
113 |
+
return ddim_latents
|
114 |
+
|
115 |
+
def load_weights(
|
116 |
+
animation_pipeline,
|
117 |
+
# motion module
|
118 |
+
motion_module_path = "",
|
119 |
+
motion_module_lora_configs = [],
|
120 |
+
# domain adapter
|
121 |
+
adapter_lora_path = "",
|
122 |
+
adapter_lora_scale = 1.0,
|
123 |
+
# image layers
|
124 |
+
dreambooth_model_path = "",
|
125 |
+
lora_model_path = "",
|
126 |
+
lora_alpha = 0.8,
|
127 |
+
):
|
128 |
+
# motion module
|
129 |
+
unet_state_dict = {}
|
130 |
+
if motion_module_path != "":
|
131 |
+
print(f"load motion module from {motion_module_path}")
|
132 |
+
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
133 |
+
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
134 |
+
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
135 |
+
unet_state_dict.pop("animatediff_config", "")
|
136 |
+
|
137 |
+
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
138 |
+
print("motion_module missing:",len(missing))
|
139 |
+
print("motion_module unexpe:",len(unexpected))
|
140 |
+
assert len(unexpected) == 0
|
141 |
+
del unet_state_dict
|
142 |
+
|
143 |
+
# base model
|
144 |
+
# if dreambooth_model_path != "":
|
145 |
+
# print(f"load dreambooth model from {dreambooth_model_path}")
|
146 |
+
# # if dreambooth_model_path.endswith(".safetensors"):
|
147 |
+
# # dreambooth_state_dict = {}
|
148 |
+
# # with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
|
149 |
+
# # for key in f.keys():
|
150 |
+
# # dreambooth_state_dict[key] = f.get_tensor(key)
|
151 |
+
# # elif dreambooth_model_path.endswith(".ckpt"):
|
152 |
+
# # dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
|
153 |
+
|
154 |
+
# # # 1. vae
|
155 |
+
# # converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
|
156 |
+
# # animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
157 |
+
# # # 2. unet
|
158 |
+
# # converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
|
159 |
+
# # animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
160 |
+
# # # 3. text_model
|
161 |
+
# # animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
162 |
+
# # del dreambooth_state_dict
|
163 |
+
# dreambooth_state_dict = {}
|
164 |
+
# with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
|
165 |
+
# for key in f.keys():
|
166 |
+
# dreambooth_state_dict[key] = f.get_tensor(key)
|
167 |
+
|
168 |
+
# converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
|
169 |
+
# # print(vae)
|
170 |
+
# #vae ->to_q,to_k,to_v
|
171 |
+
# # print(converted_vae_checkpoint)
|
172 |
+
# convert_vae_keys = list(converted_vae_checkpoint.keys())
|
173 |
+
# for key in convert_vae_keys:
|
174 |
+
# if "encoder.mid_block.attentions" in key or "decoder.mid_block.attentions" in key:
|
175 |
+
# new_key = None
|
176 |
+
# if "key" in key:
|
177 |
+
# new_key = key.replace("key","to_k")
|
178 |
+
# elif "query" in key:
|
179 |
+
# new_key = key.replace("query","to_q")
|
180 |
+
# elif "value" in key:
|
181 |
+
# new_key = key.replace("value","to_v")
|
182 |
+
# elif "proj_attn" in key:
|
183 |
+
# new_key = key.replace("proj_attn","to_out.0")
|
184 |
+
# if new_key:
|
185 |
+
# converted_vae_checkpoint[new_key] = converted_vae_checkpoint.pop(key)
|
186 |
+
|
187 |
+
# animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
188 |
+
|
189 |
+
# converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
|
190 |
+
# animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
191 |
+
|
192 |
+
# animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
193 |
+
# del dreambooth_state_dict
|
194 |
+
# lora layers
|
195 |
+
if lora_model_path != "":
|
196 |
+
print(f"load lora model from {lora_model_path}")
|
197 |
+
assert lora_model_path.endswith(".safetensors")
|
198 |
+
lora_state_dict = {}
|
199 |
+
with safe_open(lora_model_path, framework="pt", device="cpu") as f:
|
200 |
+
for key in f.keys():
|
201 |
+
lora_state_dict[key] = f.get_tensor(key)
|
202 |
+
|
203 |
+
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
|
204 |
+
del lora_state_dict
|
205 |
+
|
206 |
+
# domain adapter lora
|
207 |
+
if adapter_lora_path != "":
|
208 |
+
print(f"load domain lora from {adapter_lora_path}")
|
209 |
+
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
210 |
+
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
211 |
+
domain_lora_state_dict.pop("animatediff_config", "")
|
212 |
+
|
213 |
+
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
|
214 |
+
|
215 |
+
# motion module lora
|
216 |
+
for motion_module_lora_config in motion_module_lora_configs:
|
217 |
+
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
218 |
+
print(f"load motion LoRA from {path}")
|
219 |
+
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
220 |
+
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
221 |
+
motion_lora_state_dict.pop("animatediff_config", "")
|
222 |
+
|
223 |
+
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
|
224 |
+
|
225 |
+
return animation_pipeline
|
app.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import spaces
|
3 |
+
css = '''
|
4 |
+
.gradio-container {width: 85% !important}
|
5 |
+
'''
|
6 |
+
from animatediff.utils.util import save_videos_grid
|
7 |
+
|
8 |
+
import random
|
9 |
+
from infer import load_model
|
10 |
+
MAX_SEED=10000
|
11 |
+
import uuid
|
12 |
+
from insightface.app import FaceAnalysis
|
13 |
+
import os
|
14 |
+
import os
|
15 |
+
import cv2
|
16 |
+
from diffusers.utils import load_image
|
17 |
+
from insightface.utils import face_align
|
18 |
+
from PIL import Image
|
19 |
+
import torch
|
20 |
+
import argparse
|
21 |
+
# From command line read command adaface_ckpt_path
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument('--adaface_ckpt_path', type=str,
|
24 |
+
default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt')
|
25 |
+
# Don't use 'sd15' for base_model_type; it just generates messy videos.
|
26 |
+
parser.add_argument('--base_model_type', type=str, default='sar')
|
27 |
+
parser.add_argument('--adaface_base_model_type', type=str, default='sar')
|
28 |
+
parser.add_argument('--gpu', type=int, default=None)
|
29 |
+
parser.add_argument('--ip', type=str, default="0.0.0.0")
|
30 |
+
args = parser.parse_args()
|
31 |
+
|
32 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
33 |
+
if randomize_seed:
|
34 |
+
seed = random.randint(0, MAX_SEED)
|
35 |
+
return seed
|
36 |
+
|
37 |
+
# model = load_model()
|
38 |
+
# This FaceAnalysis uses a different model from what AdaFace uses, but it's fine.
|
39 |
+
# This is just to crop the face areas from the uploaded images.
|
40 |
+
app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
41 |
+
app.prepare(ctx_id=0, det_size=(320, 320))
|
42 |
+
device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
|
43 |
+
|
44 |
+
id_animator, adaface = load_model(base_model_type=args.base_model_type,
|
45 |
+
adaface_base_model_type=args.adaface_base_model_type,
|
46 |
+
adaface_ckpt_path=args.adaface_ckpt_path,
|
47 |
+
device=device)
|
48 |
+
basedir = os.getcwd()
|
49 |
+
savedir = os.path.join(basedir,'samples')
|
50 |
+
os.makedirs(savedir, exist_ok=True)
|
51 |
+
|
52 |
+
#print(f"### Cleaning cached examples ...")
|
53 |
+
#os.system(f"rm -rf gradio_cached_examples/")
|
54 |
+
|
55 |
+
def swap_to_gallery(images):
|
56 |
+
# Update uploaded_files_gallery, show files, hide clear_button_column
|
57 |
+
# Or:
|
58 |
+
# Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
|
59 |
+
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
|
60 |
+
|
61 |
+
def remove_back_to_files():
|
62 |
+
# Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
|
63 |
+
# Or:
|
64 |
+
# Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
|
65 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True), gr.update(value="0")
|
66 |
+
|
67 |
+
def get_clicked_image(data: gr.SelectData):
|
68 |
+
return data.index
|
69 |
+
|
70 |
+
@spaces.GPU
|
71 |
+
def gen_init_images(uploaded_image_paths, prompt, adaface_id_cfg_scale, out_image_count=3):
|
72 |
+
if uploaded_image_paths is None:
|
73 |
+
print("No image uploaded")
|
74 |
+
return None, None, None
|
75 |
+
# uploaded_image_paths is a list of tuples:
|
76 |
+
# [('/tmp/gradio/249981e66a7c665aaaf1c7eaeb24949af4366c88/jensen huang.jpg', None)]
|
77 |
+
# Extract the file paths.
|
78 |
+
uploaded_image_paths = [path[0] for path in uploaded_image_paths]
|
79 |
+
adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths,
|
80 |
+
out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
|
81 |
+
# Generate two images each time for the user to select from.
|
82 |
+
noise = torch.randn(out_image_count, 3, 512, 512)
|
83 |
+
# samples: A list of PIL Image instances.
|
84 |
+
samples = adaface(noise, prompt, out_image_count=out_image_count, verbose=True)
|
85 |
+
|
86 |
+
face_paths = []
|
87 |
+
for sample in samples:
|
88 |
+
random_name = str(uuid.uuid4())
|
89 |
+
face_path = os.path.join(savedir, f"{random_name}.jpg")
|
90 |
+
face_paths.append(face_path)
|
91 |
+
sample.save(face_path)
|
92 |
+
print(f"Generated init image: {face_path}")
|
93 |
+
|
94 |
+
# Update uploaded_init_img_gallery, update and hide init_img_files, hide init_clear_button_column
|
95 |
+
return gr.update(value=face_paths, visible=True), gr.update(value=face_paths, visible=False), gr.update(visible=True)
|
96 |
+
|
97 |
+
@spaces.GPU(duration=90)
|
98 |
+
def generate_image(image_container, uploaded_image_paths, init_img_file_paths, init_img_selected_idx,
|
99 |
+
init_image_strength, init_image_final_weight,
|
100 |
+
prompt, negative_prompt, num_steps, video_length, guidance_scale, seed, attn_scale, image_embed_scale,
|
101 |
+
is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale,
|
102 |
+
adaface_anneal_steps, progress=gr.Progress(track_tqdm=True)):
|
103 |
+
|
104 |
+
prompt = prompt + " 8k uhd, high quality"
|
105 |
+
if " shot" not in prompt:
|
106 |
+
prompt = prompt + ", medium shot"
|
107 |
+
|
108 |
+
prompt_img_lists=[]
|
109 |
+
for path in uploaded_image_paths:
|
110 |
+
img = cv2.imread(path)
|
111 |
+
faces = app.get(img)
|
112 |
+
face_roi = face_align.norm_crop(img, faces[0]['kps'], 112)
|
113 |
+
random_name = str(uuid.uuid4())
|
114 |
+
face_path = os.path.join(savedir, f"{random_name}.jpg")
|
115 |
+
cv2.imwrite(face_path, face_roi)
|
116 |
+
# prompt_img_lists is a list of PIL images.
|
117 |
+
prompt_img_lists.append(load_image(face_path).resize((224,224)))
|
118 |
+
|
119 |
+
if adaface is None or not is_adaface_enabled:
|
120 |
+
adaface_prompt_embeds = None
|
121 |
+
else:
|
122 |
+
if adaface_ckpt_path != args.adaface_ckpt_path:
|
123 |
+
# Reload the embedding manager
|
124 |
+
adaface.load_subj_basis_generator(adaface_ckpt_path)
|
125 |
+
|
126 |
+
adaface.generate_adaface_embeddings(image_folder=None, image_paths=uploaded_image_paths,
|
127 |
+
out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
|
128 |
+
# adaface_prompt_embeds: [1, 77, 768].
|
129 |
+
adaface_prompt_embeds, _ = adaface.encode_prompt(prompt)
|
130 |
+
|
131 |
+
# init_img_file_paths is a list of image paths. If not chose, init_img_file_paths is None.
|
132 |
+
if init_img_file_paths is not None:
|
133 |
+
init_img_selected_idx = int(init_img_selected_idx)
|
134 |
+
init_img_file_path = init_img_file_paths[init_img_selected_idx]
|
135 |
+
init_image = cv2.imread(init_img_file_path)
|
136 |
+
init_image = cv2.resize(init_image, (512, 512))
|
137 |
+
init_image = Image.fromarray(cv2.cvtColor(init_image, cv2.COLOR_BGR2RGB))
|
138 |
+
print(f"init_image: {init_img_file_path}")
|
139 |
+
else:
|
140 |
+
init_image = None
|
141 |
+
|
142 |
+
sample = id_animator.generate(prompt_img_lists,
|
143 |
+
init_image = init_image,
|
144 |
+
init_image_strength = (init_image_strength, init_image_final_weight),
|
145 |
+
prompt = prompt,
|
146 |
+
negative_prompt = negative_prompt,
|
147 |
+
adaface_embeds = adaface_prompt_embeds,
|
148 |
+
# adaface_scale is not so useful, and when it's set >= 2, weird artifacts appear.
|
149 |
+
# Here it's limited to 0.7~1.3.
|
150 |
+
adaface_scale = adaface_power_scale,
|
151 |
+
num_inference_steps = num_steps,
|
152 |
+
adaface_anneal_steps = adaface_anneal_steps,
|
153 |
+
seed=seed,
|
154 |
+
guidance_scale = guidance_scale,
|
155 |
+
width = 512,
|
156 |
+
height = 512,
|
157 |
+
video_length = video_length,
|
158 |
+
attn_scale = attn_scale,
|
159 |
+
image_embed_scale = image_embed_scale,
|
160 |
+
)
|
161 |
+
|
162 |
+
save_sample_path = os.path.join(savedir, f"{random_name}.mp4")
|
163 |
+
save_videos_grid(sample, save_sample_path)
|
164 |
+
return save_sample_path
|
165 |
+
|
166 |
+
def validate(prompt):
|
167 |
+
if not prompt:
|
168 |
+
raise gr.Error("Prompt cannot be blank")
|
169 |
+
|
170 |
+
examples = [
|
171 |
+
[
|
172 |
+
"demo/ann.png",
|
173 |
+
["demo/ann.png" ],
|
174 |
+
"A young girl with a passion for reading, curled up with a book in a cozy nook near a window",
|
175 |
+
"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck,",
|
176 |
+
30,
|
177 |
+
8, 8290,1,16
|
178 |
+
],
|
179 |
+
[
|
180 |
+
"demo/lecun.png",
|
181 |
+
["demo/lecun.png" ],
|
182 |
+
"Iron Man soars through the clouds, his repulsors blazing",
|
183 |
+
"worst quality, low quality, jpeg artifacts, ugly, duplicate, blurry, long neck",
|
184 |
+
30,
|
185 |
+
8, 4993,0.7,16
|
186 |
+
],
|
187 |
+
[
|
188 |
+
"demo/mix.png",
|
189 |
+
["demo/lecun.png","demo/ann.png"],
|
190 |
+
"A musician playing a guitar, fingers deftly moving across the strings, producing a soulful melody",
|
191 |
+
"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
|
192 |
+
30,
|
193 |
+
8, 1897,0.9,16
|
194 |
+
],
|
195 |
+
[
|
196 |
+
"demo/zendaya.png",
|
197 |
+
["demo/zendaya.png" ],
|
198 |
+
"A woman on a serene beach at sunset, the sky ablaze with hues of orange and purple.",
|
199 |
+
"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
|
200 |
+
30,
|
201 |
+
8, 5992,1,16
|
202 |
+
],
|
203 |
+
[
|
204 |
+
"demo/qianlong.png",
|
205 |
+
["demo/qianlong.png" ],
|
206 |
+
"A chef in a white apron, complete with a toqueblanche, garnishing a gourmet dish",
|
207 |
+
"(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, UnrealisticDream",
|
208 |
+
30,
|
209 |
+
8, 1844,0.8,16
|
210 |
+
],
|
211 |
+
[
|
212 |
+
"demo/augustus.png",
|
213 |
+
["demo/augustus.png" ],
|
214 |
+
"A man with dyed pink and purple hair, styledin a high ponytail",
|
215 |
+
"semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
|
216 |
+
30,
|
217 |
+
8, 870,0.7,16
|
218 |
+
]
|
219 |
+
]
|
220 |
+
|
221 |
+
with gr.Blocks(css=css) as demo:
|
222 |
+
gr.Markdown(
|
223 |
+
"""
|
224 |
+
# AdaFace-Animate: Zero-Shot Subject-Driven Video Generation for Humans
|
225 |
+
"""
|
226 |
+
)
|
227 |
+
gr.Markdown(
|
228 |
+
"""
|
229 |
+
❗️❗️❗️**Tips:**
|
230 |
+
- You can upload one or more subject images for generating ID-specific video.
|
231 |
+
- Try different parameter combinations for the best generation quality.
|
232 |
+
- Technical explanations and demo videos: [Readme](https://huggingface.co/spaces/adaface-neurips/adaface-animate/blob/main/README2.md).
|
233 |
+
"""
|
234 |
+
)
|
235 |
+
|
236 |
+
with gr.Row():
|
237 |
+
with gr.Column():
|
238 |
+
files = gr.File(
|
239 |
+
label="Drag / Select 1 or more photos of a person's face",
|
240 |
+
file_types=["image"],
|
241 |
+
file_count="multiple"
|
242 |
+
)
|
243 |
+
image_container = gr.Image(label="image container", sources="upload", type="numpy", height=256, visible=False)
|
244 |
+
uploaded_files_gallery = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
|
245 |
+
with gr.Column(visible=False) as clear_button_column:
|
246 |
+
remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=files, size="sm")
|
247 |
+
|
248 |
+
init_img_files = gr.File(
|
249 |
+
label="[Optional] Select 1 image for initialization, or generate 3 images with the button below and select 1",
|
250 |
+
file_types=["image"],
|
251 |
+
file_count="multiple"
|
252 |
+
)
|
253 |
+
init_img_container = gr.Image(label="init image container", sources="upload", type="numpy", height=256, visible=False)
|
254 |
+
# Although there's only one image, we still use columns=3, to scale down the image size.
|
255 |
+
# Otherwise it will occupy the full width, and the gallery won't show the whole image.
|
256 |
+
uploaded_init_img_gallery = gr.Gallery(label="Init image", visible=False, columns=3, rows=1, height=200)
|
257 |
+
# placeholder is just hint, not the real value. So we use "value='0'" instead of "placeholder='0'".
|
258 |
+
init_img_selected_idx = gr.Textbox(label="Selected init image index", value="0", visible=False)
|
259 |
+
|
260 |
+
init_image_strength = gr.Slider(
|
261 |
+
label="Init Image Strength",
|
262 |
+
info="How much the init image should influence each frame. 0: no influence (scenes are more dynamic), 3: strongest influence (scenes are more static).",
|
263 |
+
minimum=0,
|
264 |
+
maximum=3,
|
265 |
+
step=0.25,
|
266 |
+
value=1.5,
|
267 |
+
)
|
268 |
+
init_image_final_weight = gr.Slider(
|
269 |
+
label="Final Weight of the Init Image",
|
270 |
+
info="How much the init image should influence the end of the video",
|
271 |
+
minimum=0,
|
272 |
+
maximum=0.25,
|
273 |
+
step=0.025,
|
274 |
+
value=0.1,
|
275 |
+
)
|
276 |
+
|
277 |
+
with gr.Column(visible=False) as init_clear_button_column:
|
278 |
+
remove_init_and_reupload = gr.ClearButton(value="Remove and upload new init image", components=init_img_files, size="sm")
|
279 |
+
with gr.Column(visible=True) as init_gen_button_column:
|
280 |
+
gen_init = gr.Button(value="Generate 3 new init images")
|
281 |
+
|
282 |
+
prompt = gr.Textbox(label="Prompt",
|
283 |
+
info="Try something like 'man/woman walking on the beach'",
|
284 |
+
placeholder="woman playing guitar on a boat, ocean waves")
|
285 |
+
|
286 |
+
image_embed_scale = gr.Slider(
|
287 |
+
label="Image Embedding Scale",
|
288 |
+
info="The scale of the ID-Animator image embedding (influencing coarse facial features and poses)",
|
289 |
+
minimum=0,
|
290 |
+
maximum=2,
|
291 |
+
step=0.1,
|
292 |
+
value=0.8,
|
293 |
+
)
|
294 |
+
attn_scale = gr.Slider(
|
295 |
+
label="Attention Processor Scale",
|
296 |
+
info="The scale of the ID embeddings on the attention (the higher, the more focus on the face, less on the background)" ,
|
297 |
+
minimum=0,
|
298 |
+
maximum=2,
|
299 |
+
step=0.1,
|
300 |
+
value=0.8,
|
301 |
+
)
|
302 |
+
adaface_id_cfg_scale = gr.Slider(
|
303 |
+
label="AdaFace Embedding ID CFG Scale",
|
304 |
+
info="The scale of the AdaFace ID embeddings (influencing fine facial features and details)",
|
305 |
+
minimum=0.5,
|
306 |
+
maximum=6,
|
307 |
+
step=0.25,
|
308 |
+
value=1.5,
|
309 |
+
)
|
310 |
+
|
311 |
+
submit = gr.Button("Generate Video")
|
312 |
+
|
313 |
+
with gr.Accordion(open=False, label="Advanced Options"):
|
314 |
+
video_length = gr.Slider(
|
315 |
+
label="video_length",
|
316 |
+
info="Do not change, otherwise the video will be messy",
|
317 |
+
minimum=16,
|
318 |
+
maximum=21,
|
319 |
+
step=1,
|
320 |
+
value=16,
|
321 |
+
)
|
322 |
+
is_adaface_enabled = gr.Checkbox(label="Enable AdaFace",
|
323 |
+
info="Enable AdaFace for better face details. If unchecked, it falls back to ID-Animator (https://huggingface.co/spaces/ID-Animator/ID-Animator).",
|
324 |
+
value=True)
|
325 |
+
adaface_ckpt_path = gr.Textbox(
|
326 |
+
label="AdaFace ckpt Path",
|
327 |
+
placeholder=args.adaface_ckpt_path,
|
328 |
+
value=args.adaface_ckpt_path,
|
329 |
+
)
|
330 |
+
|
331 |
+
adaface_power_scale = gr.Slider(
|
332 |
+
label="AdaFace Embedding Power Scale",
|
333 |
+
info="Increase this scale slightly only if the face is defocused or the face details are not clear",
|
334 |
+
minimum=0.7,
|
335 |
+
maximum=1.3,
|
336 |
+
step=0.1,
|
337 |
+
value=1,
|
338 |
+
)
|
339 |
+
|
340 |
+
# adaface_anneal_steps is no longer necessary, but we keep it here for future use.
|
341 |
+
adaface_anneal_steps = gr.Slider(
|
342 |
+
label="AdaFace Anneal Steps",
|
343 |
+
minimum=0,
|
344 |
+
maximum=2,
|
345 |
+
step=1,
|
346 |
+
value=0,
|
347 |
+
visible=False,
|
348 |
+
)
|
349 |
+
|
350 |
+
negative_prompt = gr.Textbox(
|
351 |
+
label="Negative Prompt",
|
352 |
+
placeholder="low quality",
|
353 |
+
value="face portrait, (deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime), text, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, bare breasts, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, long neck, UnrealisticDream",
|
354 |
+
)
|
355 |
+
num_steps = gr.Slider(
|
356 |
+
label="Number of sample steps",
|
357 |
+
minimum=25,
|
358 |
+
maximum=100,
|
359 |
+
step=1,
|
360 |
+
value=40,
|
361 |
+
)
|
362 |
+
guidance_scale = gr.Slider(
|
363 |
+
label="Guidance scale",
|
364 |
+
minimum=1.0,
|
365 |
+
maximum=10.0,
|
366 |
+
step=0.5,
|
367 |
+
value=4,
|
368 |
+
)
|
369 |
+
seed = gr.Slider(
|
370 |
+
label="Seed",
|
371 |
+
minimum=0,
|
372 |
+
maximum=MAX_SEED,
|
373 |
+
step=1,
|
374 |
+
value=985,
|
375 |
+
)
|
376 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
|
377 |
+
with gr.Column():
|
378 |
+
result_video = gr.Video(label="Generated Animation", interactive=False)
|
379 |
+
|
380 |
+
files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files_gallery, clear_button_column, files])
|
381 |
+
remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, files, init_img_selected_idx])
|
382 |
+
|
383 |
+
init_img_files.upload(fn=swap_to_gallery, inputs=init_img_files, outputs=[uploaded_init_img_gallery, init_clear_button_column, init_img_files])
|
384 |
+
remove_init_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_init_img_gallery, init_clear_button_column,
|
385 |
+
init_img_files, init_img_selected_idx])
|
386 |
+
gen_init.click(fn=gen_init_images, inputs=[uploaded_files_gallery, prompt, adaface_id_cfg_scale],
|
387 |
+
outputs=[uploaded_init_img_gallery, init_img_files, init_clear_button_column])
|
388 |
+
uploaded_init_img_gallery.select(fn=get_clicked_image, inputs=None, outputs=init_img_selected_idx)
|
389 |
+
|
390 |
+
submit.click(fn=validate,
|
391 |
+
inputs=[prompt],outputs=None).success(
|
392 |
+
fn=randomize_seed_fn,
|
393 |
+
inputs=[seed, randomize_seed],
|
394 |
+
outputs=seed,
|
395 |
+
queue=False,
|
396 |
+
api_name=False,
|
397 |
+
).then(
|
398 |
+
fn=generate_image,
|
399 |
+
inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
|
400 |
+
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
401 |
+
seed, attn_scale, image_embed_scale,
|
402 |
+
is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps],
|
403 |
+
outputs=[result_video]
|
404 |
+
)
|
405 |
+
gr.Examples( fn=generate_image, examples=[], #examples,
|
406 |
+
inputs=[image_container, files, init_img_files, init_img_selected_idx, init_image_strength, init_image_final_weight,
|
407 |
+
prompt, negative_prompt, num_steps, video_length, guidance_scale,
|
408 |
+
seed, attn_scale, image_embed_scale,
|
409 |
+
is_adaface_enabled, adaface_ckpt_path, adaface_id_cfg_scale, adaface_power_scale, adaface_anneal_steps],
|
410 |
+
outputs=[result_video], cache_examples=True )
|
411 |
+
|
412 |
+
demo.launch(share=True, server_name=args.ip, ssl_verify=False)
|
assets/alita/alita armor orig.mp4
ADDED
Binary file (241 kB). View file
|
|
assets/alita/alita armor.mp4
ADDED
Binary file (207 kB). View file
|
|
assets/alita/alita beach orig.mp4
ADDED
Binary file (127 kB). View file
|
|
assets/alita/alita beach.mp4
ADDED
Binary file (137 kB). View file
|
|
assets/alita/alita cooking orig.mp4
ADDED
Binary file (225 kB). View file
|
|
assets/alita/alita cooking.mp4
ADDED
Binary file (172 kB). View file
|
|
assets/alita/alita dancing orig.mp4
ADDED
Binary file (173 kB). View file
|
|
assets/alita/alita dancing.mp4
ADDED
Binary file (255 kB). View file
|
|
assets/alita/alita iron man orig.mp4
ADDED
Binary file (166 kB). View file
|
|