alexnasa commited on
Commit
257f706
·
verified ·
1 Parent(s): 001b61a

Upload 69 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. INSTALL.md +55 -0
  3. LICENSE.txt +201 -0
  4. Makefile +5 -0
  5. README.md +12 -12
  6. app.py +546 -0
  7. examples/desi.mp4 +3 -0
  8. examples/desi.png +3 -0
  9. examples/man.png +3 -0
  10. examples/paul.mp4 +3 -0
  11. generate.py +236 -0
  12. pyproject.toml +66 -0
  13. requirements.txt +31 -0
  14. wan/__init__.py +7 -0
  15. wan/animate.py +653 -0
  16. wan/configs/__init__.py +50 -0
  17. wan/configs/shared_config.py +20 -0
  18. wan/configs/wan_animate_14B.py +40 -0
  19. wan/configs/wan_i2v_A14B.py +37 -0
  20. wan/configs/wan_s2v_14B.py +59 -0
  21. wan/configs/wan_t2v_A14B.py +37 -0
  22. wan/configs/wan_ti2v_5B.py +36 -0
  23. wan/distributed/__init__.py +1 -0
  24. wan/distributed/fsdp.py +45 -0
  25. wan/distributed/sequence_parallel.py +176 -0
  26. wan/distributed/ulysses.py +47 -0
  27. wan/distributed/util.py +51 -0
  28. wan/image2video.py +431 -0
  29. wan/modules/__init__.py +19 -0
  30. wan/modules/animate/__init__.py +4 -0
  31. wan/modules/animate/animate_utils.py +143 -0
  32. wan/modules/animate/clip.py +542 -0
  33. wan/modules/animate/face_blocks.py +383 -0
  34. wan/modules/animate/model_animate.py +500 -0
  35. wan/modules/animate/motion_encoder.py +307 -0
  36. wan/modules/animate/preprocess/UserGuider.md +70 -0
  37. wan/modules/animate/preprocess/__init__.py +3 -0
  38. wan/modules/animate/preprocess/human_visualization.py +1357 -0
  39. wan/modules/animate/preprocess/pose2d.py +430 -0
  40. wan/modules/animate/preprocess/pose2d_utils.py +1159 -0
  41. wan/modules/animate/preprocess/preprocess_data.py +121 -0
  42. wan/modules/animate/preprocess/process_pipepline.py +354 -0
  43. wan/modules/animate/preprocess/retarget_pose.py +847 -0
  44. wan/modules/animate/preprocess/sam_utils.py +155 -0
  45. wan/modules/animate/preprocess/utils.py +226 -0
  46. wan/modules/animate/preprocess/video_predictor.py +157 -0
  47. wan/modules/animate/xlm_roberta.py +170 -0
  48. wan/modules/attention.py +256 -0
  49. wan/modules/model.py +546 -0
  50. wan/modules/s2v/__init__.py +5 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/desi.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/desi.png filter=lfs diff=lfs merge=lfs -text
38
+ examples/man.png filter=lfs diff=lfs merge=lfs -text
39
+ examples/paul.mp4 filter=lfs diff=lfs merge=lfs -text
INSTALL.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation Guide
2
+
3
+ ## Install with pip
4
+
5
+ ```bash
6
+ pip install .
7
+ pip install .[dev] # Installe aussi les outils de dev
8
+ ```
9
+
10
+ ## Install with Poetry
11
+
12
+ Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
13
+
14
+ To install all dependencies:
15
+
16
+ ```bash
17
+ poetry install
18
+ ```
19
+
20
+ ### Handling `flash-attn` Installation Issues
21
+
22
+ If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
23
+
24
+ #### No-Build-Isolation Installation (Recommended)
25
+ ```bash
26
+ poetry run pip install --upgrade pip setuptools wheel
27
+ poetry run pip install flash-attn --no-build-isolation
28
+ poetry install
29
+ ```
30
+
31
+ #### Install from Git (Alternative)
32
+ ```bash
33
+ poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
34
+ ```
35
+
36
+ ---
37
+
38
+ ### Running the Model
39
+
40
+ Once the installation is complete, you can run **Wan2.2** using:
41
+
42
+ ```bash
43
+ poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
44
+ ```
45
+
46
+ #### Test
47
+ ```bash
48
+ bash tests/test.sh
49
+ ```
50
+
51
+ #### Format
52
+ ```bash
53
+ black .
54
+ isort .
55
+ ```
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
Makefile ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .PHONY: format
2
+
3
+ format:
4
+ isort generate.py wan
5
+ yapf -i -r *.py generate.py wan
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: Wan2.2 Animate ZEROGPU
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Wan2.2 Animate [Local]
3
+ emoji: 🔥
4
+ colorFrom: pink
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 5.49.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ from huggingface_hub import snapshot_download, hf_hub_download
3
+ import os
4
+ import subprocess
5
+ import importlib, site
6
+ from PIL import Image
7
+ import uuid
8
+ import shutil
9
+ import time
10
+ import cv2
11
+ from generate import generate, load_model
12
+ import json
13
+
14
+ # Re-discover all .pth/.egg-link files
15
+ for sitedir in site.getsitepackages():
16
+ site.addsitedir(sitedir)
17
+
18
+ # Clear caches so importlib will pick up new modules
19
+ importlib.invalidate_caches()
20
+
21
def sh(cmd):
    """Run *cmd* through the system shell; raise CalledProcessError on failure.

    NOTE(review): shell=True with interpolated paths is a command-injection
    risk if any path ever comes from untrusted input — confirm all callers
    pass trusted, locally-constructed strings.
    """
    subprocess.check_call(cmd, shell=True)
22
+
23
+ try:
24
+ print("Attempting to download and build sam2...")
25
+
26
+ print("download sam")
27
+ sam_dir = snapshot_download(repo_id="alexnasa/sam2")
28
+
29
+ @spaces.GPU(duration=450)
30
+ def install_sam():
31
+ os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
32
+ sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")
33
+
34
+ print("install sam")
35
+ install_sam()
36
+
37
+ # tell Python to re-scan site-packages now that the egg-link exists
38
+ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()
39
+
40
+ flash_attention_installed = True
41
+ print("sam2 installed successfully.")
42
+
43
+ except Exception as e:
44
+ print(f"⚠️ Could not install sam2: {e}")
45
+ print("Continuing without sam2...")
46
+
47
+ import torch
48
+ print(f"Torch version: {torch.__version__}")
49
+
50
+ os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"
51
+
52
+ import gradio as gr
53
+
54
+
55
+ snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")
56
+ wan_animate = load_model(True)
57
+
58
+
59
+ rc_mapping = {
60
+ "Video → Ref Image" : False,
61
+ "Video ← Ref Image" : True
62
+ }
63
+
64
+
65
def preprocess_video(input_video_path, session_id=None):
    """Normalize an uploaded video and return the processed file's path.

    The video is re-encoded to 30 fps, clipped, and center-cropped (720x1280)
    into the per-session results directory so later pipeline stages work on a
    predictable input.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    session_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(session_dir, exist_ok=True)

    processed_path = os.path.join(session_dir, 'input_video.mp4')
    convert_video_to_30fps_and_clip(
        input_video_path, processed_path, crop_width=720, crop_height=1280)

    return processed_path
78
+
79
def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
    """
    Extracts the audio track from a video file and saves it as a WAV file.

    Args:
        video_path (str): Path to the input video file.
        output_wav_path (str): Path to save the extracted WAV file.
        sample_rate (int, optional): Output sample rate (e.g., 16000).
            If None, keep the original.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    # -vn drops video; pcm_s16le mono gives a plain 16-bit WAV.
    cmd = [
        'ffmpeg',
        '-i', video_path,
        '-vn',
        '-acodec', 'pcm_s16le',
        '-ac', '1',
        '-y',
        '-loglevel', 'error'
    ]

    # Resample only when a rate was explicitly requested.
    if sample_rate is not None:
        cmd += ['-ar', str(sample_rate)]

    cmd.append(output_wav_path)

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
109
+
110
+
111
def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
    """
    Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.

    Args:
        video_path (str): Path to the silent video file.
        audio_path (str): Path to the WAV audio file.
        output_video_path (str): Path to save the output MP4 with audio.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    # Video stream is stream-copied (no re-encode); audio is transcoded to
    # AAC for MP4 compatibility. -shortest truncates to the shorter stream.
    cmd = [
        'ffmpeg',
        '-i', video_path,
        '-i', audio_path,
        '-c:v', 'copy',
        '-c:a', 'aac',
        '-shortest',
        '-y',
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
136
+
137
+
138
def convert_video_to_30fps_and_clip(
    input_video_path,
    output_video_path,
    duration_s=2,
    target_fps=30,
    crop_width=None,
    crop_height=None
):
    """Re-encode a video to *target_fps*, keep the first *duration_s* seconds,
    and optionally center-crop it to crop_width x crop_height (clamped to the
    source dimensions).

    Raises:
        RuntimeError: if the ffmpeg encode step fails.
    """
    crop_filter = None
    if crop_width and crop_height:
        # Probe the source dimensions so the requested crop can be clamped
        # and centered.
        probe_cmd = [
            'ffprobe', '-v', 'error', '-select_streams', 'v:0',
            '-show_entries', 'stream=width,height',
            '-of', 'json', input_video_path
        ]
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
        stream_info = json.loads(probe.stdout)['streams'][0]
        src_w = stream_info['width']
        src_h = stream_info['height']

        # Never crop beyond what the source actually has.
        crop_width = min(crop_width, src_w)
        crop_height = min(crop_height, src_h)

        # Center the crop window.
        offset_x = max((src_w - crop_width) // 2, 0)
        offset_y = max((src_h - crop_height) // 2, 0)
        crop_filter = f"crop={crop_width}:{crop_height}:{offset_x}:{offset_y}"

    cmd = [
        'ffmpeg',
        '-i', input_video_path,
        '-r', str(target_fps),
        '-t', str(duration_s),
    ]
    if crop_filter:
        cmd += ['-vf', crop_filter]
    cmd += [
        '-c:v', 'libx264',
        '-c:a', 'aac',
        '-strict', 'experimental',
        '-y',
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")
192
+
193
def get_frames_count(video_file):
    """Return the number of frames in *video_file*, or 0 if it cannot be opened.

    Fixes: the original warned on an unopenable file but then kept querying
    properties of the dead capture; it also computed width/height it never
    used. Returning 0 early matches what the fall-through happened to return
    (cv2 property reads on a closed capture yield 0.0).
    """
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        # Surface the problem in the UI instead of raising, keeping the app
        # responsive (same message as before).
        gr.Warning("Cannot open video file")
        cap.release()
        return 0

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()

    return frame_count
208
+
209
def calculate_time_required(input_video, rc_bool):
    """Estimate the GPU seconds needed for one run.

    Cost model: a fixed pose-tracking setup cost plus a per-chunk generation
    cost, where the video is processed in 77-frame chunks at 20 denoising
    steps each. Replacement mode (rc_bool=True) uses the heavier constants.
    """
    frames_count = get_frames_count(input_video)
    chunks = frames_count // 77 + 1

    if rc_bool:
        tracking_s, step_s = 75, 13
    else:
        tracking_s, step_s = 50, 12

    time_required = tracking_s + step_s * 20 * chunks
    print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}')
    return time_required
226
+
227
def update_time_required(input_video, rc_str):
    """Refresh the 'Zero GPU Required' label for the current video and mode.

    Returns a gr.update carrying the formatted estimate, or a placeholder
    when no video has been uploaded yet.
    """
    if input_video is None:
        return gr.update(value="⌚ Zero GPU Required: --")

    seconds = calculate_time_required(input_video, rc_mapping[rc_str])
    minutes = seconds / 60

    return gr.update(value=f"⌚ Zero GPU Required: ~{seconds}.0s ({minutes:.1f} mins)")
238
+
239
def get_duration(input_video, edited_frame, rc_bool, session_id, progress):
    """Duration hook for @spaces.GPU: forward the time estimate.

    The extra parameters mirror _animate's signature (required by spaces)
    but only the video and mode affect the estimate.
    """
    return calculate_time_required(input_video, rc_bool)
242
+
243
+
244
@spaces.GPU(duration=get_duration)
def _animate(input_video, edited_frame, rc_bool, session_id = None, progress=gr.Progress(track_tqdm=True),):
    """Run the GPU-bound Wan-Animate pipeline: preprocessing then generation.

    Args:
        input_video: path to the (already normalized) driving video.
        edited_frame: path to the reference image.
        rc_bool: True -> replace the character in the video; False -> animate
            the reference image with the video's motion.
        session_id: per-session folder name; a random hex id when None.
        progress: Gradio progress tracker (tqdm hooked).

    Returns:
        Path to the generated (silent) result video.
    """

    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    os.makedirs(preprocess_dir, exist_ok=True)

    output_video_path = os.path.join(output_dir, 'result.mp4')

    # --- Measure preprocess time ---
    start_preprocess = time.time()

    # w = 720
    # h = 480

    # w = 720
    # h = 1280

    # Working resolution for the preprocessing step; alternatives above/below
    # were tried and left for reference.
    w = 480
    h = 832

    # w = 480
    # h = 720

    # Mode flag forwarded to the preprocessing script: retarget (animate the
    # ref image) vs. replace (swap the character in the video).
    tag_string = "retarget_flag"

    if rc_bool:
        tag_string = "replace_flag"

    # NOTE(review): preprocessing runs as a shell subprocess rather than
    # in-process — presumably to isolate its model/CUDA state; confirm.
    sh("python ./wan/modules/animate/preprocess/preprocess_data.py "
       "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint "
       f"--video_path {input_video} "
       f"--refer_path {edited_frame} "
       f"--save_path {preprocess_dir} "
       f"--resolution_area {w} {h} --{tag_string} "
    )

    preprocess_time = time.time() - start_preprocess
    print(f"Preprocess took {preprocess_time:.2f} seconds")

    # --- Measure generate time ---
    start_generate = time.time()

    # Diffusion/generation step consuming the preprocessed assets.
    generate(wan_animate, preprocess_dir, output_video_path, rc_bool)

    generate_time = time.time() - start_generate
    print(f"Generate took {generate_time:.2f} seconds")

    # --- Optional total time ---
    total_time = preprocess_time + generate_time
    print(f"Total time: {total_time:.2f} seconds")

    return output_video_path
302
+
303
def animate_scene(input_video, edited_frame, rc_str, session_id = None, progress=gr.Progress(track_tqdm=True),):
    """Top-level click handler: validate inputs, run the pipeline, and mux the
    driving video's audio back onto the generated result.

    Args:
        input_video: path to the uploaded driving video.
        edited_frame: path to the reference image.
        rc_str: UI mode label, mapped through rc_mapping to a bool.
        session_id: per-session folder name; a random hex id when None.
        progress: Gradio progress tracker.

    Returns:
        (final_video, pose_video, bg_video, mask_video, face_video) paths.

    Raises:
        gr.Error: when either required input is missing.
    """
    if not input_video:
        # Fix: corrected user-facing grammar ("an video" -> "a video").
        raise gr.Error("Please provide a video")

    if not edited_frame:
        raise gr.Error("Please provide an image")

    if session_id is None:
        session_id = uuid.uuid4().hex

    rc_bool = rc_mapping[rc_str]

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    # Pull the audio out first so it can be re-attached after generation.
    input_audio_path = os.path.join(output_dir, 'input_audio.wav')
    extract_audio_from_video_ffmpeg(input_video, input_audio_path)

    output_video_path = _animate(input_video, edited_frame, rc_bool, session_id, progress)

    final_video_path = os.path.join(output_dir, 'final_result.mp4')

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    if rc_bool:
        # Replacement mode emits real mask/background/face debug videos.
        mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
        face_video = os.path.join(preprocess_dir, 'src_face.mp4')
    else:
        # Retarget mode has no mask/bg/face outputs; point the (hidden)
        # debug components at the pose video so they stay valid.
        mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
        face_video = os.path.join(preprocess_dir, 'src_pose.mp4')

    combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)

    return final_video_path, pose_video, bg_video, mask_video, face_video
342
+
343
+ css = """
344
+ #col-container {
345
+ margin: 0 auto;
346
+ max-width: 1600px;
347
+ }
348
+
349
+ #step-column {
350
+ padding: 20px;
351
+ border-radius: 8px;
352
+ box-shadow: var(--card-shadow);
353
+ margin: 10px;
354
+ }
355
+
356
+ #col-showcase {
357
+ margin: 0 auto;
358
+ max-width: 1100px;
359
+ }
360
+
361
+ .button-gradient {
362
+ background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
363
+ border: none;
364
+ padding: 14px 28px;
365
+ font-size: 16px;
366
+ font-weight: bold;
367
+ color: white;
368
+ border-radius: 10px;
369
+ cursor: pointer;
370
+ transition: 0.3s ease-in-out;
371
+ animation: 2s linear 0s infinite normal none running gradientAnimation;
372
+ box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
373
+ }
374
+
375
+ .toggle-container {
376
+ display: inline-flex;
377
+ background-color: #ffd6ff; /* light pink background */
378
+ border-radius: 9999px;
379
+ padding: 4px;
380
+ position: relative;
381
+ width: fit-content;
382
+ font-family: sans-serif;
383
+ }
384
+
385
+ .toggle-container input[type="radio"] {
386
+ display: none;
387
+ }
388
+
389
+ .toggle-container label {
390
+ position: relative;
391
+ z-index: 2;
392
+ flex: 1;
393
+ text-align: center;
394
+ font-weight: 700;
395
+ color: #4b2ab5; /* dark purple text for unselected */
396
+ padding: 6px 22px;
397
+ border-radius: 9999px;
398
+ cursor: pointer;
399
+ transition: color 0.25s ease;
400
+ }
401
+
402
+ /* Moving highlight */
403
+ .toggle-highlight {
404
+ position: absolute;
405
+ top: 4px;
406
+ left: 4px;
407
+ width: calc(50% - 4px);
408
+ height: calc(100% - 8px);
409
+ background-color: #4b2ab5; /* dark purple background */
410
+ border-radius: 9999px;
411
+ transition: transform 0.25s ease;
412
+ z-index: 1;
413
+ }
414
+
415
+ /* When "True" is checked */
416
+ #true:checked ~ label[for="true"] {
417
+ color: #ffd6ff; /* light pink text */
418
+ }
419
+
420
+ /* When "False" is checked */
421
+ #false:checked ~ label[for="false"] {
422
+ color: #ffd6ff; /* light pink text */
423
+ }
424
+
425
+ /* Move highlight to right side when False is checked */
426
+ #false:checked ~ .toggle-highlight {
427
+ transform: translateX(100%);
428
+ }
429
+ """
430
def start_session(request: gr.Request):
    """Return the Gradio session hash, stored in state to group per-user files."""
    return request.session_hash
433
+
434
def cleanup(request: gr.Request):
    """Remove the session's processed-results folder when the browser tab closes."""
    sid = request.session_hash
    if sid:
        session_dir = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
        # ignore_errors: the folder may never have been created.
        shutil.rmtree(session_dir, ignore_errors=True)
