adaface-neurips committed
Commit 02cc20b · 1 Parent(s): 550b1c1
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +7 -0
  2. Dockerfile +15 -0
  3. README.md +6 -8
  4. README2.md +241 -0
  5. adaface/adaface-infer.py +131 -0
  6. adaface/adaface-translate.py +208 -0
  7. adaface/adaface_wrapper.py +286 -0
  8. adaface/arc2face_models.py +303 -0
  9. adaface/subj_basis_generator.py +758 -0
  10. adaface/util.py +341 -0
  11. animatediff/models/attention.py +327 -0
  12. animatediff/models/attention_bkp.py +326 -0
  13. animatediff/models/motion_module.py +552 -0
  14. animatediff/models/motion_module_bkp.py +331 -0
  15. animatediff/models/resnet.py +217 -0
  16. animatediff/models/sparse_controlnet.py +587 -0
  17. animatediff/models/unet.py +600 -0
  18. animatediff/models/unet_blocks.py +760 -0
  19. animatediff/pipelines/pipeline_animation.py +793 -0
  20. animatediff/sd/.gitattributes +35 -0
  21. animatediff/sd/feature_extractor/preprocessor_config.json +20 -0
  22. animatediff/sd/model_index.json +32 -0
  23. animatediff/sd/safety_checker/config.json +175 -0
  24. animatediff/sd/safety_checker/pytorch_model.bin +3 -0
  25. animatediff/sd/scheduler/scheduler_config.json +13 -0
  26. animatediff/sd/text_encoder/config.json +25 -0
  27. animatediff/sd/text_encoder/pytorch_model.bin +3 -0
  28. animatediff/sd/tokenizer/merges.txt +0 -0
  29. animatediff/sd/tokenizer/special_tokens_map.json +24 -0
  30. animatediff/sd/tokenizer/tokenizer_config.json +34 -0
  31. animatediff/sd/tokenizer/vocab.json +0 -0
  32. animatediff/sd/unet/config.json +36 -0
  33. animatediff/sd/unet/diffusion_pytorch_model.bin +3 -0
  34. animatediff/sd/v1-inference.yaml +70 -0
  35. animatediff/sd/vae/config.json +29 -0
  36. animatediff/sd/vae/diffusion_pytorch_model.bin +3 -0
  37. animatediff/utils/convert_from_ckpt.py +959 -0
  38. animatediff/utils/convert_lora_safetensor_to_diffusers.py +152 -0
  39. animatediff/utils/convert_original_stable_diffusion_to_diffusers.py +188 -0
  40. animatediff/utils/util.py +225 -0
  41. app.py +402 -0
  42. assets/alita/alita armor orig.mp4 +0 -0
  43. assets/alita/alita armor.mp4 +0 -0
  44. assets/alita/alita beach orig.mp4 +0 -0
  45. assets/alita/alita beach.mp4 +0 -0
  46. assets/alita/alita cooking orig.mp4 +0 -0
  47. assets/alita/alita cooking.mp4 +0 -0
  48. assets/alita/alita dancing orig.mp4 +0 -0
  49. assets/alita/alita dancing.mp4 +0 -0
  50. assets/alita/alita iron man orig.mp4 +0 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/*
+ __pycache__/
+ *.pyc
+ gradio_cached_examples/*
+ gradio_cached_examples/
+ samples/*
+ samples/
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.8-slim
+ ENV PYTHONUNBUFFERED=1
+
+ RUN apt-get update && \
+     apt-get install -y \
+     bash \
+     git git-lfs \
+     wget curl procps \
+     htop vim nano && \
+     rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+ COPY --link --chown=1000 ./ /app
+
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,12 +1,10 @@
  ---
- title: Adaface Animate
- emoji: 🌖
- colorFrom: gray
+ title: AdaFace-Animate
+ emoji: 🎨
+ colorFrom: yellow
  colorTo: green
  sdk: gradio
- sdk_version: 4.36.1
+ sdk_version: 4.27.0
  app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ pinned: true
+ ---
README2.md ADDED
@@ -0,0 +1,241 @@
+ # AdaFace-Animate
+
+ This folder contains the preliminary implementation of **AdaFace-Animate**.
+ It is a zero-shot, subject-guided animation generator conditioned on human subject images, built by combining AnimateDiff, ID-Animator and AdaFace. ID-Animator provides AnimateDiff with rough subject characteristics, while AdaFace provides refined and more authentic subject facial details.
+
+ Please refer to our NeurIPS 2024 submission for more details about AdaFace:
+
+ **AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization**
+ </br>
+
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97 Hugging Face-Spaces-yellow)](https://huggingface.co/spaces/adaface-neurips/adaface-animate)
+
+ This pipeline uses 4 pretrained models: [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [AnimateDiff v3](https://github.com/guoyww/animatediff), [ID-Animator](https://github.com/ID-Animator/ID-Animator) and [AdaFace](https://huggingface.co/adaface-neurips/adaface).
+
+ AnimateDiff uses an SD-1.5 type checkpoint, referred to as a "DreamBooth" model. The DreamBooth model we use is an average of three SD-1.5 models, named "SAR": the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). In our experiments, this average model performs better than any of the individual models.
+
+ ## Procedures of Generation
+ We find that using an initial image helps stabilize the animation sequence and improve the quality. When generating each example video, an initial image is first generated by AdaFace with the same prompt as used for the video. This image is blended with multiple frames of random noise, with blending weights decreasing with $t$. The multi-frame blended noise is converted to a 1-second animation by AnimateDiff, conditioned on both the AdaFace and ID-Animator embeddings.
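Below is a minimal sketch of the blending step described above (an illustration only, not the exact code in `app.py`); it assumes $t$ indexes the video frames and that the initial image has already been encoded into a latent:

```python
import torch

def blend_init_with_noise(init_latent: torch.Tensor, num_frames: int = 16,
                          w_start: float = 0.8, w_end: float = 0.2) -> torch.Tensor:
    """Blend the init-image latent into each frame's starting noise with decaying weights."""
    # init_latent: [1, 4, H/8, W/8], the VAE latent of the AdaFace-generated initial image.
    noise = torch.randn(num_frames, *init_latent.shape[1:])
    # The blending weight decreases with the frame index t, so early frames stay
    # close to the initial image while later frames are freer to move.
    weights = torch.linspace(w_start, w_end, num_frames).view(-1, 1, 1, 1)
    return weights * init_latent + (1.0 - weights) * noise
```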
+
+ ## Gallery
+ [Gallery](./assets/) contains 100 subject videos generated with our method: 10 celebrities, each with 10 different prompts. The (shortened) prompts are: "Armor Suit", "Iron Man Costume", "Superman Costume", "Wielding a Lightsaber", "Walking on the beach", "Cooking", "Dancing", "Playing Guitar", "Reading", and "Running".
+
+ Some example videos are shown below. The full set of videos can be found in [Gallery](./assets/).
+
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
+
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
+ <tr style="line-height: 1">
+ <td width="25%" style="text-align: center">Input (Celebrities)</td>
+ <td width="25%" style="text-align: center">Animation 1: Playing Guitar</td>
+ <td width="25%" style="text-align: center">Animation 2: Cooking</td>
+ <td width="25%" style="text-align: center">Animation 3: Dancing</td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/83a08691-4f4e-4898-b4ae-be5dfcd1fb85" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/54f3745f-abf6-4608-93c5-d8e103d05dc7" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/68ad643c-8a2b-43a8-9c7b-10c36c4912d4" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
+ </tr>
+
+ </table>
+
+ To illustrate the wide range of applications of our method, we animated 8 internet memes. 4 of them are shown in the table below. The full gallery can be found in [memes](./assets/memes/).
+
+ <table class="center">
+ <tr style="line-height: 1">
+ <td width=25% style="text-align: center">Input (Memes)</td>
+ <td width=25% style="text-align: center">Animation</td>
+ <td width=25% style="text-align: center">Input</td>
+ <td width=25% style="text-align: center">Animation</td>
+ </tr>
+ <tr>
+ <td style="text-align: center">Yao Ming Laugh</td><td></td><td style="text-align: center">Girl Burning House</td><td></td>
+ </tr>
+ <tr>
+ <td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
+ <td><img src="assets/memes/girl burning house.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/11c83ae1-dece-4798-baf5-e608ab8709e3" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center">Girl with a Pearl Earring</td><td></td><td style="text-align: center">Great Gatsby</td><td></td>
+ </tr>
+ <tr>
+ <td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
+ <td><img src="assets/memes/great gatsby.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/db941805-8a6c-4596-ba3a-247b54baa5ef" type="video/mp4"></video></td>
+ </tr>
+ </table>
+
+ ## Comparison with ID-Animator, with AdaFace Initial Images
+ To compare with the baseline method "ID-Animator", for each video we disable AdaFace and generate the corresponding video with ID-Animator, using otherwise identical settings: the same subject image(s) and initial image, and the same random seed and prompt. The table below compares some of these videos side-by-side with the AdaFace-Animate videos. The full set of ID-Animator videos can be found in each subject folder in [Gallery](./assets/), named "* orig.mp4".
+
+ **NOTE** Since the ID-Animator videos here also use initial images generated by **AdaFace**, this comparison already gives ID-Animator an advantage over the original ID-Animator setup.
+
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
+
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
+ <tr style="line-height: 1">
+ <td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image: Playing Guitar</td>
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
+ <td width="14%" style="text-align: center; white-space: normal; word-wrap: break-word;">Initial Image: Dancing</td>
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
+ <td width="18%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence playing guitar.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f5d0f2c6-f4bd-4517-bfa1-021db1577895" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/ea12a906-8637-4b32-97ba-c439990fec0a" type="video/mp4"></video></td>
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/init images/jennifer lawrence dancing.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/421f5b81-e1a7-459a-869a-f7f6dc51a74e" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/1e957f80-376b-4ca7-81ca-fa63f19a1c5a"
+ type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun playing guitar.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3bbfc15b-4205-4052-b5cc-c4f8d6d17027" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0af3f6dc-d3d9-486c-a083-ab77a8397d80" type="video/mp4"></video></td>
+ <td style="text-align: center"><img src="assets/yann-lecun/init images/yann lecun dancing.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/75e191f4-87e2-486c-90e7-c9e21a1bf494" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/273ecced-a796-4e59-a43a-217db7fb4681"
+ type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/gakki/init images/gakki playing guitar.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/6a5579ce-23e3-4603-8917-00a16d6a3682" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28056aeb-5ce4-42bc-a593-877ba49834b9" type="video/mp4"></video></td>
+ <td style="text-align: center"><img src="assets/gakki/init images/gakki dancing.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/28082e58-a0ed-4492-8c51-cb563f92baeb" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/93f3891d-19c5-40fb-af21-a0b2e03d0d7f" type="video/mp4"></video></td>
+ </tr>
+
+ </table>
+
+ The table below compares the animated internet memes. The initial image for each video is the meme image itself. For "Yao Ming laughing" and "Great Gatsby", 2-3 extra portrait photos of the subject are included as subject images to enhance facial fidelity; for the other memes, the only subject image is the meme image. The full set of ID-Animator meme videos can be found in [memes](./assets/memes/), named "* orig.mp4".
+
+ <table class="center" style="width: 60%;">
+ <tr style="line-height: 1">
+ <td width=20% style="text-align: center">Input (Memes)</td>
+ <td width=20% style="text-align: center">ID-Animator</td>
+ <td width=20% style="text-align: center">AdaFace-Animate</td>
+ </tr>
+ <tr>
+ <td><img src="assets/memes/yao ming laugh.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/9daf814c-ae8a-476d-9c32-fa9ef6be16d9" type="video/mp4"></video></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/984a751f-ed2b-4ce3-aef8-41056ac111cf" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td><img src="assets/memes/girl with a pearl earring.jpg" style="width:100%"></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/05ed29d5-4eaa-4a0a-bee2-bc77e5649f58" type="video/mp4"></video></td>
+ <td><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b773486-b87e-4331-9e5d-ec8d54e11394" type="video/mp4"></video></td>
+ </tr>
+ </table>
+
+ We can see that the subjects in the AdaFace-Animate videos have more authentic facial features and better preserved facial expressions, while the subjects in the ID-Animator videos are less authentic and less faithful to the original images.
+
+ ## Comparison with ID-Animator, without AdaFace Initial Images
+ To exclude the effect of AdaFace initial images, we generate a subset of videos with AdaFace-Animate / ID-Animator *without initial images*, under otherwise the same settings as above. The table below shows a selection of these videos; the complete set can be found in [no-init](./assets/no-init/). Without the help of AdaFace initial images, the compositionality (i.e., the overall layout) deteriorates on some prompts; in particular, some background objects are suppressed by over-expressed facial features. Moreover, the performance gap between AdaFace-Animate and ID-Animator becomes more pronounced.
+
+ (Hint: use the horizontal scroll bar at the bottom of the table to view the full table)
+
+ <table class="center" style="table-layout: fixed; width: 100%; overflow-x: auto;">
+ <tr style="line-height: 1">
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">Input (Celebrities)</td>
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Playing Guitar</td>
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Playing Guitar</td>
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">ID-Animator: Dancing</td>
+ <td width="20%" style="text-align: center; white-space: normal; word-wrap: break-word;">AdaFace-Animate: Dancing</td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/jennifer-lawrence/jennifer lawrence.jpg" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2c3fa70b-4a38-48d1-aead-cd94976f6beb" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/f658f9e6-c3b6-4c4a-920c-00a89b98d97a" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/2de5cb38-f62c-4e9d-90ad-9bbb72d1ba7a" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/3b39e66d-c696-4022-81ff-6afae8147981" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/yann-lecun/yann lecun.png" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/7f7f8cd0-7ca3-47b4-a44d-8c6b399bdbc4" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/eb173058-2314-470a-8cf4-3702036022ad" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/cd5a9687-bae0-47fd-b82c-febc0d343ac2" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/e08c778c-5e87-40f6-a7a1-328f5d0d016f" type="video/mp4"></video></td>
+ </tr>
+ <tr>
+ <td style="text-align: center"><img src="assets/gakki/gakki.png" style="width:100%"></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/0370714b-d10c-422d-adee-76f6221aa1be" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/79cd95d2-95ea-4854-816e-2caf0cbebf94" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/60fa6b2a-6e1a-48c0-a777-4b62504ff679" type="video/mp4"></video></td>
+ <td style="text-align: center"><video width="100%" controls src="https://github.com/siberianlynx/video-demos/assets/77731289/c72836ee-9d7a-4525-a48a-7017b60f83f3" type="video/mp4"></video></td>
+ </tr>
+
+ </table>
+
+ ## Installation
+
+ ### Manually Download Model Checkpoints
+ - Download Stable Diffusion V1.5 into ``animatediff/sd``:
+
+ ``git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 animatediff/sd``
+ - Download the AnimateDiff motion module into ``models/v3_sd15_mm.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_mm.ckpt
+ - Download the AnimateDiff adapter into ``models/v3_adapter_sd_v15.ckpt``: https://huggingface.co/guoyww/animatediff/blob/main/v3_sd15_adapter.ckpt
+ - Download the ID-Animator checkpoint into ``models/animator.ckpt`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/blob/main/animator.ckpt
+ - Download the CLIP image encoder into ``models/image_encoder/`` from: https://huggingface.co/spaces/ID-Animator/ID-Animator/tree/main/image_encoder
+ - Download the AdaFace checkpoint into ``models/adaface/`` from: https://huggingface.co/adaface-neurips/adaface/tree/main/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt (a helper sketch for fetching the Hugging Face-hosted files follows this list)
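As an optional convenience, the Hugging Face-hosted checkpoints above can be fetched programmatically. This is only a sketch, assuming the `huggingface_hub` package is installed; the Civitai-hosted models in the next subsection still have to be downloaded manually.

```python
import os
from huggingface_hub import hf_hub_download, snapshot_download

# AnimateDiff motion module and adapter. The adapter file is renamed to the
# target path listed above (an assumption about what app.py expects).
hf_hub_download("guoyww/animatediff", "v3_sd15_mm.ckpt", local_dir="models")
hf_hub_download("guoyww/animatediff", "v3_sd15_adapter.ckpt", local_dir="models")
os.rename("models/v3_sd15_adapter.ckpt", "models/v3_adapter_sd_v15.ckpt")

# ID-Animator checkpoint and CLIP image encoder, both hosted in the ID-Animator Space.
hf_hub_download("ID-Animator/ID-Animator", "animator.ckpt", repo_type="space", local_dir="models")
snapshot_download("ID-Animator/ID-Animator", repo_type="space",
                  allow_patterns="image_encoder/*", local_dir="models")

# AdaFace checkpoint.
hf_hub_download("adaface-neurips/adaface",
                "subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt",
                local_dir="models/adaface")
```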
+
+ ### Prepare the SAR Model
+
+ Manually download the three `.safetensors` models: the original [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors), [AbsoluteReality V1.8.1](https://civitai.com/models/81458?modelVersionId=132760), and [RealisticVision V4.0](https://civitai.com/models/4201?modelVersionId=114367). Save them to `models/sar`.
+
+ Run the following command to generate an average of the three models:
+ ```
+ python3 scripts/avg_models.py --input models/sar/absolutereality_v181.safetensors models/sar/realisticVisionV40_v40VAE.safetensors models/sar/v1-5-pruned.safetensors --output models/sar/sar.safetensors
+ ```
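For reference, here is a rough sketch of what a checkpoint-averaging script such as `scripts/avg_models.py` could do (the actual script is not part of this commit and may differ): load the three SD-1.5 checkpoints and average every tensor they have in common.

```python
import sys
from safetensors.torch import load_file, save_file

def average_checkpoints(in_paths, out_path):
    # Load all input checkpoints as {name: tensor} dicts.
    state_dicts = [load_file(p) for p in in_paths]
    # Only average tensors that appear in every checkpoint.
    common_keys = set(state_dicts[0]).intersection(*state_dicts[1:])
    avg = {k: sum(sd[k].float() for sd in state_dicts) / len(state_dicts)
           for k in common_keys}
    # Cast back to fp16 to keep the output size comparable to the inputs.
    save_file({k: v.half() for k, v in avg.items()}, out_path)

if __name__ == "__main__":
    # Usage: python3 avg_checkpoints_sketch.py in1.safetensors in2.safetensors in3.safetensors out.safetensors
    average_checkpoints(sys.argv[1:-1], sys.argv[-1])
```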
+
+ \[Optional Improvements\]
+ 1. You can replace the VAE of the SAR model with the [MSE-840000 finetuned VAE](https://huggingface.co/stabilityai/sd-vae-ft-mse-original/tree/main) for slightly better video details:
+ ```
+ python3 scripts/repl_vae.py --base_ckpt models/sar/sar.safetensors --vae_ckpt models/sar/vae-ft-mse-840000-ema-pruned.ckpt --out_ckpt models/sar/sar-vae.safetensors
+ mv models/sar/sar-vae.safetensors models/sar/sar.safetensors
+ ```
+
+ 2. You can replace the text encoder of the SAR model with the text encoder of [DreamShaper V8](https://civitai.com/models/4384?modelVersionId=252914) for slightly more authentic facial features:
+ ```
+ python3 scripts/repl_textencoder.py --base_ckpt models/sar/sar.safetensors --te_ckpt models/sar/dreamshaper_8.safetensors --out_ckpt models/sar/sar2.safetensors
+ mv models/sar/sar2.safetensors models/sar/sar.safetensors
+ ```
+ ### Inference
+
+ Run the demo inference script:
+ ```
+ python3 app.py
+ ```
+ Then connect to the Gradio interface at `local-ip-address:7860`, or at the `https://*.gradio.live` URL shown in the terminal.
+
+ #### Use of Initial Image
+ The use of an initial image is optional. It usually helps stabilize the animation sequence and improve the quality.
+
+ You can generate 3 initial images in one go by clicking "Generate 3 new init images". The images are generated with the same prompt as the video, but you can also use different prompts for the initial images and the video. Select the desired initial image by clicking on it, then click "Generate Video". If none of the initial images are good enough, click "Generate 3 new init images" again.
+
+ ### Common Issues
+ 1. **Defocus**. This is the biggest potential issue. When the subject is far from the camera, the model may not be able to generate a clear face or control the subject's facial details. In this situation, consider increasing the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also add the prefix "face portrait of" to the prompt to help the model focus on the face.
+ 2. **Motion Degeneration**. When the subject is too close to the camera, the model may fail to generate correct motions and poses, and only generate the face. In this situation, consider decreasing the weights of "Image Embedding Scale", "Attention Processor Scale" and "AdaFace Embedding ID CFG Scale". You can also adjust the prompt slightly to make it focus on the whole body.
+ 3. **Lesser Facial Characteristics**. If the subject's facial characteristics are not distinctive enough in the output, you can increase the weight of "AdaFace Embedding ID CFG Scale".
+ 4. **Unstable Motions**. If the generated video has unstable motions, this is probably due to the limitations of AnimateDiff. Nonetheless, you can make it more stable by using a carefully selected initial image, and optionally increasing the "Init Image Strength" and "Final Weight of the Init Image". Note that the larger the "Final Weight of the Init Image", the less dynamic the motion in the generated video.
+
+
+ ## Disclaimer
+ This project is intended for academic purposes only. We do not accept responsibility for user-generated content. Users are solely responsible for their own actions. The contributors to this project are not legally affiliated with, nor liable for, the actions of users. Please use this generative model responsibly, in accordance with ethical and legal standards.
adaface/adaface-infer.py ADDED
@@ -0,0 +1,131 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re
+
+ def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
+     if num_images_per_row > len(images):
+         num_images_per_row = len(images)
+
+     os.makedirs(save_dir, exist_ok=True)
+
+     num_columns = int(np.ceil(len(images) / num_images_per_row))
+     # Save 4 images as a grid image in save_dir
+     grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
+     for i, image in enumerate(images):
+         image = image.resize((512, 512))
+         grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))
+
+     prompt_sig = prompt.replace(" ", "_").replace(",", "_")
+     grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
+     if os.path.exists(grid_filepath):
+         grid_count = 2
+         grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+         while os.path.exists(grid_filepath):
+             grid_count += 1
+             grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+
+     grid_image.save(grid_filepath)
+     print(f"Saved to {grid_filepath}")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
+                         help="Type of checkpoints to use (default: SD 1.5)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--subject", type=str, required=True)
+     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
+     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
+     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--randface", action="store_true")
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--id_cfg_scale", type=float, default=1,
+                         help="CFG scale when generating the identity embeddings")
+
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     if re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+     print(f"Using device {args.device}")
+
+     adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     if not args.randface:
+         image_folder = args.subject
+         if image_folder.endswith("/"):
+             image_folder = image_folder[:-1]
+
+         if os.path.isfile(image_folder):
+             # Get the second to the last part of the path
+             subject_name = os.path.basename(os.path.dirname(image_folder))
+             image_paths = [image_folder]
+
+         else:
+             subject_name = os.path.basename(image_folder)
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(image_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+
+             # image_paths contain at most args.example_image_count full image paths.
+             if args.example_image_count > 0:
+                 image_paths = alltype_image_paths[:args.example_image_count]
+             else:
+                 image_paths = alltype_image_paths
+     else:
+         subject_name = None
+         image_paths = None
+         image_folder = None
+
+     subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
+     rand_face_embs = torch.randn(1, 512)
+
+     pre_face_embs = rand_face_embs if args.randface else None
+     noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
+     # args.noise_level: the *relative* std of the noise added to the face embeddings.
+     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
+     adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
+                                                             out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
+                                                             update_text_encoder=True)
+     images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
+     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
adaface/adaface-translate.py ADDED
@@ -0,0 +1,208 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re, shutil
+
+ def str2bool(v):
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
+                         help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+     # If True, the input folder contains images of mixed subjects.
+     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
+     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
+                         help="Whether the input folder contains images of mixed subjects")
+     parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
+     parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
+     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
+     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
+     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--ref_img_strength", type=float, default=0.8,
+                         help="Strength of the reference image in the output image.")
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--prompt", type=str, default="a person z")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
+     # --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
+     # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/shaohua/VGGface2_HQ_masks/
+     # --is_mix_subj_folder 0 --out_folder /data/shaohua/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
+     if args.num_gpus > 1:
+         from accelerate import PartialState
+         distributed_state = PartialState()
+         args.device = distributed_state.device
+         process_index = distributed_state.process_index
+     elif re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+         distributed_state = None
+         process_index = 0
+
+     adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     in_folder = args.in_folder
+     if os.path.isfile(in_folder):
+         subject_folders = [ os.path.dirname(in_folder) ]
+         images_by_subject = [[in_folder]]
+     else:
+         if not args.is_mix_subj_folder:
+             in_folders = [in_folder]
+         else:
+             in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]
+
+         images_by_subject = []
+         subject_folders = []
+         for in_folder in in_folders:
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(in_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+             alltype_image_paths = sorted(alltype_image_paths)
+
+             if not args.is_mix_subj_folder:
+                 # image_paths contain at most args.max_images_per_subject full image paths.
+                 if args.max_images_per_subject > 0:
+                     image_paths = alltype_image_paths[:args.max_images_per_subject]
+                 else:
+                     image_paths = alltype_image_paths
+
+                 images_by_subject.append(image_paths)
+                 subject_folders.append(in_folder)
+             else:
+                 # Each image in the folder is treated as an individual subject.
+                 images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
+                 subject_folders.extend([in_folder] * len(alltype_image_paths))
+
+             if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
+                 break
+
+     if args.trans_subject_count > 0:
+         images_by_subject = images_by_subject[:args.trans_subject_count]
+         subject_folders = subject_folders[:args.trans_subject_count]
+
+     out_image_count = 0
+     out_mask_count = 0
+     if not args.out_folder.endswith("/"):
+         args.out_folder += "/"
+
+     if args.num_gpus > 1:
+         # Split the subjects across the GPUs.
+         subject_folders = subject_folders[process_index::args.num_gpus]
+         images_by_subject = images_by_subject[process_index::args.num_gpus]
+         #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
+
+     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
+         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
+         # Otherwise, we use the folder name as the signature of the images.
+         images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])
+
+         print(f"Translating {images_sig}...")
+         with torch.no_grad():
+             adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
+                                                                     out_id_embs_scale=1, noise_level=args.noise_level,
+                                                                     update_text_encoder=True)
+
+         # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
+         subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
+         if not os.path.exists(subject_out_folder):
+             os.makedirs(subject_out_folder)
+         print(f"Output images will be saved to {subject_out_folder}")
+
+         in_images = []
+         for image_path in image_paths:
+             image = Image.open(image_path).convert("RGB").resize((512, 512))
+             # [512, 512, 3] -> [3, 512, 512].
+             image = np.array(image).transpose(2, 0, 1)
+             # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+             image = torch.tensor(image).unsqueeze(0).float().cuda()
+             in_images.append(image)
+
+         # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+         # NOTE: For simplicity, we do not check overly large batch sizes.
+         in_images = torch.cat(in_images, dim=0)
+         # in_images: [5, 3, 512, 512].
+         # Normalize the pixel values to [0, 1].
+         in_images = in_images / 255.0
+         num_out_images = len(in_images) * args.out_count_per_input_image
+
+         with torch.no_grad():
+             # args.noise_level: the *relative* std of the noise added to the face embeddings.
+             # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+             # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
+             # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
+             out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
+
+         for img_i, img in enumerate(out_images):
+             # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
+             subj_i = img_i % len(in_images)
+             copy_i = img_i // len(in_images)
+             image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
+             if copy_i == 0:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
+             else:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
+
+             if args.copy_masks:
+                 mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
+                 if os.path.exists(mask_path):
+                     if copy_i == 0:
+                         shutil.copy(mask_path, subject_out_folder)
+                     else:
+                         mask_filename_stem = image_filename_stem
+                         shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))
+
+                     out_mask_count += 1
+
+         out_image_count += len(out_images)
+
+     print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
adaface/adaface_wrapper.py ADDED
@@ -0,0 +1,286 @@
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from diffusers import (
5
+ StableDiffusionPipeline,
6
+ StableDiffusionImg2ImgPipeline,
7
+ UNet2DConditionModel,
8
+ DDIMScheduler,
9
+ AutoencoderKL,
10
+ )
11
+ from insightface.app import FaceAnalysis
12
+ from adaface.arc2face_models import CLIPTextModelWrapper
13
+ from adaface.util import get_arc2face_id_prompt_embs
14
+ import re, os
15
+
16
+ class AdaFaceWrapper(nn.Module):
17
+ def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
18
+ subject_string='z', num_vectors=16,
19
+ num_inference_steps=50, negative_prompt=None,
20
+ use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
21
+ '''
22
+ pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
23
+ removed from the pipeline to release RAM.
24
+ '''
25
+ super().__init__()
26
+ self.pipeline_name = pipeline_name
27
+ self.base_model_path = base_model_path
28
+ self.adaface_ckpt_path = adaface_ckpt_path
29
+ self.use_840k_vae = use_840k_vae
30
+ self.use_ds_text_encoder = use_ds_text_encoder
31
+ self.subject_string = subject_string
32
+ self.num_vectors = num_vectors
33
+ self.num_inference_steps = num_inference_steps
34
+ self.device = device
35
+ self.is_training = is_training
36
+ self.initialize_pipeline()
37
+ self.extend_tokenizer_and_text_encoder()
38
+ if negative_prompt is None:
39
+ self.negative_prompt = \
40
+ "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
41
+ "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
42
+ "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
43
+ "nude, naked, nsfw, topless, bare breasts"
44
+ else:
45
+ self.negative_prompt = negative_prompt
46
+
47
+ def load_subj_basis_generator(self, adaface_ckpt_path):
48
+ ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
49
+ string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
50
+ if self.subject_string not in string_to_subj_basis_generator_dict:
51
+ print(f"Subject '{self.subject_string}' not found in the embedding manager.")
52
+ breakpoint()
53
+
54
+ self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
55
+ # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
56
+ # But we don't do layerwise embeddings here, so we set it to 1.
57
+ self.subj_basis_generator.num_out_layers = 1
58
+ print(f"Loaded subject basis generator for '{self.subject_string}'.")
59
+ print(repr(self.subj_basis_generator))
60
+ self.subj_basis_generator.to(self.device)
61
+ if self.is_training:
62
+ self.subj_basis_generator.train()
63
+ else:
64
+ self.subj_basis_generator.eval()
65
+
66
+ def initialize_pipeline(self):
67
+ self.load_subj_basis_generator(self.adaface_ckpt_path)
68
+ # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
69
+ # in the UNet image space.
70
+ arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
71
+ 'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
72
+ )
73
+ self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
74
+
75
+ if self.use_840k_vae:
76
+ # The 840000-step vae model is slightly better in face details than the original vae model.
77
+ # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
78
+ vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
79
+ else:
80
+ vae = None
81
+
82
+ if self.use_ds_text_encoder:
83
+ # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
84
+ # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
85
+ text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
86
+ else:
87
+ text_encoder = None
88
+
89
+ remove_unet = False
90
+
91
+ if self.pipeline_name == "img2img":
92
+ PipelineClass = StableDiffusionImg2ImgPipeline
93
+ elif self.pipeline_name == "text2img":
94
+ PipelineClass = StableDiffusionPipeline
95
+ # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
96
+ elif self.pipeline_name is None:
97
+ PipelineClass = StableDiffusionPipeline
98
+ remove_unet = True
99
+ else:
100
+ raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
101
+
102
+ if os.path.isfile(self.base_model_path):
103
+ pipeline = PipelineClass.from_single_file(
104
+ self.base_model_path,
105
+ torch_dtype=torch.float16
106
+ )
107
+ else:
108
+ pipeline = PipelineClass.from_pretrained(
109
+ self.base_model_path,
110
+ torch_dtype=torch.float16,
111
+ safety_checker=None
112
+ )
113
+ print(f"Loaded pipeline from {self.base_model_path}.")
114
+
115
+ if self.use_840k_vae:
116
+ pipeline.vae = vae
117
+ print("Replaced the VAE with the 840k-step VAE.")
118
+
119
+ if self.use_ds_text_encoder:
120
+ pipeline.text_encoder = text_encoder
121
+ print("Replaced the text encoder with the DreamShaper text encoder.")
122
+
123
+ if remove_unet:
124
+ # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
125
+ pipeline.unet = None
126
+ pipeline.vae = None
127
+ print("Removed UNet and VAE from the pipeline.")
128
+
129
+ noise_scheduler = DDIMScheduler(
130
+ num_train_timesteps=1000,
131
+ beta_start=0.00085,
132
+ beta_end=0.012,
133
+ beta_schedule="scaled_linear",
134
+ clip_sample=False,
135
+ set_alpha_to_one=False,
136
+ steps_offset=1,
137
+ )
138
+
139
+ pipeline.scheduler = noise_scheduler
140
+ self.pipeline = pipeline.to(self.device)
141
+ # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
142
+ # Note there's a second "model" in the path.
143
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
144
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
145
+ # Patch the missing tokenizer in the subj_basis_generator.
146
+ if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
147
+ self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
148
+ print("Patched the missing tokenizer in the subj_basis_generator.")
149
+
150
+ def extend_tokenizer_and_text_encoder(self):
151
+ if self.num_vectors < 1:
152
+ raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
153
+
154
+ tokenizer = self.pipeline.tokenizer
155
+ # Add z0, z1, z2, ..., z15.
156
+ self.placeholder_tokens = []
157
+ for i in range(0, self.num_vectors):
158
+ self.placeholder_tokens.append(f"{self.subject_string}_{i}")
159
+
160
+ self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
161
+
162
+ # Add the new tokens to the tokenizer.
163
+ num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
164
+ if num_added_tokens != self.num_vectors:
165
+ raise ValueError(
166
+ f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
167
+ " `subject_string` that is not already in the tokenizer.")
168
+
169
+ print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
170
+
171
+ # placeholder_token_ids: [49408, ..., 49423].
172
+ self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
173
+ # print(self.placeholder_token_ids)
174
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
175
+ old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
176
+ self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
177
+ new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
178
+ print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
179
+
180
+ # Extend pipeline.text_encoder with the adaface subject emeddings.
181
+ # subj_embs: [16, 768].
182
+ def update_text_encoder_subj_embs(self, subj_embs):
183
+ # Initialise the newly added placeholder token with the embeddings of the initializer token
184
+ token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
185
+ with torch.no_grad():
186
+ for i, token_id in enumerate(self.placeholder_token_ids):
187
+ token_embeds[token_id] = subj_embs[i]
188
+ print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
189
+
190
+ def update_prompt(self, prompt):
191
+ # If the placeholder tokens are already in the prompt, then return the prompt as is.
192
+ if self.placeholder_tokens_str in prompt:
193
+ return prompt
194
+
195
+ # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
196
+ if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
197
+ print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
198
+ comp_prompt = self.placeholder_tokens_str + " " + prompt
199
+ else:
200
+ # Replace the subject string 'z' with the placeholder tokens.
201
+ comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
202
+ return comp_prompt
203
+
204
+ # image_paths: a list of image paths. image_folder: the parent folder name.
205
+ def generate_adaface_embeddings(self, image_paths, image_folder=None,
206
+ pre_face_embs=None, gen_rand_face=False,
207
+ out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
208
+ # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
209
+ # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
210
+ # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
211
+ # The same applies to id_prompt_emb.
212
+ # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
213
+ # Here id_batch_size = 1, so
214
+ # faceid_embeds: [1, 512]. NOT used later.
215
+ # id_prompt_emb: [1, 16, 768].
216
+ # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
217
+ # arc2face prompt template: "photo of a id person"
218
+ # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
219
+ faceid_embeds, id_prompt_emb \
220
+ = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
221
+ extract_faceid_embeds=not gen_rand_face,
222
+ pre_face_embs=pre_face_embs,
223
+ # image_folder is passed only for logging purposes.
224
+ # image_paths contains the paths of the images.
225
+ image_folder=image_folder, image_paths=image_paths,
226
+ images_np=None,
227
+ id_batch_size=1,
228
+ device=self.device,
229
+ # input_max_length == 22: only keep the first 22 tokens,
230
+ # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
231
+ # The results are indistinguishable from input_max_length=77.
232
+ input_max_length=22,
233
+ noise_level=noise_level,
234
+ return_core_id_embs=True,
235
+ gen_neg_prompt=False,
236
+ verbose=True)
237
+
238
+ # adaface_subj_embs: [1, 1, 16, 768].
239
+ # adaface_prompt_embs: [1, 77, 768] (not used).
240
+ adaface_subj_embs, adaface_prompt_embs = \
241
+ self.subj_basis_generator(id_prompt_emb, None, None,
242
+ out_id_embs_scale=out_id_embs_scale,
243
+ is_face=True, is_training=False,
244
+ adaface_prompt_embs_inf_type='full_half_pad')
245
+ # adaface_subj_embs: [16, 768]
246
+ adaface_subj_embs = adaface_subj_embs.squeeze()
247
+ if update_text_encoder:
248
+ self.update_text_encoder_subj_embs(adaface_subj_embs)
249
+ return adaface_subj_embs
250
+
251
+ def encode_prompt(self, prompt, device="cuda", verbose=False):
252
+ prompt = self.update_prompt(prompt)
253
+ if verbose:
254
+ print(f"Prompt: {prompt}")
255
+
256
+ # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
257
+ # So we manually move it to GPU here.
258
+ self.pipeline.text_encoder.to(device)
259
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
260
+ prompt_embeds_, negative_prompt_embeds_ = \
261
+ self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
262
+ do_classifier_free_guidance=True, negative_prompt=self.negative_prompt)
263
+ return prompt_embeds_, negative_prompt_embeds_
264
+
265
+ # ref_img_strength is used only in the img2img pipeline.
266
+ def forward(self, noise, prompt, guidance_scale=4.0, out_image_count=4, ref_img_strength=0.8, verbose=False):
267
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
268
+ prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, device=self.device, verbose=verbose)
269
+
270
+ # Repeat the prompt embeddings for all images in the batch.
271
+ prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
272
+ negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
273
+ noise = noise.to(self.device).to(torch.float16)
274
+
275
+ # noise: [BS, 4, 64, 64]
276
+ # When the pipeline is text2img, strength is ignored.
277
+ images = self.pipeline(image=noise,
278
+ prompt_embeds=prompt_embeds_,
279
+ negative_prompt_embeds=negative_prompt_embeds_,
280
+ num_inference_steps=self.num_inference_steps,
281
+ guidance_scale=guidance_scale,
282
+ num_images_per_prompt=1,
283
+ strength=ref_img_strength).images
284
+ # images: [BS, 3, 512, 512]
285
+ return images
286
+
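Taken together, the methods above form the inference path of the wrapper in adaface/adaface_wrapper.py: generate_adaface_embeddings() produces the 16 subject embeddings and writes them into the placeholder token slots of the text encoder, then forward() encodes the placeholder-substituted prompt and runs the diffusion pipeline. Below is a minimal sketch of that flow. The wrapper instance is assumed to be constructed elsewhere (as in adaface/adaface-infer.py), so its construction is not shown, and the image paths and prompt are placeholders.

    import torch

    def generate_subject_images(adaface, image_paths, prompt, out_image_count=4):
        # 'adaface' is an already-constructed instance of the wrapper class in this file.
        # Step 1: derive the 16 AdaFace subject embeddings from the reference photos and
        # patch them into the placeholder tokens of the pipeline's text encoder.
        adaface.generate_adaface_embeddings(image_paths, out_id_embs_scale=1.0,
                                            noise_level=0, update_text_encoder=True)
        # Step 2: sample latent noise ([BS, 4, 64, 64], as noted in forward()) and generate.
        # The subject string "z" in 'prompt' is mapped to the placeholder tokens by
        # update_prompt() inside encode_prompt().
        noise = torch.randn(out_image_count, 4, 64, 64)
        return adaface.forward(noise, prompt, guidance_scale=4.0,
                               out_image_count=out_image_count, verbose=True)

    # Hypothetical call:
    # images = generate_subject_images(adaface, ["face1.jpg", "face2.jpg"],
    #                                  "a portrait of z hiking in the mountains")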
adaface/arc2face_models.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from transformers.models.clip.modeling_clip import CLIPAttention
5
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, List
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
9
+ _make_causal_mask = AttentionMaskConverter._make_causal_mask
10
+ _expand_mask = AttentionMaskConverter._expand_mask
11
+
12
+ from adaface.util import add_noise_to_tensor
13
+
14
+ # Extend CLIPAttention by using multiple k_proj and v_proj in each head.
15
+ # To avoid increasing the computation too much, we don't extend q_proj.
16
+ class CLIPAttentionMKV(nn.Module):
17
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
18
+
19
+ def __init__(self, config, multiplier=2):
20
+ super().__init__()
21
+ self.config = config
22
+ self.embed_dim = config.hidden_size
23
+ self.num_heads = config.num_attention_heads
24
+ self.head_dim = self.embed_dim // self.num_heads
25
+ if self.head_dim * self.num_heads != self.embed_dim:
26
+ raise ValueError(
27
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
28
+ f" {self.num_heads})."
29
+ )
30
+ self.scale = self.head_dim**-0.5
31
+ self.dropout = config.attention_dropout
32
+ self.multiplier = multiplier
33
+
34
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
35
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
36
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
37
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
38
+
39
+ # The (approximately) repeated token features are laid out along the last dim of the tensor
40
+ # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
41
+ # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
42
+ # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
43
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
44
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
45
+
46
+ def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
47
+ noise_std_is_relative=True, keep_norm=False, verbose=False):
48
+ self.multiplier *= multiplier
49
+ # q_proj and out_proj are the same as the original CLIPAttention.
50
+ self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
51
+ self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
52
+ self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
53
+ self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()
54
+
55
+ # bias doesn't need noise perturbation, as after the weights are noised,
56
+ # different copies of the weight/bias will receive different gradients,
57
+ # making the bias terms diverge and become distinguishable after training.
58
+ self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
59
+ self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)
60
+
61
+ self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
62
+ self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)
63
+
64
+ if noise_std > 0:
65
+ ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
66
+ ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
67
+ # Adding noise to the extra copies of the weights (keep the first copy unchanged).
68
+ self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
69
+ add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
70
+ noise_std, noise_std_is_relative, keep_norm)
71
+ if verbose:
72
+ NEW_V_SHAPE = list(self.v_proj.weight.shape)
73
+ NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
74
+ print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {noise_std} noise")
75
+
76
+ ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
77
+ ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
78
+ # Adding noise to the extra copies of the weights.
79
+ self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
80
+ add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
81
+ noise_std, noise_std_is_relative, keep_norm)
82
+ if verbose:
83
+ NEW_K_SHAPE = list(self.k_proj.weight.shape)
84
+ NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
85
+ print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {noise_std} noise")
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ causal_attention_mask: Optional[torch.Tensor] = None,
92
+ output_attentions: Optional[bool] = False,
93
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
94
+ """Input shape: Batch x Time x Channel"""
95
+
96
+ bsz, tgt_len, embed_dim = hidden_states.size()
97
+
98
+ query_states = self.q_proj(hidden_states) * self.scale
99
+ # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
100
+ # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
101
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
102
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
103
+
104
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
105
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
106
+ key_states = key_states.view(*proj_shape)
107
+ value_states = value_states.view(*proj_shape)
108
+
109
+ src_len = key_states.size(1)
110
+ # src_len0 is the original src_len without the multiplier.
111
+ src_len0 = src_len // self.multiplier
112
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
113
+
114
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
115
+ raise ValueError(
116
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
117
+ f" {attn_weights.size()}"
118
+ )
119
+
120
+ # apply the causal_attention_mask first
121
+ if causal_attention_mask is not None:
122
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
123
+ raise ValueError(
124
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
125
+ f" {causal_attention_mask.size()}"
126
+ )
127
+ # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
128
+ # If reshaping it as (self.multiplier, src_len0), it will become
129
+ # [[token0, token0, ..., tokenN//2-1, tokenN//2-1], [tokenN//2, tokenN//2, ..., tokenN-1, tokenN-1]],
130
+ # and the mask will be applied to the wrong elements.
131
+ # If reshaping it as (src_len0, self.multiplier), it will become
132
+ # [[token0, token0], [token1, token1], ..., [tokenN-1, tokenN-1]], so broadcasting the per-token
133
+ # mask over the trailing multiplier dim masks all copies of token i together, which is desired (see the small demo after this class).
134
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
135
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
136
+
137
+ if attention_mask is not None:
138
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
139
+ raise ValueError(
140
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
141
+ )
142
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
143
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
144
+
145
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
146
+
147
+ if output_attentions:
148
+ # this operation is a bit awkward, but it's required to
149
+ # make sure that attn_weights keeps its gradient.
150
+ # In order to do so, attn_weights have to reshaped
151
+ # twice and have to be reused in the following
152
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
153
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
154
+ else:
155
+ attn_weights_reshaped = None
156
+
157
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
158
+
159
+ attn_output = torch.bmm(attn_probs, value_states)
160
+
161
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
168
+ attn_output = attn_output.transpose(1, 2)
169
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
170
+
171
+ attn_output = self.out_proj(attn_output)
172
+
173
+ return attn_output, attn_weights_reshaped
174
+
175
+ class CLIPTextModelWrapper(CLIPTextModel):
176
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
177
+ # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.Tensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ position_ids: Optional[torch.Tensor] = None,
183
+ output_attentions: Optional[bool] = None,
184
+ output_hidden_states: Optional[bool] = None,
185
+ return_dict: Optional[bool] = None,
186
+ input_token_embs: Optional[torch.Tensor] = None,
187
+ hidden_state_layer_weights: Optional[torch.Tensor] = None,
188
+ return_token_embs: Optional[bool] = False,
189
+ ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
190
+
191
+ if return_token_embs:
192
+ return self.text_model.embeddings.token_embedding(input_ids)
193
+
194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
195
+
196
+ output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
197
+ output_hidden_states = (
198
+ output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
199
+ )
200
+ if hidden_state_layer_weights is not None:
201
+ output_hidden_states = True
202
+ return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
203
+
204
+ if input_ids is None:
205
+ raise ValueError("You have to specify input_ids")
206
+
207
+ input_shape = input_ids.size()
208
+ input_ids = input_ids.view(-1, input_shape[-1])
209
+
210
+ hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
211
+
212
+ # CLIP's text model uses causal mask, prepare it here.
213
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
214
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
215
+ # expand attention_mask
216
+ if attention_mask is not None:
217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
218
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
219
+
220
+ encoder_outputs = self.text_model.encoder(
221
+ inputs_embeds=hidden_states,
222
+ attention_mask=attention_mask,
223
+ causal_attention_mask=causal_attention_mask,
224
+ output_attentions=output_attentions,
225
+ # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
226
+ output_hidden_states=output_hidden_states,
227
+ return_dict=return_dict,
228
+ )
229
+
230
+ # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
231
+ # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
232
+ # encoder_outputs[0] == encoder_outputs[1][12].
233
+ if hidden_state_layer_weights is None:
234
+ last_hidden_state = encoder_outputs[0]
235
+ else:
236
+ num_hidden_state_layers = len(hidden_state_layer_weights)
237
+ last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
238
+ hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
239
+ # Normalize the weights to sum to 1 across layers.
240
+ # hidden_state_layer_weights: [3, 1] or [3, 768].
241
+ hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
242
+ # [3, 1/768] -> [3, 1, 1, 1/768]
243
+ hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
244
+ # A weighted sum of last_hidden_states.
245
+ # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
246
+ last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)
247
+
248
+ last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
249
+
250
+ # self.text_model.eos_token_id == 2 is True.
251
+ if self.text_model.eos_token_id == 2:
252
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
253
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
254
+ # ------------------------------------------------------------
255
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
256
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
257
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
258
+ pooled_output = last_hidden_state[
259
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
260
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
261
+ ]
262
+ else:
263
+ # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
264
+ pooled_output = last_hidden_state[
265
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
266
+ # We need to get the first position of the `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`).
267
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
268
+ .int()
269
+ .argmax(dim=-1),
270
+ ]
271
+
272
+ if not return_dict:
273
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
274
+
275
+ return BaseModelOutputWithPooling(
276
+ last_hidden_state=last_hidden_state,
277
+ pooler_output=pooled_output,
278
+ hidden_states=encoder_outputs.hidden_states,
279
+ attentions=encoder_outputs.attentions,
280
+ )
281
+
282
+ # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
283
+ # The layer indexed by end_layer_idx is not included.
284
+ # If both layer indices are -1, then apply to all layers (0-11).
285
+ def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
286
+ num_extended_layers = 0
287
+
288
+ for layer_idx, layer in enumerate(self.text_model.encoder.layers):
289
+ if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
290
+ continue
291
+ if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
292
+ break
293
+ # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
294
+ if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
295
+ breakpoint()
296
+ old_attn_layer = layer.self_attn
297
+ if not isinstance(old_attn_layer, CLIPAttentionMKV):
298
+ layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
299
+ layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
300
+ num_extended_layers += 1
301
+
302
+ return num_extended_layers
303
+
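The extension entry point is extend_clip_attention_MKV_multiplier(), which swaps the CLIPAttention layers in the chosen index range for CLIPAttentionMKV and copies (and noises) the original weights. A minimal sketch of how it might be driven follows; since CLIPTextModelWrapper subclasses CLIPTextModel, from_pretrained() works as usual. The checkpoint name is only a placeholder, not the one used by this repository.

    from adaface.arc2face_models import CLIPTextModelWrapper

    # Placeholder checkpoint; the actual Arc2Face text encoder is loaded elsewhere in this repo.
    text_encoder = CLIPTextModelWrapper.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="text_encoder")

    # Double the k/v projections of encoder layers 8-11 and perturb the extra weight copies
    # with relative noise of std 0.1 (noise_std_is_relative defaults to True in extend_weights).
    n = text_encoder.extend_clip_attention_MKV_multiplier(
        begin_layer_idx=8, end_layer_idx=12, multiplier=2, noise_std=0.1)
    print(f"Extended {n} CLIP attention layers to multi-k/v attention.")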
adaface/subj_basis_generator.py ADDED
@@ -0,0 +1,758 @@