441
+
442
+ with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
443
+
444
+ session_state = gr.State()
445
+ demo.load(start_session, outputs=[session_state])
446
+
447
+ with gr.Column(elem_id="col-container"):
448
+ with gr.Row():
449
+ gr.HTML(
450
+ """
451
+ <div style="text-align: center;">
452
+ <p style="font-size:16px; display: inline; margin: 0;">
453
+ <strong>Wan2.2-Animate-14B </strong>
454
+ </p>
455
+ <a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
456
+ [Model]
457
+ </a>
458
+ <div style="text-align: center;">
459
+ <p style="font-size:16px; display: inline; margin: 0;">
460
+ HF Space By:
461
+ </p>
462
+ <a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
463
+ <img src="https://img.shields.io/badge/🤗-Follow Me-yellow.svg">
464
+ </a>
465
+ </div>
466
+ """
467
+ )
468
+ with gr.Row():
469
+ with gr.Column(elem_id="step-column"):
470
+ gr.HTML("""
471
+ <div>
472
+ <span style="font-size: 24px;">1. Upload a Video</span><br>
473
+ </div>
474
+ """)
475
+ input_video = gr.Video(label="Input Video", height=512)
476
+
477
+
478
+ with gr.Column(elem_id="step-column"):
479
+ gr.HTML("""
480
+ <div>
481
+ <span style="font-size: 24px;">2. Upload a Ref Image</span><br>
482
+ </div>
483
+ """)
484
+ edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)
485
+ gr.HTML("""
486
+ <div>
487
+ <span style="font-size: 24px;">3. Choose Mode</span><br>
488
+ </div>
489
+ """)
490
+ replace_character_string = gr.Radio(
491
+ ["Video → Ref Image", "Video ← Ref Image"], value="Video → Ref Image", show_label=False
492
+ )
493
+
494
+ with gr.Column(elem_id="step-column"):
495
+ gr.HTML("""
496
+ <div>
497
+ <span style="font-size: 24px;">4. Wan Animate it!</span><br>
498
+ </div>
499
+ """)
500
+ output_video = gr.Video(label="Edited Video", height=512)
501
+
502
+ time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
503
+ action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")
504
+
505
+ with gr.Accordion("Preprocessed Data", open=False, visible=False):
506
+ pose_video = gr.Video(label="Pose Video", height=512)
507
+ bg_video = gr.Video(label="Background Video", height=512)
508
+ face_video = gr.Video(label="Face Video", height=512)
509
+ mask_video = gr.Video(label="Mask Video", height=512)
510
+
511
+ with gr.Row():
512
+ with gr.Column(elem_id="col-showcase"):
513
+
514
+ gr.Examples(
515
+ examples=[
516
+
517
+ [
518
+ "./examples/desi.mp4",
519
+ "./examples/desi.png",
520
+ "Video ← Ref Image"
521
+ ],
522
+
523
+ [
524
+ "./examples/paul.mp4",
525
+ "./examples/man.png",
526
+ "Video → Ref Image"
527
+ ],
528
+
529
+
530
+ ],
531
+ inputs=[input_video, edited_frame, replace_character_string],
532
+ outputs=[output_video, pose_video, bg_video, mask_video, face_video],
533
+ fn=animate_scene,
534
+ cache_examples=True,
535
+ )
536
+
537
+ action_button.click(fn=animate_scene, inputs=[input_video, edited_frame, replace_character_string, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video])
538
+
539
+ input_video.upload(preprocess_video, inputs=[input_video, session_state], outputs=[input_video]).then(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
540
+ replace_character_string.change(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
541
+
542
+ if __name__ == "__main__":
543
+ demo.queue()
544
+ demo.unload(cleanup)
545
+ demo.launch(ssr_mode=False, share=True)
546
+
examples/desi.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02e84151e5625fb3863ebdf65dfab06940afac5fbd471db3b46a4ebd84b248d
3
+ size 551595
examples/desi.png ADDED

Git LFS Details

  • SHA256: 3f1a6ac41049380ddb43dcfb9efe1a0b6c561c4bb4132332fe07a82df263df66
  • Pointer size: 131 Bytes
  • Size of remote file: 477 kB
examples/man.png ADDED

Git LFS Details

  • SHA256: 6dc2c61f01a0290a8478fe3b494cf69ca054b2502b00b0be8c68a42ac544d5b5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
examples/paul.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb065c2d24bff8a49955389f94c05c80d39638410dad8082f7e0eb7f2dc5c672
3
+ size 1029922
generate.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import merge_video_audio, save_video, str2bool
22
+
23
+
24
+ EXAMPLE_PROMPT = {
25
+ "t2v-A14B": {
26
+ "prompt":
27
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
28
+ },
29
+ "i2v-A14B": {
30
+ "prompt":
31
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
32
+ "image":
33
+ "examples/i2v_input.JPG",
34
+ },
35
+ "ti2v-5B": {
36
+ "prompt":
37
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
38
+ },
39
+ "animate-14B": {
40
+ "prompt": "视频中的人在做动作",
41
+ "video": "",
42
+ "pose": "",
43
+ "mask": "",
44
+ },
45
+ "s2v-14B": {
46
+ "prompt":
47
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
48
+ "image":
49
+ "examples/i2v_input.JPG",
50
+ "audio":
51
+ "examples/talk.wav",
52
+ "tts_prompt_audio":
53
+ "examples/zero_shot_prompt.wav",
54
+ "tts_prompt_text":
55
+ "希望你以后能够做的比我还好呦。",
56
+ "tts_text":
57
+ "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
58
+ },
59
+ }
60
+
61
+
62
+ def _validate_args(args):
63
+ # Basic check
64
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
65
+ assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
66
+ assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}"
67
+
68
+ if args.prompt is None:
69
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
70
+ if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
71
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
72
+ if args.audio is None and args.enable_tts is False and "audio" in EXAMPLE_PROMPT[args.task]:
73
+ args.audio = EXAMPLE_PROMPT[args.task]["audio"]
74
+ if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and "audio" in EXAMPLE_PROMPT[args.task]:
75
+ args.tts_prompt_audio = EXAMPLE_PROMPT[args.task]["tts_prompt_audio"]
76
+ args.tts_prompt_text = EXAMPLE_PROMPT[args.task]["tts_prompt_text"]
77
+ args.tts_text = EXAMPLE_PROMPT[args.task]["tts_text"]
78
+
79
+ if args.task == "i2v-A14B":
80
+ assert args.image is not None, "Please specify the image path for i2v."
81
+
82
+ cfg = WAN_CONFIGS[args.task]
83
+
84
+ if args.sample_steps is None:
85
+ args.sample_steps = cfg.sample_steps
86
+
87
+ if args.sample_shift is None:
88
+ args.sample_shift = cfg.sample_shift
89
+
90
+ if args.sample_guide_scale is None:
91
+ args.sample_guide_scale = cfg.sample_guide_scale
92
+
93
+ if args.frame_num is None:
94
+ args.frame_num = cfg.frame_num
95
+
96
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
97
+ 0, sys.maxsize)
98
+ # Size check
99
+ if not 's2v' in args.task:
100
+ assert args.size in SUPPORTED_SIZES[
101
+ args.
102
+ task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
103
+
104
+
105
+ class _Args:
106
+ pass
107
+
108
+ def _parse_args():
109
+ args = _Args()
110
+
111
+ # core generation options
112
+ args.task = "animate-14B"
113
+ # args.size = "1280*720"
114
+ args.size = "720*1280"
115
+ args.frame_num = None
116
+ args.ckpt_dir = "./Wan2.2-Animate-14B/"
117
+ args.offload_model = True
118
+ args.ulysses_size = 1
119
+ args.t5_fsdp = False
120
+ args.t5_cpu = False
121
+ args.dit_fsdp = False
122
+ args.prompt = None
123
+ args.use_prompt_extend = False
124
+ args.prompt_extend_method = "local_qwen" # ["dashscope", "local_qwen"]
125
+ args.prompt_extend_model = None
126
+ args.prompt_extend_target_lang = "zh" # ["zh", "en"]
127
+ args.base_seed = 0
128
+ args.image = None
129
+ args.sample_solver = "unipc" # ['unipc', 'dpm++']
130
+ args.sample_steps = None
131
+ args.sample_shift = None
132
+ args.sample_guide_scale = None
133
+ args.convert_model_dtype = False
134
+
135
+ # animate
136
+ args.refert_num = 1
137
+
138
+ # s2v-only
139
+ args.num_clip = None
140
+ args.audio = None
141
+ args.enable_tts = False
142
+ args.tts_prompt_audio = None
143
+ args.tts_prompt_text = None
144
+ args.tts_text = None
145
+ args.pose_video = None
146
+ args.start_from_ref = False
147
+ args.infer_frames = 80
148
+
149
+ _validate_args(args)
150
+ return args
151
+
152
+
153
+
154
+ def _init_logging(rank):
155
+ # logging
156
+ if rank == 0:
157
+ # set format
158
+ logging.basicConfig(
159
+ level=logging.INFO,
160
+ format="[%(asctime)s] %(levelname)s: %(message)s",
161
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
162
+ else:
163
+ logging.basicConfig(level=logging.ERROR)
164
+
165
+ def load_model(use_relighting_lora = False):
166
+
167
+ cfg = WAN_CONFIGS["animate-14B"]
168
+
169
+ return wan.WanAnimate(
170
+ config=cfg,
171
+ checkpoint_dir="./Wan2.2-Animate-14B/",
172
+ device_id=0,
173
+ rank=0,
174
+ t5_fsdp=False,
175
+ dit_fsdp=False,
176
+ use_sp=False,
177
+ t5_cpu=False,
178
+ convert_model_dtype=False,
179
+ use_relighting_lora=use_relighting_lora
180
+ )
181
+
182
+ def generate(wan_animate, preprocess_dir, save_file, replace_flag = False):
183
+ args = _parse_args()
184
+ rank = int(os.getenv("RANK", 0))
185
+ world_size = int(os.getenv("WORLD_SIZE", 1))
186
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
187
+ device = local_rank
188
+ _init_logging(rank)
189
+
190
+ cfg = WAN_CONFIGS[args.task]
191
+
192
+ logging.info(f"Input prompt: {args.prompt}")
193
+ img = None
194
+ if args.image is not None:
195
+ img = Image.open(args.image).convert("RGB")
196
+ logging.info(f"Input image: {args.image}")
197
+
198
+ print(f'rank:{rank}')
199
+
200
+
201
+
202
+ logging.info(f"Generating video ...")
203
+ video = wan_animate.generate(
204
+ src_root_path=preprocess_dir,
205
+ replace_flag=replace_flag,
206
+ refert_num = args.refert_num,
207
+ clip_len=args.frame_num,
208
+ shift=args.sample_shift,
209
+ sample_solver=args.sample_solver,
210
+ sampling_steps=args.sample_steps,
211
+ guide_scale=args.sample_guide_scale,
212
+ seed=args.base_seed,
213
+ offload_model=args.offload_model)
214
+ if rank == 0:
215
+
216
+ save_video(
217
+ tensor=video[None],
218
+ save_file=save_file,
219
+ fps=cfg.sample_fps,
220
+ nrow=1,
221
+ normalize=True,
222
+ value_range=(-1, 1))
223
+ # if "s2v" in args.task:
224
+ # if args.enable_tts is False:
225
+ # merge_video_audio(video_path=args.save_file, audio_path=args.audio)
226
+ # else:
227
+ # merge_video_audio(video_path=args.save_file, audio_path="tts.wav")
228
+ del video
229
+
230
+ torch.cuda.synchronize()
231
+ if dist.is_initialized():
232
+ dist.barrier()
233
+ dist.destroy_process_group()
234
+
235
+ logging.info("Finished.")
236
+
pyproject.toml ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "wan"
7
+ version = "2.2.0"
8
+ description = "Wan: Open and Advanced Large-Scale Video Generative Models"
9
+ authors = [
10
+ { name = "Wan Team", email = "wan.ai@alibabacloud.com" }
11
+ ]
12
+ license = { file = "LICENSE.txt" }
13
+ readme = "README.md"
14
+ requires-python = ">=3.10,<4.0"
15
+ dependencies = [
16
+ "torch>=2.4.0",
17
+ "torchvision>=0.19.0",
18
+ "opencv-python>=4.9.0.80",
19
+ "diffusers>=0.31.0",
20
+ "transformers>=4.49.0",
21
+ "tokenizers>=0.20.3",
22
+ "accelerate>=1.1.1",
23
+ "tqdm",
24
+ "imageio",
25
+ "easydict",
26
+ "ftfy",
27
+ "dashscope",
28
+ "imageio-ffmpeg",
29
+ "flash_attn",
30
+ "numpy>=1.23.5,<2"
31
+ ]
32
+
33
+ [project.optional-dependencies]
34
+ dev = [
35
+ "pytest",
36
+ "black",
37
+ "flake8",
38
+ "isort",
39
+ "mypy",
40
+ "huggingface-hub[cli]"
41
+ ]
42
+
43
+ [project.urls]
44
+ homepage = "https://wanxai.com"
45
+ documentation = "https://github.com/Wan-Video/Wan2.2"
46
+ repository = "https://github.com/Wan-Video/Wan2.2"
47
+ huggingface = "https://huggingface.co/Wan-AI/"
48
+ modelscope = "https://modelscope.cn/organization/Wan-AI"
49
+ discord = "https://discord.gg/p5XbdQV7"
50
+
51
+ [tool.setuptools]
52
+ packages = ["wan"]
53
+
54
+ [tool.setuptools.package-data]
55
+ "wan" = ["**/*.py"]
56
+
57
+ [tool.black]
58
+ line-length = 88
59
+
60
+ [tool.isort]
61
+ profile = "black"
62
+
63
+ [tool.mypy]
64
+ strict = true
65
+
66
+
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.8.0
2
+ decord
3
+ peft
4
+ pandas
5
+ matplotlib
6
+ loguru
7
+ sentencepiece
8
+ dashscope
9
+ ftfy
10
+ diffusers
11
+ opencv-python
12
+ moviepy
13
+ torchvision
14
+ torchaudio
15
+ transformers
16
+ tokenizers
17
+ accelerate
18
+ tqdm
19
+ imageio[ffmpeg]
20
+ easydict
21
+ imageio-ffmpeg
22
+ numpy>=1.23.5,<2
23
+ hydra-core
24
+ iopath
25
+ pytest
26
+ pillow
27
+ fvcore
28
+ librosa
29
+ flash-attn
30
+ onnxruntime-gpu
31
+ flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
wan/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from . import configs, distributed, modules
3
+ from .image2video import WanI2V
4
+ from .speech2video import WanS2V
5
+ from .text2video import WanT2V
6
+ from .textimage2video import WanTI2V
7
+ from .animate import WanAnimate
wan/animate.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+ import math
4
+ import os
5
+ import cv2
6
+ import types
7
+ from copy import deepcopy
8
+ from functools import partial
9
+ from einops import rearrange
10
+ import numpy as np
11
+ import torch
12
+
13
+ import torch.distributed as dist
14
+ from peft import set_peft_model_state_dict
15
+ from decord import VideoReader
16
+ from tqdm import tqdm
17
+ import torch.nn.functional as F
18
+ from .distributed.fsdp import shard_model
19
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
20
+ from .distributed.util import get_world_size
21
+
22
+ from .modules.animate import WanAnimateModel
23
+ from .modules.animate import CLIPModel
24
+ from .modules.t5 import T5EncoderModel
25
+ from .modules.vae2_1 import Wan2_1_VAE
26
+ from .modules.animate.animate_utils import TensorList, get_loraconfig
27
+ from .utils.fm_solvers import (
28
+ FlowDPMSolverMultistepScheduler,
29
+ get_sampling_sigmas,
30
+ retrieve_timesteps,
31
+ )
32
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
33
+
34
+
35
+
36
+ class WanAnimate:
37
+
38
    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
        use_relighting_lora=False
    ):
        r"""
        Initializes the generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_sp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of sequence parallel.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            convert_model_dtype (`bool`, *optional*, defaults to False):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
            use_relighting_lora (`bool`, *optional*, defaults to False):
                Whether to use relighting lora for character replacement.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.init_on_cpu = init_on_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        # Any parallel strategy requires parameters on-device from the start.
        if t5_fsdp or dit_fsdp or use_sp:
            self.init_on_cpu = False

        shard_fn = partial(shard_model, device_id=device_id)
        # T5 is constructed on CPU; generate() moves it to the GPU unless
        # t5_cpu is set (in which case encoding runs on CPU and only the
        # resulting embeddings move to the device).
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        # CLIP image encoder runs in fp16 directly on the target device.
        self.clip = CLIPModel(
            dtype=torch.float16,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        self.vae = Wan2_1_VAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating WanAnimate from {checkpoint_dir}")

        # Without FSDP the DiT can be placed directly via device_map; with
        # FSDP the shard function in _configure_model handles placement.
        if not dit_fsdp:
            self.noise_model = WanAnimateModel.from_pretrained(
                checkpoint_dir,
                torch_dtype=self.param_dtype,
                device_map=self.device)
        else:
            self.noise_model = WanAnimateModel.from_pretrained(
                checkpoint_dir, torch_dtype=self.param_dtype)

        self.noise_model = self._configure_model(
            model=self.noise_model,
            use_sp=use_sp,
            dit_fsdp=dit_fsdp,
            shard_fn=shard_fn,
            convert_model_dtype=convert_model_dtype,
            use_lora=use_relighting_lora,
            checkpoint_dir=checkpoint_dir,
            config=config
        )

        # sp_size is the sequence-parallel group size used to round
        # max_seq_len in generate().
        if use_sp:
            self.sp_size = get_world_size()
        else:
            self.sp_size = 1

        self.sample_neg_prompt = config.sample_neg_prompt
        self.sample_prompt = config.prompt
142
+
143
+
144
    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
                         convert_model_dtype, use_lora, checkpoint_dir, config):
        """
        Configures a model object. This includes setting evaluation modes,
        applying distributed parallel strategy, and handling device placement.

        Args:
            model (torch.nn.Module):
                The model instance to configure.
            use_sp (`bool`):
                Enable distribution strategy of sequence parallel.
            dit_fsdp (`bool`):
                Enable FSDP sharding for DiT model.
            shard_fn (callable):
                The function to apply FSDP sharding.
            convert_model_dtype (`bool`):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
            use_lora (`bool`):
                Load the relighting LoRA adapter weights onto the model.
            checkpoint_dir (`str`):
                Directory containing the LoRA checkpoint file.
            config (EasyDict):
                Task config providing `lora_checkpoint` (the LoRA filename).

        Returns:
            torch.nn.Module:
                The configured model.
        """
        model.eval().requires_grad_(False)

        if use_sp:
            # Patch every attention block with the sequence-parallel forward.
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(
                    sp_attn_forward, block.self_attn)

            model.use_context_parallel = True

        if dist.is_initialized():
            dist.barrier()

        # LoRA must be attached before FSDP sharding so the adapter
        # parameters are included in the shards.
        if use_lora:
            logging.info("Loading Relighting Lora. ")
            lora_config = get_loraconfig(
                transformer=model,
                rank=128,
                alpha=128
            )
            model.add_adapter(lora_config)
            lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)
            peft_state_dict = torch.load(lora_path)["state_dict"]
            set_peft_model_state_dict(model, peft_state_dict)

        if dit_fsdp:
            model = shard_fn(model, use_lora=use_lora)
        else:
            if convert_model_dtype:
                model.to(self.param_dtype)
            if not self.init_on_cpu:
                model.to(self.device)

        return model
200
+
201
+ def inputs_padding(self, array, target_len):
202
+ idx = 0
203
+ flip = False
204
+ target_array = []
205
+ while len(target_array) < target_len:
206
+ target_array.append(deepcopy(array[idx]))
207
+ if flip:
208
+ idx -= 1
209
+ else:
210
+ idx += 1
211
+ if idx == 0 or idx == len(array) - 1:
212
+ flip = not flip
213
+ return target_array[:target_len]
214
+
215
+ def get_valid_len(self, real_len, clip_len=81, overlap=1):
216
+ real_clip_len = clip_len - overlap
217
+ last_clip_num = (real_len - overlap) % real_clip_len
218
+ if last_clip_num == 0:
219
+ extra = 0
220
+ else:
221
+ extra = real_clip_len - last_clip_num
222
+ target_len = real_len + extra
223
+ return target_len
224
+
225
+
226
    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        """Build the frame-conditioning mask in latent layout.

        Marks the first ``mask_len`` pixel frames as conditioning (value 1)
        and packs the pixel-frame mask into the latent frame layout
        (4 pixel frames per latent frame; the first pixel frame is repeated
        4x so the total becomes divisible by 4).

        Returns:
            Tensor of shape (4, lat_t, lat_h, lat_w).
        """
        if mask_pixel_values is None:
            # (lat_t-1)*4 + 1 pixel frames correspond to lat_t latent frames.
            msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
        else:
            # Start from a caller-provided per-pixel mask (replacement mode).
            msk = mask_pixel_values.clone()
        msk[:, :mask_len] = 1
        # Repeat the first frame 4x, then fold groups of 4 pixel frames into
        # the channel dimension of each latent frame.
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        return msk
236
+
237
    def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
        """Resize ``img_ori`` to (height, width) preserving aspect ratio,
        letterboxing the remainder with ``padding_color``.

        Args:
            img_ori: H x W x C image array (C is 1 or 3).
            height (int): output height.
            width (int): output width.
            padding_color (tuple): fill color for the letterbox bars.
            interpolation: cv2 interpolation flag for the resize.

        Returns:
            np.uint8 array of shape (height, width, C).
        """
        ori_height = img_ori.shape[0]
        ori_width = img_ori.shape[1]
        channel = img_ori.shape[2]

        # Pre-fill the canvas with the padding color.
        img_pad = np.zeros((height, width, channel))
        if channel == 1:
            img_pad[:, :, 0] = padding_color[0]
        else:
            img_pad[:, :, 0] = padding_color[0]
            img_pad[:, :, 1] = padding_color[1]
            img_pad[:, :, 2] = padding_color[2]

        if (ori_height / ori_width) > (height / width):
            # Source is relatively taller: fit the height, pad left/right.
            new_width = int(height / ori_height * ori_width)
            img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
            padding = int((width - new_width) / 2)
            if len(img.shape) == 2:
                # cv2.resize drops the channel axis for single-channel input.
                img = img[:, :, np.newaxis]
            img_pad[:, padding: padding + new_width, :] = img
        else:
            # Source is relatively wider: fit the width, pad top/bottom.
            new_height = int(width / ori_width * ori_height)
            img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
            padding = int((height - new_height) / 2)
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
            img_pad[padding: padding + new_height, :, :] = img

        img_pad = np.uint8(img_pad)

        return img_pad
268
+
269
    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        """Load driving pose frames, face frames, and the reference image.

        The reference image is read as BGR by cv2, flipped to RGB via the
        ``[..., ::-1]`` slice, and letterboxed to the pose-frame resolution.

        Returns:
            Tuple of (cond_images, face_images, refer_images) as uint8
            numpy arrays; the first two are stacks of video frames.
        """
        pose_video_reader = VideoReader(src_pose_path)
        pose_len = len(pose_video_reader)
        pose_idxs = list(range(pose_len))
        # Decode every frame of the pose video in one batch.
        cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()

        face_video_reader = VideoReader(src_face_path)
        face_len = len(face_video_reader)
        face_idxs = list(range(face_len))
        face_images = face_video_reader.get_batch(face_idxs).asnumpy()
        height, width = cond_images[0].shape[:2]
        refer_images = cv2.imread(src_ref_path)[..., ::-1]  # BGR -> RGB
        refer_images = self.padding_resize(refer_images, height=height, width=width)
        return cond_images, face_images, refer_images
283
+
284
    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        """Load the background video and foreground mask for replacement mode.

        Only channel 0 of the mask video is used and values are scaled from
        [0, 255] to [0, 1] floats (1 = foreground/character region).

        Returns:
            Tuple of (bg_images, mask_images) numpy arrays; bg_images is
            uint8 frames, mask_images is float frames without a channel axis.
        """
        bg_video_reader = VideoReader(src_bg_path)
        bg_len = len(bg_video_reader)
        bg_idxs = list(range(bg_len))
        bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()

        mask_video_reader = VideoReader(src_mask_path)
        mask_len = len(mask_video_reader)
        mask_idxs = list(range(mask_len))
        mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
        # Collapse to single channel and normalize to [0, 1].
        mask_images = mask_images[:, :, :, 0] / 255
        return bg_images, mask_images
296
+
297
+ def generate(
298
+ self,
299
+ src_root_path,
300
+ replace_flag=False,
301
+ clip_len=77,
302
+ refert_num=1,
303
+ shift=5.0,
304
+ sample_solver='dpm++',
305
+ sampling_steps=20,
306
+ guide_scale=1,
307
+ input_prompt="",
308
+ n_prompt="",
309
+ seed=-1,
310
+ offload_model=True,
311
+ ):
312
+ r"""
313
+ Generates video frames from input image using diffusion process.
314
+
315
+ Args:
316
+ src_root_path ('str'):
317
+ Process output path
318
+ replace_flag (`bool`, *optional*, defaults to False):
319
+ Whether to use character replace.
320
+ clip_len (`int`, *optional*, defaults to 77):
321
+ How many frames to generate per clips. The number should be 4n+1
322
+ refert_num (`int`, *optional*, defaults to 1):
323
+ How many frames used for temporal guidance. Recommended to be 1 or 5.
324
+ shift (`float`, *optional*, defaults to 5.0):
325
+ Noise schedule shift parameter.
326
+ sample_solver (`str`, *optional*, defaults to 'dpm++'):
327
+ Solver used to sample the video.
328
+ sampling_steps (`int`, *optional*, defaults to 20):
329
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
330
+ guide_scale (`float` or tuple[`float`], *optional*, defaults 1.0):
331
+ Classifier-free guidance scale. We only use it for expression control.
332
+ In most cases, it's not necessary and faster generation can be achieved without it.
333
+ When expression adjustments are needed, you may consider using this feature.
334
+ input_prompt (`str`):
335
+ Text prompt for content generation. We don't recommend custom prompts (although they work)
336
+ n_prompt (`str`, *optional*, defaults to ""):
337
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
338
+ seed (`int`, *optional*, defaults to -1):
339
+ Random seed for noise generation. If -1, use random seed
340
+ offload_model (`bool`, *optional*, defaults to True):
341
+ If True, offloads models to CPU during generation to save VRAM
342
+
343
+ Returns:
344
+ torch.Tensor:
345
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
346
+ - C: Color channels (3 for RGB)
347
+ - N: Number of frames
348
+ - H: Frame height
349
+ - W: Frame width
350
+ """
351
+ assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5."
352
+
353
+ seed_g = torch.Generator(device=self.device)
354
+ seed_g.manual_seed(seed)
355
+
356
+
357
+ if n_prompt == "":
358
+ n_prompt = self.sample_neg_prompt
359
+
360
+ if input_prompt == "":
361
+ input_prompt = self.sample_prompt
362
+
363
+ src_pose_path = os.path.join(src_root_path, "src_pose.mp4")
364
+ src_face_path = os.path.join(src_root_path, "src_face.mp4")
365
+ src_ref_path = os.path.join(src_root_path, "src_ref.png")
366
+
367
+ cond_images, face_images, refer_images = self.prepare_source(src_pose_path=src_pose_path, src_face_path=src_face_path, src_ref_path=src_ref_path)
368
+
369
+ if not self.t5_cpu:
370
+ self.text_encoder.model.to(self.device)
371
+ context = self.text_encoder([input_prompt], self.device)
372
+ context_null = self.text_encoder([n_prompt], self.device)
373
+ if offload_model:
374
+ self.text_encoder.model.cpu()
375
+ else:
376
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
377
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
378
+ context = [t.to(self.device) for t in context]
379
+ context_null = [t.to(self.device) for t in context_null]
380
+
381
+ real_frame_len = len(cond_images)
382
+ target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
383
+ logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))
384
+ cond_images = self.inputs_padding(cond_images, target_len)
385
+ face_images = self.inputs_padding(face_images, target_len)
386
+
387
+ if replace_flag:
388
+ src_bg_path = os.path.join(src_root_path, "src_bg.mp4")
389
+ src_mask_path = os.path.join(src_root_path, "src_mask.mp4")
390
+ bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
391
+ bg_images = self.inputs_padding(bg_images, target_len)
392
+ mask_images = self.inputs_padding(mask_images, target_len)
393
+ self.noise_model.disable_adapters()
394
+ else:
395
+ self.noise_model.disable_adapters()
396
+
397
+
398
+ height, width = refer_images.shape[:2]
399
+ start = 0
400
+ end = clip_len
401
+ all_out_frames = []
402
+ while True:
403
+ if start + refert_num >= len(cond_images):
404
+ break
405
+
406
+ if start == 0:
407
+ mask_reft_len = 0
408
+ else:
409
+ mask_reft_len = refert_num
410
+
411
+ batch = {
412
+ "conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width),
413
+ "bg_pixel_values": torch.zeros(1, 3, clip_len, height, width),
414
+ "mask_pixel_values": torch.zeros(1, 1, clip_len, height, width),
415
+ "face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512),
416
+ "refer_pixel_values": torch.zeros(1, 3, height, width),
417
+ "refer_t_pixel_values": torch.zeros(refert_num, 3, height, width)
418
+ }
419
+
420
+ batch["conditioning_pixel_values"] = rearrange(
421
+ torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),
422
+ "t h w c -> 1 c t h w",
423
+ )
424
+ batch["face_pixel_values"] = rearrange(
425
+ torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),
426
+ "t h w c -> 1 c t h w",
427
+ )
428
+
429
+ batch["refer_pixel_values"] = rearrange(
430
+ torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w"
431
+ )
432
+
433
+ if start > 0:
434
+ batch["refer_t_pixel_values"] = rearrange(
435
+ out_frames[0, :, -refert_num:].clone().detach(),
436
+ "c t h w -> t c h w",
437
+ )
438
+
439
+ batch["refer_t_pixel_values"] = rearrange(batch["refer_t_pixel_values"],
440
+ "t c h w -> 1 c t h w",
441
+ )
442
+
443
+ if replace_flag:
444
+ batch["bg_pixel_values"] = rearrange(
445
+ torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),
446
+ "t h w c -> 1 c t h w",
447
+ )
448
+
449
+ batch["mask_pixel_values"] = rearrange(
450
+ torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),
451
+ "t h w c -> 1 t c h w",
452
+ )
453
+
454
+
455
+ for key, value in batch.items():
456
+ if isinstance(value, torch.Tensor):
457
+ batch[key] = value.to(device=self.device, dtype=torch.bfloat16)
458
+
459
+ ref_pixel_values = batch["refer_pixel_values"]
460
+ refer_t_pixel_values = batch["refer_t_pixel_values"]
461
+ conditioning_pixel_values = batch["conditioning_pixel_values"]
462
+ face_pixel_values = batch["face_pixel_values"]
463
+
464
+ B, _, H, W = ref_pixel_values.shape
465
+ T = clip_len
466
+ lat_h = H // 8
467
+ lat_w = W // 8
468
+ lat_t = T // 4 + 1
469
+ target_shape = [lat_t + 1, lat_h, lat_w]
470
+ noise = [
471
+ torch.randn(
472
+ 16,
473
+ target_shape[0],
474
+ target_shape[1],
475
+ target_shape[2],
476
+ dtype=torch.float32,
477
+ device=self.device,
478
+ generator=seed_g,
479
+ )
480
+ ]
481
+
482
+ max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size
483
+ if max_seq_len % self.sp_size != 0:
484
+ raise ValueError(f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}")
485
+
486
+ with (
487
+ torch.autocast(device_type=str(self.device), dtype=torch.bfloat16, enabled=True),
488
+ torch.no_grad()
489
+ ):
490
+ if sample_solver == 'unipc':
491
+ sample_scheduler = FlowUniPCMultistepScheduler(
492
+ num_train_timesteps=self.num_train_timesteps,
493
+ shift=1,
494
+ use_dynamic_shifting=False)
495
+ sample_scheduler.set_timesteps(
496
+ sampling_steps, device=self.device, shift=shift)
497
+ timesteps = sample_scheduler.timesteps
498
+ elif sample_solver == 'dpm++':
499
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
500
+ num_train_timesteps=self.num_train_timesteps,
501
+ shift=1,
502
+ use_dynamic_shifting=False)
503
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
504
+ timesteps, _ = retrieve_timesteps(
505
+ sample_scheduler,
506
+ device=self.device,
507
+ sigmas=sampling_sigmas)
508
+ else:
509
+ raise NotImplementedError("Unsupported solver.")
510
+
511
+ latents = noise
512
+
513
+ pose_latents_no_ref = self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))
514
+ pose_latents_no_ref = torch.stack(pose_latents_no_ref)
515
+ pose_latents = torch.cat([pose_latents_no_ref], dim=2)
516
+
517
+ ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w")
518
+ ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16))
519
+ ref_latents = torch.stack(ref_latents)
520
+
521
+ mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)
522
+ y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)
523
+
524
+ img = ref_pixel_values[0, :, 0]
525
+ clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)
526
+
527
+ if mask_reft_len > 0:
528
+ if replace_flag:
529
+ bg_pixel_values = batch["bg_pixel_values"]
530
+ y_reft = self.vae.encode(
531
+ [
532
+ torch.concat([refer_t_pixel_values[0, :, :mask_reft_len], bg_pixel_values[0, :, mask_reft_len:]], dim=1).to(self.device)
533
+ ]
534
+ )[0]
535
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
536
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
537
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
538
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
539
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
540
+ else:
541
+ y_reft = self.vae.encode(
542
+ [
543
+ torch.concat(
544
+ [
545
+ torch.nn.functional.interpolate(refer_t_pixel_values[0, :, :mask_reft_len].cpu(),
546
+ size=(H, W), mode="bicubic"),
547
+ torch.zeros(3, T - mask_reft_len, H, W),
548
+ ],
549
+ dim=1,
550
+ ).to(self.device)
551
+ ]
552
+ )[0]
553
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
554
+ else:
555
+ if replace_flag:
556
+ bg_pixel_values = batch["bg_pixel_values"]
557
+ mask_pixel_values = 1 - batch["mask_pixel_values"]
558
+ mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
559
+ mask_pixel_values = F.interpolate(mask_pixel_values, size=(H//8, W//8), mode='nearest')
560
+ mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0]
561
+ y_reft = self.vae.encode(
562
+ [
563
+ torch.concat(
564
+ [
565
+ bg_pixel_values[0],
566
+ ],
567
+ dim=1,
568
+ ).to(self.device)
569
+ ]
570
+ )[0]
571
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, mask_pixel_values=mask_pixel_values, device=self.device)
572
+ else:
573
+ y_reft = self.vae.encode(
574
+ [
575
+ torch.concat(
576
+ [
577
+ torch.zeros(3, T - mask_reft_len, H, W),
578
+ ],
579
+ dim=1,
580
+ ).to(self.device)
581
+ ]
582
+ )[0]
583
+ msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
584
+
585
+ y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)
586
+ y = torch.concat([y_ref, y_reft], dim=1)
587
+
588
+ arg_c = {
589
+ "context": context,
590
+ "seq_len": max_seq_len,
591
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
592
+ "y": [y],
593
+ "pose_latents": pose_latents,
594
+ "face_pixel_values": face_pixel_values,
595
+ }
596
+
597
+ if guide_scale > 1:
598
+ face_pixel_values_uncond = face_pixel_values * 0 - 1
599
+ arg_null = {
600
+ "context": context_null,
601
+ "seq_len": max_seq_len,
602
+ "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
603
+ "y": [y],
604
+ "pose_latents": pose_latents,
605
+ "face_pixel_values": face_pixel_values_uncond,
606
+ }
607
+
608
+ for i, t in enumerate(tqdm(timesteps)):
609
+ latent_model_input = latents
610
+ timestep = [t]
611
+
612
+ timestep = torch.stack(timestep)
613
+
614
+ noise_pred_cond = TensorList(
615
+ self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)
616
+ )
617
+
618
+ if guide_scale > 1:
619
+ noise_pred_uncond = TensorList(
620
+ self.noise_model(
621
+ TensorList(latent_model_input), t=timestep, **arg_null
622
+ )
623
+ )
624
+ noise_pred = noise_pred_uncond + guide_scale * (
625
+ noise_pred_cond - noise_pred_uncond
626
+ )
627
+ else:
628
+ noise_pred = noise_pred_cond
629
+
630
+ temp_x0 = sample_scheduler.step(
631
+ noise_pred[0].unsqueeze(0),
632
+ t,
633
+ latents[0].unsqueeze(0),
634
+ return_dict=False,
635
+ generator=seed_g,
636
+ )[0]
637
+ latents[0] = temp_x0.squeeze(0)
638
+
639
+ x0 = latents
640
+
641
+ x0 = [x.to(dtype=torch.float32) for x in x0]
642
+ out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))
643
+
644
+ if start != 0:
645
+ out_frames = out_frames[:, :, refert_num:]
646
+
647
+ all_out_frames.append(out_frames.cpu())
648
+
649
+ start += clip_len - refert_num
650
+ end += clip_len - refert_num
651
+
652
+ videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
653
+ return videos[0] if self.rank == 0 else None
wan/configs/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

# Silence HuggingFace tokenizers fork-safety warnings/deadlocks.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_A14B import i2v_A14B
from .wan_s2v_14B import s2v_14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B
from .wan_animate_14B import animate_14B

# Task identifier -> model configuration registry.
WAN_CONFIGS = {
    't2v-A14B': t2v_A14B,
    'i2v-A14B': i2v_A14B,
    'ti2v-5B': ti2v_5B,
    'animate-14B': animate_14B,
    's2v-14B': s2v_14B,
}

# 'H*W' resolution string -> (height, width) in pixels.
SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '704*1280': (704, 1280),
    '1280*704': (1280, 704),
    '1024*704': (1024, 704),
    '704*1024': (704, 1024),
}

# Pixel budget per resolution; derived from SIZE_CONFIGS (same keys, same order).
MAX_AREA_CONFIGS = {size: h * w for size, (h, w) in SIZE_CONFIGS.items()}

# Resolutions each task accepts.
SUPPORTED_SIZES = {
    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'ti2v-5B': ('704*1280', '1280*704'),
    's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',
                '704*1024', '704*1280', '1280*704'),
    'animate-14B': ('720*1280', '1280*720'),
}
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
# Baseline settings inherited by every task config via `cfg.update(wan_shared_cfg)`.
wan_shared_cfg = EasyDict()

# t5 text encoder
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese); runtime string, kept verbatim.
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan_shared_cfg.frame_num = 81
wan/configs/wan_animate_14B.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan animate 14B ------------------------#
animate_14B = EasyDict(__name__='Config: Wan animate 14B')
animate_14B.update(wan_shared_cfg)

# t5 text encoder
animate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
animate_14B.t5_tokenizer = 'google/umt5-xxl'

# CLIP image encoder
animate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
animate_14B.clip_tokenizer = 'xlm-roberta-large'
# LoRA weights used for the relighting variant.
animate_14B.lora_checkpoint = 'relighting_lora.ckpt'
# vae
animate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
animate_14B.vae_stride = (4, 8, 8)

# transformer
animate_14B.patch_size = (1, 2, 2)
animate_14B.dim = 5120
animate_14B.ffn_dim = 13824
animate_14B.freq_dim = 256
animate_14B.num_heads = 40
animate_14B.num_layers = 40
animate_14B.window_size = (-1, -1)
animate_14B.qk_norm = True
animate_14B.cross_attn_norm = True
animate_14B.eps = 1e-6
animate_14B.use_face_encoder = True
animate_14B.motion_encoder_dim = 512

# inference
animate_14B.sample_shift = 5.0
animate_14B.sample_steps = 20
animate_14B.sample_guide_scale = 1.0
animate_14B.frame_num = 77
animate_14B.sample_fps = 30
# Default prompt (Chinese); runtime string, kept verbatim.
animate_14B.prompt = '视频中的人在做动作'
wan/configs/wan_i2v_A14B.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V A14B ------------------------#

i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

# t5 text encoder
i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)

# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
# Two-expert MoE-style denoiser: separate checkpoints per noise regime.
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
# Timestep fraction splitting the high-noise and low-noise experts.
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise
wan/configs/wan_s2v_14B.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan S2V 14B ------------------------#

s2v_14B = EasyDict(__name__='Config: Wan S2V 14B')
s2v_14B.update(wan_shared_cfg)

# t5 text encoder
s2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
s2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
s2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
s2v_14B.vae_stride = (4, 8, 8)

# wav2vec audio feature extractor
s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english"

s2v_14B.num_heads = 40
# transformer (nested sub-config consumed by WanModel_S2V)
s2v_14B.transformer = EasyDict(
    __name__="Config: Transformer config for WanModel_S2V")
s2v_14B.transformer.patch_size = (1, 2, 2)
s2v_14B.transformer.dim = 5120
s2v_14B.transformer.ffn_dim = 13824
s2v_14B.transformer.freq_dim = 256
s2v_14B.transformer.num_heads = 40
s2v_14B.transformer.num_layers = 40
s2v_14B.transformer.window_size = (-1, -1)
s2v_14B.transformer.qk_norm = True
s2v_14B.transformer.cross_attn_norm = True
s2v_14B.transformer.eps = 1e-6
s2v_14B.transformer.enable_adain = True
s2v_14B.transformer.adain_mode = "attn_norm"
# Block indices that receive audio conditioning.
s2v_14B.transformer.audio_inject_layers = [
    0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39
]
s2v_14B.transformer.zero_init = True
s2v_14B.transformer.zero_timestep = True
s2v_14B.transformer.enable_motioner = False
s2v_14B.transformer.add_last_motion = True
s2v_14B.transformer.trainable_token = False
s2v_14B.transformer.enable_tsm = False
s2v_14B.transformer.enable_framepack = True
s2v_14B.transformer.framepack_drop_mode = 'padd'
s2v_14B.transformer.audio_dim = 1024

s2v_14B.transformer.motion_frames = 73
s2v_14B.transformer.cond_dim = 16

# inference
# Task-specific negative prompt (Chinese); runtime string, kept verbatim.
s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
s2v_14B.drop_first_motion = True
s2v_14B.sample_shift = 3
s2v_14B.sample_steps = 40
s2v_14B.sample_guide_scale = 4.5
wan/configs/wan_t2v_A14B.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V A14B ------------------------#

t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)

# t5 text encoder
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_A14B.vae_stride = (4, 8, 8)

# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
# Two-expert MoE-style denoiser: separate checkpoints per noise regime.
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
# Timestep fraction splitting the high-noise and low-noise experts.
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise
wan/configs/wan_ti2v_5B.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan TI2V 5B ------------------------#

ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)

# t5 text encoder
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'

# vae (Wan2.2 VAE with a larger 16x spatial stride than the 2.1 VAE)
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)

# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6

# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121
wan/distributed/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
    use_lora=False
):
    """Wrap ``model`` with FSDP, sharding each transformer block separately.

    Args:
        model: module exposing a ``blocks`` iterable of transformer blocks;
            each block becomes its own FSDP unit via the lambda wrap policy.
        device_id: target CUDA device for sharded parameters.
        param_dtype / reduce_dtype / buffer_dtype: FSDP mixed-precision policy.
        process_group: optional process group (defaults to the world group).
        sharding_strategy: FSDP sharding strategy (FULL_SHARD by default).
        sync_module_states: broadcast rank-0 weights to all ranks at wrap time.
        use_lora: keep original parameter objects so LoRA adapters stay addressable.

    Returns:
        The FSDP-wrapped model.
    """
    wrap_policy = partial(
        lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks)
    precision_policy = MixedPrecision(
        param_dtype=param_dtype,
        reduce_dtype=reduce_dtype,
        buffer_dtype=buffer_dtype)
    return FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=wrap_policy,
        mixed_precision=precision_policy,
        device_id=device_id,
        sync_module_states=sync_module_states,
        use_orig_params=bool(use_lora))
37
+
38
+
39
def free_model(model):
    """Release the GPU storage held by an FSDP-wrapped model.

    Frees each FSDP unit's flattened parameter storage, drops the local
    reference, then forces garbage collection and empties the CUDA cache.
    """
    for module in model.modules():
        if isinstance(module, FSDP):
            # NOTE(review): relies on private FSDP internals
            # (_handle / flat_param / _free_storage); may break across
            # torch versions — confirm against the pinned torch release.
            _free_storage(module._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
wan/distributed/sequence_parallel.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+
5
+ from ..modules.model import sinusoidal_embedding_1d
6
+ from .ulysses import distributed_attention
7
+ from .util import gather_forward, get_rank, get_world_size
8
+
9
+
10
def pad_freqs(original_tensor, target_len):
    """Pad a [seq_len, s1, s2] frequency tensor to ``target_len`` rows.

    The extra rows are filled with ones (the multiplicative identity for
    rotary frequencies), matching the input's dtype and device.
    """
    seq_len, s1, s2 = original_tensor.shape
    filler = original_tensor.new_ones(target_len - seq_len, s1, s2)
    return torch.cat((original_tensor, filler), dim=0)
21
+
22
+
23
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Apply 3D rotary position embedding to this rank's sequence shard.

    Args:
        x: [B, L, N, C] query/key tensor, L being the per-rank shard length.
        grid_sizes: [B, 3] latent grid (frames, height, width) per sample.
        freqs: [M, C // 2] precomputed rotary frequency table.

    Returns:
        float32 tensor of the same shape as ``x``.
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # Split the channel budget between temporal / height / width axes.
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # View the shard as complex pairs for the rotary multiplication.
        x_i = torch.view_as_complex(
            x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
        ], dim=-1).reshape(seq_len, 1, -1)

        # Pad the frequency table to the global length, then select the
        # slice belonging to this sequence-parallel rank.
        sp_size = get_world_size()
        sp_rank = get_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        freqs_i_rank = freqs_i[(sp_rank * s):((sp_rank + 1) * s), :, :]

        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        # Re-attach any tail tokens beyond the rotated region unchanged.
        x_i = torch.cat([x_i, x[i, s:]])

        output.append(x_i)
    return torch.stack(output).float()
62
+
63
+
64
def sp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
):
    r"""
    Sequence-parallel forward pass of the DiT.

    Args:
        x: A list of videos, each with shape [C, T, H, W].
        t: [B] timesteps.
        context: A list of text embeddings, each with shape [L, C].
        seq_len: padded token count shared by all ranks.
        y: optional conditioning latents (required for i2v), concatenated
            channel-wise with ``x``.

    Returns:
        A list of float32 denoised video tensors.
    """
    if self.model_type == 'i2v':
        assert y is not None
    # Keep the cached rotary frequencies on the same device as the weights.
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # Patchify, then flatten the (T, H, W) grid into one token axis.
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    # Right-pad every sample to the common sequence length.
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # Per-token time embeddings (scalar t broadcast over the sequence).
    if t.dim() == 1:
        t = t.expand(t.size(0), seq_len)
    with torch.amp.autocast('cuda', dtype=torch.float32):
        bt = t.size(0)
        t = t.flatten()
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim,
                                    t).unflatten(0, (bt, seq_len)).float())
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # Text context, padded to the model's fixed text length.
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # Context parallel: each rank keeps its own token slice.
    x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
    e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]

    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    for block in self.blocks:
        x = block(x, **kwargs)

    # Output head on the local shard.
    x = self.head(x, e)

    # Context parallel: regather the full sequence before unpatchifying.
    x = gather_forward(x, dim=1)

    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]
145
+
146
+
147
def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
    """Sequence-parallel self-attention via Ulysses all-to-all exchange."""
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(t):
        # Leave half-precision tensors untouched; cast anything else down.
        return t if t.dtype in half_dtypes else t.to(dtype)

    # Project to q/k/v heads and rotate q/k with rotary embeddings
    # computed on this rank's shard.
    q = self.norm_q(self.q(x)).view(b, s, n, d)
    k = self.norm_k(self.k(x)).view(b, s, n, d)
    v = self.v(x).view(b, s, n, d)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    attn_out = distributed_attention(
        half(q),
        half(k),
        half(v),
        seq_lens,
        window_size=self.window_size,
    )

    # Merge heads and apply the output projection.
    return self.o(attn_out.flatten(2))
wan/distributed/ulysses.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ from ..modules.attention import flash_attention
6
+ from .util import all_to_all
7
+
8
+
9
def distributed_attention(
    q,
    k,
    v,
    seq_lens,
    window_size=(-1, -1),
):
    """
    Distributed attention following DeepSpeed Ulysses
    (https://arxiv.org/pdf/2309.14509).

    Args:
        q: [B, Lq // p, Nq, C1].
        k: [B, Lk // p, Nk, C1].
        v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
        seq_lens: [B], length of each sequence in the batch.
        window_size: (left, right); if not (-1, -1), apply sliding-window
            local attention.
    """
    if not dist.is_initialized():
        raise ValueError("distributed group should be initialized.")

    # Exchange shards: after this, each rank holds the full sequence
    # for a subset of the attention heads.
    q, k, v = (all_to_all(u, scatter_dim=2, gather_dim=1) for u in (q, k, v))

    out = flash_attention(
        q,
        k,
        v,
        k_lens=seq_lens,
        window_size=window_size,
    )

    # Exchange back: sequence re-sharded, full head set restored per rank.
    return all_to_all(out, scatter_dim=1, gather_dim=2)
wan/distributed/util.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
def init_distributed_group():
    r"""Initialize the default NCCL process group if none exists yet."""
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')
11
+
12
+
13
def get_rank():
    """Rank of this process in the default process group."""
    return dist.get_rank()
15
+
16
+
17
def get_world_size():
    """Number of processes in the default process group."""
    return dist.get_world_size()
19
+
20
+
21
def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
    """
    `scatter` x along one dimension and `gather` it along another.
    A no-op when running single-process.
    """
    world_size = get_world_size()
    if world_size > 1:
        shards = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
        received = [torch.empty_like(u) for u in shards]
        dist.all_to_all(received, shards, group=group, **kwargs)
        x = torch.cat(received, dim=gather_dim).contiguous()
    return x
32
+
33
+
34
def all_gather(tensor):
    """Gather ``tensor`` from every rank; returns a list of per-rank tensors."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return [tensor]
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor)
    return gathered
41
+
42
+
43
def gather_forward(input, dim):
    """Concatenate the per-rank shards of ``input`` along ``dim``.

    Identity when the world size is 1.
    """
    if dist.get_world_size() == 1:
        return input
    shards = all_gather(input)
    return torch.cat(shards, dim=dim).contiguous()
wan/image2video.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_1 import Wan2_1_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+
32
+
33
+ class WanI2V:
34
+
35
+ def __init__(
36
+ self,
37
+ config,
38
+ checkpoint_dir,
39
+ device_id=0,
40
+ rank=0,
41
+ t5_fsdp=False,
42
+ dit_fsdp=False,
43
+ use_sp=False,
44
+ t5_cpu=False,
45
+ init_on_cpu=True,
46
+ convert_model_dtype=False,
47
+ ):
48
+ r"""
49
+ Initializes the image-to-video generation model components.
50
+
51
+ Args:
52
+ config (EasyDict):
53
+ Object containing model parameters initialized from config.py
54
+ checkpoint_dir (`str`):
55
+ Path to directory containing model checkpoints
56
+ device_id (`int`, *optional*, defaults to 0):
57
+ Id of target GPU device
58
+ rank (`int`, *optional*, defaults to 0):
59
+ Process rank for distributed training
60
+ t5_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for T5 model
62
+ dit_fsdp (`bool`, *optional*, defaults to False):
63
+ Enable FSDP sharding for DiT model
64
+ use_sp (`bool`, *optional*, defaults to False):
65
+ Enable distribution strategy of sequence parallel.
66
+ t5_cpu (`bool`, *optional*, defaults to False):
67
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
68
+ init_on_cpu (`bool`, *optional*, defaults to True):
69
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
+ convert_model_dtype (`bool`, *optional*, defaults to False):
71
+ Convert DiT model parameters dtype to 'config.param_dtype'.
72
+ Only works without FSDP.
73
+ """
74
+ self.device = torch.device(f"cuda:{device_id}")
75
+ self.config = config
76
+ self.rank = rank
77
+ self.t5_cpu = t5_cpu
78
+ self.init_on_cpu = init_on_cpu
79
+
80
+ self.num_train_timesteps = config.num_train_timesteps
81
+ self.boundary = config.boundary
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None,
95
+ )
96
+
97
+ self.vae_stride = config.vae_stride
98
+ self.patch_size = config.patch_size
99
+ self.vae = Wan2_1_VAE(
100
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
+ device=self.device)
102
+
103
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
104
+ self.low_noise_model = WanModel.from_pretrained(
105
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
+ self.low_noise_model = self._configure_model(
107
+ model=self.low_noise_model,
108
+ use_sp=use_sp,
109
+ dit_fsdp=dit_fsdp,
110
+ shard_fn=shard_fn,
111
+ convert_model_dtype=convert_model_dtype)
112
+
113
+ self.high_noise_model = WanModel.from_pretrained(
114
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
+ self.high_noise_model = self._configure_model(
116
+ model=self.high_noise_model,
117
+ use_sp=use_sp,
118
+ dit_fsdp=dit_fsdp,
119
+ shard_fn=shard_fn,
120
+ convert_model_dtype=convert_model_dtype)
121
+ if use_sp:
122
+ self.sp_size = get_world_size()
123
+ else:
124
+ self.sp_size = 1
125
+
126
+ self.sample_neg_prompt = config.sample_neg_prompt
127
+
128
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
+ convert_model_dtype):
130
+ """
131
+ Configures a model object. This includes setting evaluation modes,
132
+ applying distributed parallel strategy, and handling device placement.
133
+
134
+ Args:
135
+ model (torch.nn.Module):
136
+ The model instance to configure.
137
+ use_sp (`bool`):
138
+ Enable distribution strategy of sequence parallel.
139
+ dit_fsdp (`bool`):
140
+ Enable FSDP sharding for DiT model.
141
+ shard_fn (callable):
142
+ The function to apply FSDP sharding.
143
+ convert_model_dtype (`bool`):
144
+ Convert DiT model parameters dtype to 'config.param_dtype'.
145
+ Only works without FSDP.
146
+
147
+ Returns:
148
+ torch.nn.Module:
149
+ The configured model.
150
+ """
151
+ model.eval().requires_grad_(False)
152
+
153
+ if use_sp:
154
+ for block in model.blocks:
155
+ block.self_attn.forward = types.MethodType(
156
+ sp_attn_forward, block.self_attn)
157
+ model.forward = types.MethodType(sp_dit_forward, model)
158
+
159
+ if dist.is_initialized():
160
+ dist.barrier()
161
+
162
+ if dit_fsdp:
163
+ model = shard_fn(model)
164
+ else:
165
+ if convert_model_dtype:
166
+ model.to(self.param_dtype)
167
+ if not self.init_on_cpu:
168
+ model.to(self.device)
169
+
170
+ return model
171
+
172
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
+ r"""
174
+ Prepares and returns the required model for the current timestep.
175
+
176
+ Args:
177
+ t (torch.Tensor):
178
+ current timestep.
179
+ boundary (`int`):
180
+ The timestep threshold. If `t` is at or above this value,
181
+ the `high_noise_model` is considered as the required model.
182
+ offload_model (`bool`):
183
+ A flag intended to control the offloading behavior.
184
+
185
+ Returns:
186
+ torch.nn.Module:
187
+ The active model on the target device for the current timestep.
188
+ """
189
+ if t.item() >= boundary:
190
+ required_model_name = 'high_noise_model'
191
+ offload_model_name = 'low_noise_model'
192
+ else:
193
+ required_model_name = 'low_noise_model'
194
+ offload_model_name = 'high_noise_model'
195
+ if offload_model or self.init_on_cpu:
196
+ if next(getattr(
197
+ self,
198
+ offload_model_name).parameters()).device.type == 'cuda':
199
+ getattr(self, offload_model_name).to('cpu')
200
+ if next(getattr(
201
+ self,
202
+ required_model_name).parameters()).device.type == 'cpu':
203
+ getattr(self, required_model_name).to(self.device)
204
+ return getattr(self, required_model_name)
205
+
206
def generate(self,
             input_prompt,
             img,
             max_area=720 * 1280,
             frame_num=81,
             shift=5.0,
             sample_solver='unipc',
             sampling_steps=40,
             guide_scale=5.0,
             n_prompt="",
             seed=-1,
             offload_model=True):
    r"""
    Generates video frames from input image and text prompt using diffusion process.

    Args:
        input_prompt (`str`):
            Text prompt for content generation.
        img (PIL.Image.Image):
            Input image tensor. Shape: [3, H, W]
        max_area (`int`, *optional*, defaults to 720*1280):
            Maximum pixel area for latent space calculation. Controls video resolution scaling
        frame_num (`int`, *optional*, defaults to 81):
            How many frames to sample from a video. The number should be 4n+1
        shift (`float`, *optional*, defaults to 5.0):
            Noise schedule shift parameter. Affects temporal dynamics
            [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
        sample_solver (`str`, *optional*, defaults to 'unipc'):
            Solver used to sample the video.
        sampling_steps (`int`, *optional*, defaults to 40):
            Number of diffusion sampling steps. Higher values improve quality but slow generation
        guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
            Classifier-free guidance scale. Controls prompt adherence vs. creativity.
            If tuple, the first guide_scale will be used for low noise model and
            the second guide_scale will be used for high noise model.
        n_prompt (`str`, *optional*, defaults to ""):
            Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
        seed (`int`, *optional*, defaults to -1):
            Random seed for noise generation. If -1, use random seed
        offload_model (`bool`, *optional*, defaults to True):
            If True, offloads models to CPU during generation to save VRAM

    Returns:
        torch.Tensor:
            Generated video frames tensor. Dimensions: (C, N, H, W) where:
            - C: Color channels (3 for RGB)
            - N: Number of frames (81)
            - H: Frame height (from max_area)
            - W: Frame width (from max_area)
    """
    # preprocess
    # Normalize a scalar guidance value into a (low_noise, high_noise) pair.
    guide_scale = (guide_scale, guide_scale) if isinstance(
        guide_scale, float) else guide_scale
    # Map the PIL image to [-1, 1] and move it to the compute device.
    img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

    F = frame_num
    h, w = img.shape[1:]
    aspect_ratio = h / w
    # Latent spatial size: largest patch-aligned grid whose decoded area
    # stays within max_area while preserving the input aspect ratio.
    lat_h = round(
        np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
        self.patch_size[1] * self.patch_size[1])
    lat_w = round(
        np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
        self.patch_size[2] * self.patch_size[2])
    h = lat_h * self.vae_stride[1]
    w = lat_w * self.vae_stride[2]

    # Token count after patchifying; padded up to a multiple of the
    # sequence-parallel world size.
    max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
        self.patch_size[1] * self.patch_size[2])
    max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

    seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
    seed_g = torch.Generator(device=self.device)
    seed_g.manual_seed(seed)
    noise = torch.randn(
        16,
        (F - 1) // self.vae_stride[0] + 1,
        lat_h,
        lat_w,
        dtype=torch.float32,
        generator=seed_g,
        device=self.device)

    # Conditioning mask: 1 for the first (reference) frame, 0 for frames to
    # be generated; the first frame is replicated 4x, presumably to match
    # the temporal VAE compression — TODO confirm against vae_stride[0].
    msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
    msk[:, 1:] = 0
    msk = torch.concat([
        torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
    ],
                       dim=1)
    msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
    msk = msk.transpose(1, 2)[0]

    if n_prompt == "":
        n_prompt = self.sample_neg_prompt

    # preprocess
    # Text encoding: either on-device (optionally offloaded afterwards) or
    # fully on CPU with only the embeddings moved to the device.
    if not self.t5_cpu:
        self.text_encoder.model.to(self.device)
        context = self.text_encoder([input_prompt], self.device)
        context_null = self.text_encoder([n_prompt], self.device)
        if offload_model:
            self.text_encoder.model.cpu()
    else:
        context = self.text_encoder([input_prompt], torch.device('cpu'))
        context_null = self.text_encoder([n_prompt], torch.device('cpu'))
        context = [t.to(self.device) for t in context]
        context_null = [t.to(self.device) for t in context_null]

    # Encode [first frame | zero padding] into the VAE latent space, then
    # stack the conditioning mask on top as extra channels.
    y = self.vae.encode([
        torch.concat([
            torch.nn.functional.interpolate(
                img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                    0, 1),
            torch.zeros(3, F - 1, h, w)
        ],
                     dim=1).to(self.device)
    ])[0]
    y = torch.concat([msk, y])

    @contextmanager
    def noop_no_sync():
        # Fallback when the model is not FSDP-wrapped (no `no_sync`).
        yield

    no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
                                noop_no_sync)
    no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
                                 noop_no_sync)

    # evaluation mode
    with (
            torch.amp.autocast('cuda', dtype=self.param_dtype),
            torch.no_grad(),
            no_sync_low_noise(),
            no_sync_high_noise(),
    ):
        # Timestep threshold separating the high- and low-noise experts.
        boundary = self.boundary * self.num_train_timesteps

        if sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                sampling_steps, device=self.device, shift=shift)
            timesteps = sample_scheduler.timesteps
        elif sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
            timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                device=self.device,
                sigmas=sampling_sigmas)
        else:
            raise NotImplementedError("Unsupported solver.")

        # sample videos
        latent = noise

        arg_c = {
            'context': [context[0]],
            'seq_len': max_seq_len,
            'y': [y],
        }

        arg_null = {
            'context': context_null,
            'seq_len': max_seq_len,
            'y': [y],
        }

        if offload_model:
            torch.cuda.empty_cache()

        for _, t in enumerate(tqdm(timesteps)):
            latent_model_input = [latent.to(self.device)]
            timestep = [t]

            timestep = torch.stack(timestep).to(self.device)

            # Pick (and, if offloading, device-swap) the expert for this t.
            model = self._prepare_model_for_timestep(
                t, boundary, offload_model)
            sample_guide_scale = guide_scale[1] if t.item(
            ) >= boundary else guide_scale[0]

            # Classifier-free guidance: conditional and unconditional passes.
            noise_pred_cond = model(
                latent_model_input, t=timestep, **arg_c)[0]
            if offload_model:
                torch.cuda.empty_cache()
            noise_pred_uncond = model(
                latent_model_input, t=timestep, **arg_null)[0]
            if offload_model:
                torch.cuda.empty_cache()
            noise_pred = noise_pred_uncond + sample_guide_scale * (
                noise_pred_cond - noise_pred_uncond)

            temp_x0 = sample_scheduler.step(
                noise_pred.unsqueeze(0),
                t,
                latent.unsqueeze(0),
                return_dict=False,
                generator=seed_g)[0]
            latent = temp_x0.squeeze(0)

        x0 = [latent]
        del latent_model_input, timestep

        if offload_model:
            self.low_noise_model.cpu()
            self.high_noise_model.cpu()
            torch.cuda.empty_cache()

        # Only rank 0 decodes; other ranks return None below.
        if self.rank == 0:
            videos = self.vae.decode(x0)

    del noise, latent, x0
    del sample_scheduler
    if offload_model:
        gc.collect()
        torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()

    return videos[0] if self.rank == 0 else None
wan/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .attention import flash_attention
3
+ from .model import WanModel
4
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
+ from .tokenizers import HuggingfaceTokenizer
6
+ from .vae2_1 import Wan2_1_VAE
7
+ from .vae2_2 import Wan2_2_VAE
8
+
9
+ __all__ = [
10
+ 'Wan2_1_VAE',
11
+ 'Wan2_2_VAE',
12
+ 'WanModel',
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ 'HuggingfaceTokenizer',
18
+ 'flash_attention',
19
+ ]
wan/modules/animate/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .model_animate import WanAnimateModel
3
+ from .clip import CLIPModel
4
+ __all__ = ['WanAnimateModel', 'CLIPModel']
wan/modules/animate/animate_utils.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import numbers
4
+ from peft import LoraConfig
5
+
6
+
7
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
    """
    Build a PEFT ``LoraConfig`` for a Wan transformer.

    Every ``nn.Linear`` whose qualified module name contains "blocks" is
    targeted, excluding face-adapter and modulation layers.

    Args:
        transformer: model whose linear submodules are scanned.
        rank: LoRA rank ``r``.
        alpha: LoRA scaling ``lora_alpha``.
        init_lora_weights: weight-initialization scheme for the adapters.

    Returns:
        peft.LoraConfig targeting the selected linear layers.
    """
    target_modules = [
        name for name, module in transformer.named_modules()
        if isinstance(module, torch.nn.Linear) and "blocks" in name
        and "face" not in name and "modulation" not in name
    ]

    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        init_lora_weights=init_lora_weights,
        target_modules=target_modules,
    )
20
+
21
+
22
+
23
class TensorList(object):
    """A thin, list-of-tensors container supporting elementwise arithmetic.

    Wraps a list of ``torch.Tensor`` objects that share ndim/dtype/device but
    may differ in shape, and broadcasts scalar or per-element operands across
    them. Fix vs. original: the TypeError raised for unsupported operands no
    longer hard-codes the ``*`` operator in its message.
    """

    def __init__(self, tensors):
        """
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
        """
        assert isinstance(tensors, (list, tuple))
        assert all(isinstance(u, torch.Tensor) for u in tensors)
        # All members must agree on rank, dtype and device (shapes may vary).
        assert len(set([u.ndim for u in tensors])) == 1
        assert len(set([u.dtype for u in tensors])) == 1
        assert len(set([u.device for u in tensors])) == 1
        self.tensors = tensors

    def to(self, *args, **kwargs):
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])

    def size(self, dim):
        assert dim == 0, 'only support get the 0th size'
        return len(self.tensors)

    def pow(self, *args, **kwargs):
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])

    def squeeze(self, dim):
        # dim indexes the virtual stacked layout; shift positive dims down
        # by one before applying to each member tensor.
        assert dim != 0
        if dim > 0:
            dim -= 1
        return TensorList([u.squeeze(dim) for u in self.tensors])

    def type(self, *args, **kwargs):
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])

    def type_as(self, other):
        assert isinstance(other, (torch.Tensor, TensorList))
        if isinstance(other, torch.Tensor):
            return TensorList([u.type_as(other) for u in self.tensors])
        else:
            return TensorList([u.type(other.dtype) for u in self.tensors])

    @property
    def dtype(self):
        return self.tensors[0].dtype

    @property
    def device(self):
        return self.tensors[0].device

    @property
    def ndim(self):
        # +1 for the implicit leading "list" dimension.
        return 1 + self.tensors[0].ndim

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def __add__(self, other):
        return self._apply(other, lambda u, v: u + v)

    def __radd__(self, other):
        return self._apply(other, lambda u, v: v + u)

    def __sub__(self, other):
        return self._apply(other, lambda u, v: u - v)

    def __rsub__(self, other):
        return self._apply(other, lambda u, v: v - u)

    def __mul__(self, other):
        return self._apply(other, lambda u, v: u * v)

    def __rmul__(self, other):
        return self._apply(other, lambda u, v: v * u)

    def __floordiv__(self, other):
        return self._apply(other, lambda u, v: u // v)

    def __truediv__(self, other):
        return self._apply(other, lambda u, v: u / v)

    def __rfloordiv__(self, other):
        return self._apply(other, lambda u, v: v // u)

    def __rtruediv__(self, other):
        return self._apply(other, lambda u, v: v / u)

    def __pow__(self, other):
        return self._apply(other, lambda u, v: u ** v)

    def __rpow__(self, other):
        return self._apply(other, lambda u, v: v ** u)

    def __neg__(self):
        return TensorList([-u for u in self.tensors])

    def __iter__(self):
        for tensor in self.tensors:
            yield tensor

    def __repr__(self):
        return 'TensorList: \n' + repr(self.tensors)

    def _apply(self, other, op):
        """Apply ``op`` pairwise (sequence operand) or broadcast (scalar)."""
        if isinstance(other, (list, tuple, TensorList)) or (
            isinstance(other, torch.Tensor) and (
                other.numel() > 1 or other.ndim > 1
            )
        ):
            assert len(other) == len(self.tensors)
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
        elif isinstance(other, numbers.Number) or (
            isinstance(other, torch.Tensor) and (
                other.numel() == 1 and other.ndim <= 1
            )
        ):
            return TensorList([op(u, other) for u in self.tensors])
        else:
            # Fix: the original message claimed operator "*" regardless of
            # which arithmetic operation actually failed.
            raise TypeError(
                f'unsupported operand type for TensorList arithmetic: '
                f'"{type(other)}"'
            )
wan/modules/animate/clip.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from ..attention import flash_attention
12
+ from ..tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
def pos_interpolate(pos, seq_len):
    """Resize a [1, L, C] positional-embedding table to ``seq_len`` tokens.

    Leading non-grid tokens (e.g. a class token) are carried over untouched;
    the square spatial grid is bicubically resampled to the new grid size.
    Returns ``pos`` unchanged when the length already matches.
    """
    if pos.size(1) == seq_len:
        return pos

    src_grid = int(math.sqrt(pos.size(1)))
    tar_grid = int(math.sqrt(seq_len))
    num_extra = pos.size(1) - src_grid * src_grid

    # [1, G*G, C] -> [1, C, G, G] for spatial interpolation
    grid = pos[:, num_extra:].float()
    grid = grid.reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid,
        size=(tar_grid, tar_grid),
        mode='bicubic',
        align_corners=False)
    # back to [1, G'*G', C]
    grid = grid.flatten(2).transpose(1, 2)

    return torch.cat([pos[:, :num_extra], grid], dim=1)
39
+
40
+
41
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
45
+
46
+
47
class LayerNorm(nn.LayerNorm):
    """LayerNorm computed in float32, with the result cast back to the
    input's dtype (keeps normalization numerically stable under fp16/bf16)."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type_as(x)
51
+
52
+
53
class SelfAttention(nn.Module):
    """Multi-head self-attention backed by flash attention (v2)."""

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        batch, seq, chans = x.size()
        heads, head_dim = self.num_heads, self.head_dim

        # single projection, then split into per-head query/key/value
        qkv = self.to_qkv(x).view(batch, seq, 3, heads, head_dim)
        q, k, v = qkv.unbind(2)

        # attention dropout only applies in training mode
        drop = self.attn_dropout if self.training else 0.0
        out = flash_attention(
            q, k, v, dropout_p=drop, causal=self.causal, version=2)
        out = out.reshape(batch, seq, chans)

        # output projection + dropout
        return F.dropout(self.proj(out), self.proj_dropout, self.training)
+ return x
92
+
93
+
94
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: ``fc3(silu(fc1(x)) * fc2(x))``."""

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        gated = F.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)
+ return x
110
+
111
+
112
class AttentionBlock(nn.Module):
    """Transformer block (self-attention + MLP with residuals), supporting
    pre-norm (default) or post-norm placement and three activation choices."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        mid_dim = int(dim * mlp_ratio)
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, mid_dim)
        else:
            act = QuickGELU() if activation == 'quick_gelu' else nn.GELU()
            self.mlp = nn.Sequential(
                nn.Linear(dim, mid_dim), act, nn.Linear(mid_dim, dim),
                nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            # post-norm: normalize each branch's output before the residual add
            x = x + self.norm1(self.attn(x))
            return x + self.norm2(self.mlp(x))
        # pre-norm: normalize each branch's input
        x = x + self.attn(self.norm1(x))
        return x + self.mlp(self.norm2(x))
+ return x
154
+
155
+
156
class AttentionPool(nn.Module):
    """Pool a token sequence to a single vector: a learned CLS query attends
    over the tokens, followed by a residual MLP refinement."""

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value: the query comes from the learned CLS
        # embedding (shared across the batch); keys/values from the tokens
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp refinement with residual, then drop the singleton token axis
        x = x + self.mlp(self.norm(x))
        return x[:, 0]
+ return x[:, 0]
207
+
208
+
209
class VisionTransformer(nn.Module):
    """ViT image encoder used by the CLIP model: patch embedding + learned
    positional table + a stack of AttentionBlocks."""

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        # NOTE(review): this bool is overwritten below by a LayerNorm module
        # of the same name; the flag is still forwarded to each block.
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        # positional table has one extra slot for the CLS token when present
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        # NOTE(review): `head` is constructed but not applied inside this
        # forward; presumably used by external callers — confirm.
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings: patchify to a token sequence, prepend CLS if configured
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            # resample the positional table when the token count differs
            # from training resolution
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            # penultimate features: run all blocks except the last one
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x
+ return x
301
+
302
+
303
class XLMRobertaWithHead(XLMRoberta):
    """XLM-Roberta encoder topped with masked mean pooling and a bias-free
    two-layer projection head to ``out_dim``."""

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # projection head: dim -> (dim + out_dim) // 2 -> out_dim
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # token features from the base encoder
        feats = super().forward(ids)

        # average pooling over non-padding positions only
        pad_mask = ids.ne(self.pad_id).unsqueeze(-1).to(feats)
        pooled = (feats * pad_mask).sum(dim=1) / pad_mask.sum(dim=1)

        # project to the shared embedding space
        return self.head(pooled)
+ return x
326
+
327
+
328
class XLMRobertaCLIP(nn.Module):
    """CLIP model pairing a ViT image encoder with an XLM-Roberta text
    encoder, both projected into a shared ``embed_dim`` space."""

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        # learnable logit temperature, initialized to log(1/0.07)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
            - mean: [0.48145466, 0.4578275, 0.40821073]
            - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long.
            Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        """Split parameters into no-weight-decay (norms/biases) and the rest."""
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups
+ return groups
432
+
433
+
434
def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    """Instantiate a CLIP model and optionally its image preprocessing
    transforms.

    Args:
        pretrained: unused here; kept for API compatibility.
        pretrained_name: checkpoint name; only consulted to pick SigLIP-style
            normalization when building transforms. May be None.
        model_cls: model class to instantiate with ``kwargs``.
        return_transforms: also return a torchvision transform pipeline.
        return_tokenizer / tokenizer_padding: unused here; kept for API
            compatibility.
        dtype, device: placement for the created model.

    Returns:
        The model, or a ``(model, transforms)`` tuple when
        ``return_transforms`` is True.
    """
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std: SigLIP checkpoints use 0.5/0.5 normalization.
        # Fix: guard against pretrained_name=None, which previously raised
        # AttributeError on .lower().
        if pretrained_name is not None and 'siglip' in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([
            T.Resize((model.image_size, model.image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std)
        ])
        output += (transforms,)
    return output[0] if len(output) == 1 else output
+ return output[0] if len(output) == 1 else output
469
+
470
+
471
def clip_xlm_roberta_vit_h_14(
        pretrained=False,
        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
        **kwargs):
    """Configuration preset for the XLM-Roberta-large / ViT-H-14 CLIP
    variant; any entry can be overridden via ``kwargs``."""
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool='token',
        activation='gelu',
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0)
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
499
+
500
+
501
class CLIPModel:
    """Inference wrapper bundling the CLIP vision/text model, its image
    transforms and tokenizer, with weights loaded from ``checkpoint_path``.

    Fix vs. original: ``visual`` now uses ``torch.amp.autocast('cuda', ...)``
    instead of the deprecated ``torch.cuda.amp.autocast``, matching the
    autocast usage elsewhere in this package.
    """

    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
        """
        Args:
            dtype: autocast dtype used for the visual forward pass.
            device: device the model is created on.
            checkpoint_path: path to the state-dict checkpoint to load.
            tokenizer_path: HuggingFace tokenizer name/path.
        """
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model (frozen, eval-mode)
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))

        # init tokenizer; reserve 2 positions, presumably for special tokens
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path,
            seq_len=self.model.max_text_len - 2,
            clean='whitespace')

    def visual(self, videos):
        """Encode a batch of videos to penultimate-layer ViT features.

        Args:
            videos: iterable of [C, T, H, W] tensors in [-1, 1].

        Returns:
            Features from all but the last transformer block.
        """
        # preprocess: resize every frame to the model resolution, map back to
        # [0, 1], then apply the pipeline's Normalize step
        size = (self.model.image_size,) * 2
        videos = torch.cat([
            F.interpolate(
                u.transpose(0, 1),
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward (torch.amp.autocast replaces deprecated torch.cuda.amp)
        with torch.amp.autocast('cuda', dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
        return out
+ return out
wan/modules/animate/face_blocks.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from torch import nn
3
+ import torch
4
+ from typing import Tuple, Optional
5
+ from einops import rearrange
6
+ import torch.nn.functional as F
7
+ import math
8
+ from ...distributed.util import gather_forward, get_rank, get_world_size
9
+
10
+
11
+ try:
12
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
13
+ except ImportError:
14
+ flash_attn_func = None
15
+
16
MEMORY_LAYOUT = {
    # flash: flash_attn_func consumes [b, s, a, d] directly; the output is
    # un-flattened explicitly inside `attention`, so the post hook is a no-op.
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    # torch / vanilla: move heads in front of the sequence axis,
    # [b, s, a, d] <-> [b, a, s, d].
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform QKV attention with a selectable backend.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention backend. One of 'flash', 'torch', 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        max_seqlen_q (int): Query length used to un-flatten the flash output.
            Required when mode='flash'.
        batch_size (int): Batch size used to un-flatten the flash output.
            Only used when mode='flash'.

    Returns:
        torch.Tensor: Output tensor after attention with shape [b, s, a * d]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        # fix: the layout hook was computed but never applied, so SDPA saw
        # [b, s, a, d] and attended over the wrong axis. Transpose to the
        # [b, a, s, d] layout scaled_dot_product_attention expects.
        q, k, v = pre_attn_layout(q), pre_attn_layout(k), pre_attn_layout(v)
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "flash":
        # flash_attn_func consumes [b, s, a, d] directly (no pre-layout).
        x = flash_attn_func(
            q,
            k,
            v,
        )
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        # Reference softmax(QK^T / sqrt(d)) V implementation.
        # fix: apply the [b, s, a, d] -> [b, a, s, d] layout first; the code
        # below indexes q as [b, a, s, d].
        q, k, v = pre_attn_layout(q), pre_attn_layout(k), pre_attn_layout(v)
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            # fix: Tensor.to() is not in-place; keep the cast result.
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Back to [b, s, a, d], then merge heads into the channel axis.
    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
110
+
111
+
112
class CausalConv1d(nn.Module):
    """1-D convolution that is causal along the time axis.

    All padding is applied on the left (past) side before the convolution,
    so the output at step t never depends on inputs later than t.
    """

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        # (left, right) padding amounts for F.pad on the last (time) axis.
        self.time_causal_padding = (kernel_size - 1, 0)  # T

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        # Left-pad the time axis, then run the plain convolution.
        padded = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(padded)
126
+
127
+
128
+
129
class FaceEncoder(nn.Module):
    """Encode per-frame face-motion vectors into multi-head transformer tokens.

    The input [B, T, in_dim] is lifted to ``num_heads`` parallel 1024-dim
    streams with a causal conv, temporally downsampled 4x by two stride-2
    causal convs, projected to ``hidden_dim``, and a zero-initialized learned
    padding token is appended along the head axis.
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int = 4, dtype=None, device=None):
        # NOTE(review): the original signature read ``num_heads=int`` — the
        # builtin type object as a default value, unusable at runtime. A real
        # integer default (4, matching the known caller) is provided instead;
        # callers passing num_heads explicitly are unaffected.
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        # One 1024-dim stream per head, produced by a single causal conv.
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        # fix: removed a dead duplicate assignment of norm1 with
        # LayerNorm(hidden_dim // 8) that was immediately overwritten below.
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Learned padding token appended along the head axis (zero init).
        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        # x: [B, T, in_dim] -> channels-first for the causal convs.
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        # Split into per-head streams: [(B * num_heads), T, 1024].
        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        # Two causal stride-2 convs: T -> ceil(T/2) -> ceil(T/4).
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        # Regroup heads: [B, T', num_heads, hidden_dim].
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        # Append the learned padding token along the head axis.
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local
177
+
178
+
179
+
180
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction).

    The last dimension is divided by its RMS value; when
    ``elementwise_affine`` is True a learnable per-channel scale is applied.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Args:
            dim (int): Size of the normalized (last) dimension.
            elementwise_affine (bool): Whether to learn a per-channel scale.
            eps (float): Small constant added to the denominator for
                numerical stability.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        # `weight` exists only in the affine case; forward() probes for it
        # with hasattr, so no attribute is registered otherwise.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """Divide x by the RMS of its last dimension."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize in float32 for stability, cast back, then scale."""
        y = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            y = y * self.weight
        return y
235
+
236
+
237
def get_norm_layer(norm_layer):
    """
    Resolve a normalization-layer name to its class.

    Args:
        norm_layer (str): "layer" for nn.LayerNorm or "rms" for RMSNorm.

    Returns:
        norm_layer (nn.Module): The normalization layer class.

    Raises:
        NotImplementedError: For any other name.
    """
    if norm_layer == "rms":
        return RMSNorm
    if norm_layer == "layer":
        return nn.LayerNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
253
+
254
+
255
class FaceAdapter(nn.Module):
    """Stack of FaceBlock fusion layers used as a face-motion adapter.

    ``forward`` dispatches to a single block selected by ``idx``; the
    number of blocks is ``num_adapter_layers``.
    """

    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num
        # Build the fusion blocks one by one; all share the same config.
        fuser_layers = []
        for _ in range(num_adapter_layers):
            fuser_layers.append(
                FaceBlock(
                    self.hidden_size,
                    self.heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
            )
        self.fuser_blocks = nn.ModuleList(fuser_layers)

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        # Delegate to the idx-th fusion block. NOTE(review): the two
        # freqs_cis arguments land on FaceBlock's motion_mask /
        # use_context_parallel positional slots — confirm this is intended.
        selected = self.fuser_blocks[idx]
        return selected(x, motion_embed, freqs_cis_q, freqs_cis_k)
294
+
295
+
296
+
297
class FaceBlock(nn.Module):
    """Cross-attention block fusing face-motion tokens into video features.

    Queries come from the video feature sequence ``x``; keys/values come from
    the per-frame motion tokens ``motion_vec``. The frame axis is folded into
    the batch so attention runs independently per (batch, frame) pair.
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        # NOTE(review): self.scale is stored but not passed to attention();
        # the backend derives its own scaling — confirm this is intended.
        self.scale = qk_scale or head_dim**-0.5

        # Fused key/value projection for motion tokens; separate query
        # projection for the video features.
        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        # Output projection.
        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        # Per-head Q/K normalization ("rms" or "layer"), or identity.
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        # Non-affine pre-norms for the feature and motion inputs.
        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        """
        Args:
            x: Video features [B, T*S, C] (or this rank's shard of the
                sequence axis when use_context_parallel is True).
            motion_vec: Motion tokens, shape [B, T, N, C].
            motion_mask: Optional [B, T, H, W] mask multiplied into the
                output; assumes T*H*W equals x's sequence length — TODO confirm.
            use_context_parallel: Gather the full sequence before attention
                and re-shard the result afterwards.

        Returns:
            Residual tensor with the same shape as ``x``.
        """
        B, T, N, C = motion_vec.shape
        T_comp = T

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        # Split the fused KV (K=2) and expose heads.
        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Fold the frame axis into the batch: per-frame attention.
        k = rearrange(k, "B L N H D -> (B L) N H D")
        v = rearrange(v, "B L N H D -> (B L) N H D")

        if use_context_parallel:
            # Collect the full (sequence-parallel) query sequence from all ranks.
            q = gather_forward(q, dim=1)

        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
        # Compute attention.
        attn = attention(
            q,
            k,
            v,
            max_seqlen_q=q.shape[1],
            batch_size=q.shape[0],
        )

        # Restore the [B, T*S, C] layout.
        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
        if use_context_parallel:
            # Keep only this rank's shard of the sequence axis.
            attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]

        output = self.linear2(attn)

        if motion_mask is not None:
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output
wan/modules/animate/model_animate.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+ import types
4
+ from copy import deepcopy
5
+ from einops import rearrange
6
+ from typing import List
7
+ import numpy as np
8
+ import torch
9
+ import torch.cuda.amp as amp
10
+ import torch.nn as nn
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.loaders import PeftAdapterMixin
14
+
15
+ from ...distributed.sequence_parallel import (
16
+ distributed_attention,
17
+ gather_forward,
18
+ get_rank,
19
+ get_world_size,
20
+ )
21
+
22
+
23
+ from ..model import (
24
+ Head,
25
+ WanAttentionBlock,
26
+ WanLayerNorm,
27
+ WanRMSNorm,
28
+ WanModel,
29
+ WanSelfAttention,
30
+ flash_attention,
31
+ rope_params,
32
+ sinusoidal_embedding_1d,
33
+ rope_apply
34
+ )
35
+
36
+ from .face_blocks import FaceEncoder, FaceAdapter
37
+ from .motion_encoder import Generator
38
+
39
class HeadAnimate(Head):
    """Output head with sequence-resolved modulation.

    ``self.modulation``, ``self.norm`` and ``self.head`` come from the base
    ``Head``; here the conditioning signal ``e`` carries a sequence axis, so
    shift/scale vary per token.
    """

    def forward(self, x, e):
        """
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, L1, C]
        """
        # Modulation math is kept in float32 for numerical stability.
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            # chunk -> (shift, scale) along the inserted modulation axis.
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
            x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
        return x
52
+
53
+
54
class WanAnimateSelfAttention(WanSelfAttention):
    """Self-attention with 3D rotary position embedding for animate blocks."""

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """
        Args:
            x(Tensor): Shape [B, L, C], projected to per-head Q/K/V below.
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        batch, seq = x.shape[:2]
        heads, head_dim = self.num_heads, self.head_dim

        # Project to per-head query/key/value; q and k are norm'd.
        q = self.norm_q(self.q(x)).view(batch, seq, heads, head_dim)
        k = self.norm_k(self.k(x)).view(batch, seq, heads, head_dim)
        v = self.v(x).view(batch, seq, heads, head_dim)

        # RoPE on q/k, then flash attention with per-sample key lengths.
        attn_out = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # Merge heads and apply the output projection.
        return self.o(attn_out.flatten(2))
86
+
87
+
88
class WanAnimateCrossAttention(WanSelfAttention):
    """Cross-attention over text context, optionally fused with CLIP image tokens.

    When ``use_img_emb`` is True the first 257 context tokens are treated as
    the CLIP image embedding (attended through dedicated K/V projections) and
    the result is added to the text-attention output.
    """

    # Number of leading context tokens holding the CLIP image embedding.
    NUM_IMG_TOKENS = 257

    def __init__(
        self,
        dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        eps=1e-6,
        use_img_emb=True
    ):
        super().__init__(
            dim,
            num_heads,
            window_size,
            qk_norm,
            eps
        )
        self.use_img_emb = use_img_emb

        if use_img_emb:
            # Dedicated key/value projections (and key norm) for image tokens.
            self.k_img = nn.Linear(dim, dim)
            self.v_img = nn.Linear(dim, dim)
            self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        """
        Args:
            x: [B, L1, C] query features.
            context: [B, L2, C]; when use_img_emb is True the first 257
                tokens are the CLIP image embedding, the rest is text.
            context_lens: [B] valid lengths of the text part.
        """
        context_img = None
        if self.use_img_emb:
            # Split image tokens from text tokens.
            # fix: removed a dead `else: context = context` no-op branch.
            context_img = context[:, :self.NUM_IMG_TOKENS]
            context = context[:, self.NUM_IMG_TOKENS:]

        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)

        if self.use_img_emb:
            k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
            v_img = self.v_img(context_img).view(b, -1, n, d)
            img_x = flash_attention(q, k_img, v_img, k_lens=None)
        # compute attention over the text tokens
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output: merge heads, add the image branch, project
        x = x.flatten(2)

        if self.use_img_emb:
            img_x = img_x.flatten(2)
            x = x + img_x

        x = self.o(x)
        return x
147
+
148
+
149
class WanAnimateAttentionBlock(nn.Module):
    """DiT block: modulated self-attention, cross-attention, modulated FFN.

    A learned 6-way modulation parameter is added to the timestep signal and
    split into (shift, scale, gate) pairs for the self-attention and FFN
    branches; all modulation arithmetic runs in float32.
    """

    def __init__(self,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 use_img_emb=True):

        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)

        # Pre-cross-attention norm (affine) or identity.
        self.norm3 = WanLayerNorm(
            dim, eps, elementwise_affine=True
        ) if cross_attn_norm else nn.Identity()

        self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim)
        )

        # modulation: learned base added to the timestep-conditioned signal
        # in forward; scaled init keeps it small relative to dim.
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        """
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            # Split into (shift1, scale1, gate1, shift2, scale2, gate2).
            e = (self.modulation + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32

        # self-attention: pre-norm with shift/scale, gated residual add.
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
        )
        with amp.autocast(dtype=torch.float32):
            x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            # Un-modulated residual cross-attention, then modulated FFN.
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with amp.autocast(dtype=torch.float32):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x
228
+
229
+
230
class MLPProj(torch.nn.Module):
    """Project CLIP image embeddings into the transformer's hidden space.

    Per-token pipeline: LayerNorm -> Linear -> GELU -> Linear -> LayerNorm,
    mapping ``in_dim`` to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()

        layers = [
            torch.nn.LayerNorm(in_dim),
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, image_embeds):
        """Return projected context tokens with the same leading dims."""
        return self.proj(image_embeds)
245
+
246
class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """Wan-Animate diffusion transformer.

    Extends the Wan video DiT with: a pose branch (``pose_patch_embedding``)
    whose latents are added onto the patchified video latents; a face-motion
    branch (``motion_encoder`` -> ``face_encoder``) whose tokens are fused
    into every 5th transformer block via ``face_adapter``; and an optional
    CLIP image-embedding context (``img_emb``).
    """

    # Hint for device-map sharding: never split an attention block.
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=36,
                 dim=5120,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 motion_encoder_dim=512,
                 use_context_parallel=False,
                 use_img_emb=True):
        """
        Args:
            patch_size: (T, H, W) patchification factors.
            text_len: Fixed text-context length after padding.
            in_dim: Input channels of the (latent + conditioning) video.
            dim: Transformer hidden size.
            ffn_dim: FFN inner size.
            freq_dim: Sinusoidal timestep embedding size.
            text_dim: Dimension of incoming text features.
            out_dim: Output latent channels.
            num_heads, num_layers: Transformer width/depth.
            window_size: Attention window; (-1, -1) means full attention.
            qk_norm, cross_attn_norm, eps: Normalization options.
            motion_encoder_dim: Dimension of the face-motion descriptor.
            use_context_parallel: Enable sequence-parallel execution.
            use_img_emb: Whether context includes CLIP image tokens.
        """
        super().__init__()
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.motion_encoder_dim = motion_encoder_dim
        self.use_context_parallel = use_context_parallel
        self.use_img_emb = use_img_emb

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        # Separate patch embedding for the 16-channel pose latents.
        self.pose_patch_embedding = nn.Conv3d(
            16, dim, kernel_size=patch_size, stride=patch_size
        )

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        # Projects the timestep embedding to the 6 modulation vectors.
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                                     cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
        ])

        # head
        self.head = HeadAnimate(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        # 3D RoPE tables: one temporal band and two spatial bands.
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)

        # CLIP image embedding (1280-dim) -> transformer context tokens.
        self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

        # Face-motion pipeline, created AFTER init_weights() so these
        # submodules keep their own initializations.
        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        self.face_adapter = FaceAdapter(
            heads_num=self.num_heads,
            hidden_dim=self.dim,
            num_adapter_layers=self.num_layers // 5,
        )

        self.face_encoder = FaceEncoder(
            in_dim=motion_encoder_dim,
            hidden_dim=self.dim,
            num_heads=4,
        )

    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
        """Add pose latents onto the video latents and build motion tokens.

        Args:
            x: List of patchified video latents, each [1, dim, F, h, w].
            pose_latents: List of pose latents; after patch embedding they
                are added to all frames except the first (reference) frame.
            face_pixel_values: Face crops, shape [B, C, T, H, W].

        Returns:
            (x, motion_vec): latents with pose added, and motion tokens with
            an all-zero token prepended along the time axis.
        """
        pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
        # Pose conditioning skips frame 0 (dim 2 is the frame axis).
        for x_, pose_latents_ in zip(x, pose_latents):
            x_[:, :, 1:] += pose_latents_

        b,c,T,h,w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")

        # Run the motion encoder in mini-batches to bound peak memory.
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))

        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        # Prepend an all-zero motion token along time — presumably aligned
        # with the reference frame; confirm against the training pipeline.
        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec


    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        """Add the face-adapter residual after every 5th transformer block."""
        if block_idx % 5 == 0:
            adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x


    def forward(
        self,
        x,
        t,
        clip_fea,
        context,
        seq_len,
        y=None,
        pose_latents=None,
        face_pixel_values=None
    ):
        """Run one denoising step of the animate DiT.

        Args:
            x: List of per-sample noisy latents; concatenated channel-wise
                with `y` when given.
            t: Diffusion timestep tensor.
            clip_fea: CLIP image features fed to ``img_emb`` when
                ``use_img_emb`` is True (projected to 257 context tokens).
            context: List of per-sample text embeddings [L_i, text_dim].
            seq_len: Common (padded) token-sequence length.
            y: Optional conditioning latents (channel concat).
            pose_latents: Pose latents (see ``after_patch_embedding``).
            face_pixel_values: Face crops driving the motion branch.

        Returns:
            List of denoised latents as float32 tensors.
        """
        # params
        device = self.patch_embedding.weight.device
        # self.freqs is a plain tensor (not a buffer) so it must be moved
        # to the compute device manually.
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        # Flatten spatial-temporal patches into a token sequence.
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        # Right-pad every sample to the common seq_len and batch them.
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings (kept in float32 for stability)
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float()
            )
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context: pad each text embedding to text_len, then embed.
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.use_img_emb:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            # Image tokens are prepended to the text tokens.
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        if self.use_context_parallel:
            # Shard the token axis across ranks for sequence parallelism.
            x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]

        for idx, block in enumerate(self.blocks):
            x = block(x, **kwargs)
            # Fuse face-motion tokens after every 5th block.
            x = self.after_transformer_block(idx, x, motion_vec)

        # head (uses the un-projected timestep embedding e, not e0)
        x = self.head(x, e)

        if self.use_context_parallel:
            x = gather_forward(x, dim=1)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]


    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """

        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            # Drop padding tokens, then fold patches back into (C, F, H, W).
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """

        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer: zero so the head starts as an identity-free
        # (no-contribution) projection.
        nn.init.zeros_(self.head.head.weight)
wan/modules/animate/motion_encoder.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from ``https://github.com/wyhsirius/LIA``
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+ import math
7
+
8
+ def custom_qr(input_tensor):
9
+ original_dtype = input_tensor.dtype
10
+ if original_dtype == torch.bfloat16:
11
+ q, r = torch.linalg.qr(input_tensor.to(torch.float32))
12
+ return q.to(original_dtype), r.to(original_dtype)
13
+ return torch.linalg.qr(input_tensor)
14
+
15
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a bias, apply leaky ReLU, then rescale (StyleGAN2-style activation)."""
    biased = input + bias
    return scale * F.leaky_relu(biased, negative_slope)
17
+
18
+
19
def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    """Reference upsample-FIR-downsample on 2D feature maps.

    Zero-stuffs by (up_x, up_y), pads (negative pads crop), convolves with
    ``kernel`` flipped (true convolution), then keeps every
    (down_x, down_y)-th sample.
    """
    _, channels, height, width = input.shape
    kh, kw = kernel.shape

    # Upsample by zero insertion along both spatial axes.
    stuffed = input.view(-1, channels, height, 1, width, 1)
    stuffed = F.pad(stuffed, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    stuffed = stuffed.view(-1, channels, height * up_y, width * up_x)

    # Apply positive padding, then crop away any negative padding.
    padded = F.pad(stuffed, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    padded = padded[:, :,
                    max(-pad_y0, 0): padded.shape[2] - max(-pad_y1, 0),
                    max(-pad_x0, 0): padded.shape[3] - max(-pad_x1, 0)]

    # FIR filter: fold channels into the batch and use a single-channel conv.
    flat = padded.reshape([-1, 1, height * up_y + pad_y0 + pad_y1, width * up_x + pad_x0 + pad_x1])
    weight = torch.flip(kernel, [0, 1]).view(1, 1, kh, kw)
    filtered = F.conv2d(flat, weight)
    filtered = filtered.reshape(-1, channels,
                                height * up_y + pad_y0 + pad_y1 - kh + 1,
                                width * up_x + pad_x0 + pad_x1 - kw + 1)

    # Downsample by strided slicing.
    return filtered[:, :, ::down_y, ::down_x]
37
+
38
+
39
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Convenience wrapper around ``upfirdn2d_native`` with symmetric x/y factors."""
    pad0, pad1 = pad
    return upfirdn2d_native(input, kernel, up, up, down, down, pad0, pad1, pad0, pad1)
41
+
42
+
43
def make_kernel(k):
    """Build a normalized 2D FIR kernel from 1D (separable) or 2D taps."""
    kernel = torch.tensor(k, dtype=torch.float32)
    if kernel.ndim == 1:
        # Outer product turns separable 1D taps into a 2D kernel.
        kernel = kernel[:, None] * kernel[None, :]
    # Normalize so the kernel preserves overall signal energy.
    kernel = kernel / kernel.sum()
    return kernel
49
+
50
+
51
class FusedLeakyReLU(nn.Module):
    """Module wrapper for ``fused_leaky_relu`` with a learnable per-channel bias."""

    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        # Bias shaped to broadcast over (N, C, H, W) activations.
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
61
+
62
+
63
class Blur(nn.Module):
    """FIR blur layer used around resampling, as in StyleGAN2."""

    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()
        fir = make_kernel(kernel)
        if upsample_factor > 1:
            # Compensate for the energy lost by zero-insertion upsampling.
            fir = fir * (upsample_factor ** 2)
        # Buffer (not Parameter): the kernel is fixed, but moves with .to()/.cuda().
        self.register_buffer('kernel', fir)
        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)
78
+
79
+
80
class ScaledLeakyReLU(nn.Module):
    """Leaky-ReLU activation module.

    NOTE(review): despite the name, no sqrt(2) gain is applied here (unlike
    the upstream StyleGAN2 layer of the same name) — presumably intentional;
    confirm against pretrained weights before changing.
    """

    def __init__(self, negative_slope=0.2):
        super().__init__()
        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)
88
+
89
+
90
class EqualConv2d(nn.Module):
    """Conv2d with equalized learning rate.

    Weights are drawn from N(0, 1) and rescaled at runtime by
    1/sqrt(fan_in), keeping the effective learning rate uniform
    across layers (StyleGAN technique).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        # He-style runtime scaling constant.
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
        self.stride = stride
        self.padding = padding
        self.bias = nn.Parameter(torch.zeros(out_channel)) if bias else None

    def forward(self, input):
        return F.conv2d(
            input,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    def __repr__(self):
        out_c, in_c, k = self.weight.shape[0], self.weight.shape[1], self.weight.shape[2]
        return (f'{self.__class__.__name__}({in_c}, {out_c},'
                f' {k}, stride={self.stride}, padding={self.padding})')
114
+
115
+
116
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate and optional fused activation.

    Weights are stored pre-divided by ``lr_mul`` and rescaled at runtime by
    ``scale = lr_mul / sqrt(in_dim)``; the bias is rescaled by ``lr_mul``.

    Args:
        in_dim (int): input feature size.
        out_dim (int): output feature size.
        bias (bool): whether to learn an additive bias.
        bias_init (float): initial bias value.
        lr_mul (float): learning-rate multiplier.
        activation: if truthy, apply ``fused_leaky_relu`` after the matmul
            (this path requires ``bias=True``).
    """

    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        # BUGFIX: guard the bias rescaling — with bias=False the original
        # evaluated ``self.bias * self.lr_mul`` on None and raised TypeError.
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
144
+
145
+
146
class ConvLayer(nn.Sequential):
    """Equalized conv block: optional anti-aliased stride-2 downsampling and
    optional fused leaky-ReLU activation."""

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            # Blur before the strided conv to reduce aliasing.
            factor = 2
            total_pad = (len(blur_kernel) - factor) + (kernel_size - 1)
            layers.append(Blur(blur_kernel, pad=((total_pad + 1) // 2, total_pad // 2)))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2

        # When an activation follows, FusedLeakyReLU carries the bias instead
        # of the conv itself.
        layers.append(EqualConv2d(in_channel, out_channel, kernel_size,
                                  padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            layers.append(FusedLeakyReLU(out_channel) if bias else ScaledLeakyReLU(0.2))

        super().__init__(*layers)
184
+
185
+
186
class ResBlock(nn.Module):
    """Residual downsampling block: two 3x3 convs plus a 1x1 skip, both
    halving spatial resolution, merged with 1/sqrt(2) variance-preserving
    scaling."""

    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True,
                              activate=False, bias=False)

    def forward(self, input):
        residual = self.conv2(self.conv1(input))
        shortcut = self.skip(input)
        return (residual + shortcut) / math.sqrt(2)
203
+
204
+
205
class EncoderApp(nn.Module):
    """Appearance encoder: a StyleGAN2-style conv pyramid that produces a
    global latent vector plus intermediate feature maps."""

    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        # Channel width per spatial resolution.
        channels = {
            4: 512, 8: 512, 16: 512, 32: 512,
            64: 256, 128: 128, 256: 64, 512: 32, 1024: 16,
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        # Halve spatial resolution repeatedly down to 4x4.
        in_ch = channels[size]
        for res_log in range(log_size, 2, -1):
            out_ch = channels[2 ** (res_log - 1)]
            self.convs.append(ResBlock(in_ch, out_ch))
            in_ch = out_ch

        # Final 4x4 conv collapses the spatial dims into the latent vector.
        self.convs.append(EqualConv2d(in_ch, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        feats = []
        h = x
        for layer in self.convs:
            h = layer(h)
            feats.append(h)
        # Global latent (spatial dims squeezed) + intermediate maps,
        # deepest-first, skipping the last two entries.
        return feats[-1].squeeze(-1).squeeze(-1), feats[::-1][2:]
244
+
245
+
246
class Encoder(nn.Module):
    """Joint appearance + motion encoder (from LIA)."""

    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # Appearance network.
        self.net_app = EncoderApp(size, dim)

        # Motion network: 4 dim->dim layers followed by a projection to the
        # low-dimensional motion code.
        layers = [EqualLinear(dim, dim) for _ in range(4)]
        layers.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*layers)

    def enc_app(self, x):
        """Return the appearance features (latent, intermediate maps)."""
        return self.net_app(x)

    def enc_motion(self, x):
        """Return the motion code derived from the appearance latent."""
        latent, _ = self.net_app(x)
        return self.fc(latent)
269
+
270
+
271
class Direction(nn.Module):
    """Learned orthogonal basis of motion directions.

    QR-orthogonalizes a learned 512 x motion_dim matrix. Given per-direction
    magnitudes, returns their weighted combination; given None, returns the
    orthonormal basis itself.
    """

    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        # Small epsilon keeps QR away from exactly rank-deficient matrices.
        basis, _ = custom_qr(self.weight + 1e-8)
        if input is None:
            return basis
        # Scale each basis row by its magnitude and sum the contributions.
        scaled = torch.matmul(torch.diag_embed(input), basis.T)
        return torch.sum(scaled, dim=1)
287
+
288
+
289
class Synthesis(nn.Module):
    """Decoder stub: only the motion-direction basis is kept from LIA."""

    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)
293
+
294
+
295
class Generator(nn.Module):
    """LIA motion extractor: encoder + direction basis, no image synthesis."""

    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()
        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        """Extract the motion code for a batch of images."""
        # Checkpointing trades recompute for activation memory.
        motion_feat = torch.utils.checkpoint.checkpoint(
            self.enc.enc_motion, img, use_reentrant=True)
        # The QR inside the direction basis needs full precision.
        with torch.cuda.amp.autocast(dtype=torch.float32):
            return self.dec.direction(motion_feat)
wan/modules/animate/preprocess/UserGuider.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Wan-animate Preprocessing User Guide
2
+
3
+ ## 1. Introduction
4
+
5
+
6
+ Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the reference video, they each have a distinct preprocessing pipeline.
7
+
8
+ ### 1.1 Animation Mode
9
+
10
+ In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.
11
+
12
+ - A simplified version of pose retargeting pipeline is provided to help developers quickly implement this functionality.
13
+
14
+ - **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
15
+
16
+ - Community contributions to improve on this feature are welcome.
17
+
18
+ ### 1.2 Replacement Mode
19
+
20
+ - Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
21
+
22
+ - **WARNING**: If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
23
+
24
+ - A simplified version for extracting the character's mask is also provided.
25
+ - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail in multi-person videos (incorrect pose tracking). For multi-person video, users are required to either develop their own solution or integrate a suitable open-source tool.
26
+
27
+ ---
28
+
29
+ ## 2. Preprocessing Instructions and Recommendations
30
+
31
+ ### 2.1 Basic Usage
32
+
33
+ - The preprocessing process requires some additional models, including pose detection (mandatory), and mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
34
+ ```
35
+ /path/to/your/ckpt_path/
36
+ ├── det/
37
+ │ └── yolov10m.onnx
38
+ ├── pose2d/
39
+ │ └── vitpose_h_wholebody.onnx
40
+ ├── sam2/
41
+ │ └── sam2_hiera_large.pt
42
+ └── FLUX.1-Kontext-dev/
43
+ ```
44
+ - `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results.
45
+
46
+ - When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
47
+
48
+ - The `resolution_area` parameter determines the resolution for both preprocessing and the generation model. Its size is determined by pixel area.
49
+
50
+ - The `fps` parameter can specify the frame rate for video processing. A lower frame rate can improve generation efficiency, but may cause stuttering or choppiness.
51
+
52
+ ---
53
+
54
+ ### 2.2 Animation Mode
55
+
56
+ - We support three forms: not using pose retargeting, using basic pose retargeting, and using enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are specified via the `retarget_flag` and `use_flux` parameters.
57
+
58
+ - Specifying `retarget_flag` to use basic pose retargeting requires ensuring that both the reference character and the character in the first frame of the driving video are in a front-facing, stretched pose.
59
+
60
+ - Other than that, we recommend using enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency is not maintained, the pose is incorrect, etc.). It is recommended to check the intermediate results as well as the finally generated pose video; both are stored in `save_path`. Of course, users can also use a better image editing model, or explore the prompts for Flux on their own.
61
+
62
+ ---
63
+
64
+ ### 2.3 Replacement Mode
65
+
66
+ - Specify `replace_flag` to enable data preprocessing for this mode. The preprocessing will additionally produce a mask for the character in the video; its size and shape can be adjusted via the following parameters.
67
+ - `iterations` and `k` can make the mask larger, covering more area.
68
+ - `w_len` and `h_len` can adjust the mask's shape. Smaller values will make the outline coarser, while larger values will make it finer.
69
+
70
+ - A smaller, finer-contoured mask can allow for more of the original background to be preserved, but may potentially limit the character's generation area (considering potential appearance differences, this can lead to some shape leakage). A larger, coarser mask can allow the character generation to be more flexible and consistent, but because it includes more of the background, it might affect the background's consistency. We recommend users to adjust the relevant parameters based on their specific input data.
wan/modules/animate/preprocess/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .process_pipepline import ProcessPipeline
3
+ from .video_predictor import SAM2VideoPredictor
wan/modules/animate/preprocess/human_visualization.py ADDED
@@ -0,0 +1,1357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ import time
5
+ import math
6
+ import matplotlib
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ from typing import Dict, List
10
+ import random
11
+ from pose2d_utils import AAPoseMeta
12
+
13
+
14
def draw_handpose(canvas, keypoints, hand_score_th=0.6):
    """
    Render a 21-keypoint hand skeleton onto ``canvas`` (modified in place).

    Args:
        canvas (np.ndarray): HxWx3 image to draw on.
        keypoints: sequence of 21 entries, each (x, y, score) in pixel
            coordinates, or None for a missing point.
        hand_score_th (float): minimum score for a point/bone to be drawn.

    Returns:
        np.ndarray: the same canvas with the hand overlay.
    """
    eps = 0.01

    H, W, C = canvas.shape
    stickwidth = max(int(min(H, W) / 200), 1)

    # Bone connectivity: wrist (0) to each finger chain.
    edges = [
        [0, 1], [1, 2], [2, 3], [3, 4],
        [0, 5], [5, 6], [6, 7], [7, 8],
        [0, 9], [9, 10], [10, 11], [11, 12],
        [0, 13], [13, 14], [14, 15], [15, 16],
        [0, 17], [17, 18], [18, 19], [19, 20],
    ]

    for edge_idx, (a, b) in enumerate(edges):
        kp_a = keypoints[a]
        kp_b = keypoints[b]
        if kp_a is None or kp_b is None:
            continue
        if kp_a[2] < hand_score_th or kp_b[2] < hand_score_th:
            continue

        xa, ya = int(kp_a[0]), int(kp_a[1])
        xb, yb = int(kp_b[0]), int(kp_b[1])
        # Points at (or near) the origin are treated as undetected.
        if xa > eps and ya > eps and xb > eps and yb > eps:
            hue = edge_idx / float(len(edges))
            color = matplotlib.colors.hsv_to_rgb([hue, 1.0, 1.0]) * 255
            cv2.line(canvas, (xa, ya), (xb, yb), color, thickness=stickwidth)

    for kp in keypoints:
        if kp is None:
            continue
        if kp[2] < hand_score_th:
            continue
        x, y = int(kp[0]), int(kp[1])
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
    return canvas
91
+
92
+
93
def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6):
    """
    Render a 21-keypoint hand skeleton onto ``canvas`` (modified in place),
    with a selectable stroke-width scheme.

    Args:
        canvas (np.ndarray): HxWx3 image to draw on.
        keypoints: sequence of 21 entries, each (x, y, score) in pixel
            coordinates, or None for a missing point.
        stickwidth_type (str): 'v1' matches ``draw_handpose``; 'v2' (default)
            uses roughly half that width for finer strokes.
        hand_score_th (float): minimum score for a point/bone to be drawn.

    Returns:
        np.ndarray: the same canvas with the hand overlay.
    """
    eps = 0.01

    H, W, C = canvas.shape
    # NOTE: any other stickwidth_type leaves stickwidth undefined (NameError),
    # matching the original behavior.
    if stickwidth_type == 'v1':
        stickwidth = max(int(min(H, W) / 200), 1)
    elif stickwidth_type == 'v2':
        stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1)

    # Bone connectivity: wrist (0) to each finger chain.
    edges = [
        [0, 1], [1, 2], [2, 3], [3, 4],
        [0, 5], [5, 6], [6, 7], [7, 8],
        [0, 9], [9, 10], [10, 11], [11, 12],
        [0, 13], [13, 14], [14, 15], [15, 16],
        [0, 17], [17, 18], [18, 19], [19, 20],
    ]

    for edge_idx, (a, b) in enumerate(edges):
        kp_a = keypoints[a]
        kp_b = keypoints[b]
        if kp_a is None or kp_b is None:
            continue
        if kp_a[2] < hand_score_th or kp_b[2] < hand_score_th:
            continue

        xa, ya = int(kp_a[0]), int(kp_a[1])
        xb, yb = int(kp_b[0]), int(kp_b[1])
        # Points at (or near) the origin are treated as undetected.
        if xa > eps and ya > eps and xb > eps and yb > eps:
            hue = edge_idx / float(len(edges))
            color = matplotlib.colors.hsv_to_rgb([hue, 1.0, 1.0]) * 255
            cv2.line(canvas, (xa, ya), (xb, yb), color, thickness=stickwidth)

    for kp in keypoints:
        if kp is None:
            continue
        if kp[2] < hand_score_th:
            continue
        x, y = int(kp[0]), int(kp[1])
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
    return canvas
173
+
174
+
175
def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):
    """Draw a filled, slightly darkened ellipse (limb) between two keypoints.

    Skips drawing when either keypoint's score is below ``threshold``.
    Returns ``img`` (modified in place).
    """
    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / 200), 1)

    if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
        return img

    xs = np.array([keypoint1[0], keypoint2[0]])
    ys = np.array([keypoint1[1], keypoint2[1]])
    center = (int(np.mean(xs)), int(np.mean(ys)))
    length = ((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2) ** 0.5
    angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
    polygon = cv2.ellipse2Poly(center, (int(length / 2), stickwidth), int(angle), 0, 360, 1)
    # Darken the limb color slightly so joint dots remain visible on top.
    cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
    return img
191
+
192
+
193
def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:
    """Convert 133 COCO-WholeBody keypoints into AA-format groups.

    Args:
        kp2ds (np.ndarray): [133, 2] keypoint array.

    Returns:
        Tuple of copies: body [20, 2], left hand [21, 2], right hand [21, 2].
        Each body point is the average of two wholebody indices (e.g. the
        neck is the midpoint of the two shoulders).
    """
    first = [0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]
    second = [0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]
    kp2ds_body = (kp2ds[first] + kp2ds[second]) / 2
    kp2ds_lhand = kp2ds[91:112]
    kp2ds_rhand = kp2ds[112:133]
    return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()
209
+
210
+
211
def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
    """Draw the full AA pose stored in ``meta`` (body + both hands) onto ``img``."""
    # Append the per-point confidence as a third column: (x, y, score).
    body = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    return draw_aapose(img, body, threshold, kp2ds_lhand=lhand, kp2ds_rhand=rhand,
                       stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
217
+
218
def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', draw_hand=True, draw_head=True):
    """Draw the full AA pose in ``meta`` using the 'new' stroke-width scheme."""
    # Append the per-point confidence as a third column: (x, y, score).
    body = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    return draw_aapose_new(img, body, threshold, kp2ds_lhand=lhand, kp2ds_rhand=rhand,
                           stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head)
225
+
226
def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200):
    """Draw only the hands from ``meta``; body confidences are zeroed out."""
    # Multiplying by 0 suppresses all body points while keeping array shapes.
    body = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1)
    lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    return draw_aapose(img, body, threshold, kp2ds_lhand=lhand, kp2ds_rhand=rhand,
                       stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False)
232
+
233
+
234
def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True):
    """Draw only the face portion (nose/eyes/ears) of the body pose in ``meta``."""
    body = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    return draw_M(img, body, threshold, kp2ds_lhand=None, kp2ds_rhand=None,
                  stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
240
+
241
+
242
def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False):
    """Draw only the nose keypoint of the body pose in ``meta``."""
    body = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    return draw_nose(img, body, threshold, kp2ds_lhand=None, kp2ds_rhand=None,
                     stick_width_norm=stick_width_norm, draw_hand=draw_hand)
248
+
249
+
250
def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200):
    """Placeholder for face-motion sequence generation.

    NOTE(review): currently a no-op that returns None unconditionally —
    implement or remove.
    """
    return None
253
+
254
+
255
def draw_M(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
    draw_head=True
):
    """
    Draw only the face links (nose-eyes-ears) of a 20-point AA body pose.

    All limb/torso/foot confidences are zeroed so just the head is rendered;
    optionally both hands are overlaid, and normalized keypoints can be
    recorded into ``data_to_json``.

    Args:
        img (np.ndarray): HxWx3 canvas, modified in place.
        kp2ds (np.ndarray): [20, 3] body keypoints (x, y, score) in AA order.
        threshold (float): minimum score for a point/link to be drawn.
        data_to_json (list | None): if given, a metadata dict with keypoints
            normalized to [0, 1] is appended (idx == -1) or written at idx.
        kp2ds_lhand / kp2ds_rhand: [21, 3] hand keypoints; required when
            ``draw_hand`` is True or ``data_to_json`` is provided.
        draw_hand (bool): also draw both hands.
        stick_width_norm (int): divisor of min(H, W) controlling stroke width.
        draw_head (bool): if False, suppress the face points too.

    Returns:
        np.ndarray: the canvas.
    """
    kp2ds = kp2ds.copy()
    # Zero every limb/torso/foot confidence: only the face remains drawable.
    kp2ds[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 18, 19], 2] = 0
    if not draw_head:
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    # 1-based links: nose->eyes->ears.
    limbSeq = [
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],
    ]

    colors = [
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    for (start, end), color in zip(limbSeq, colors):
        kp_a = kp2ds_body[start - 1]
        kp_b = kp2ds_body[end - 1]

        if kp_a[-1] < threshold or kp_b[-1] < threshold:
            continue

        xs = np.array([kp_a[0], kp_b[0]])
        ys = np.array([kp_a[1], kp_b[1]])
        center = (int(np.mean(xs)), int(np.mean(ys)))
        length = ((ys[0] - ys[1]) ** 2 + (xs[0] - xs[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(ys[0] - ys[1], xs[0] - xs[1]))
        polygon = cv2.ellipse2Poly(center, (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        # Limb fill is slightly darkened so the joint dots stay visible.
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    # NOTE: zip() pairs keypoints with only 4 colors, so at most the first
    # four body entries can get a dot (of those, only the nose has score > 0).
    for keypoint, color in zip(kp2ds_body, colors):
        if keypoint[-1] < threshold:
            continue
        cv2.circle(img, (int(keypoint[0]), int(keypoint[1])), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    # Normalize to [0, 1] for serialization (mutates the local copy only).
    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        entry_idx = len(data_to_json) if idx == -1 else idx
        entry = {
            "image_id": "frame_{:05d}.jpg".format(entry_idx + 1),
            "height": H,
            "width": W,
            "category_id": 1,
            "keypoints_body": kp2ds_body.tolist(),
            "keypoints_left_hand": kp2ds_lhand.tolist(),
            "keypoints_right_hand": kp2ds_rhand.tolist(),
        }
        if idx == -1:
            data_to_json.append(entry)
        else:
            data_to_json[idx] = entry
    return img
420
+
421
+
422
def draw_nose(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
):
    """
    Draw only the nose keypoint (and optionally the hands) onto ``img`` in place.

    All body joints except the nose have their confidence zeroed, so at most
    one dot is rendered for the body.

    Args:
        img (np.ndarray): HxWx3 canvas, modified in place.
        kp2ds (np.ndarray): [N, 3] body keypoints as (x, y, score) in *pixel*
            coordinates; index 0 is the nose.
        threshold (float): minimum score for a point to be drawn.
        data_to_json (list | None): when given, a record with keypoints
            normalised to [0, 1] is appended (``idx == -1``) or stored at
            position ``idx``. Requires both hand arrays to be non-None.
        idx (int): frame index for ``data_to_json``; -1 means append.
        kp2ds_lhand, kp2ds_rhand (np.ndarray | None): [21, 3] hand keypoints
            in pixel coordinates.
        draw_hand (bool): also render both hands via ``draw_handpose``.
        stick_width_norm (int): dot radius is min(H, W) / stick_width_norm.

    Returns:
        np.ndarray: the modified image.
    """
    kp2ds = kp2ds.copy()
    # Suppress every joint except the nose (index 0) by zeroing its score.
    kp2ds[1:, 2] = 0
    kp2ds_body = kp2ds

    # A single colour: zip() below stops at the shorter sequence, so only the
    # first (nose) keypoint is ever considered.
    colors = [
        [170, 0, 255],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    for keypoint, color in zip(kp2ds_body, colors):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    # Normalise to [0, 1] for serialisation (mutates only the local copy).
    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        frame_no = len(data_to_json) + 1 if idx == -1 else idx + 1
        record = {
            "image_id": "frame_{:05d}.jpg".format(frame_no),
            "height": H,
            "width": W,
            "category_id": 1,
            "keypoints_body": kp2ds_body.tolist(),
            "keypoints_left_hand": kp2ds_lhand.tolist(),
            "keypoints_right_hand": kp2ds_rhand.tolist(),
        }
        if idx == -1:
            data_to_json.append(record)
        else:
            data_to_json[idx] = record
    return img
584
+
585
+
586
def draw_aapose(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
    draw_head=True
):
    """
    Draw an OpenPose-style body skeleton (limbs + joints) onto ``img`` in place.

    Keypoint layout (20 points): Nose, Neck, RShoulder, RElbow, RWrist,
    LShoulder, LElbow, LWrist, RHip, RKnee, RAnkle, LHip, LKnee, LAnkle,
    REye, LEye, REar, LEar, LToe, RToe.

    Args:
        img (np.ndarray): HxWx3 canvas, drawn on in place.
        kp2ds (np.ndarray): [20, 3] body keypoints as (x, y, score) in
            *pixel* coordinates.
        threshold (float): minimum score for a joint/limb to be drawn.
        data_to_json (list | None): when given, a record with the keypoints
            normalised to [0, 1] is appended (``idx == -1``) or written at
            position ``idx``. Requires both hand arrays to be non-None.
        idx (int): frame index for ``data_to_json``; -1 means append.
        kp2ds_lhand, kp2ds_rhand (np.ndarray | None): [21, 3] hand keypoints
            in pixel coordinates.
        draw_hand (bool): also render both hands via ``draw_handpose``.
        stick_width_norm (int): limb thickness is min(H, W) / stick_width_norm.
        draw_head (bool): when False, head joints (nose, eyes, ears) are
            suppressed by zeroing their scores.

    Returns:
        np.ndarray: the modified image.
    """
    kp2ds = kp2ds.copy()
    if not draw_head:
        # Zero the scores of nose/eyes/ears so they fall under the threshold.
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    # 1-based joint index pairs for each limb.
    limbSeq = [
        [2, 3],
        [2, 6],  # shoulders
        [3, 4],
        [4, 5],  # left arm
        [6, 7],
        [7, 8],  # right arm
        [2, 9],
        [9, 10],
        [10, 11],  # right leg
        [2, 12],
        [12, 13],
        [13, 14],  # left leg
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],  # face (nose, eyes, ears)
        [14, 19],
        [11, 20],  # foot
    ]

    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
        # foot
        [200, 200, 0],
        [100, 100, 0],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    # Limbs: filled ellipses between joint pairs, darkened to 60% colour.
    for (k1_index, k2_index), color in zip(limbSeq, colors):
        keypoint1 = kp2ds_body[k1_index - 1]
        keypoint2 = kp2ds_body[k2_index - 1]

        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
            continue

        Y = np.array([keypoint1[0], keypoint2[0]])
        X = np.array([keypoint1[1], keypoint2[1]])
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    # Joints: filled circles in full colour.
    for keypoint, color in zip(kp2ds_body, colors):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    # Normalise to [0, 1] for serialisation (mutates only the local copy).
    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        frame_no = len(data_to_json) + 1 if idx == -1 else idx + 1
        record = {
            "image_id": "frame_{:05d}.jpg".format(frame_no),
            "height": H,
            "width": W,
            "category_id": 1,
            "keypoints_body": kp2ds_body.tolist(),
            "keypoints_left_hand": kp2ds_lhand.tolist(),
            "keypoints_right_hand": kp2ds_rhand.tolist(),
        }
        if idx == -1:
            data_to_json.append(record)
        else:
            data_to_json[idx] = record
    return img
748
+
749
+
750
def draw_aapose_new(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stickwidth_type='v2',
    draw_head=True
):
    """
    Draw an OpenPose-style body skeleton onto ``img`` in place.

    Variant of ``draw_aapose`` that selects the stroke width via
    ``stickwidth_type`` and renders hands with ``draw_handpose_new``.

    Args:
        img (np.ndarray): HxWx3 canvas, drawn on in place.
        kp2ds (np.ndarray): [20, 3] body keypoints as (x, y, score) in
            *pixel* coordinates.
        threshold (float): minimum score for a joint/limb to be drawn.
        data_to_json (list | None): when given, a record with the keypoints
            normalised to [0, 1] is appended (``idx == -1``) or written at
            position ``idx``. Requires both hand arrays to be non-None.
        idx (int): frame index for ``data_to_json``; -1 means append.
        kp2ds_lhand, kp2ds_rhand (np.ndarray | None): [21, 3] hand keypoints
            in pixel coordinates.
        draw_hand (bool): also render both hands via ``draw_handpose_new``.
        stickwidth_type (str): 'v1' -> min(H, W)/200; 'v2' -> one pixel
            thinner (floored at 1).
        draw_head (bool): when False, head joints (nose, eyes, ears) are
            suppressed by zeroing their scores.

    Returns:
        np.ndarray: the modified image.

    Raises:
        ValueError: for an unknown ``stickwidth_type``.
    """
    kp2ds = kp2ds.copy()
    if not draw_head:
        # Zero the scores of nose/eyes/ears so they fall under the threshold.
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    # 1-based joint index pairs for each limb.
    limbSeq = [
        [2, 3],
        [2, 6],  # shoulders
        [3, 4],
        [4, 5],  # left arm
        [6, 7],
        [7, 8],  # right arm
        [2, 9],
        [9, 10],
        [10, 11],  # right leg
        [2, 12],
        [12, 13],
        [13, 14],  # left leg
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],  # face (nose, eyes, ears)
        [14, 19],
        [11, 20],  # foot
    ]

    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
        # foot
        [200, 200, 0],
        [100, 100, 0],
    ]

    H, W, C = img.shape

    if stickwidth_type == 'v1':
        stickwidth = max(int(min(H, W) / 200), 1)
    elif stickwidth_type == 'v2':
        stickwidth = max(int(min(H, W) / 200) - 1, 1)
    else:
        # Was a bare `raise` (RuntimeError: "No active exception to re-raise");
        # fail with an explicit, descriptive error instead.
        raise ValueError("Unknown stickwidth_type: {!r}".format(stickwidth_type))

    # Limbs: filled ellipses between joint pairs, darkened to 60% colour.
    for (k1_index, k2_index), color in zip(limbSeq, colors):
        keypoint1 = kp2ds_body[k1_index - 1]
        keypoint2 = kp2ds_body[k2_index - 1]

        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
            continue

        Y = np.array([keypoint1[0], keypoint2[0]])
        X = np.array([keypoint1[1], keypoint2[1]])
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    # Joints: filled circles in full colour.
    for keypoint, color in zip(kp2ds_body, colors):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
        img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)

    # Normalise to [0, 1] for serialisation (mutates only the local copy).
    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        frame_no = len(data_to_json) + 1 if idx == -1 else idx + 1
        record = {
            "image_id": "frame_{:05d}.jpg".format(frame_no),
            "height": H,
            "width": W,
            "category_id": 1,
            "keypoints_body": kp2ds_body.tolist(),
            "keypoints_left_hand": kp2ds_lhand.tolist(),
            "keypoints_right_hand": kp2ds_rhand.tolist(),
        }
        if idx == -1:
            data_to_json.append(record)
        else:
            data_to_json[idx] = record
    return img
919
+
920
+
921
def draw_bbox(img, bbox, color=(255, 0, 0)):
    """Draw a 2-pixel rectangle on the image.

    Args:
        img: image array or a path accepted by ``load_image``.
        bbox: box as (x1, y1, x2, y2); values are truncated to int.
        color: rectangle colour triple.

    Returns:
        np.ndarray: the annotated image.
    """
    img = load_image(img)
    coords = [int(value) for value in bbox]
    cv2.rectangle(img, (coords[0], coords[1]), (coords[2], coords[3]), color, 2)
    return img
926
+
927
+
928
def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False):
    """
    Draw 2-D keypoints (and optionally a named skeleton) onto an image.

    Args:
        img: image array or a path accepted by ``load_image``.
        kp2ds (np.ndarray): [N, 3] keypoints as (x, y, score) in pixel
            coordinates.
        threshold (float): only joints with score > threshold get a dot.
        color: dot colour (joints are always drawn in this colour).
        skeleton (str | None): "coco17" or "cocowholebody" to also draw limb
            lines; None draws joints only.
        reverse (bool): forwarded to ``load_image`` (channel-order swap).

    Returns:
        np.ndarray: the annotated image.

    Raises:
        ValueError: for an unrecognised ``skeleton`` name (previously this
            crashed with a NameError on the undefined skeleton list).
    """
    img = load_image(img, reverse)

    if skeleton is not None:
        if skeleton == "coco17":
            skeleton_list = [
                [6, 8],
                [8, 10],
                [5, 7],
                [7, 9],
                [11, 13],
                [13, 15],
                [12, 14],
                [14, 16],
                [5, 6],
                [6, 12],
                [12, 11],
                [11, 5],
            ]
            color_list = [
                (255, 0, 0),
                (0, 255, 0),
                (0, 0, 255),
                (255, 255, 0),
                (255, 0, 255),
                (0, 255, 255),
            ]
        elif skeleton == "cocowholebody":
            # Body limbs, then toes, then per-finger chains of both hands.
            skeleton_list = [
                [6, 8],
                [8, 10],
                [5, 7],
                [7, 9],
                [11, 13],
                [13, 15],
                [12, 14],
                [14, 16],
                [5, 6],
                [6, 12],
                [12, 11],
                [11, 5],
                [15, 17],
                [15, 18],
                [15, 19],
                [16, 20],
                [16, 21],
                [16, 22],
                [91, 92, 93, 94, 95],
                [91, 96, 97, 98, 99],
                [91, 100, 101, 102, 103],
                [91, 104, 105, 106, 107],
                [91, 108, 109, 110, 111],
                [112, 113, 114, 115, 116],
                [112, 117, 118, 119, 120],
                [112, 121, 122, 123, 124],
                [112, 125, 126, 127, 128],
                [112, 129, 130, 131, 132],
            ]
            color_list = [
                (255, 0, 0),
                (0, 255, 0),
                (0, 0, 255),
                (255, 255, 0),
                (255, 0, 255),
                (0, 255, 255),
            ]
        else:
            raise ValueError("Unsupported skeleton type: {!r}".format(skeleton))

        # Each chain may hold more than two indices; connect consecutive pairs.
        for _idx, _skeleton in enumerate(skeleton_list):
            for i in range(len(_skeleton) - 1):
                cv2.line(
                    img,
                    (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])),
                    (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])),
                    color_list[_idx % len(color_list)],
                    3,
                )

    for kp2d in kp2ds:
        if kp2d[2] > threshold:
            cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1)

    return img
1020
+
1021
+
1022
def draw_mask(img, mask, background=0, return_rgba=False):
    """
    Composite ``img`` over ``background`` using ``mask`` as the alpha channel.

    Args:
        img: image array or a path accepted by ``load_image``.
        mask (np.ndarray): [H, W, 1] alpha channel; concatenated onto the
            image before blending.
        background (int | np.ndarray): a constant (0 = black, 1 = white) or a
            background image; resized to match ``img``.
        return_rgba (bool): NOTE(review): currently ignored — the call below
            always requests an RGBA result, which existing callers appear to
            rely on; confirm before honouring this flag.

    Returns:
        np.ndarray: the blended image from ``alphaMerge``.
    """
    img = load_image(img)
    h, w, _ = img.shape
    if isinstance(background, int):
        background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background
    # Fix: the original assigned the resize to a typo'd name ("backgournd"),
    # silently discarding it, so mismatched backgrounds were never resized.
    background = cv2.resize(background, (w, h))
    img_rgba = np.concatenate([img, mask], -1)
    return alphaMerge(img_rgba, background, 0, 0, return_rgba=True)
1030
+
1031
+
1032
def draw_pcd(pcd_list, save_path=None):
    """
    Scatter-plot one or more 3-D point clouds and save the figure.

    Args:
        pcd_list: sequence of [N, 3] arrays; each cloud gets the next colour
            from a fixed palette.
        save_path (str | None): output path; falls back to "tmp.png".
    """
    fig = plt.figure()
    axes = fig.add_subplot(111, projection="3d")

    palette = ["r", "g", "b", "y", "p"]

    for cloud_idx, cloud in enumerate(pcd_list):
        axes.scatter(cloud[:, 0], cloud[:, 1], cloud[:, 2], c=palette[cloud_idx], marker="o")

    axes.set_xlabel("X")
    axes.set_ylabel("Y")
    axes.set_zlabel("Z")

    plt.savefig(save_path if save_path is not None else "tmp.png")
1049
+
1050
+
1051
def load_image(img, reverse=False):
    """
    Return an image array, loading from disk if a path is given.

    Args:
        img (str | np.ndarray): file path (read with ``cv2.imread``) or an
            already-loaded HxWx3 array.
        reverse (bool): when True, swap the channel order (BGR <-> RGB) and
            return a uint8 copy; when False the input array is returned as-is.

    Returns:
        np.ndarray: the (possibly channel-swapped) image.
    """
    if isinstance(img, str):  # idiomatic type check (was `type(img) == str`)
        img = cv2.imread(img)
    if reverse:
        img = img.astype(np.float32)
        img = img[:, :, ::-1]
        img = img.astype(np.uint8)
    return img
1059
+
1060
+
1061
def draw_skeleten(meta):
    """
    Rasterise the skeleton described by a pose meta record onto a black canvas.

    Args:
        meta (dict): must provide "keypoints_body" (normalised [0, 1] (x, y)
            pairs or None per joint), "keypoints_left_hand",
            "keypoints_right_hand", "width" and "height".

    Returns:
        np.ndarray: [height, width, 3] uint8 pose image.
    """
    # Missing joints become (0, 0) with zero confidence; present joints get
    # full confidence appended.
    kps = np.array([
        [0, 0, 0] if kp is None else [*kp, 1]
        for kp in meta["keypoints_body"]
    ])

    # Scale normalised coordinates up to pixel space.
    kps[:, 0] *= meta["width"]
    kps[:, 1] *= meta["height"]

    canvas = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8)

    return draw_aapose(
        canvas,
        kps,
        draw_hand=True,
        kp2ds_lhand=meta["keypoints_left_hand"],
        kp2ds_rhand=meta["keypoints_right_hand"],
    )
1083
+
1084
+
1085
def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:
    """
    Overlay a body skeleton onto a PNCC face rendering.

    The neck segment is drawn first, the PNCC pixels are pasted over it, and
    the remaining skeleton (minus the head joints) is drawn on top.

    Args:
        pncc: [H, W, 3] PNCC rendering; non-zero pixels are treated as face.
        meta: dict with "keypoints_body" (normalised [0, 1]; entries appear to
            be (x, y, score) triples since index 2 is thresholded below —
            TODO confirm), "keypoints_left_hand", "keypoints_right_hand"
            (normalised arrays), "width" and "height".

    Returns:
        np.ndarray: [H, W, 3] composite image.

    Note:
        The hand keypoint arrays inside ``meta`` are scaled *in place* below,
        so the caller's ``meta`` is mutated.
    """
    # Preprocess body keypoints: drop missing joints and head joints
    # (eyes/ears, indices 14-17) by zeroing them out.
    kps = []
    for i, kp in enumerate(meta["keypoints_body"]):
        if kp is None:
            kps.append([0, 0, 0])
        elif i in [14, 15, 16, 17]:
            kps.append([0, 0, 0])
        else:
            kps.append([*kp])
    kps = np.stack(kps)

    # Scale normalised body coordinates to the PNCC's pixel space.
    kps[:, 0] *= pncc.shape[1]
    kps[:, 1] *= pncc.shape[0]

    # Draw the neck (nose -> neck segment) only when both ends are confident.
    canvas = np.zeros_like(pncc)
    if kps[0][2] > 0.6 and kps[1][2] > 0.6:
        canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255])

    # Paste PNCC pixels over the neck so the face covers it.
    mask = (pncc > 0).max(axis=2)
    canvas[mask] = pncc[mask]
    pncc = canvas

    # Suppress the nose so draw_aapose does not re-draw head geometry.
    kps[0] = 0

    # Scale hand keypoints to pixel space using meta dims (NOTE(review):
    # body joints use pncc.shape above — confirm these always agree).
    # This mutates the arrays stored in ``meta``.
    meta["keypoints_left_hand"][:, 0] *= meta["width"]
    meta["keypoints_left_hand"][:, 1] *= meta["height"]

    meta["keypoints_right_hand"][:, 0] *= meta["width"]
    meta["keypoints_right_hand"][:, 1] *= meta["height"]
    pose_img = draw_aapose(
        pncc,
        kps,
        draw_hand=True,
        kp2ds_lhand=meta["keypoints_left_hand"],
        kp2ds_rhand=meta["keypoints_right_hand"],
    )
    return pose_img
1134
+
1135
+
1136
# Rendering spec for the 70-point facial landmark set, keyed by face region.
# Per entry:
#   indexs  - landmark indices belonging to the region
#   color   - colour triple used to draw it
#   connect - False draws individual dots instead of a polyline (default True)
#   close   - close the polyline into a loop (default False)
FACE_CUSTOM_STYLE = {
    "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False},
    "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]},
    "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]},
    "left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True},
    "right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True},
    "mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True},
    "mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True},
}
1145
+
1146
+
1147
def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):
    """
    Overlay facial landmarks onto a copy of the image.

    Args:
        img: [H, W, 3] image; not modified (a copy is drawn on).
        kps: [70, 2] landmark coordinates in pixels.
        thickness: polyline thickness; dots use twice this radius.
        style: region spec mapping (see ``FACE_CUSTOM_STYLE``).

    Returns:
        np.ndarray: annotated copy of the image.
    """
    canvas = img.copy()
    for region_name, spec in style.items():
        pts = np.array(kps[spec["indexs"]]).astype(np.int32)
        region_color = spec["color"]
        if spec.get("connect", True):
            # Connected regions are drawn as a (possibly closed) polyline.
            cv2.polylines(canvas, [pts], spec.get("close", False), region_color, thickness=thickness)
        else:
            # Unconnected regions (e.g. eyeballs) are drawn as filled dots.
            for pt in pts:
                cv2.circle(canvas, np.array(pt).astype(np.int32), thickness * 2, color=region_color, thickness=-1)
    return canvas
1166
+
1167
+
1168
def draw_traj(metas: List[AAPoseMeta], threshold=0.6):
    """
    Render per-frame trajectory dot images for a sparse set of sampled points.

    One point is sampled on each body limb (at a random fraction along the
    segment) plus one random left-hand and one random right-hand keypoint.
    Every sampled point is drawn as a coloured dot on a black canvas, one
    canvas per frame. Output is intentionally non-deterministic: sampling
    positions, hand indices and colours all come from ``random``.

    Args:
        metas: per-frame pose metadata; all frames share width/height.
        threshold: visibility threshold on keypoint confidence.

    Returns:
        list[np.ndarray]: one float canvas per frame.
    """
    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50],
              # foot
              [200, 200, 0],
              [100, 100, 0]
              ]
    limbSeq = [
        [1, 2], [1, 5],  # shoulders
        [2, 3], [3, 4],  # left arm
        [5, 6], [6, 7],  # right arm
        [1, 8], [8, 9], [9, 10],  # right leg
        [1, 11], [11, 12], [12, 13],  # left leg
        # face (nose, eyes, ears)
        [13, 18], [10, 19]  # foot
    ]

    face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]]
    kp_body = np.array([meta.kps_body for meta in metas])

    # NOTE(review): the sampled face segments are never used afterwards; the
    # call is kept only so RNG consumption stays close to the old behaviour.
    face_seq = random.sample(face_seq, 2)

    kp_lh = np.array([meta.kps_lhand for meta in metas])
    kp_rh = np.array([meta.kps_rhand for meta in metas])

    kp_lh_p = np.array([meta.kps_lhand_p for meta in metas])
    kp_rh_p = np.array([meta.kps_rhand_p for meta in metas])

    key_point_list = []

    for k1_index, k2_index in limbSeq:
        keypoint1 = kp_body[:, k1_index - 1]
        keypoint2 = kp_body[:, k2_index - 1]
        interleave = random.randint(4, 7)
        randind = random.randint(0, interleave - 1)

        Y = np.array([keypoint1[:, 0], keypoint2[:, 0]])
        X = np.array([keypoint1[:, 1], keypoint2[:, 1]])

        vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1

        # Linear interpolation at a random fraction along the limb.
        t = randind / interleave
        x = (1 - t) * Y[0, :] + t * Y[1, :]
        y = (1 - t) * X[0, :] + t * X[1, :]

        x = x.astype(int)
        y = y.astype(int)

        key_point_list.append(np.array([x, y, vis]).T)

    indx_lh = random.randint(0, kp_lh.shape[1] - 1)
    lh = kp_lh[:, indx_lh, :]
    lh_p = kp_lh_p[:, indx_lh:indx_lh + 1]
    lh = np.concatenate([lh, lh_p], axis=-1)

    indx_rh = random.randint(0, kp_rh.shape[1] - 1)
    # Fix: use the same sampled index for positions and confidences (the
    # original drew a *second* random index for the positions only, so the
    # visibility flag could belong to a different keypoint).
    rh = kp_rh[:, indx_rh, :]
    rh_p = kp_rh_p[:, indx_rh:indx_rh + 1]
    rh = np.concatenate([rh, rh_p], axis=-1)

    # NOTE(review): this thresholds the *last frame's* row; lh[:, -1] (the
    # confidence column) may have been intended — confirm before changing.
    lh[-1, :] = (lh[-1, :] > threshold) * 1
    rh[-1, :] = (rh[-1, :] > threshold) * 1

    key_point_list.append(lh.astype(int))
    key_point_list.append(rh.astype(int))

    key_points_list = np.stack(key_point_list)
    num_points = len(key_points_list)
    sample_colors = random.sample(colors, num_points)

    stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2)

    image_list_ori = []
    for i in range(key_points_list.shape[-2]):
        # NOTE(review): canvas is (width, height, 3) while other helpers in
        # this file build (height, width, 3) — confirm which orientation the
        # callers expect before swapping.
        _image_vis = np.zeros((metas[0].width, metas[0].height, 3))
        points = key_points_list[:, i, :]
        for idx, point in enumerate(points):
            x, y, vis = point
            if vis == 1:
                cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1)

        image_list_ori.append(_image_vis)

    return image_list_ori
1278
+
1279
+
1280
if __name__ == "__main__":
    # Smoke test: build one pose meta record (body keypoints normalised to
    # [0, 1]; hand keypoints in pixel coordinates with confidence scores),
    # render a 5-frame trajectory and write the first frame to disk.
    meta = {
        "image_id": "00472.jpg",
        "height": 540,
        "width": 414,
        "category_id": 1,
        "keypoints_body": [
            [0.5084776947463768, 0.11350188078703703],
            [0.504467655495169, 0.20419560185185184],
            [0.3982016153381642, 0.198046875],
            [0.3841664779589372, 0.34869068287037036],
            [0.3901815368357488, 0.4670536747685185],
            [0.610733695652174, 0.2103443287037037],
            [0.6167487545289855, 0.3517650462962963],
            [0.6448190292874396, 0.4762767650462963],
            [0.4523371452294686, 0.47320240162037036],
            [0.4503321256038647, 0.6776475694444445],
            [0.47639738073671495, 0.8544234664351852],
            [0.5766483620169082, 0.47320240162037036],
            [0.5666232638888888, 0.6761103877314815],
            [0.534542949879227, 0.863646556712963],
            [0.4864224788647343, 0.09505570023148148],
            [0.5285278910024155, 0.09351851851851851],
            [0.46236224335748793, 0.10581597222222222],
            [0.5586031853864735, 0.10274160879629629],
            [0.4994551064311594, 0.9405056423611111],
            [0.4152442821557971, 0.9312825520833333],
        ],
        "keypoints_left_hand": [
            [267.78515625, 263.830078125, 1.2840936183929443],
            [265.294921875, 269.640625, 1.2546794414520264],
            [263.634765625, 277.111328125, 1.2863062620162964],
            [262.8046875, 285.412109375, 1.267038345336914],
            [261.14453125, 292.8828125, 1.280144453048706],
            [273.595703125, 281.26171875, 1.2592815160751343],
            [271.10546875, 291.22265625, 1.3256099224090576],
            [265.294921875, 294.54296875, 1.2368024587631226],
            [261.14453125, 294.54296875, 0.9771889448165894],
            [274.42578125, 282.091796875, 1.250044584274292],
            [269.4453125, 291.22265625, 1.2571144104003906],
            [264.46484375, 292.8828125, 1.177802324295044],
            [260.314453125, 292.052734375, 0.9283463358879089],
            [273.595703125, 282.091796875, 1.1834490299224854],
            [269.4453125, 290.392578125, 1.188171625137329],
            [265.294921875, 290.392578125, 1.192609429359436],
            [261.974609375, 289.5625, 0.9366656541824341],
            [271.935546875, 281.26171875, 1.0946396589279175],
            [268.615234375, 287.072265625, 0.9906131029129028],
            [265.294921875, 287.90234375, 1.0219476222991943],
            [262.8046875, 287.072265625, 0.9240120053291321],
        ],
        "keypoints_right_hand": [
            [161.53515625, 258.849609375, 1.2069408893585205],
            [168.17578125, 263.0, 1.1846840381622314],
            [173.986328125, 269.640625, 1.1435924768447876],
            [173.986328125, 277.94140625, 1.1802611351013184],
            [173.986328125, 286.2421875, 1.2599592208862305],
            [165.685546875, 275.451171875, 1.0633569955825806],
            [167.345703125, 286.2421875, 1.1693341732025146],
            [169.8359375, 291.22265625, 1.2698509693145752],
            [170.666015625, 294.54296875, 1.0619274377822876],
            [160.705078125, 276.28125, 1.0995020866394043],
            [163.1953125, 287.90234375, 1.2735884189605713],
            [166.515625, 291.22265625, 1.339503526687622],
            [169.005859375, 294.54296875, 1.0835273265838623],
            [157.384765625, 277.111328125, 1.0866981744766235],
            [161.53515625, 287.072265625, 1.2468621730804443],
            [164.025390625, 289.5625, 1.2817761898040771],
            [166.515625, 292.052734375, 1.099466323852539],
            [155.724609375, 277.111328125, 1.1065717935562134],
            [159.044921875, 285.412109375, 1.1924479007720947],
            [160.705078125, 287.072265625, 1.1304771900177002],
            [162.365234375, 287.90234375, 1.0040509700775146],
        ],
    }
    demo_meta = AAPoseMeta(meta)
    res = draw_traj([demo_meta]*5)
    # Channel flip before writing (cv2 expects BGR).
    cv2.imwrite("traj.png", res[0][..., ::-1])
wan/modules/animate/preprocess/pose2d.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ import cv2
4
+ from typing import Union, List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import onnxruntime
9
+
10
+ from pose2d_utils import (
11
+ read_img,
12
+ box_convert_simple,
13
+ bbox_from_detector,
14
+ crop,
15
+ keypoints_from_heatmaps,
16
+ load_pose_metas_from_kp2ds_seq
17
+ )
18
+
19
+
20
class SimpleOnnxInference(object):
    """Minimal wrapper around an ONNX Runtime inference session.

    Resolves the execution providers from a torch-style device spec
    ("cuda", "cuda:1", "cpu", or a ``torch.device``), loads the model from
    ``checkpoint`` (a file, or a directory containing ``end2end.onnx``),
    and caches the first input/output names and the model input resolution.

    Args:
        checkpoint: Path to an .onnx file or a directory holding end2end.onnx.
        device: Target device; CUDA devices select CUDAExecutionProvider
            with the matching device_id, anything else runs on CPU.
        reverse_input: If True, report the input resolution as (W, H)
            instead of the ONNX (H, W) ordering.

    Raises:
        RuntimeError: If ``checkpoint`` does not exist on disk.
    """

    def __init__(self, checkpoint, device='cuda', reverse_input=False, **kwargs):
        if isinstance(device, str):
            device = torch.device(device)
        if device.type == 'cuda':
            device = '{}:{}'.format(device.type, device.index)
            # "cuda:N" -> device_id "N"; a bare "cuda" renders as "cuda:None",
            # whose last char is not a digit, so we fall back to GPU 0.
            providers = [("CUDAExecutionProvider",
                          {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}),
                         "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self.device = device
        if not os.path.exists(checkpoint):
            raise RuntimeError("{} is not existed!".format(checkpoint))

        # A directory checkpoint is expected to contain the exported model
        # under the conventional name "end2end.onnx".
        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(checkpoint, 'end2end.onnx')

        self.session = onnxruntime.InferenceSession(checkpoint,
                                                    providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        # ONNX input shape is (N, C, H, W); keep [H, W] (or [W, H] when
        # reverse_input is set).
        self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input else self.session.get_inputs()[0].shape[2:][::-1]
        self.input_resolution = np.array(self.input_resolution)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def get_output_names(self):
        """Return the names of all model outputs, in session order."""
        output_names = []
        for node in self.session.get_outputs():
            output_names.append(node.name)
        return output_names

    def set_device(self, device):
        """Re-bind the session to ``device`` by swapping execution providers.

        Args:
            device: Same accepted forms as in ``__init__``.
        """
        if isinstance(device, str):
            device = torch.device(device)
        if device.type == 'cuda':
            device = '{}:{}'.format(device.type, device.index)
            providers = [("CUDAExecutionProvider",
                          {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}),
                         "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        # FIX: previously this hard-coded ["CUDAExecutionProvider"], ignoring
        # the provider list computed above — that dropped the device_id option
        # and raised on CPU-only onnxruntime builds when device was "cpu".
        self.session.set_providers(providers)
        self.device = device
+
68
class Yolo(SimpleOnnxInference):
    """YOLO person detector on ONNX Runtime.

    Runs object detection and reduces the raw detections to (at most) one
    primary person box per image, optionally keeping secondary persons.

    Args:
        checkpoint: Path to the YOLO .onnx model (or containing directory).
        device: Execution device, forwarded to SimpleOnnxInference.
        threshold_conf: Minimum confidence for a detection to be kept.
        threshold_multi_persons: Secondary boxes are counted as extra persons
            only if their size exceeds this fraction of the primary box size.
        input_resolution: (H, W) the input image is resized to.
        threshold_iou: IoU threshold for non-maximum suppression.
        threshold_bbox_shape_ratio: A candidate is vetoed as primary if its
            longest side is below this fraction of the largest seen box side.
        cat_id: 1-based category ids to keep (1 == person for COCO models).
        select_type: 'max' picks the largest box; 'center' picks the box
            closest to the horizontal image center.
        strict: If True, the shape-ratio veto also applies to the first
            candidate, not only to replacements of an already-chosen one.
        sorted_func: Optional callable (bboxes, shape_raw) -> index that
            overrides the primary-person selection.
    """

    def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs):
        super(Yolo, self).__init__(checkpoint, device=device, **kwargs)
        # NOTE(review): this forces CUDA regardless of the providers chosen by
        # the base class and will raise on CPU-only onnxruntime — confirm
        # whether CPU execution needs to be supported here.
        self.session.set_providers(["CUDAExecutionProvider"])
        model_inputs = self.session.get_inputs()
        input_shape = model_inputs[0].shape

        # Scaling back to the raw image assumes the canonical 640x640 YOLO
        # input, independent of `input_resolution`.
        self.input_width = 640
        self.input_height = 640

        self.threshold_multi_persons = threshold_multi_persons
        self.threshold_conf = threshold_conf
        self.threshold_iou = threshold_iou
        self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio
        self.input_resolution = input_resolution
        self.cat_id = cat_id
        self.select_type = select_type
        self.strict = strict
        self.sorted_func = sorted_func

    def preprocess(self, input_image):
        """Resize and normalize an image for YOLO inference.

        Args:
            input_image: Image path or array accepted by ``read_img``.

        Returns:
            Tuple of (CHW float32 image in [0, 1],
            np.ndarray [raw_height, raw_width]).
        """
        img = read_img(input_image)
        # Get the height and width of the input image
        img_height, img_width = img.shape[:2]
        # Resize the image to match the input shape
        img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0]))
        # Normalize the image data by dividing it by 255.0
        image_data = np.array(img) / 255.0
        # Transpose the image to have the channel dimension as the first dimension
        image_data = np.transpose(image_data, (2, 0, 1))  # Channel first
        # Expand the dimensions of the image data to match the expected input shape
        # image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
        image_data = image_data.astype(np.float32)
        # Return the preprocessed image data
        return image_data, np.array([img_height, img_width])

    def postprocess(self, output, shape_raw, cat_id=[1]):
        """Decode raw YOLO output into NMS-filtered detections.

        Supports two layouts: already-decoded rows of
        [x1, y1, x2, y2, score, class] (width 6), and the raw 84-channel
        head [cx, cy, w, h, class_scores...].

        Args:
            output: Raw model output for one image.
            shape_raw: [raw_height, raw_width] of the original image.
            cat_id: Unused here; filtering by category happens in
                ``process_results``.

        Returns:
            np.ndarray of rows [x1, y1, x2, y2, score, class_id] in raw-image
            coordinates (empty array when nothing survives NMS).
        """
        outputs = np.squeeze(output)
        if len(outputs.shape) == 1:
            outputs = outputs[None]
        # Raw 84-channel heads come channel-first; transpose to rows.
        if output.shape[-1] != 6 and output.shape[1] == 84:
            outputs = np.transpose(outputs)

        # Get the number of rows in the outputs array
        rows = outputs.shape[0]

        # Calculate the scaling factors for the bounding box coordinates
        x_factor = shape_raw[1] / self.input_width
        y_factor = shape_raw[0] / self.input_height

        # Lists to store the bounding boxes, scores, and class IDs of the detections
        boxes = []
        scores = []
        class_ids = []

        if outputs.shape[-1] == 6:
            max_scores = outputs[:, 4]
            classid = outputs[:, -1]

            threshold_conf_masks = max_scores >= self.threshold_conf
            # Sentinel comparison: never true, so no class is filtered out.
            classid_masks = classid[threshold_conf_masks] != 3.14159

            max_scores = max_scores[threshold_conf_masks][classid_masks]
            classid = classid[threshold_conf_masks][classid_masks]

            # Scale xyxy to raw-image coordinates, then convert to xywh.
            boxes = outputs[:, :4][threshold_conf_masks][classid_masks]
            boxes[:, [0, 2]] *= x_factor
            boxes[:, [1, 3]] *= y_factor
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes = boxes.astype(np.int32)

        else:
            classes_scores = outputs[:, 4:]
            max_scores = np.amax(classes_scores, -1)
            threshold_conf_masks = max_scores >= self.threshold_conf

            classid = np.argmax(classes_scores[threshold_conf_masks], -1)

            # Sentinel comparison: always true, keeps every class.
            classid_masks = classid!=3.14159

            classes_scores = classes_scores[threshold_conf_masks][classid_masks]
            max_scores = max_scores[threshold_conf_masks][classid_masks]
            classid = classid[classid_masks]

            xywh = outputs[:, :4][threshold_conf_masks][classid_masks]

            x = xywh[:, 0:1]
            y = xywh[:, 1:2]
            w = xywh[:, 2:3]
            h = xywh[:, 3:4]

            # Center-format -> top-left xywh in raw-image coordinates.
            left = ((x - w / 2) * x_factor)
            top = ((y - h / 2) * y_factor)
            width = (w * x_factor)
            height = (h * y_factor)
            boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)

        boxes = boxes.tolist()
        scores = max_scores.tolist()
        class_ids = classid.tolist()

        # Apply non-maximum suppression to filter out overlapping bounding boxes
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)

        results = []
        for i in indices:
            # Get the box, score, and class ID corresponding to the index
            box = box_convert_simple(boxes[i], 'xywh2xyxy')
            score = scores[i]
            class_id = class_ids[i]
            results.append(box + [score] + [class_id])

        return np.array(results)

    def process_results(self, results, shape_raw, cat_id=[1], single_person=True):
        """Select the primary person from NMS'ed detections.

        Args:
            results: Array of [x1, y1, x2, y2, score, class_id] rows (or a
                tuple whose first element is that array).
            shape_raw: [raw_height, raw_width] of the image.
            cat_id: 1-based category ids to keep (detector classes are
                0-based, hence the ``bbox[-1] + 1`` check below).
            single_person: If False, secondary persons are appended too.

        Returns:
            List of {'bbox': ndarray[5], 'track_id': int} dicts, or None
            when no detection passed the filters.
        """
        if isinstance(results, tuple):
            det_results = results[0]
        else:
            det_results = results

        person_results = []
        person_count = 0
        if len(results):
            max_idx = -1
            # Start low enough that 'center' scores (which are negated
            # distances) can still win.
            max_bbox_size = shape_raw[0] * shape_raw[1] * -10
            max_bbox_shape = -1

            bboxes = []
            idx_list = []
            # First pass: keep wanted categories above the confidence
            # threshold, and record the largest box side seen.
            for i in range(results.shape[0]):
                bbox = results[i]
                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
                    idx_list.append(i)
                    bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
                    if bbox_shape > max_bbox_shape:
                        max_bbox_shape = bbox_shape

            results = results[idx_list]

            # Second pass: pick the primary person by 'max' area or by
            # 'center' proximity, vetoing boxes much smaller than the
            # largest one (threshold_bbox_shape_ratio).
            for i in range(results.shape[0]):
                bbox = results[i]
                bboxes.append(bbox)
                if self.select_type == 'max':
                    bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
                elif self.select_type == 'center':
                    bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
                bbox_shape = max((bbox[2] - bbox[0]), ((bbox[3] - bbox[1])))
                if bbox_size > max_bbox_size:
                    if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:
                        continue
                    max_bbox_size = bbox_size
                    max_bbox_shape = bbox_shape
                    max_idx = i

            # An external ranking function overrides the built-in choice.
            if self.sorted_func is not None and len(bboxes) > 0:
                max_idx = self.sorted_func(bboxes, shape_raw)
                bbox = bboxes[max_idx]
                if self.select_type == 'max':
                    max_bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
                elif self.select_type == 'center':
                    max_bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1

            if max_idx != -1:
                person_count = 1

            if max_idx != -1:
                person = {}
                person['bbox'] = results[max_idx, :5]
                person['track_id'] = int(0)
                person_results.append(person)

            # Count (and optionally keep) secondary persons whose size is a
            # significant fraction of the primary box.
            for i in range(results.shape[0]):
                bbox = results[i]
                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
                    if self.select_type == 'max':
                        bbox_size = (bbox[2] - bbox[0]) * ((bbox[3] - bbox[1]))
                    elif self.select_type == 'center':
                        bbox_size = (abs((bbox[2] + bbox[0]) / 2 - shape_raw[1]/2)) * -1
                    if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:
                        person_count += 1
                        if not single_person:
                            person = {}
                            person['bbox'] = results[i, :5]
                            person['track_id'] = int(person_count - 1)
                            person_results.append(person)
            return person_results
        else:
            return None

    def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):
        """Decode + select for image ``i``, writing into ``person_results``.

        Leaves the pre-filled whole-image fallback entry untouched when no
        person is found.
        """
        result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id)
        result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)
        if result is not None and len(result) != 0:
            person_results[i] = result

    def forward(self, img, shape_raw, **kwargs):
        """Run detection on a batch.

        Args:
            img: Batched preprocessed images (ndarray or torch.Tensor).
            shape_raw: Per-image [raw_height, raw_width] array.

        Returns:
            Per-image lists of person dicts; images with no detection keep a
            whole-frame bbox with score -1 and track_id -1.
        """
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()
            shape_raw = shape_raw.cpu().numpy()

        outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]
        # Fallback: full-image bbox with sentinel score/track_id, replaced
        # per image when a real detection exists.
        person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))]

        for i in range(len(outputs)):
            self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)
        return person_results
309
class ViTPose(SimpleOnnxInference):
    """ViTPose keypoint estimator on ONNX Runtime.

    Runs the heatmap head on a person crop and decodes keypoints back into
    original-image coordinates.
    """

    def __init__(self, checkpoint, device='cuda', **kwargs):
        # FIX: forward **kwargs (e.g. reverse_input) to the base class
        # instead of silently dropping them.
        super(ViTPose, self).__init__(checkpoint, device=device, **kwargs)
        # FIX: removed the unconditional
        # self.session.set_providers(["CUDAExecutionProvider"]) — it
        # overrode the CUDA/CPU provider list the base class already chose
        # from `device` (losing the device_id option) and raised on
        # CPU-only onnxruntime builds. The base-class selection stands.

    def forward(self, img, center, scale, **kwargs):
        """Predict keypoints for a batch of preprocessed crops.

        Args:
            img: Batched normalized crops, NCHW float32.
            center: Per-crop bbox centers, as produced by ``preprocess``.
            scale: Per-crop scales (bbox size / 200), as produced by
                ``preprocess``; multiplied back by 200 for decoding.

        Returns:
            np.ndarray of shape (N, K, 3): x, y in original-image
            coordinates plus confidence per keypoint.
        """
        heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]
        points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
                                               center=center,
                                               scale=scale*200,
                                               unbiased=True,
                                               use_udp=False)
        return np.concatenate([points, prob], axis=2)

    @staticmethod
    def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):
        """Crop and normalize one image around a person bbox.

        Args:
            img: RGB image array (H, W, 3).
            bbox: [x1, y1, x2, y2, score]; invalid/low-score/tiny boxes
                (< 10 px on a side) fall back to the whole image.
            input_resolution: Target crop size (H, W), or a single int for a
                square crop.
            rescale: Padding factor applied around the detected bbox.
            mask: Optional mask; pixels where mask <= 128 are blanked out.

        Returns:
            Tuple of (CHW float32 normalized crop, center array, scale array).
        """
        if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
            bbox = np.array([0, 0, img.shape[1], img.shape[0]])

        bbox_xywh = bbox
        if mask is not None:
            img = np.where(mask>128, img, mask)

        if isinstance(input_resolution, int):
            center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)
            img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))
        else:
            center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
            img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))

        # Standard ImageNet normalization.
        IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
        IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
        img_norm = (img / 255. - IMG_NORM_MEAN) / IMG_NORM_STD
        img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)
        return img_norm, np.array(center), np.array(scale)
class Pose2d:
    """End-to-end 2D pose pipeline: optional YOLO detection + ViTPose.

    Args:
        checkpoint: ONNX checkpoint for the ViTPose keypoint model.
        detector_checkpoint: Optional ONNX checkpoint for the YOLO person
            detector; when None, the whole frame is used as the bbox.
        device: Execution device forwarded to both models.
    """

    def __init__(self, checkpoint, detector_checkpoint=None, device='cuda', **kwargs):
        if detector_checkpoint is not None:
            self.detector = Yolo(detector_checkpoint, device)
        else:
            self.detector = None

        self.model = ViTPose(checkpoint, device)
        self.device = device

    def load_images(self, inputs):
        """
        Load images from various input types.

        Args:
            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be a
                video/image file path, a batch array of BGR frames, or a list
                of BGR image arrays. Arrays are assumed to be BGR (OpenCV
                convention) and are converted to RGB.

        Returns:
            List[np.ndarray]: List of RGB image arrays

        Raises:
            ValueError: If the file format / input type is unsupported or the
                image cannot be read.
        """
        if isinstance(inputs, str):
            if inputs.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
                cap = cv2.VideoCapture(inputs)
                frames = []
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                cap.release()
                images = frames
            elif inputs.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                # FIX: check imread's result BEFORE cvtColor. Previously
                # cvtColor was applied first, so an unreadable file raised
                # an opaque cv2 error instead of the intended ValueError.
                img = cv2.imread(inputs)
                if img is None:
                    raise ValueError(f"Cannot read image: {inputs}")
                images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)]
            else:
                raise ValueError(f"Unsupported file format: {inputs}")
        elif isinstance(inputs, (np.ndarray, list)):
            # An ndarray is iterated along its first axis, i.e. it is assumed
            # to be a batch of frames (N, H, W, 3) — TODO confirm callers
            # never pass a single (H, W, 3) image here.
            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
        else:
            # FIX: previously this fell through and crashed with
            # UnboundLocalError on `images`; fail with a clear message.
            raise ValueError(f"Unsupported input type: {type(inputs)}")
        return images

    def __call__(
        self,
        inputs: Union[str, np.ndarray, List[np.ndarray]],
        return_image: bool = False,
        **kwargs
    ):
        """
        Process input and estimate 2D keypoints.

        Args:
            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be file path,
                single image array, or list of image arrays
            return_image (bool): Unused here; kept for interface compatibility.
            **kwargs: Additional arguments for processing

        Returns:
            Pose metas built from the detected 2D keypoints of all frames
            (see ``load_pose_metas_from_kp2ds_seq``).
        """
        images = self.load_images(inputs)
        # All frames are assumed to share the first frame's resolution.
        H, W = images[0].shape[:2]
        if self.detector is not None:
            bboxes = []
            # Detect one primary person per frame; Yolo.forward falls back to
            # a whole-frame bbox (score -1) when nothing is found.
            for _image in images:
                img, shape = self.detector.preprocess(_image)
                bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"])
        else:
            bboxes = [None] * len(images)

        kp2ds = []
        for _image, _bbox in zip(images, bboxes):
            img, center, scale = self.model.preprocess(_image, _bbox)
            kp2ds.append(self.model(img[None], center[None], scale[None]))
        kp2ds = np.concatenate(kp2ds, 0)
        metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
        return metas
wan/modules/animate/preprocess/pose2d_utils.py ADDED
@@ -0,0 +1,1159 @@