Spaces: Running on Zero

Upload 69 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
- .gitattributes +4 -0
- INSTALL.md +55 -0
- LICENSE.txt +201 -0
- Makefile +5 -0
- README.md +12 -12
- app.py +546 -0
- examples/desi.mp4 +3 -0
- examples/desi.png +3 -0
- examples/man.png +3 -0
- examples/paul.mp4 +3 -0
- generate.py +236 -0
- pyproject.toml +66 -0
- requirements.txt +31 -0
- wan/__init__.py +7 -0
- wan/animate.py +653 -0
- wan/configs/__init__.py +50 -0
- wan/configs/shared_config.py +20 -0
- wan/configs/wan_animate_14B.py +40 -0
- wan/configs/wan_i2v_A14B.py +37 -0
- wan/configs/wan_s2v_14B.py +59 -0
- wan/configs/wan_t2v_A14B.py +37 -0
- wan/configs/wan_ti2v_5B.py +36 -0
- wan/distributed/__init__.py +1 -0
- wan/distributed/fsdp.py +45 -0
- wan/distributed/sequence_parallel.py +176 -0
- wan/distributed/ulysses.py +47 -0
- wan/distributed/util.py +51 -0
- wan/image2video.py +431 -0
- wan/modules/__init__.py +19 -0
- wan/modules/animate/__init__.py +4 -0
- wan/modules/animate/animate_utils.py +143 -0
- wan/modules/animate/clip.py +542 -0
- wan/modules/animate/face_blocks.py +383 -0
- wan/modules/animate/model_animate.py +500 -0
- wan/modules/animate/motion_encoder.py +307 -0
- wan/modules/animate/preprocess/UserGuider.md +70 -0
- wan/modules/animate/preprocess/__init__.py +3 -0
- wan/modules/animate/preprocess/human_visualization.py +1357 -0
- wan/modules/animate/preprocess/pose2d.py +430 -0
- wan/modules/animate/preprocess/pose2d_utils.py +1159 -0
- wan/modules/animate/preprocess/preprocess_data.py +121 -0
- wan/modules/animate/preprocess/process_pipepline.py +354 -0
- wan/modules/animate/preprocess/retarget_pose.py +847 -0
- wan/modules/animate/preprocess/sam_utils.py +155 -0
- wan/modules/animate/preprocess/utils.py +226 -0
- wan/modules/animate/preprocess/video_predictor.py +157 -0
- wan/modules/animate/xlm_roberta.py +170 -0
- wan/modules/attention.py +256 -0
- wan/modules/model.py +546 -0
- wan/modules/s2v/__init__.py +5 -0
.gitattributes CHANGED

@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/desi.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/desi.png filter=lfs diff=lfs merge=lfs -text
+examples/man.png filter=lfs diff=lfs merge=lfs -text
+examples/paul.mp4 filter=lfs diff=lfs merge=lfs -text
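As a small illustration of what these `.gitattributes` rules do, the sketch below (Python stdlib only, hypothetical paths) checks which files the new patterns would route through the LFS filter. The patterns in this commit are literal paths, so matching degenerates to equality, but `fnmatch` would also handle glob rules like the `*.zip` entries above:

```python
from fnmatch import fnmatch

# The four patterns added to .gitattributes in this commit.
lfs_patterns = [
    "examples/desi.mp4",
    "examples/desi.png",
    "examples/man.png",
    "examples/paul.mp4",
]

def uses_lfs(path: str) -> bool:
    # Literal paths make this an equality test; glob rules such as
    # '*.zip' would match via fnmatch's wildcard handling.
    return any(fnmatch(path, pat) for pat in lfs_patterns)

print(uses_lfs("examples/desi.mp4"))  # True
print(uses_lfs("README.md"))          # False
```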
INSTALL.md ADDED

# Installation Guide

## Install with pip

```bash
pip install .
pip install .[dev]  # also installs the dev tooling
```

## Install with Poetry

Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.

To install all dependencies:

```bash
poetry install
```

### Handling `flash-attn` Installation Issues

If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.

#### No-Build-Isolation Installation (Recommended)
```bash
poetry run pip install --upgrade pip setuptools wheel
poetry run pip install flash-attn --no-build-isolation
poetry install
```

#### Install from Git (Alternative)
```bash
poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
```

---

### Running the Model

Once the installation is complete, you can run **Wan2.2** using:

```bash
poetry run python generate.py --task t2v-A14B --size '1280*720' --ckpt_dir ./Wan2.2-T2V-A14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
```

#### Test
```bash
bash tests/test.sh
```

#### Format
```bash
black .
isort .
```
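The `pip install .[dev]` command in INSTALL.md relies on an optional-dependency group declared in `pyproject.toml` (also added in this commit, but not shown in this truncated view). A `dev` extra of that kind is typically declared along these lines; the package names here are assumptions, chosen only because the Format section above invokes `black` and `isort`:

```toml
# Hypothetical sketch of the [dev] extra that `pip install .[dev]` resolves;
# the actual pyproject.toml in this commit may list different tools.
[project.optional-dependencies]
dev = [
    "black",
    "isort",
]
```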
LICENSE.txt ADDED

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
Makefile ADDED

.PHONY: format

format:
	isort generate.py wan
	yapf -i -r *.py generate.py wan
README.md CHANGED

@@ -1,12 +1,12 @@
----
-title: Wan2.2 Animate
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version: 5.49.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: Wan2.2 Animate [Local]
+emoji: 🔥
+colorFrom: pink
+colorTo: yellow
+sdk: gradio
+sdk_version: 5.49.1
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED (+546 lines; this view shows the first 299)

```python
import spaces
from huggingface_hub import snapshot_download, hf_hub_download
import os
import subprocess
import importlib, site
from PIL import Image
import uuid
import shutil
import time
import cv2
from generate import generate, load_model
import json

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)

# Clear caches so importlib will pick up new modules
importlib.invalidate_caches()

def sh(cmd): subprocess.check_call(cmd, shell=True)

try:
    print("Attempting to download and build sam2...")

    print("download sam")
    sam_dir = snapshot_download(repo_id="alexnasa/sam2")

    @spaces.GPU(duration=450)
    def install_sam():
        os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0"
        sh(f"cd {sam_dir} && python setup.py build_ext --inplace && pip install -e .")

    print("install sam")
    install_sam()

    # tell Python to re-scan site-packages now that the egg-link exists
    import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.invalidate_caches()

    flash_attention_installed = True
    print("sam2 installed successfully.")

except Exception as e:
    print(f"⚠️ Could not install sam2: {e}")
    print("Continuing without sam2...")

import torch
print(f"Torch version: {torch.__version__}")

os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/processed_results"

import gradio as gr


snapshot_download(repo_id="Wan-AI/Wan2.2-Animate-14B", local_dir="./Wan2.2-Animate-14B")
wan_animate = load_model(True)


rc_mapping = {
    "Video → Ref Image": False,
    "Video ← Ref Image": True
}


def preprocess_video(input_video_path, session_id=None):

    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    process_video_path = os.path.join(output_dir, 'input_video.mp4')

    convert_video_to_30fps_and_clip(input_video_path, process_video_path, crop_width=720, crop_height=1280)

    return process_video_path

def extract_audio_from_video_ffmpeg(video_path, output_wav_path, sample_rate=None):
    """
    Extracts the audio track from a video file and saves it as a WAV file.

    Args:
        video_path (str): Path to the input video file.
        output_wav_path (str): Path to save the extracted WAV file.
        sample_rate (int, optional): Output sample rate (e.g., 16000).
            If None, keep the original.
    """
    cmd = [
        'ffmpeg',
        '-i', video_path,        # Input video
        '-vn',                   # Disable video
        '-acodec', 'pcm_s16le',  # 16-bit PCM (WAV format)
        '-ac', '1',              # Mono channel (use '2' for stereo)
        '-y',                    # Overwrite output
        '-loglevel', 'error'     # Cleaner output
    ]

    # Only add the sample rate option if explicitly specified
    if sample_rate is not None:
        cmd.extend(['-ar', str(sample_rate)])

    cmd.append(output_wav_path)

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")


def combine_video_and_audio_ffmpeg(video_path, audio_path, output_video_path):
    """
    Combines a silent MP4 video with a WAV audio file into a single MP4 with sound.

    Args:
        video_path (str): Path to the silent video file.
        audio_path (str): Path to the WAV audio file.
        output_video_path (str): Path to save the output MP4 with audio.
    """
    cmd = [
        'ffmpeg',
        '-i', video_path,   # Input video
        '-i', audio_path,   # Input audio
        '-c:v', 'copy',     # Copy video without re-encoding
        '-c:a', 'aac',      # Encode audio as AAC (MP4-compatible)
        '-shortest',        # Stop when the shortest stream ends
        '-y',               # Overwrite output
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")


def convert_video_to_30fps_and_clip(
    input_video_path,
    output_video_path,
    duration_s=2,
    target_fps=30,
    crop_width=None,
    crop_height=None
):
    # Get input video dimensions using ffprobe
    if crop_width and crop_height:
        probe_cmd = [
            'ffprobe', '-v', 'error', '-select_streams', 'v:0',
            '-show_entries', 'stream=width,height',
            '-of', 'json', input_video_path
        ]
        probe_result = subprocess.run(probe_cmd, capture_output=True, text=True, check=True)
        video_info = json.loads(probe_result.stdout)
        w = video_info['streams'][0]['width']
        h = video_info['streams'][0]['height']

        # Clamp crop size to not exceed actual dimensions
        crop_width = min(crop_width, w)
        crop_height = min(crop_height, h)

        # Center crop offsets
        crop_x = max((w - crop_width) // 2, 0)
        crop_y = max((h - crop_height) // 2, 0)
        crop_filter = f"crop={crop_width}:{crop_height}:{crop_x}:{crop_y}"
    else:
        crop_filter = None

    cmd = [
        'ffmpeg',
        '-i', input_video_path,
        '-r', str(target_fps),
        '-t', str(duration_s),
    ]

    if crop_filter:
        cmd += ['-vf', crop_filter]

    cmd += [
        '-c:v', 'libx264',
        '-c:a', 'aac',
        '-strict', 'experimental',
        '-y',
        '-loglevel', 'error',
        output_video_path
    ]

    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"ffmpeg failed ({e.returncode}): {e.stderr.strip()}")

def get_frames_count(video_file):

    # Get video information
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        error_msg = "Cannot open video file"
        gr.Warning(error_msg)

    orig_frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    cap.release()

    return orig_frame_count

def calculate_time_required(input_video, rc_bool):

    frames_count = get_frames_count(input_video)

    chunks = frames_count // 77 + 1

    if rc_bool:
        pose2d_tracking_duration_s = 75
        iteration_per_step_s = 13
    else:
        pose2d_tracking_duration_s = 50
        iteration_per_step_s = 12

    time_required = pose2d_tracking_duration_s + iteration_per_step_s * 20 * chunks
    print(f'for frames_count:{frames_count} doing {chunks} chunks the time_required is {time_required}')
    return time_required

def update_time_required(input_video, rc_str):

    if input_video is None:
        return gr.update(value="⌚ Zero GPU Required: --")

    rc_bool = rc_mapping[rc_str]

    duration_s = calculate_time_required(input_video, rc_bool)
    duration_m = duration_s / 60

    return gr.update(value=f"⌚ Zero GPU Required: ~{duration_s}.0s ({duration_m:.1f} mins)")

def get_duration(input_video, edited_frame, rc_bool, session_id, progress):

    return calculate_time_required(input_video, rc_bool)


@spaces.GPU(duration=get_duration)
def _animate(input_video, edited_frame, rc_bool, session_id=None, progress=gr.Progress(track_tqdm=True)):

    if session_id is None:
        session_id = uuid.uuid4().hex

    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
    os.makedirs(output_dir, exist_ok=True)

    preprocess_dir = os.path.join(output_dir, "preprocess_dir")
    os.makedirs(preprocess_dir, exist_ok=True)

    output_video_path = os.path.join(output_dir, 'result.mp4')

    # --- Measure preprocess time ---
    start_preprocess = time.time()

    # w = 720
    # h = 480

    # w = 720
    # h = 1280

    w = 480
    h = 832

    # w = 480
    # h = 720

    tag_string = "retarget_flag"

    if rc_bool:
        tag_string = "replace_flag"

    sh("python ./wan/modules/animate/preprocess/preprocess_data.py "
       "--ckpt_path ./Wan2.2-Animate-14B/process_checkpoint "
       f"--video_path {input_video} "
       f"--refer_path {edited_frame} "
       f"--save_path {preprocess_dir} "
       f"--resolution_area {w} {h} --{tag_string} "
       )

    preprocess_time = time.time() - start_preprocess
    print(f"Preprocess took {preprocess_time:.2f} seconds")

    # --- Measure generate time ---
    start_generate = time.time()

    generate(wan_animate, preprocess_dir, output_video_path, rc_bool)

    generate_time = time.time() - start_generate
    print(f"Generate took {generate_time:.2f} seconds")

    # --- Optional total time ---
    total_time = preprocess_time + generate_time
    print(f"Total time: {total_time:.2f} seconds")
```

(The remaining lines of app.py are not shown in this view.)
|
| 300 |
+
|
| 301 |
+
return output_video_path
|
| 302 |
+
|
| 303 |
+
def animate_scene(input_video, edited_frame, rc_str, session_id = None, progress=gr.Progress(track_tqdm=True),):
|
| 304 |
+
|
| 305 |
+
if not input_video:
|
| 306 |
+
raise gr.Error("Please provide an video")
|
| 307 |
+
|
| 308 |
+
if not edited_frame:
|
| 309 |
+
raise gr.Error("Please provide an image")
|
| 310 |
+
|
| 311 |
+
if session_id is None:
|
| 312 |
+
session_id = uuid.uuid4().hex
|
| 313 |
+
|
| 314 |
+
rc_bool = rc_mapping[rc_str]
|
| 315 |
+
|
| 316 |
+
output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
|
| 317 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 318 |
+
|
| 319 |
+
input_audio_path = os.path.join(output_dir, 'input_audio.wav')
|
| 320 |
+
|
| 321 |
+
extract_audio_from_video_ffmpeg(input_video, input_audio_path)
|
| 322 |
+
|
| 323 |
+
output_video_path = _animate(input_video, edited_frame, rc_bool, session_id, progress)
|
| 324 |
+
|
| 325 |
+
final_video_path = os.path.join(output_dir, 'final_result.mp4')
|
| 326 |
+
|
| 327 |
+
preprocess_dir = os.path.join(output_dir, "preprocess_dir")
|
| 328 |
+
pose_video = os.path.join(preprocess_dir, 'src_pose.mp4')
|
| 329 |
+
|
| 330 |
+
if rc_bool:
|
| 331 |
+
mask_video = os.path.join(preprocess_dir, 'src_mask.mp4')
|
| 332 |
+
bg_video = os.path.join(preprocess_dir, 'src_bg.mp4')
|
| 333 |
+
face_video = os.path.join(preprocess_dir, 'src_face.mp4')
|
| 334 |
+
else:
|
| 335 |
+
mask_video = os.path.join(preprocess_dir, 'src_pose.mp4')
|
| 336 |
+
bg_video = os.path.join(preprocess_dir, 'src_pose.mp4')
|
| 337 |
+
face_video = os.path.join(preprocess_dir, 'src_pose.mp4')
|
| 338 |
+
|
| 339 |
+
combine_video_and_audio_ffmpeg(output_video_path, input_audio_path, final_video_path)
|
| 340 |
+
|
| 341 |
+
return final_video_path, pose_video, bg_video, mask_video, face_video
|
| 342 |
+
|
| 343 |
+
css = """
|
| 344 |
+
#col-container {
|
| 345 |
+
margin: 0 auto;
|
| 346 |
+
max-width: 1600px;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
#step-column {
|
| 350 |
+
padding: 20px;
|
| 351 |
+
border-radius: 8px;
|
| 352 |
+
box-shadow: var(--card-shadow);
|
| 353 |
+
margin: 10px;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
#col-showcase {
|
| 357 |
+
margin: 0 auto;
|
| 358 |
+
max-width: 1100px;
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
.button-gradient {
|
| 362 |
+
background: linear-gradient(45deg, rgb(255, 65, 108), rgb(255, 75, 43), rgb(255, 155, 0), rgb(255, 65, 108)) 0% 0% / 400% 400%;
|
| 363 |
+
border: none;
|
| 364 |
+
padding: 14px 28px;
|
| 365 |
+
font-size: 16px;
|
| 366 |
+
font-weight: bold;
|
| 367 |
+
color: white;
|
| 368 |
+
border-radius: 10px;
|
| 369 |
+
cursor: pointer;
|
| 370 |
+
transition: 0.3s ease-in-out;
|
| 371 |
+
animation: 2s linear 0s infinite normal none running gradientAnimation;
|
| 372 |
+
box-shadow: rgba(255, 65, 108, 0.6) 0px 4px 10px;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
.toggle-container {
|
| 376 |
+
display: inline-flex;
|
| 377 |
+
background-color: #ffd6ff; /* light pink background */
|
| 378 |
+
border-radius: 9999px;
|
| 379 |
+
padding: 4px;
|
| 380 |
+
position: relative;
|
| 381 |
+
width: fit-content;
|
| 382 |
+
font-family: sans-serif;
|
| 383 |
+
}
|
| 384 |
+
|
| 385 |
+
.toggle-container input[type="radio"] {
|
| 386 |
+
display: none;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
.toggle-container label {
|
| 390 |
+
position: relative;
|
| 391 |
+
z-index: 2;
|
| 392 |
+
flex: 1;
|
| 393 |
+
text-align: center;
|
| 394 |
+
font-weight: 700;
|
| 395 |
+
color: #4b2ab5; /* dark purple text for unselected */
|
| 396 |
+
padding: 6px 22px;
|
| 397 |
+
border-radius: 9999px;
|
| 398 |
+
cursor: pointer;
|
| 399 |
+
transition: color 0.25s ease;
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
/* Moving highlight */
|
| 403 |
+
.toggle-highlight {
|
| 404 |
+
position: absolute;
|
| 405 |
+
top: 4px;
|
| 406 |
+
left: 4px;
|
| 407 |
+
width: calc(50% - 4px);
|
| 408 |
+
height: calc(100% - 8px);
|
| 409 |
+
background-color: #4b2ab5; /* dark purple background */
|
| 410 |
+
border-radius: 9999px;
|
| 411 |
+
transition: transform 0.25s ease;
|
| 412 |
+
z-index: 1;
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
/* When "True" is checked */
|
| 416 |
+
#true:checked ~ label[for="true"] {
|
| 417 |
+
color: #ffd6ff; /* light pink text */
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
/* When "False" is checked */
|
| 421 |
+
#false:checked ~ label[for="false"] {
|
| 422 |
+
color: #ffd6ff; /* light pink text */
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
/* Move highlight to right side when False is checked */
|
| 426 |
+
#false:checked ~ .toggle-highlight {
|
| 427 |
+
transform: translateX(100%);
|
| 428 |
+
}
|
| 429 |
+
"""
|
| 430 |
+
def start_session(request: gr.Request):
|
| 431 |
+
|
| 432 |
+
return request.session_hash
|
| 433 |
+
|
| 434 |
+
def cleanup(request: gr.Request):
|
| 435 |
+
|
| 436 |
+
sid = request.session_hash
|
| 437 |
+
|
| 438 |
+
if sid:
|
| 439 |
+
d1 = os.path.join(os.environ["PROCESSED_RESULTS"], sid)
|
| 440 |
+
shutil.rmtree(d1, ignore_errors=True)
|
| 441 |
+
|
| 442 |
+
with gr.Blocks(css=css, title="Wan 2.2 Animate --replace", theme=gr.themes.Ocean()) as demo:
|
| 443 |
+
|
| 444 |
+
session_state = gr.State()
|
| 445 |
+
demo.load(start_session, outputs=[session_state])
|
| 446 |
+
|
| 447 |
+
with gr.Column(elem_id="col-container"):
|
| 448 |
+
with gr.Row():
|
| 449 |
+
gr.HTML(
|
| 450 |
+
"""
|
| 451 |
+
<div style="text-align: center;">
|
| 452 |
+
<p style="font-size:16px; display: inline; margin: 0;">
|
| 453 |
+
<strong>Wan2.2-Animate-14B </strong>
|
| 454 |
+
</p>
|
| 455 |
+
<a href="https://huggingface.co/Wan-AI/Wan2.2-Animate-14B" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
|
| 456 |
+
[Model]
|
| 457 |
+
</a>
|
| 458 |
+
<div style="text-align: center;">
|
| 459 |
+
<p style="font-size:16px; display: inline; margin: 0;">
|
| 460 |
+
HF Space By:
|
| 461 |
+
</p>
|
| 462 |
+
<a href="https://huggingface.co/alexnasa" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
|
| 463 |
+
<img src="https://img.shields.io/badge/🤗-Follow Me-yellow.svg">
|
| 464 |
+
</a>
|
| 465 |
+
</div>
|
| 466 |
+
"""
|
| 467 |
+
)
|
| 468 |
+
with gr.Row():
|
| 469 |
+
with gr.Column(elem_id="step-column"):
|
| 470 |
+
gr.HTML("""
|
| 471 |
+
<div>
|
| 472 |
+
<span style="font-size: 24px;">1. Upload a Video</span><br>
|
| 473 |
+
</div>
|
| 474 |
+
""")
|
| 475 |
+
input_video = gr.Video(label="Input Video", height=512)
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
with gr.Column(elem_id="step-column"):
|
| 479 |
+
gr.HTML("""
|
| 480 |
+
<div>
|
| 481 |
+
<span style="font-size: 24px;">2. Upload a Ref Image</span><br>
|
| 482 |
+
</div>
|
| 483 |
+
""")
|
| 484 |
+
edited_frame = gr.Image(label="Ref Image", type="filepath", height=512)
|
| 485 |
+
gr.HTML("""
|
| 486 |
+
<div>
|
| 487 |
+
<span style="font-size: 24px;">3. Choose Mode</span><br>
|
| 488 |
+
</div>
|
| 489 |
+
""")
|
| 490 |
+
replace_character_string = gr.Radio(
|
| 491 |
+
["Video → Ref Image", "Video ← Ref Image"], value="Video → Ref Image", show_label=False
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
with gr.Column(elem_id="step-column"):
|
| 495 |
+
gr.HTML("""
|
| 496 |
+
<div>
|
| 497 |
+
<span style="font-size: 24px;">4. Wan Animate it!</span><br>
|
| 498 |
+
</div>
|
| 499 |
+
""")
|
| 500 |
+
output_video = gr.Video(label="Edited Video", height=512)
|
| 501 |
+
|
| 502 |
+
time_required = gr.Text(value="⌚ Zero GPU Required: --", show_label=False)
|
| 503 |
+
action_button = gr.Button("Wan Animate 🦆", variant='primary', elem_classes="button-gradient")
|
| 504 |
+
|
| 505 |
+
with gr.Accordion("Preprocessed Data", open=False, visible=False):
|
| 506 |
+
pose_video = gr.Video(label="Pose Video", height=512)
|
| 507 |
+
bg_video = gr.Video(label="Background Video", height=512)
|
| 508 |
+
face_video = gr.Video(label="Face Video", height=512)
|
| 509 |
+
mask_video = gr.Video(label="Mask Video", height=512)
|
| 510 |
+
|
| 511 |
+
with gr.Row():
|
| 512 |
+
with gr.Column(elem_id="col-showcase"):
|
| 513 |
+
|
| 514 |
+
gr.Examples(
|
| 515 |
+
examples=[
|
| 516 |
+
|
| 517 |
+
[
|
| 518 |
+
"./examples/desi.mp4",
|
| 519 |
+
"./examples/desi.png",
|
| 520 |
+
"Video ← Ref Image"
|
| 521 |
+
],
|
| 522 |
+
|
| 523 |
+
[
|
| 524 |
+
"./examples/paul.mp4",
|
| 525 |
+
"./examples/man.png",
|
| 526 |
+
"Video → Ref Image"
|
| 527 |
+
],
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
],
|
| 531 |
+
inputs=[input_video, edited_frame, replace_character_string],
|
| 532 |
+
outputs=[output_video, pose_video, bg_video, mask_video, face_video],
|
| 533 |
+
fn=animate_scene,
|
| 534 |
+
cache_examples=True,
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
action_button.click(fn=animate_scene, inputs=[input_video, edited_frame, replace_character_string, session_state], outputs=[output_video, pose_video, bg_video, mask_video, face_video])
|
| 538 |
+
|
| 539 |
+
input_video.upload(preprocess_video, inputs=[input_video, session_state], outputs=[input_video]).then(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
|
| 540 |
+
replace_character_string.change(update_time_required, inputs=[input_video, replace_character_string], outputs=[time_required])
|
| 541 |
+
|
| 542 |
+
if __name__ == "__main__":
|
| 543 |
+
demo.queue()
|
| 544 |
+
demo.unload(cleanup)
|
| 545 |
+
demo.launch(ssr_mode=False, share=True)
|
| 546 |
+
|
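The duration estimate that `calculate_time_required` hands to `@spaces.GPU` is simple enough to check in isolation. A minimal standalone sketch of the same arithmetic (the 77-frame chunk size, 20 steps per chunk, and the per-mode constants are taken from the code above; the function name `estimate_seconds` is invented here):

```python
# Standalone sketch of the ZeroGPU duration estimate used in app.py.
# Mirrors calculate_time_required: one extra chunk via floor-division,
# pose-tracking setup cost plus 20 denoising steps per chunk.
def estimate_seconds(frames_count: int, replace_mode: bool) -> int:
    chunks = frames_count // 77 + 1
    if replace_mode:
        tracking_s, step_s = 75, 13  # character-replacement path is pricier
    else:
        tracking_s, step_s = 50, 12  # retargeting-only path
    return tracking_s + step_s * 20 * chunks

print(estimate_seconds(76, True))    # 1 chunk:  75 + 13*20*1 = 335
print(estimate_seconds(77, False))   # 2 chunks: 50 + 12*20*2 = 530
```

Note the `+ 1` means even a zero-frame input is billed for one chunk, which matches the behavior of the original function.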
examples/desi.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e02e84151e5625fb3863ebdf65dfab06940afac5fbd471db3b46a4ebd84b248d
size 551595

examples/desi.png
ADDED
Git LFS Details

examples/man.png
ADDED
Git LFS Details

examples/paul.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb065c2d24bff8a49955389f94c05c80d39638410dad8082f7e0eb7f2dc5c672
size 1029922
generate.py
ADDED
@@ -0,0 +1,236 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
import sys
import warnings
from datetime import datetime

warnings.filterwarnings('ignore')

import random

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.distributed.util import init_distributed_group
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import merge_video_audio, save_video, str2bool


EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "animate-14B": {
        "prompt": "视频中的人在做动作",  # "The person in the video is performing movements"
        "video": "",
        "pose": "",
        "mask": "",
    },
    "s2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
        "audio":
            "examples/talk.wav",
        "tts_prompt_audio":
            "examples/zero_shot_prompt.wav",
        "tts_prompt_text":
            "希望你以后能够做的比我还好呦。",  # "I hope you can do even better than me."
        "tts_text":
            "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"  # "Receiving a birthday gift from a faraway friend, the unexpected surprise and deep blessing filled my heart with sweet joy, and my smile bloomed like a flower."
    },
}


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
    assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
    assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"

    if args.prompt is None:
        args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
    if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
        args.image = EXAMPLE_PROMPT[args.task]["image"]
    if args.audio is None and args.enable_tts is False and "audio" in EXAMPLE_PROMPT[args.task]:
        args.audio = EXAMPLE_PROMPT[args.task]["audio"]
    if (args.tts_prompt_audio is None or args.tts_text is None) and args.enable_tts is True and "audio" in EXAMPLE_PROMPT[args.task]:
        args.tts_prompt_audio = EXAMPLE_PROMPT[args.task]["tts_prompt_audio"]
        args.tts_prompt_text = EXAMPLE_PROMPT[args.task]["tts_prompt_text"]
        args.tts_text = EXAMPLE_PROMPT[args.task]["tts_text"]

    if args.task == "i2v-A14B":
        assert args.image is not None, "Please specify the image path for i2v."

    cfg = WAN_CONFIGS[args.task]

    if args.sample_steps is None:
        args.sample_steps = cfg.sample_steps

    if args.sample_shift is None:
        args.sample_shift = cfg.sample_shift

    if args.sample_guide_scale is None:
        args.sample_guide_scale = cfg.sample_guide_scale

    if args.frame_num is None:
        args.frame_num = cfg.frame_num

    args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
        0, sys.maxsize)
    # Size check
    if not 's2v' in args.task:
        assert args.size in SUPPORTED_SIZES[
            args.task], f"Unsupported size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"


class _Args:
    pass


def _parse_args():
    args = _Args()

    # core generation options
    args.task = "animate-14B"
    # args.size = "1280*720"
    args.size = "720*1280"
    args.frame_num = None
    args.ckpt_dir = "./Wan2.2-Animate-14B/"
    args.offload_model = True
    args.ulysses_size = 1
    args.t5_fsdp = False
    args.t5_cpu = False
    args.dit_fsdp = False
    args.prompt = None
    args.use_prompt_extend = False
    args.prompt_extend_method = "local_qwen"  # ["dashscope", "local_qwen"]
    args.prompt_extend_model = None
    args.prompt_extend_target_lang = "zh"  # ["zh", "en"]
    args.base_seed = 0
    args.image = None
    args.sample_solver = "unipc"  # ['unipc', 'dpm++']
    args.sample_steps = None
    args.sample_shift = None
    args.sample_guide_scale = None
    args.convert_model_dtype = False

    # animate
    args.refert_num = 1

    # s2v-only
    args.num_clip = None
    args.audio = None
    args.enable_tts = False
    args.tts_prompt_audio = None
    args.tts_prompt_text = None
    args.tts_text = None
    args.pose_video = None
    args.start_from_ref = False
    args.infer_frames = 80

    _validate_args(args)
    return args


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def load_model(use_relighting_lora=False):

    cfg = WAN_CONFIGS["animate-14B"]

    return wan.WanAnimate(
        config=cfg,
        checkpoint_dir="./Wan2.2-Animate-14B/",
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        convert_model_dtype=False,
        use_relighting_lora=use_relighting_lora
    )


def generate(wan_animate, preprocess_dir, save_file, replace_flag=False):
    args = _parse_args()
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    cfg = WAN_CONFIGS[args.task]

    logging.info(f"Input prompt: {args.prompt}")
    img = None
    if args.image is not None:
        img = Image.open(args.image).convert("RGB")
        logging.info(f"Input image: {args.image}")

    print(f'rank:{rank}')

    logging.info("Generating video ...")
    video = wan_animate.generate(
        src_root_path=preprocess_dir,
        replace_flag=replace_flag,
        refert_num=args.refert_num,
        clip_len=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)
    if rank == 0:

        save_video(
            tensor=video[None],
            save_file=save_file,
            fps=cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))
        # if "s2v" in args.task:
        #     if args.enable_tts is False:
        #         merge_video_audio(video_path=args.save_file, audio_path=args.audio)
        #     else:
        #         merge_video_audio(video_path=args.save_file, audio_path="tts.wav")
    del video

    torch.cuda.synchronize()
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()

    logging.info("Finished.")
pyproject.toml
ADDED
@@ -0,0 +1,66 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "wan"
version = "2.2.0"
description = "Wan: Open and Advanced Large-Scale Video Generative Models"
authors = [
    { name = "Wan Team", email = "wan.ai@alibabacloud.com" }
]
license = { file = "LICENSE.txt" }
readme = "README.md"
requires-python = ">=3.10,<4.0"
dependencies = [
    "torch>=2.4.0",
    "torchvision>=0.19.0",
    "opencv-python>=4.9.0.80",
    "diffusers>=0.31.0",
    "transformers>=4.49.0",
    "tokenizers>=0.20.3",
    "accelerate>=1.1.1",
    "tqdm",
    "imageio",
    "easydict",
    "ftfy",
    "dashscope",
    "imageio-ffmpeg",
    "flash_attn",
    "numpy>=1.23.5,<2"
]

[project.optional-dependencies]
dev = [
    "pytest",
    "black",
    "flake8",
    "isort",
    "mypy",
    "huggingface-hub[cli]"
]

[project.urls]
homepage = "https://wanxai.com"
documentation = "https://github.com/Wan-Video/Wan2.2"
repository = "https://github.com/Wan-Video/Wan2.2"
huggingface = "https://huggingface.co/Wan-AI/"
modelscope = "https://modelscope.cn/organization/Wan-AI"
discord = "https://discord.gg/p5XbdQV7"

[tool.setuptools]
packages = ["wan"]

[tool.setuptools.package-data]
"wan" = ["**/*.py"]

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.mypy]
strict = true
requirements.txt
ADDED
@@ -0,0 +1,31 @@
torch==2.8.0
decord
peft
pandas
matplotlib
loguru
sentencepiece
dashscope
ftfy
diffusers
opencv-python
moviepy
torchvision
torchaudio
transformers
tokenizers
accelerate
tqdm
imageio[ffmpeg]
easydict
imageio-ffmpeg
numpy>=1.23.5,<2
hydra-core
iopath
pytest
pillow
fvcore
librosa
flash-attn
onnxruntime-gpu
flash-attn-3 @ https://huggingface.co/alexnasa/flash-attn-3/resolve/main/128/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
wan/__init__.py
ADDED
@@ -0,0 +1,7 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from . import configs, distributed, modules
from .image2video import WanI2V
from .speech2video import WanS2V
from .text2video import WanT2V
from .textimage2video import WanTI2V
from .animate import WanAnimate
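`wan/__init__.py` imports every pipeline eagerly, so `import wan` pays the import cost of all five models even when the Space only uses `WanAnimate`. PEP 562's module-level `__getattr__` is one way to defer such imports. A self-contained sketch of the mechanism on a synthetic module (the module name `demo_pkg` and the `HeavyThing` attribute are invented for illustration; this is not a change the repo makes):

```python
import sys
import types

# Build a throwaway module and register it, so the example needs no files.
mod = types.ModuleType("demo_pkg")

def _module_getattr(name):
    # PEP 562: called only when `name` is missing from the module's __dict__,
    # i.e. the "heavy import" runs lazily, on first attribute access.
    if name == "HeavyThing":
        return type("HeavyThing", (), {})  # stand-in for a costly import
    raise AttributeError(name)

mod.__getattr__ = _module_getattr
sys.modules["demo_pkg"] = mod

import demo_pkg  # cheap: nothing heavy has been loaded yet
print(demo_pkg.HeavyThing.__name__)  # HeavyThing
```

In a real package the body of `__getattr__` would do `from .animate import WanAnimate` on demand and cache the result in `globals()`.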
wan/animate.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import os
import cv2
import types
from copy import deepcopy
from functools import partial
from einops import rearrange
import numpy as np
import torch

import torch.distributed as dist
from peft import set_peft_model_state_dict
from decord import VideoReader
from tqdm import tqdm
import torch.nn.functional as F
from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size

from .modules.animate import WanAnimateModel
from .modules.animate import CLIPModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .modules.animate.animate_utils import TensorList, get_loraconfig
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanAnimate:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
        use_relighting_lora=False
    ):
        r"""
        Initializes the generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed inference
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for the T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for the DiT model
            use_sp (`bool`, *optional*, defaults to False):
                Enable the sequence-parallel distribution strategy.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place the T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing the Transformer model on CPU. Only works without FSDP or USP.
            convert_model_dtype (`bool`, *optional*, defaults to False):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
            use_relighting_lora (`bool`, *optional*, defaults to False):
                Whether to use the relighting LoRA for character replacement.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.init_on_cpu = init_on_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        if t5_fsdp or dit_fsdp or use_sp:
            self.init_on_cpu = False

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.clip = CLIPModel(
            dtype=torch.float16,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        self.vae = Wan2_1_VAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating WanAnimate from {checkpoint_dir}")

        if not dit_fsdp:
            self.noise_model = WanAnimateModel.from_pretrained(
                checkpoint_dir,
                torch_dtype=self.param_dtype,
                device_map=self.device)
        else:
            self.noise_model = WanAnimateModel.from_pretrained(
                checkpoint_dir, torch_dtype=self.param_dtype)

        self.noise_model = self._configure_model(
            model=self.noise_model,
            use_sp=use_sp,
            dit_fsdp=dit_fsdp,
            shard_fn=shard_fn,
            convert_model_dtype=convert_model_dtype,
            use_lora=use_relighting_lora,
            checkpoint_dir=checkpoint_dir,
            config=config
        )

        if use_sp:
            self.sp_size = get_world_size()
        else:
            self.sp_size = 1

        self.sample_neg_prompt = config.sample_neg_prompt
        self.sample_prompt = config.prompt

    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
                         convert_model_dtype, use_lora, checkpoint_dir, config):
        """
        Configures a model object: sets evaluation mode, applies the
        distributed parallel strategy, and handles device placement.

        Args:
            model (torch.nn.Module):
                The model instance to configure.
            use_sp (`bool`):
                Enable the sequence-parallel distribution strategy.
            dit_fsdp (`bool`):
                Enable FSDP sharding for the DiT model.
            shard_fn (callable):
                The function to apply FSDP sharding.
            convert_model_dtype (`bool`):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
            use_lora (`bool`):
                Whether to load the relighting LoRA weights.
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints.
            config (EasyDict):
                Object containing model parameters.

        Returns:
            torch.nn.Module:
                The configured model.
        """
        model.eval().requires_grad_(False)

        if use_sp:
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(
                    sp_attn_forward, block.self_attn)

            model.use_context_parallel = True

        if dist.is_initialized():
            dist.barrier()

        if use_lora:
            logging.info("Loading relighting LoRA.")
            lora_config = get_loraconfig(
                transformer=model,
                rank=128,
                alpha=128
            )
            model.add_adapter(lora_config)
            lora_path = os.path.join(checkpoint_dir, config.lora_checkpoint)
            peft_state_dict = torch.load(lora_path)["state_dict"]
            set_peft_model_state_dict(model, peft_state_dict)

        if dit_fsdp:
            model = shard_fn(model, use_lora=use_lora)
        else:
            if convert_model_dtype:
                model.to(self.param_dtype)
            if not self.init_on_cpu:
                model.to(self.device)

        return model

    def inputs_padding(self, array, target_len):
        """Extend `array` to `target_len` frames by ping-pong (reflect) looping."""
        idx = 0
        flip = False
        target_array = []
        while len(target_array) < target_len:
            target_array.append(deepcopy(array[idx]))
            if flip:
                idx -= 1
            else:
                idx += 1
            if idx == 0 or idx == len(array) - 1:
                flip = not flip
        return target_array[:target_len]

    def get_valid_len(self, real_len, clip_len=81, overlap=1):
        """Round `real_len` up so the frames split evenly into clips that overlap by `overlap`."""
        real_clip_len = clip_len - overlap
        last_clip_num = (real_len - overlap) % real_clip_len
        if last_clip_num == 0:
            extra = 0
        else:
            extra = real_clip_len - last_clip_num
        target_len = real_len + extra
        return target_len

    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
        if mask_pixel_values is None:
            msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device)
        else:
            msk = mask_pixel_values.clone()
        msk[:, :mask_len] = 1
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        return msk

    def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR):
        """Resize `img_ori` to fit (height, width) while keeping aspect ratio, padding with `padding_color`."""
        ori_height = img_ori.shape[0]
        ori_width = img_ori.shape[1]
        channel = img_ori.shape[2]

        img_pad = np.zeros((height, width, channel))
        if channel == 1:
            img_pad[:, :, 0] = padding_color[0]
        else:
            img_pad[:, :, 0] = padding_color[0]
            img_pad[:, :, 1] = padding_color[1]
            img_pad[:, :, 2] = padding_color[2]

        if (ori_height / ori_width) > (height / width):
            new_width = int(height / ori_height * ori_width)
            img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation)
            padding = int((width - new_width) / 2)
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
            img_pad[:, padding: padding + new_width, :] = img
        else:
            new_height = int(width / ori_width * ori_height)
            img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation)
            padding = int((height - new_height) / 2)
            if len(img.shape) == 2:
                img = img[:, :, np.newaxis]
            img_pad[padding: padding + new_height, :, :] = img

        img_pad = np.uint8(img_pad)

        return img_pad

    def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
        pose_video_reader = VideoReader(src_pose_path)
        pose_len = len(pose_video_reader)
        pose_idxs = list(range(pose_len))
        cond_images = pose_video_reader.get_batch(pose_idxs).asnumpy()

        face_video_reader = VideoReader(src_face_path)
        face_len = len(face_video_reader)
        face_idxs = list(range(face_len))
        face_images = face_video_reader.get_batch(face_idxs).asnumpy()
        height, width = cond_images[0].shape[:2]
        refer_images = cv2.imread(src_ref_path)[..., ::-1]  # BGR -> RGB
        refer_images = self.padding_resize(refer_images, height=height, width=width)
        return cond_images, face_images, refer_images

    def prepare_source_for_replace(self, src_bg_path, src_mask_path):
        bg_video_reader = VideoReader(src_bg_path)
        bg_len = len(bg_video_reader)
        bg_idxs = list(range(bg_len))
        bg_images = bg_video_reader.get_batch(bg_idxs).asnumpy()

        mask_video_reader = VideoReader(src_mask_path)
        mask_len = len(mask_video_reader)
        mask_idxs = list(range(mask_len))
        mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
        mask_images = mask_images[:, :, :, 0] / 255
        return bg_images, mask_images

    def generate(
        self,
        src_root_path,
        replace_flag=False,
        clip_len=77,
        refert_num=1,
        shift=5.0,
        sample_solver='dpm++',
        sampling_steps=20,
        guide_scale=1,
        input_prompt="",
        n_prompt="",
        seed=-1,
        offload_model=True,
    ):
        r"""
        Generates video frames from the preprocessed inputs using a diffusion process.

        Args:
            src_root_path (`str`):
                Path to the preprocessing output directory (containing
                src_pose.mp4, src_face.mp4, src_ref.png, and, in replacement
                mode, src_bg.mp4 / src_mask.mp4).
            replace_flag (`bool`, *optional*, defaults to False):
                Whether to use character replacement.
            clip_len (`int`, *optional*, defaults to 77):
                How many frames to generate per clip. The number should be 4n+1.
            refert_num (`int`, *optional*, defaults to 1):
                How many frames are used for temporal guidance. Recommended to be 1 or 5.
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter.
            sample_solver (`str`, *optional*, defaults to 'dpm++'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 20):
                Number of diffusion sampling steps. Higher values improve quality but slow generation.
            guide_scale (`float`, *optional*, defaults to 1):
                Classifier-free guidance scale. We only use it for expression control.
                In most cases it's not necessary, and faster generation can be achieved without it.
                When expression adjustments are needed, you may consider using this feature.
            input_prompt (`str`):
                Text prompt for content generation. We don't recommend custom prompts (although they work).
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`.
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, a random seed is used.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM.

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames
                - H: Frame height
                - W: Frame width
        """
        assert refert_num == 1 or refert_num == 5, "refert_num should be 1 or 5."

        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        if input_prompt == "":
            input_prompt = self.sample_prompt

        src_pose_path = os.path.join(src_root_path, "src_pose.mp4")
        src_face_path = os.path.join(src_root_path, "src_face.mp4")
        src_ref_path = os.path.join(src_root_path, "src_ref.png")

        cond_images, face_images, refer_images = self.prepare_source(
            src_pose_path=src_pose_path,
            src_face_path=src_face_path,
            src_ref_path=src_ref_path)

        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        real_frame_len = len(cond_images)
        target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num)
        logging.info('real frames: {} target frames: {}'.format(real_frame_len, target_len))
        cond_images = self.inputs_padding(cond_images, target_len)
        face_images = self.inputs_padding(face_images, target_len)

        if replace_flag:
            src_bg_path = os.path.join(src_root_path, "src_bg.mp4")
            src_mask_path = os.path.join(src_root_path, "src_mask.mp4")
            bg_images, mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
            bg_images = self.inputs_padding(bg_images, target_len)
            mask_images = self.inputs_padding(mask_images, target_len)
            self.noise_model.disable_adapters()
        else:
            self.noise_model.disable_adapters()

        height, width = refer_images.shape[:2]
        start = 0
        end = clip_len
        all_out_frames = []
        while True:
            if start + refert_num >= len(cond_images):
                break

            if start == 0:
                mask_reft_len = 0
            else:
                mask_reft_len = refert_num

            batch = {
                "conditioning_pixel_values": torch.zeros(1, 3, clip_len, height, width),
                "bg_pixel_values": torch.zeros(1, 3, clip_len, height, width),
                "mask_pixel_values": torch.zeros(1, 1, clip_len, height, width),
                "face_pixel_values": torch.zeros(1, 3, clip_len, 512, 512),
                "refer_pixel_values": torch.zeros(1, 3, height, width),
                "refer_t_pixel_values": torch.zeros(refert_num, 3, height, width)
            }

            batch["conditioning_pixel_values"] = rearrange(
                torch.tensor(np.stack(cond_images[start:end]) / 127.5 - 1),
                "t h w c -> 1 c t h w",
            )
            batch["face_pixel_values"] = rearrange(
                torch.tensor(np.stack(face_images[start:end]) / 127.5 - 1),
                "t h w c -> 1 c t h w",
            )

            batch["refer_pixel_values"] = rearrange(
                torch.tensor(refer_images / 127.5 - 1), "h w c -> 1 c h w"
            )

            if start > 0:
                # Reuse the tail of the previous clip for temporal guidance.
                batch["refer_t_pixel_values"] = rearrange(
                    out_frames[0, :, -refert_num:].clone().detach(),
                    "c t h w -> t c h w",
                )

            batch["refer_t_pixel_values"] = rearrange(
                batch["refer_t_pixel_values"],
                "t c h w -> 1 c t h w",
            )

            if replace_flag:
                batch["bg_pixel_values"] = rearrange(
                    torch.tensor(np.stack(bg_images[start:end]) / 127.5 - 1),
                    "t h w c -> 1 c t h w",
                )

                batch["mask_pixel_values"] = rearrange(
                    torch.tensor(np.stack(mask_images[start:end])[:, :, :, None]),
                    "t h w c -> 1 t c h w",
                )

            for key, value in batch.items():
                if isinstance(value, torch.Tensor):
                    batch[key] = value.to(device=self.device, dtype=torch.bfloat16)

            ref_pixel_values = batch["refer_pixel_values"]
            refer_t_pixel_values = batch["refer_t_pixel_values"]
            conditioning_pixel_values = batch["conditioning_pixel_values"]
            face_pixel_values = batch["face_pixel_values"]

            B, _, H, W = ref_pixel_values.shape
            T = clip_len
            lat_h = H // 8
            lat_w = W // 8
            lat_t = T // 4 + 1
            target_shape = [lat_t + 1, lat_h, lat_w]
            noise = [
                torch.randn(
                    16,
                    target_shape[0],
                    target_shape[1],
                    target_shape[2],
                    dtype=torch.float32,
                    device=self.device,
                    generator=seed_g,
                )
            ]

            max_seq_len = int(math.ceil(np.prod(target_shape) // 4 / self.sp_size)) * self.sp_size
            if max_seq_len % self.sp_size != 0:
                raise ValueError(f"max_seq_len {max_seq_len} is not divisible by sp_size {self.sp_size}")

            with (
                torch.autocast(device_type=str(self.device), dtype=torch.bfloat16, enabled=True),
                torch.no_grad()
            ):
                if sample_solver == 'unipc':
                    sample_scheduler = FlowUniPCMultistepScheduler(
                        num_train_timesteps=self.num_train_timesteps,
                        shift=1,
                        use_dynamic_shifting=False)
                    sample_scheduler.set_timesteps(
                        sampling_steps, device=self.device, shift=shift)
                    timesteps = sample_scheduler.timesteps
                elif sample_solver == 'dpm++':
                    sample_scheduler = FlowDPMSolverMultistepScheduler(
                        num_train_timesteps=self.num_train_timesteps,
                        shift=1,
                        use_dynamic_shifting=False)
                    sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                    timesteps, _ = retrieve_timesteps(
                        sample_scheduler,
                        device=self.device,
                        sigmas=sampling_sigmas)
                else:
                    raise NotImplementedError("Unsupported solver.")

                latents = noise

                pose_latents_no_ref = self.vae.encode(conditioning_pixel_values.to(torch.bfloat16))
                pose_latents_no_ref = torch.stack(pose_latents_no_ref)
                pose_latents = torch.cat([pose_latents_no_ref], dim=2)

                ref_pixel_values = rearrange(ref_pixel_values, "t c h w -> 1 c t h w")
                ref_latents = self.vae.encode(ref_pixel_values.to(torch.bfloat16))
                ref_latents = torch.stack(ref_latents)

                mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=self.device)
                y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=self.device)

                img = ref_pixel_values[0, :, 0]
                clip_context = self.clip.visual([img[:, None, :, :]]).to(dtype=torch.bfloat16, device=self.device)

                if mask_reft_len > 0:
                    if replace_flag:
                        bg_pixel_values = batch["bg_pixel_values"]
                        y_reft = self.vae.encode(
                            [
                                torch.concat(
                                    [refer_t_pixel_values[0, :, :mask_reft_len],
                                     bg_pixel_values[0, :, mask_reft_len:]],
                                    dim=1).to(self.device)
                            ]
                        )[0]
                        mask_pixel_values = 1 - batch["mask_pixel_values"]
                        mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
                        mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode='nearest')
                        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:, :, 0]
                        msk_reft = self.get_i2v_mask(
                            lat_t, lat_h, lat_w, mask_reft_len,
                            mask_pixel_values=mask_pixel_values, device=self.device)
                    else:
                        y_reft = self.vae.encode(
                            [
                                torch.concat(
                                    [
                                        torch.nn.functional.interpolate(
                                            refer_t_pixel_values[0, :, :mask_reft_len].cpu(),
                                            size=(H, W), mode="bicubic"),
                                        torch.zeros(3, T - mask_reft_len, H, W),
                                    ],
                                    dim=1,
                                ).to(self.device)
                            ]
                        )[0]
                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)
                else:
                    if replace_flag:
                        bg_pixel_values = batch["bg_pixel_values"]
                        mask_pixel_values = 1 - batch["mask_pixel_values"]
                        mask_pixel_values = rearrange(mask_pixel_values, "b t c h w -> (b t) c h w")
                        mask_pixel_values = F.interpolate(mask_pixel_values, size=(H // 8, W // 8), mode='nearest')
                        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:, :, 0]
                        y_reft = self.vae.encode(
                            [bg_pixel_values[0].to(self.device)]
                        )[0]
                        msk_reft = self.get_i2v_mask(
                            lat_t, lat_h, lat_w, mask_reft_len,
                            mask_pixel_values=mask_pixel_values, device=self.device)
                    else:
                        y_reft = self.vae.encode(
                            [torch.zeros(3, T - mask_reft_len, H, W).to(self.device)]
                        )[0]
                        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, mask_reft_len, device=self.device)

                y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=self.device)
                y = torch.concat([y_ref, y_reft], dim=1)

                arg_c = {
                    "context": context,
                    "seq_len": max_seq_len,
                    "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
                    "y": [y],
                    "pose_latents": pose_latents,
                    "face_pixel_values": face_pixel_values,
                }

                if guide_scale > 1:
                    face_pixel_values_uncond = face_pixel_values * 0 - 1
                    arg_null = {
                        "context": context_null,
                        "seq_len": max_seq_len,
                        "clip_fea": clip_context.to(dtype=torch.bfloat16, device=self.device),
                        "y": [y],
                        "pose_latents": pose_latents,
                        "face_pixel_values": face_pixel_values_uncond,
                    }

                for i, t in enumerate(tqdm(timesteps)):
                    latent_model_input = latents
                    timestep = [t]

                    timestep = torch.stack(timestep)

                    noise_pred_cond = TensorList(
                        self.noise_model(TensorList(latent_model_input), t=timestep, **arg_c)
                    )

                    if guide_scale > 1:
                        noise_pred_uncond = TensorList(
                            self.noise_model(
                                TensorList(latent_model_input), t=timestep, **arg_null
                            )
                        )
                        noise_pred = noise_pred_uncond + guide_scale * (
                            noise_pred_cond - noise_pred_uncond
                        )
                    else:
                        noise_pred = noise_pred_cond

                    temp_x0 = sample_scheduler.step(
                        noise_pred[0].unsqueeze(0),
                        t,
                        latents[0].unsqueeze(0),
                        return_dict=False,
                        generator=seed_g,
                    )[0]
                    latents[0] = temp_x0.squeeze(0)

                x0 = latents

                x0 = [x.to(dtype=torch.float32) for x in x0]
                # Drop the leading reference latent frame before decoding.
                out_frames = torch.stack(self.vae.decode([x0[0][:, 1:]]))

            if start != 0:
                out_frames = out_frames[:, :, refert_num:]

            all_out_frames.append(out_frames.cpu())

            start += clip_len - refert_num
            end += clip_len - refert_num

        videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len]
        return videos[0] if self.rank == 0 else None
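The clip-scheduling helpers `get_valid_len` and `inputs_padding` above are self-contained; a minimal standalone sketch of the same logic (copied out of the class for illustration), showing how 100 real frames are padded so they split into 77-frame clips overlapping by 1:

```python
from copy import deepcopy


def get_valid_len(real_len, clip_len=81, overlap=1):
    # Round real_len up so the frames split evenly into clips of
    # clip_len frames that overlap by `overlap` frames.
    real_clip_len = clip_len - overlap
    last_clip_num = (real_len - overlap) % real_clip_len
    extra = 0 if last_clip_num == 0 else real_clip_len - last_clip_num
    return real_len + extra


def inputs_padding(array, target_len):
    # Extend `array` to target_len by ping-pong (reflect) looping.
    idx, flip, out = 0, False, []
    while len(out) < target_len:
        out.append(deepcopy(array[idx]))
        idx = idx - 1 if flip else idx + 1
        if idx == 0 or idx == len(array) - 1:
            flip = not flip
    return out[:target_len]


# 100 real frames, clip_len=77, overlap=1 -> padded to 153 frames,
# i.e. two clips: 1 + 2 * 76 = 153.
target = get_valid_len(100, clip_len=77, overlap=1)
padded = inputs_padding(list(range(100)), target)
```

The padded tail mirrors the input (…, 98, 99, 98, 97, …), so the extra frames never introduce motion discontinuities at the loop boundary.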
wan/configs/__init__.py
ADDED
@@ -0,0 +1,50 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_A14B import i2v_A14B
from .wan_s2v_14B import s2v_14B
from .wan_t2v_A14B import t2v_A14B
from .wan_ti2v_5B import ti2v_5B
from .wan_animate_14B import animate_14B

WAN_CONFIGS = {
    't2v-A14B': t2v_A14B,
    'i2v-A14B': i2v_A14B,
    'ti2v-5B': ti2v_5B,
    'animate-14B': animate_14B,
    's2v-14B': s2v_14B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '704*1280': (704, 1280),
    '1280*704': (1280, 704),
    '1024*704': (1024, 704),
    '704*1024': (704, 1024),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
    '704*1280': 704 * 1280,
    '1280*704': 1280 * 704,
    '1024*704': 1024 * 704,
    '704*1024': 704 * 1024,
}

SUPPORTED_SIZES = {
    't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'ti2v-5B': ('704*1280', '1280*704'),
    's2v-14B': ('720*1280', '1280*720', '480*832', '832*480', '1024*704',
                '704*1024', '704*1280', '1280*704'),
    'animate-14B': ('720*1280', '1280*720')
}
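The three tables in this config module are redundant by construction: each key encodes `H*W`, and `MAX_AREA_CONFIGS` is simply the product of the corresponding `SIZE_CONFIGS` tuple. A quick standalone consistency check over a subset of the keys (the dicts are re-declared here so the snippet runs on its own):

```python
# Subset of the tables above, re-declared for a self-contained check.
SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
}
MAX_AREA_CONFIGS = {key: h * w for key, (h, w) in SIZE_CONFIGS.items()}

for key, (h, w) in SIZE_CONFIGS.items():
    # Each key spells "H*W" and must match the tuple it maps to.
    kh, kw = map(int, key.split('*'))
    assert (kh, kw) == (h, w)
    assert MAX_AREA_CONFIGS[key] == h * w
```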
wan/configs/shared_config.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese); roughly: "garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, frame, frozen, overall gray,
# worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed
# limbs, fused fingers, motionless frame, cluttered background, three legs, crowded
# background, walking backwards".
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
wan_shared_cfg.frame_num = 81
wan/configs/wan_animate_14B.py
ADDED
@@ -0,0 +1,40 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan animate 14B ------------------------#
animate_14B = EasyDict(__name__='Config: Wan animate 14B')
animate_14B.update(wan_shared_cfg)

animate_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
animate_14B.t5_tokenizer = 'google/umt5-xxl'

animate_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
animate_14B.clip_tokenizer = 'xlm-roberta-large'
animate_14B.lora_checkpoint = 'relighting_lora.ckpt'
# vae
animate_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
animate_14B.vae_stride = (4, 8, 8)

# transformer
animate_14B.patch_size = (1, 2, 2)
animate_14B.dim = 5120
animate_14B.ffn_dim = 13824
animate_14B.freq_dim = 256
animate_14B.num_heads = 40
animate_14B.num_layers = 40
animate_14B.window_size = (-1, -1)
animate_14B.qk_norm = True
animate_14B.cross_attn_norm = True
animate_14B.eps = 1e-6
animate_14B.use_face_encoder = True
animate_14B.motion_encoder_dim = 512

# inference
animate_14B.sample_shift = 5.0
animate_14B.sample_steps = 20
animate_14B.sample_guide_scale = 1.0
animate_14B.frame_num = 77
animate_14B.sample_fps = 30
# Default prompt (Chinese): "The person in the video is performing actions."
animate_14B.prompt = '视频中的人在做动作'
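The inference-time latent grid in `wan/animate.py` follows directly from the `vae_stride` and `patch_size` entries above. A minimal sketch for the 720*1280, 77-frame setting (values taken from this config; the token-count formula is an illustrative assumption, not copied from the model code):

```python
# Values assumed from the animate-14B config above.
frame_num, height, width = 77, 720, 1280
vae_stride = (4, 8, 8)   # temporal, height, width compression of the VAE
patch_size = (1, 2, 2)   # DiT patchification per latent axis

# VAE latent grid: (frame_num - 1) / 4 + 1 temporal slots, H/8 x W/8 spatial.
lat_t = (frame_num - 1) // vae_stride[0] + 1
lat_h = height // vae_stride[1]
lat_w = width // vae_stride[2]

# Token count seen by the transformer after patchification.
tokens = (lat_t // patch_size[0]) * (lat_h // patch_size[1]) * (lat_w // patch_size[2])
print(lat_t, lat_h, lat_w, tokens)  # 20 90 160 72000
```

This is why `frame_num` must be of the form 4n+1: only then does the pixel clip map to a whole number of latent frames.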
wan/configs/wan_i2v_A14B.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V A14B ------------------------#

i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
i2v_A14B.update(wan_shared_cfg)

i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_A14B.vae_stride = (4, 8, 8)

# transformer
i2v_A14B.patch_size = (1, 2, 2)
i2v_A14B.dim = 5120
i2v_A14B.ffn_dim = 13824
i2v_A14B.freq_dim = 256
i2v_A14B.num_heads = 40
i2v_A14B.num_layers = 40
i2v_A14B.window_size = (-1, -1)
i2v_A14B.qk_norm = True
i2v_A14B.cross_attn_norm = True
i2v_A14B.eps = 1e-6
i2v_A14B.low_noise_checkpoint = 'low_noise_model'
i2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
i2v_A14B.sample_shift = 5.0
i2v_A14B.sample_steps = 40
i2v_A14B.boundary = 0.900
i2v_A14B.sample_guide_scale = (3.5, 3.5)  # low noise, high noise
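The `boundary` field together with the paired `sample_guide_scale` drives a two-expert denoising schedule: timesteps at or above `boundary * num_train_timesteps` are handled by the high-noise model, the rest by the low-noise model (the same rule appears in `WanI2V._prepare_model_for_timestep` later in this upload). A minimal sketch of the selection rule; the helper name `select_expert` is hypothetical, not part of the repo:

```python
def select_expert(t, boundary=0.900, num_train_timesteps=1000):
    """Return (model_name, guide_scale_index) for integer timestep t.

    Hypothetical helper illustrating the boundary rule: the index picks
    the matching entry of the (low_noise, high_noise) guide-scale pair.
    """
    threshold = boundary * num_train_timesteps
    if t >= threshold:
        return 'high_noise_model', 1  # second entry of sample_guide_scale
    return 'low_noise_model', 0       # first entry of sample_guide_scale

# Early (noisy) steps route to the high-noise expert,
# late (low-noise) steps to the low-noise expert.
assert select_expert(950) == ('high_noise_model', 1)
assert select_expert(300) == ('low_noise_model', 0)
```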
wan/configs/wan_s2v_14B.py
ADDED
@@ -0,0 +1,59 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan S2V 14B ------------------------#

s2v_14B = EasyDict(__name__='Config: Wan S2V 14B')
s2v_14B.update(wan_shared_cfg)

# t5
s2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
s2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
s2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
s2v_14B.vae_stride = (4, 8, 8)

# wav2vec
s2v_14B.wav2vec = "wav2vec2-large-xlsr-53-english"

s2v_14B.num_heads = 40
# transformer
s2v_14B.transformer = EasyDict(
    __name__="Config: Transformer config for WanModel_S2V")
s2v_14B.transformer.patch_size = (1, 2, 2)
s2v_14B.transformer.dim = 5120
s2v_14B.transformer.ffn_dim = 13824
s2v_14B.transformer.freq_dim = 256
s2v_14B.transformer.num_heads = 40
s2v_14B.transformer.num_layers = 40
s2v_14B.transformer.window_size = (-1, -1)
s2v_14B.transformer.qk_norm = True
s2v_14B.transformer.cross_attn_norm = True
s2v_14B.transformer.eps = 1e-6
s2v_14B.transformer.enable_adain = True
s2v_14B.transformer.adain_mode = "attn_norm"
s2v_14B.transformer.audio_inject_layers = [
    0, 4, 8, 12, 16, 20, 24, 27, 30, 33, 36, 39
]
s2v_14B.transformer.zero_init = True
s2v_14B.transformer.zero_timestep = True
s2v_14B.transformer.enable_motioner = False
s2v_14B.transformer.add_last_motion = True
s2v_14B.transformer.trainable_token = False
s2v_14B.transformer.enable_tsm = False
s2v_14B.transformer.enable_framepack = True
s2v_14B.transformer.framepack_drop_mode = 'padd'
s2v_14B.transformer.audio_dim = 1024

s2v_14B.transformer.motion_frames = 73
s2v_14B.transformer.cond_dim = 16

# inference
# Chinese negative prompt, roughly: "blurry, worst quality, blurred details,
# violent emotion, rapidly shaking hands, subtitles, ugly, incomplete, extra
# fingers, poorly drawn hands, poorly drawn face, deformed, disfigured,
# malformed limbs, fused fingers, static frame, cluttered background, three
# legs, crowded background, walking backwards"
s2v_14B.sample_neg_prompt = "画面模糊,最差质量,画面模糊,细节模糊不清,情绪激动剧烈,手快速抖动,字幕,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
s2v_14B.drop_first_motion = True
s2v_14B.sample_shift = 3
s2v_14B.sample_steps = 40
s2v_14B.sample_guide_scale = 4.5
wan/configs/wan_t2v_A14B.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V A14B ------------------------#

t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
t2v_A14B.update(wan_shared_cfg)

# t5
t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_A14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_A14B.vae_stride = (4, 8, 8)

# transformer
t2v_A14B.patch_size = (1, 2, 2)
t2v_A14B.dim = 5120
t2v_A14B.ffn_dim = 13824
t2v_A14B.freq_dim = 256
t2v_A14B.num_heads = 40
t2v_A14B.num_layers = 40
t2v_A14B.window_size = (-1, -1)
t2v_A14B.qk_norm = True
t2v_A14B.cross_attn_norm = True
t2v_A14B.eps = 1e-6
t2v_A14B.low_noise_checkpoint = 'low_noise_model'
t2v_A14B.high_noise_checkpoint = 'high_noise_model'

# inference
t2v_A14B.sample_shift = 12.0
t2v_A14B.sample_steps = 40
t2v_A14B.boundary = 0.875
t2v_A14B.sample_guide_scale = (3.0, 4.0)  # low noise, high noise
wan/configs/wan_ti2v_5B.py
ADDED
@@ -0,0 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan TI2V 5B ------------------------#

ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
ti2v_5B.update(wan_shared_cfg)

# t5
ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
ti2v_5B.t5_tokenizer = 'google/umt5-xxl'

# vae
ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
ti2v_5B.vae_stride = (4, 16, 16)

# transformer
ti2v_5B.patch_size = (1, 2, 2)
ti2v_5B.dim = 3072
ti2v_5B.ffn_dim = 14336
ti2v_5B.freq_dim = 256
ti2v_5B.num_heads = 24
ti2v_5B.num_layers = 30
ti2v_5B.window_size = (-1, -1)
ti2v_5B.qk_norm = True
ti2v_5B.cross_attn_norm = True
ti2v_5B.eps = 1e-6

# inference
ti2v_5B.sample_fps = 24
ti2v_5B.sample_shift = 5.0
ti2v_5B.sample_steps = 50
ti2v_5B.sample_guide_scale = 5.0
ti2v_5B.frame_num = 121
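The `vae_stride` tuples fix the pixel-to-latent geometry for each config. A sketch of the mapping, assuming the usual causal-VAE convention that the first frame is encoded alone and every further temporal stride of frames adds one latent frame (consistent with the `4n+1` frame counts 77, 81, and 121 used throughout); the helper name `latent_shape` is hypothetical:

```python
def latent_shape(frame_num, height, width, vae_stride=(4, 16, 16)):
    """Map a pixel-space video shape to its latent shape (assumption:
    first frame encoded alone, then one latent frame per t_stride frames)."""
    t_stride, h_stride, w_stride = vae_stride
    lat_f = (frame_num - 1) // t_stride + 1
    return lat_f, height // h_stride, width // w_stride

# ti2v_5B: 121 frames at 704x1280 -> 31 latent frames on a 44x80 grid
assert latent_shape(121, 704, 1280) == (31, 44, 80)
# the (4, 8, 8) VAEs halve the spatial stride: 77 frames at 480x832
assert latent_shape(77, 480, 832, vae_stride=(4, 8, 8)) == (20, 60, 104)
```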
wan/distributed/__init__.py
ADDED
@@ -0,0 +1 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/fsdp.py
ADDED
@@ -0,0 +1,45 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
    use_lora=False
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states,
        use_orig_params=use_lora)
    return model


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
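With `lambda_auto_wrap_policy`, every module for which `lambda_fn` returns True becomes its own FSDP shard unit, so `shard_model` wraps each transformer block individually while the rest of the model stays in the outer wrapper. A dependency-free sketch of that predicate; `FakeBlock` and `FakeModel` are stand-ins for the real transformer, not repo classes:

```python
class FakeBlock:
    """Stand-in for one transformer block."""


class FakeModel:
    """Stand-in for a model exposing a .blocks list, as WanModel does."""

    def __init__(self, n):
        self.blocks = [FakeBlock() for _ in range(n)]


model = FakeModel(4)
# The predicate passed to lambda_auto_wrap_policy above: wrap a module
# exactly when it is one of the model's transformer blocks.
lambda_fn = lambda m: m in model.blocks

assert all(lambda_fn(b) for b in model.blocks)   # every block gets wrapped
assert not lambda_fn(FakeBlock())                # unrelated modules do not
```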
wan/distributed/sequence_parallel.py
ADDED
@@ -0,0 +1,176 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp

from ..modules.model import sinusoidal_embedding_1d
from .ulysses import distributed_attention
from .util import gather_forward, get_rank, get_world_size


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding to this rank's slice of the sequence
        sp_size = get_world_size()
        sp_rank = get_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def sp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
):
    """
    x: A list of videos, each with shape [C, T, H, W].
    t: [B].
    context: A list of text embeddings, each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    if t.dim() == 1:
        t = t.expand(t.size(0), seq_len)
    with torch.amp.autocast('cuda', dtype=torch.float32):
        bt = t.size(0)
        t = t.flatten()
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim,
                                    t).unflatten(0, (bt, seq_len)).float())
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # context parallel: each rank keeps its slice of the sequence
    x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
    e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # context parallel: gather the full sequence back
    x = gather_forward(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]


def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    x = distributed_attention(
        half(q),
        half(k),
        half(v),
        seq_lens,
        window_size=self.window_size,
    )

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
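In `rope_apply` the frequency table is padded to `s * world_size` positions (padding with ones, which is the identity under the complex multiply), and rank `r` applies the slice covering its local `s` tokens. A list-based sketch of that slicing; `freqs_for_rank` is a hypothetical helper for illustration only:

```python
def freqs_for_rank(freqs, s_per_rank, rank, world_size, pad_value=1):
    """Pad freqs to s_per_rank * world_size entries, then return the
    slice that rank `rank` applies to its local tokens (mirrors
    pad_freqs + the freqs_i_rank indexing above, with plain lists)."""
    padded = freqs + [pad_value] * (s_per_rank * world_size - len(freqs))
    return padded[rank * s_per_rank:(rank + 1) * s_per_rank]


freqs = [10, 20, 30, 40, 50]  # 5 real sequence positions
# world_size = 2, 3 local tokens per rank -> pad to 6 positions
assert freqs_for_rank(freqs, 3, 0, 2) == [10, 20, 30]
assert freqs_for_rank(freqs, 3, 1, 2) == [40, 50, 1]  # identity padding
```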
wan/distributed/ulysses.py
ADDED
@@ -0,0 +1,47 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist

from ..modules.attention import flash_attention
from .util import all_to_all


def distributed_attention(
    q,
    k,
    v,
    seq_lens,
    window_size=(-1, -1),
):
    """
    Performs distributed attention based on the DeepSpeed Ulysses attention
    mechanism; please refer to https://arxiv.org/pdf/2309.14509

    Args:
        q: [B, Lq // p, Nq, C1].
        k: [B, Lk // p, Nk, C1].
        v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
        seq_lens: [B], length of each sequence in the batch.
        window_size: (left, right). If not (-1, -1), apply sliding window local attention.
    """
    if not dist.is_initialized():
        raise ValueError("distributed group should be initialized.")
    b = q.shape[0]

    # scatter heads, gather the full q/k/v sequence on each rank
    q = all_to_all(q, scatter_dim=2, gather_dim=1)
    k = all_to_all(k, scatter_dim=2, gather_dim=1)
    v = all_to_all(v, scatter_dim=2, gather_dim=1)

    # apply attention
    x = flash_attention(
        q,
        k,
        v,
        k_lens=seq_lens,
        window_size=window_size,
    )

    # scatter the sequence back, gathering all heads on each rank
    x = all_to_all(x, scatter_dim=1, gather_dim=2)
    return x
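The Ulysses pattern above is pure shape bookkeeping: the first `all_to_all` scatters heads across the group while gathering the full sequence, so each rank attends over the whole sequence with a subset of heads, and the final `all_to_all` transposes the layout back. A communication-free sketch of how the shapes transform (`after_all_to_all` is a hypothetical helper, not a repo function):

```python
def after_all_to_all(shape, scatter_dim, gather_dim, world_size):
    """Shape of a tensor after all_to_all(scatter_dim, gather_dim)."""
    shape = list(shape)
    assert shape[scatter_dim] % world_size == 0
    shape[scatter_dim] //= world_size
    shape[gather_dim] *= world_size
    return tuple(shape)


q_local = (1, 128, 16, 64)  # [B, L // p, N, C] with p = 4 ranks
q_full_seq = after_all_to_all(q_local, scatter_dim=2, gather_dim=1, world_size=4)
assert q_full_seq == (1, 512, 4, 64)  # [B, L, N // p, C]: full seq, 1/p heads
# the output all_to_all restores the original sequence-parallel layout
assert after_all_to_all(q_full_seq, scatter_dim=1, gather_dim=2, world_size=4) == q_local
```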
wan/distributed/util.py
ADDED
@@ -0,0 +1,51 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.distributed as dist


def init_distributed_group():
    r"""Initialize the sequence parallel group."""
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')


def get_rank():
    return dist.get_rank()


def get_world_size():
    return dist.get_world_size()


def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
    """
    `scatter` along one dimension and `gather` along another.
    """
    world_size = get_world_size()
    if world_size > 1:
        inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(u) for u in inputs]
        dist.all_to_all(outputs, inputs, group=group, **kwargs)
        x = torch.cat(outputs, dim=gather_dim).contiguous()
    return x


def all_gather(tensor):
    world_size = dist.get_world_size()
    if world_size == 1:
        return [tensor]
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensor_list, tensor)
    return tensor_list


def gather_forward(input, dim):
    # skip if world_size == 1
    world_size = dist.get_world_size()
    if world_size == 1:
        return input

    # gather sequence
    output = all_gather(input)
    return torch.cat(output, dim=dim).contiguous()
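The data movement behind `all_to_all` can be simulated in a single process: each rank splits its payload into `world_size` chunks along the scatter dimension, then receives chunk `r` from every peer and concatenates along the gather dimension. A minimal sketch with nested lists standing in for per-rank tensors (`simulate_all_to_all` is hypothetical, for illustration only):

```python
def simulate_all_to_all(per_rank_chunks):
    """per_rank_chunks[src][dst] = chunk held by rank src, destined for
    rank dst. Returns the post-exchange layout: rank dst ends up with
    chunk dst from every source rank, in source-rank order."""
    world_size = len(per_rank_chunks)
    return [[per_rank_chunks[src][dst] for src in range(world_size)]
            for dst in range(world_size)]


before = [['a0', 'a1'],   # rank 0's two chunks
          ['b0', 'b1']]   # rank 1's two chunks
after = simulate_all_to_all(before)
# Rank 0 now holds chunk 0 from every rank; rank 1 holds chunk 1.
assert after == [['a0', 'b0'], ['a1', 'b1']]
```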
wan/image2video.py
ADDED
@@ -0,0 +1,431 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
from .distributed.util import get_world_size
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae2_1 import Wan2_1_VAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanI2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_sp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of sequence parallel.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            convert_model_dtype (`bool`, *optional*, defaults to False):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.init_on_cpu = init_on_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.boundary = config.boundary
        self.param_dtype = config.param_dtype

        if t5_fsdp or dit_fsdp or use_sp:
            self.init_on_cpu = False

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = Wan2_1_VAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.low_noise_model = WanModel.from_pretrained(
            checkpoint_dir, subfolder=config.low_noise_checkpoint)
        self.low_noise_model = self._configure_model(
            model=self.low_noise_model,
            use_sp=use_sp,
            dit_fsdp=dit_fsdp,
            shard_fn=shard_fn,
            convert_model_dtype=convert_model_dtype)

        self.high_noise_model = WanModel.from_pretrained(
            checkpoint_dir, subfolder=config.high_noise_checkpoint)
        self.high_noise_model = self._configure_model(
            model=self.high_noise_model,
            use_sp=use_sp,
            dit_fsdp=dit_fsdp,
            shard_fn=shard_fn,
            convert_model_dtype=convert_model_dtype)
        if use_sp:
            self.sp_size = get_world_size()
        else:
            self.sp_size = 1

        self.sample_neg_prompt = config.sample_neg_prompt

    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
                         convert_model_dtype):
        """
        Configures a model object. This includes setting evaluation modes,
        applying a distributed parallel strategy, and handling device placement.

        Args:
            model (torch.nn.Module):
                The model instance to configure.
            use_sp (`bool`):
                Enable distribution strategy of sequence parallel.
            dit_fsdp (`bool`):
                Enable FSDP sharding for DiT model.
            shard_fn (callable):
                The function to apply FSDP sharding.
            convert_model_dtype (`bool`):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.

        Returns:
            torch.nn.Module:
                The configured model.
        """
        model.eval().requires_grad_(False)

        if use_sp:
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(
                    sp_attn_forward, block.self_attn)
            model.forward = types.MethodType(sp_dit_forward, model)

        if dist.is_initialized():
            dist.barrier()

        if dit_fsdp:
            model = shard_fn(model)
        else:
            if convert_model_dtype:
                model.to(self.param_dtype)
            if not self.init_on_cpu:
                model.to(self.device)

        return model

    def _prepare_model_for_timestep(self, t, boundary, offload_model):
        r"""
        Prepares and returns the required model for the current timestep.

        Args:
            t (torch.Tensor):
                current timestep.
            boundary (`int`):
                The timestep threshold. If `t` is at or above this value,
                the `high_noise_model` is considered as the required model.
            offload_model (`bool`):
                A flag intended to control the offloading behavior.

        Returns:
            torch.nn.Module:
                The active model on the target device for the current timestep.
        """
        if t.item() >= boundary:
            required_model_name = 'high_noise_model'
            offload_model_name = 'low_noise_model'
        else:
            required_model_name = 'low_noise_model'
            offload_model_name = 'high_noise_model'
        if offload_model or self.init_on_cpu:
            if next(getattr(
                    self,
                    offload_model_name).parameters()).device.type == 'cuda':
                getattr(self, offload_model_name).to('cpu')
            if next(getattr(
                    self,
                    required_model_name).parameters()).device.type == 'cpu':
                getattr(self, required_model_name).to(self.device)
        return getattr(self, required_model_name)

    def generate(self,
                 input_prompt,
                 img,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=40,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from input image and text prompt using diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float` or tuple[`float`], *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity.
                If a tuple, the first guide_scale is used for the low noise model and
                the second for the high noise model.
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        # preprocess
        guide_scale = (guide_scale, guide_scale) if isinstance(
            guide_scale, float) else guide_scale
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
+
self.patch_size[1] * self.patch_size[1])
|
| 267 |
+
lat_w = round(
|
| 268 |
+
np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
|
| 269 |
+
self.patch_size[2] * self.patch_size[2])
|
| 270 |
+
h = lat_h * self.vae_stride[1]
|
| 271 |
+
w = lat_w * self.vae_stride[2]
|
| 272 |
+
|
| 273 |
+
max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
|
| 274 |
+
self.patch_size[1] * self.patch_size[2])
|
| 275 |
+
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 276 |
+
|
| 277 |
+
seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
|
| 278 |
+
seed_g = torch.Generator(device=self.device)
|
| 279 |
+
seed_g.manual_seed(seed)
|
| 280 |
+
noise = torch.randn(
|
| 281 |
+
16,
|
| 282 |
+
(F - 1) // self.vae_stride[0] + 1,
|
| 283 |
+
lat_h,
|
| 284 |
+
lat_w,
|
| 285 |
+
dtype=torch.float32,
|
| 286 |
+
generator=seed_g,
|
| 287 |
+
device=self.device)
|
| 288 |
+
|
| 289 |
+
msk = torch.ones(1, F, lat_h, lat_w, device=self.device)
|
| 290 |
+
msk[:, 1:] = 0
|
| 291 |
+
msk = torch.concat([
|
| 292 |
+
torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
|
| 293 |
+
],
|
| 294 |
+
dim=1)
|
| 295 |
+
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
|
| 296 |
+
msk = msk.transpose(1, 2)[0]
|
| 297 |
+
|
| 298 |
+
if n_prompt == "":
|
| 299 |
+
n_prompt = self.sample_neg_prompt
|
| 300 |
+
|
| 301 |
+
# preprocess
|
| 302 |
+
if not self.t5_cpu:
|
| 303 |
+
self.text_encoder.model.to(self.device)
|
| 304 |
+
context = self.text_encoder([input_prompt], self.device)
|
| 305 |
+
context_null = self.text_encoder([n_prompt], self.device)
|
| 306 |
+
if offload_model:
|
| 307 |
+
self.text_encoder.model.cpu()
|
| 308 |
+
else:
|
| 309 |
+
context = self.text_encoder([input_prompt], torch.device('cpu'))
|
| 310 |
+
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 311 |
+
context = [t.to(self.device) for t in context]
|
| 312 |
+
context_null = [t.to(self.device) for t in context_null]
|
| 313 |
+
|
| 314 |
+
y = self.vae.encode([
|
| 315 |
+
torch.concat([
|
| 316 |
+
torch.nn.functional.interpolate(
|
| 317 |
+
img[None].cpu(), size=(h, w), mode='bicubic').transpose(
|
| 318 |
+
0, 1),
|
| 319 |
+
torch.zeros(3, F - 1, h, w)
|
| 320 |
+
],
|
| 321 |
+
dim=1).to(self.device)
|
| 322 |
+
])[0]
|
| 323 |
+
y = torch.concat([msk, y])
|
| 324 |
+
|
| 325 |
+
@contextmanager
|
| 326 |
+
def noop_no_sync():
|
| 327 |
+
yield
|
| 328 |
+
|
| 329 |
+
no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
|
| 330 |
+
noop_no_sync)
|
| 331 |
+
no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
|
| 332 |
+
noop_no_sync)
|
| 333 |
+
|
| 334 |
+
# evaluation mode
|
| 335 |
+
with (
|
| 336 |
+
torch.amp.autocast('cuda', dtype=self.param_dtype),
|
| 337 |
+
torch.no_grad(),
|
| 338 |
+
no_sync_low_noise(),
|
| 339 |
+
no_sync_high_noise(),
|
| 340 |
+
):
|
| 341 |
+
boundary = self.boundary * self.num_train_timesteps
|
| 342 |
+
|
| 343 |
+
if sample_solver == 'unipc':
|
| 344 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
| 345 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 346 |
+
shift=1,
|
| 347 |
+
use_dynamic_shifting=False)
|
| 348 |
+
sample_scheduler.set_timesteps(
|
| 349 |
+
sampling_steps, device=self.device, shift=shift)
|
| 350 |
+
timesteps = sample_scheduler.timesteps
|
| 351 |
+
elif sample_solver == 'dpm++':
|
| 352 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
| 353 |
+
num_train_timesteps=self.num_train_timesteps,
|
| 354 |
+
shift=1,
|
| 355 |
+
use_dynamic_shifting=False)
|
| 356 |
+
sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
|
| 357 |
+
timesteps, _ = retrieve_timesteps(
|
| 358 |
+
sample_scheduler,
|
| 359 |
+
device=self.device,
|
| 360 |
+
sigmas=sampling_sigmas)
|
| 361 |
+
else:
|
| 362 |
+
raise NotImplementedError("Unsupported solver.")
|
| 363 |
+
|
| 364 |
+
# sample videos
|
| 365 |
+
latent = noise
|
| 366 |
+
|
| 367 |
+
arg_c = {
|
| 368 |
+
'context': [context[0]],
|
| 369 |
+
'seq_len': max_seq_len,
|
| 370 |
+
'y': [y],
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
arg_null = {
|
| 374 |
+
'context': context_null,
|
| 375 |
+
'seq_len': max_seq_len,
|
| 376 |
+
'y': [y],
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
if offload_model:
|
| 380 |
+
torch.cuda.empty_cache()
|
| 381 |
+
|
| 382 |
+
for _, t in enumerate(tqdm(timesteps)):
|
| 383 |
+
latent_model_input = [latent.to(self.device)]
|
| 384 |
+
timestep = [t]
|
| 385 |
+
|
| 386 |
+
timestep = torch.stack(timestep).to(self.device)
|
| 387 |
+
|
| 388 |
+
model = self._prepare_model_for_timestep(
|
| 389 |
+
t, boundary, offload_model)
|
| 390 |
+
sample_guide_scale = guide_scale[1] if t.item(
|
| 391 |
+
) >= boundary else guide_scale[0]
|
| 392 |
+
|
| 393 |
+
noise_pred_cond = model(
|
| 394 |
+
latent_model_input, t=timestep, **arg_c)[0]
|
| 395 |
+
if offload_model:
|
| 396 |
+
torch.cuda.empty_cache()
|
| 397 |
+
noise_pred_uncond = model(
|
| 398 |
+
latent_model_input, t=timestep, **arg_null)[0]
|
| 399 |
+
if offload_model:
|
| 400 |
+
torch.cuda.empty_cache()
|
| 401 |
+
noise_pred = noise_pred_uncond + sample_guide_scale * (
|
| 402 |
+
noise_pred_cond - noise_pred_uncond)
|
| 403 |
+
|
| 404 |
+
temp_x0 = sample_scheduler.step(
|
| 405 |
+
noise_pred.unsqueeze(0),
|
| 406 |
+
t,
|
| 407 |
+
latent.unsqueeze(0),
|
| 408 |
+
return_dict=False,
|
| 409 |
+
generator=seed_g)[0]
|
| 410 |
+
latent = temp_x0.squeeze(0)
|
| 411 |
+
|
| 412 |
+
x0 = [latent]
|
| 413 |
+
del latent_model_input, timestep
|
| 414 |
+
|
| 415 |
+
if offload_model:
|
| 416 |
+
self.low_noise_model.cpu()
|
| 417 |
+
self.high_noise_model.cpu()
|
| 418 |
+
torch.cuda.empty_cache()
|
| 419 |
+
|
| 420 |
+
if self.rank == 0:
|
| 421 |
+
videos = self.vae.decode(x0)
|
| 422 |
+
|
| 423 |
+
del noise, latent, x0
|
| 424 |
+
del sample_scheduler
|
| 425 |
+
if offload_model:
|
| 426 |
+
gc.collect()
|
| 427 |
+
torch.cuda.synchronize()
|
| 428 |
+
if dist.is_initialized():
|
| 429 |
+
dist.barrier()
|
| 430 |
+
|
| 431 |
+
return videos[0] if self.rank == 0 else None
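The resolution and sequence-length arithmetic at the top of `generate` can be checked in isolation. The sketch below mirrors that computation with plain `math`; the stride values are assumptions matching a typical Wan2.2-style config (`vae_stride=(4, 8, 8)`, `patch_size=(1, 2, 2)`), not values read from this upload:

```python
import math

# Assumed compression factors (temporal, height, width) and DiT patch size.
VAE_STRIDE = (4, 8, 8)
PATCH_SIZE = (1, 2, 2)

def latent_shape(h, w, frame_num=81, max_area=720 * 1280):
    """Mirror of the lat_h/lat_w/max_seq_len arithmetic in `generate`."""
    aspect_ratio = h / w
    # Scale to max_area while keeping aspect ratio, then snap the latent
    # grid to a multiple of the patch size (floor division does the snap).
    lat_h = round(
        math.sqrt(max_area * aspect_ratio) // VAE_STRIDE[1] //
        PATCH_SIZE[1] * PATCH_SIZE[1])
    lat_w = round(
        math.sqrt(max_area / aspect_ratio) // VAE_STRIDE[2] //
        PATCH_SIZE[2] * PATCH_SIZE[2])
    # Temporal compression: 4n+1 frames map to n+1 latent frames.
    lat_t = (frame_num - 1) // VAE_STRIDE[0] + 1
    seq_len = lat_t * lat_h * lat_w // (PATCH_SIZE[1] * PATCH_SIZE[2])
    return lat_t, lat_h, lat_w, seq_len

# A 720x1280 input at the default max_area keeps its resolution:
print(latent_shape(720, 1280))  # (21, 90, 160, 75600)
```

Under these assumed strides, 81 frames at 720x1280 become a 21x90x160 latent, i.e. 75600 transformer tokens before sequence-parallel padding.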
|
wan/modules/__init__.py
ADDED

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae2_1 import Wan2_1_VAE
from .vae2_2 import Wan2_2_VAE

__all__ = [
    'Wan2_1_VAE',
    'Wan2_2_VAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'flash_attention',
]
wan/modules/animate/__init__.py
ADDED

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .model_animate import WanAnimateModel
from .clip import CLIPModel

__all__ = ['WanAnimateModel', 'CLIPModel']
wan/modules/animate/animate_utils.py
ADDED

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numbers

import torch
from peft import LoraConfig


def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
    # Target every Linear layer inside the transformer blocks, excluding
    # face-adapter and modulation layers.
    target_modules = []
    for name, module in transformer.named_modules():
        if ("blocks" in name and "face" not in name and
                "modulation" not in name and
                isinstance(module, torch.nn.Linear)):
            target_modules.append(name)

    transformer_lora_config = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        init_lora_weights=init_lora_weights,
        target_modules=target_modules,
    )
    return transformer_lora_config


class TensorList(object):

    def __init__(self, tensors):
        """
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
        """
        assert isinstance(tensors, (list, tuple))
        assert all(isinstance(u, torch.Tensor) for u in tensors)
        assert len(set([u.ndim for u in tensors])) == 1
        assert len(set([u.dtype for u in tensors])) == 1
        assert len(set([u.device for u in tensors])) == 1
        self.tensors = tensors

    def to(self, *args, **kwargs):
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])

    def size(self, dim):
        assert dim == 0, 'only support get the 0th size'
        return len(self.tensors)

    def pow(self, *args, **kwargs):
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])

    def squeeze(self, dim):
        assert dim != 0
        if dim > 0:
            dim -= 1
        return TensorList([u.squeeze(dim) for u in self.tensors])

    def type(self, *args, **kwargs):
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])

    def type_as(self, other):
        assert isinstance(other, (torch.Tensor, TensorList))
        if isinstance(other, torch.Tensor):
            return TensorList([u.type_as(other) for u in self.tensors])
        else:
            return TensorList([u.type(other.dtype) for u in self.tensors])

    @property
    def dtype(self):
        return self.tensors[0].dtype

    @property
    def device(self):
        return self.tensors[0].device

    @property
    def ndim(self):
        return 1 + self.tensors[0].ndim

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def __add__(self, other):
        return self._apply(other, lambda u, v: u + v)

    def __radd__(self, other):
        return self._apply(other, lambda u, v: v + u)

    def __sub__(self, other):
        return self._apply(other, lambda u, v: u - v)

    def __rsub__(self, other):
        return self._apply(other, lambda u, v: v - u)

    def __mul__(self, other):
        return self._apply(other, lambda u, v: u * v)

    def __rmul__(self, other):
        return self._apply(other, lambda u, v: v * u)

    def __floordiv__(self, other):
        return self._apply(other, lambda u, v: u // v)

    def __truediv__(self, other):
        return self._apply(other, lambda u, v: u / v)

    def __rfloordiv__(self, other):
        return self._apply(other, lambda u, v: v // u)

    def __rtruediv__(self, other):
        return self._apply(other, lambda u, v: v / u)

    def __pow__(self, other):
        return self._apply(other, lambda u, v: u ** v)

    def __rpow__(self, other):
        return self._apply(other, lambda u, v: v ** u)

    def __neg__(self):
        return TensorList([-u for u in self.tensors])

    def __iter__(self):
        for tensor in self.tensors:
            yield tensor

    def __repr__(self):
        return 'TensorList: \n' + repr(self.tensors)

    def _apply(self, other, op):
        if isinstance(other, (list, tuple, TensorList)) or (
            isinstance(other, torch.Tensor) and (
                other.numel() > 1 or other.ndim > 1
            )
        ):
            assert len(other) == len(self.tensors)
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
        elif isinstance(other, numbers.Number) or (
            isinstance(other, torch.Tensor) and (
                other.numel() == 1 and other.ndim <= 1
            )
        ):
            return TensorList([op(u, other) for u in self.tensors])
        else:
            raise TypeError(
                f'unsupported operand for *: "TensorList" and "{type(other)}"'
            )
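The heart of `TensorList` is the `_apply` dispatch: sequence-like operands are zipped element-wise against the stored tensors, while scalars broadcast to every entry. A simplified, torch-free analogue (an illustration of the dispatch pattern only, not the class above):

```python
import numbers

class NumList:
    """Minimal TensorList-style wrapper over plain numbers."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __iter__(self):
        return iter(self.items)

    def _apply(self, other, op):
        # Sequence operand: element-wise zip (lengths must match).
        if isinstance(other, (list, tuple, NumList)):
            assert len(other) == len(self.items)
            return NumList([op(u, v) for u, v in zip(self.items, other)])
        # Scalar operand: broadcast to every entry.
        elif isinstance(other, numbers.Number):
            return NumList([op(u, other) for u in self.items])
        raise TypeError(f'unsupported operand: {type(other)}')

    def __add__(self, other):
        return self._apply(other, lambda u, v: u + v)

    def __mul__(self, other):
        return self._apply(other, lambda u, v: u * v)

a = NumList([1, 2, 3])
print((a + 10).items)         # [11, 12, 13]  (scalar broadcast)
print((a * [2, 3, 4]).items)  # [2, 6, 12]    (element-wise zip)
```

The real class adds a third case for `torch.Tensor` operands, broadcasting single-element tensors and zipping multi-element ones, which lets ragged-shape latents flow through schedulers that were written for ordinary tensors.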
wan/modules/animate/clip.py
ADDED

# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from ..attention import flash_attention
from ..tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta

__all__ = [
    'XLMRobertaCLIP',
    'clip_xlm_roberta_vit_h_14',
    'CLIPModel',
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
                    0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
                align_corners=False).flatten(2).transpose(1, 2)
        ],
                         dim=1)


class QuickGELU(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        p = self.attn_dropout if self.training else 0.0
        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]


class VisionTransformer(nn.Module):

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):

    def __init__(self,
                 embed_dim=1024,
                 image_size=224,
                 patch_size=14,
                 vision_dim=1280,
                 vision_mlp_ratio=4,
                 vision_heads=16,
                 vision_layers=32,
                 vision_pool='token',
                 vision_pre_norm=True,
                 vision_post_norm=False,
                 activation='gelu',
                 vocab_size=250002,
                 max_text_len=514,
                 type_size=1,
                 pad_id=1,
                 text_dim=1024,
                 text_heads=16,
                 text_layers=24,
                 text_post_norm=True,
                 text_dropout=0.1,
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.text_dim = text_dim
        self.text_heads = text_heads
        self.text_layers = text_layers
        self.text_post_norm = text_post_norm
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps)
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout)
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
        """
        imgs: [B, 3, H, W] of torch.float32.
            - mean: [0.48145466, 0.4578275, 0.40821073]
            - std: [0.26862954, 0.26130258, 0.27577711]
        txt_ids: [B, L] of torch.long.
            Encoded by data.CLIPTokenizer.
        """
        xi = self.visual(imgs)
        xt = self.textual(txt_ids)
        return xi, xt

    def param_groups(self):
        groups = [{
            'params': [
                p for n, p in self.named_parameters()
                if 'norm' in n or n.endswith('bias')
            ],
            'weight_decay': 0.0
        }, {
            'params': [
                p for n, p in self.named_parameters()
                if not ('norm' in n or n.endswith('bias'))
            ]
        }]
        return groups


def _clip(pretrained=False,
          pretrained_name=None,
          model_cls=XLMRobertaCLIP,
          return_transforms=False,
          return_tokenizer=False,
          tokenizer_padding='eos',
          dtype=torch.float32,
          device='cpu',
          **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
|
| 449 |
+
output = (model,)
|
| 450 |
+
|
| 451 |
+
# init transforms
|
| 452 |
+
if return_transforms:
|
| 453 |
+
# mean and std
|
| 454 |
+
if 'siglip' in pretrained_name.lower():
|
| 455 |
+
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
|
| 456 |
+
else:
|
| 457 |
+
mean = [0.48145466, 0.4578275, 0.40821073]
|
| 458 |
+
std = [0.26862954, 0.26130258, 0.27577711]
|
| 459 |
+
|
| 460 |
+
# transforms
|
| 461 |
+
transforms = T.Compose([
|
| 462 |
+
T.Resize((model.image_size, model.image_size),
|
| 463 |
+
interpolation=T.InterpolationMode.BICUBIC),
|
| 464 |
+
T.ToTensor(),
|
| 465 |
+
T.Normalize(mean=mean, std=std)
|
| 466 |
+
])
|
| 467 |
+
output += (transforms,)
|
| 468 |
+
return output[0] if len(output) == 1 else output
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def clip_xlm_roberta_vit_h_14(
|
| 472 |
+
pretrained=False,
|
| 473 |
+
pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
|
| 474 |
+
**kwargs):
|
| 475 |
+
cfg = dict(
|
| 476 |
+
embed_dim=1024,
|
| 477 |
+
image_size=224,
|
| 478 |
+
patch_size=14,
|
| 479 |
+
vision_dim=1280,
|
| 480 |
+
vision_mlp_ratio=4,
|
| 481 |
+
vision_heads=16,
|
| 482 |
+
vision_layers=32,
|
| 483 |
+
vision_pool='token',
|
| 484 |
+
activation='gelu',
|
| 485 |
+
vocab_size=250002,
|
| 486 |
+
max_text_len=514,
|
| 487 |
+
type_size=1,
|
| 488 |
+
pad_id=1,
|
| 489 |
+
text_dim=1024,
|
| 490 |
+
text_heads=16,
|
| 491 |
+
text_layers=24,
|
| 492 |
+
text_post_norm=True,
|
| 493 |
+
text_dropout=0.1,
|
| 494 |
+
attn_dropout=0.0,
|
| 495 |
+
proj_dropout=0.0,
|
| 496 |
+
embedding_dropout=0.0)
|
| 497 |
+
cfg.update(**kwargs)
|
| 498 |
+
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
class CLIPModel:
|
| 502 |
+
|
| 503 |
+
def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
|
| 504 |
+
self.dtype = dtype
|
| 505 |
+
self.device = device
|
| 506 |
+
self.checkpoint_path = checkpoint_path
|
| 507 |
+
self.tokenizer_path = tokenizer_path
|
| 508 |
+
|
| 509 |
+
# init model
|
| 510 |
+
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
|
| 511 |
+
pretrained=False,
|
| 512 |
+
return_transforms=True,
|
| 513 |
+
return_tokenizer=False,
|
| 514 |
+
dtype=dtype,
|
| 515 |
+
device=device)
|
| 516 |
+
self.model = self.model.eval().requires_grad_(False)
|
| 517 |
+
logging.info(f'loading {checkpoint_path}')
|
| 518 |
+
self.model.load_state_dict(
|
| 519 |
+
torch.load(checkpoint_path, map_location='cpu'))
|
| 520 |
+
|
| 521 |
+
# init tokenizer
|
| 522 |
+
self.tokenizer = HuggingfaceTokenizer(
|
| 523 |
+
name=tokenizer_path,
|
| 524 |
+
seq_len=self.model.max_text_len - 2,
|
| 525 |
+
clean='whitespace')
|
| 526 |
+
|
| 527 |
+
def visual(self, videos):
|
| 528 |
+
# preprocess
|
| 529 |
+
size = (self.model.image_size,) * 2
|
| 530 |
+
videos = torch.cat([
|
| 531 |
+
F.interpolate(
|
| 532 |
+
u.transpose(0, 1),
|
| 533 |
+
size=size,
|
| 534 |
+
mode='bicubic',
|
| 535 |
+
align_corners=False) for u in videos
|
| 536 |
+
])
|
| 537 |
+
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
| 538 |
+
|
| 539 |
+
# forward
|
| 540 |
+
with torch.cuda.amp.autocast(dtype=self.dtype):
|
| 541 |
+
out = self.model.visual(videos, use_31_block=True)
|
| 542 |
+
return out
|
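For reference, `CLIPModel.visual` above first maps decoder-range frames from [-1, 1] back to [0, 1] (`videos.mul_(0.5).add_(0.5)`) and then applies the CLIP per-channel normalization with the mean/std hard-coded in the file. A torch-free sketch of that per-pixel transform (illustrative only; the helper names here are not from the repo):

```python
# Per-channel CLIP normalization as applied in CLIPModel.visual.
# The mean/std constants are the ones hard-coded in clip.py above.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

def from_decoder_range(rgb):
    """Mimic videos.mul_(0.5).add_(0.5): map [-1, 1] to [0, 1]."""
    return [0.5 * c + 0.5 for c in rgb]

def normalize_pixel(rgb, mean=CLIP_MEAN, std=CLIP_STD):
    """Map an RGB pixel in [0, 1] to CLIP's normalized input range."""
    return [(c - m) / s for c, m, s in zip(rgb, mean, std)]

# One extreme pixel: channel values at -1, 0, and +1 in decoder range.
pixel = normalize_pixel(from_decoder_range([-1.0, 0.0, 1.0]))
```

This shows why the `T.Normalize` step is reused from `self.transforms.transforms[-1]` rather than the whole pipeline: resizing is done separately with `F.interpolate`, so only the normalization stage applies.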
wan/modules/animate/face_blocks.py
ADDED
@@ -0,0 +1,383 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from torch import nn
import torch
from typing import Tuple, Optional
from einops import rearrange
import torch.nn.functional as F
import math
from ...distributed.util import gather_forward, get_rank, get_world_size


try:
    from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
except ImportError:
    flash_attn_func = None

MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        max_seqlen_q (int): The maximum sequence length in the batch of q, used to
            reshape the flash-attention output.
        batch_size (int): Batch size used to reshape the flash-attention output.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    elif mode == "flash":
        x = flash_attn_func(
            q,
            k,
            v,
        )
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out


class CausalConv1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # pad only on the left along T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class FaceEncoder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        # note: norm1 is re-assigned here, overriding the hidden_dim // 8 layer above
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")


class FaceAdapter(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num
        self.fuser_blocks = nn.ModuleList(
            [
                FaceBlock(
                    self.hidden_size,
                    self.heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(num_adapter_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)


class FaceBlock(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        self.scale = qk_scale or head_dim**-0.5

        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        B, T, N, C = motion_vec.shape
        T_comp = T

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        k = rearrange(k, "B L N H D -> (B L) N H D")
        v = rearrange(v, "B L N H D -> (B L) N H D")

        if use_context_parallel:
            q = gather_forward(q, dim=1)

        q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
        # Compute attention.
        attn = attention(
            q,
            k,
            v,
            max_seqlen_q=q.shape[1],
            batch_size=q.shape[0],
        )

        attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
        if use_context_parallel:
            attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]

        output = self.linear2(attn)

        if motion_mask is not None:
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output
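The `RMSNorm` used for QK-norm in `FaceBlock` computes `x * rsqrt(mean(x**2) + eps)`, optionally scaled by a learned per-channel weight. A torch-free sketch of the same math over a plain Python list (illustrative only; these helper names are not from the file):

```python
import math

def rms_norm(x, eps=1e-6, weight=None):
    """Pure-Python equivalent of RMSNorm._norm plus the optional affine scale."""
    # root-mean-square of the last (here: only) dimension
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    out = [v / rms for v in x]
    if weight is not None:  # elementwise_affine=True path
        out = [v * w for v, w in zip(out, weight)]
    return out

y = rms_norm([3.0, 4.0])
```

After normalization the output has unit RMS (up to `eps`), which is why the attention logits stay well-scaled regardless of the magnitude of the projected queries and keys.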
wan/modules/animate/model_animate.py
ADDED
@@ -0,0 +1,500 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import math
|
| 3 |
+
import types
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from typing import List
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.cuda.amp as amp
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.loaders import PeftAdapterMixin
|
| 14 |
+
|
| 15 |
+
from ...distributed.sequence_parallel import (
|
| 16 |
+
distributed_attention,
|
| 17 |
+
gather_forward,
|
| 18 |
+
get_rank,
|
| 19 |
+
get_world_size,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
from ..model import (
|
| 24 |
+
Head,
|
| 25 |
+
WanAttentionBlock,
|
| 26 |
+
WanLayerNorm,
|
| 27 |
+
WanRMSNorm,
|
| 28 |
+
WanModel,
|
| 29 |
+
WanSelfAttention,
|
| 30 |
+
flash_attention,
|
| 31 |
+
rope_params,
|
| 32 |
+
sinusoidal_embedding_1d,
|
| 33 |
+
rope_apply
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
from .face_blocks import FaceEncoder, FaceAdapter
|
| 37 |
+
from .motion_encoder import Generator
|
| 38 |
+
|
| 39 |
+
class HeadAnimate(Head):
|
| 40 |
+
|
| 41 |
+
def forward(self, x, e):
|
| 42 |
+
"""
|
| 43 |
+
Args:
|
| 44 |
+
x(Tensor): Shape [B, L1, C]
|
| 45 |
+
e(Tensor): Shape [B, L1, C]
|
| 46 |
+
"""
|
| 47 |
+
assert e.dtype == torch.float32
|
| 48 |
+
with amp.autocast(dtype=torch.float32):
|
| 49 |
+
e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
|
| 50 |
+
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
|
| 51 |
+
return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class WanAnimateSelfAttention(WanSelfAttention):
|
| 55 |
+
|
| 56 |
+
def forward(self, x, seq_lens, grid_sizes, freqs):
|
| 57 |
+
"""
|
| 58 |
+
Args:
|
| 59 |
+
x(Tensor): Shape [B, L, num_heads, C / num_heads]
|
| 60 |
+
seq_lens(Tensor): Shape [B]
|
| 61 |
+
grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
|
| 62 |
+
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
|
| 63 |
+
"""
|
| 64 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 65 |
+
|
| 66 |
+
# query, key, value function
|
| 67 |
+
def qkv_fn(x):
|
| 68 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 69 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 70 |
+
v = self.v(x).view(b, s, n, d)
|
| 71 |
+
return q, k, v
|
| 72 |
+
|
| 73 |
+
q, k, v = qkv_fn(x)
|
| 74 |
+
|
| 75 |
+
x = flash_attention(
|
| 76 |
+
q=rope_apply(q, grid_sizes, freqs),
|
| 77 |
+
k=rope_apply(k, grid_sizes, freqs),
|
| 78 |
+
v=v,
|
| 79 |
+
k_lens=seq_lens,
|
| 80 |
+
window_size=self.window_size)
|
| 81 |
+
|
| 82 |
+
# output
|
| 83 |
+
x = x.flatten(2)
|
| 84 |
+
x = self.o(x)
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class WanAnimateCrossAttention(WanSelfAttention):
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
dim,
|
| 92 |
+
num_heads,
|
| 93 |
+
window_size=(-1, -1),
|
| 94 |
+
qk_norm=True,
|
| 95 |
+
eps=1e-6,
|
| 96 |
+
use_img_emb=True
|
| 97 |
+
):
|
| 98 |
+
super().__init__(
|
| 99 |
+
dim,
|
| 100 |
+
num_heads,
|
| 101 |
+
window_size,
|
| 102 |
+
qk_norm,
|
| 103 |
+
eps
|
| 104 |
+
)
|
| 105 |
+
self.use_img_emb = use_img_emb
|
| 106 |
+
|
| 107 |
+
if use_img_emb:
|
| 108 |
+
self.k_img = nn.Linear(dim, dim)
|
| 109 |
+
self.v_img = nn.Linear(dim, dim)
|
| 110 |
+
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
|
| 111 |
+
|
| 112 |
+
def forward(self, x, context, context_lens):
|
| 113 |
+
"""
|
| 114 |
+
x: [B, L1, C].
|
| 115 |
+
context: [B, L2, C].
|
| 116 |
+
context_lens: [B].
|
| 117 |
+
"""
|
| 118 |
+
if self.use_img_emb:
|
| 119 |
+
context_img = context[:, :257]
|
| 120 |
+
context = context[:, 257:]
|
| 121 |
+
else:
|
| 122 |
+
context = context
|
| 123 |
+
|
| 124 |
+
b, n, d = x.size(0), self.num_heads, self.head_dim
|
| 125 |
+
|
| 126 |
+
# compute query, key, value
|
| 127 |
+
q = self.norm_q(self.q(x)).view(b, -1, n, d)
|
| 128 |
+
k = self.norm_k(self.k(context)).view(b, -1, n, d)
|
| 129 |
+
v = self.v(context).view(b, -1, n, d)
|
| 130 |
+
|
| 131 |
+
if self.use_img_emb:
|
| 132 |
+
k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
|
| 133 |
+
v_img = self.v_img(context_img).view(b, -1, n, d)
|
| 134 |
+
img_x = flash_attention(q, k_img, v_img, k_lens=None)
|
| 135 |
+
# compute attention
|
| 136 |
+
x = flash_attention(q, k, v, k_lens=context_lens)
|
| 137 |
+
|
| 138 |
+
# output
|
| 139 |
+
x = x.flatten(2)
|
| 140 |
+
|
| 141 |
+
if self.use_img_emb:
|
| 142 |
+
img_x = img_x.flatten(2)
|
| 143 |
+
x = x + img_x
|
| 144 |
+
|
| 145 |
+
x = self.o(x)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class WanAnimateAttentionBlock(nn.Module):
|
| 150 |
+
def __init__(self,
|
| 151 |
+
dim,
|
| 152 |
+
ffn_dim,
|
| 153 |
+
num_heads,
|
| 154 |
+
window_size=(-1, -1),
|
| 155 |
+
qk_norm=True,
|
| 156 |
+
cross_attn_norm=True,
|
| 157 |
+
eps=1e-6,
|
| 158 |
+
use_img_emb=True):
|
| 159 |
+
|
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanAnimateSelfAttention(dim, num_heads, window_size, qk_norm, eps)

        self.norm3 = WanLayerNorm(
            dim, eps, elementwise_affine=True
        ) if cross_attn_norm else nn.Identity()

        self.cross_attn = WanAnimateCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps, use_img_emb=use_img_emb)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim)
        )

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        """
        Args:
            x(Tensor): Shape [B, L, C]
            e(Tensor): Shape [B, L1, 6, C]
            seq_lens(Tensor): Shape [B], length of each sequence in batch
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        assert e.dtype == torch.float32
        with amp.autocast(dtype=torch.float32):
            e = (self.modulation + e).chunk(6, dim=1)
        assert e[0].dtype == torch.float32

        # self-attention
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
        )
        with amp.autocast(dtype=torch.float32):
            x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            with amp.autocast(dtype=torch.float32):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x

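The six modulation terms above follow the DiT-style adaLN pattern: `e[0]`/`e[1]` shift and scale the self-attention input, `e[2]` gates its residual, and `e[3]`/`e[4]`/`e[5]` do the same for the FFN. A minimal scalar sketch of one such modulate-then-gate step (illustrative values, not real activations):

```python
# adaLN-style modulation on plain floats: x * (1 + scale) + shift, then a
# gated residual. Mirrors the pattern used in the block's forward() above.
def modulate(x, shift, scale):
    return x * (1 + scale) + shift

x = 2.0
shift, scale, gate = 0.5, 0.5, 0.5
attn_in = modulate(x, shift, scale)  # 2.0 * 1.5 + 0.5 = 3.5
y = attn_in                          # stand-in for the attention output
x = x + y * gate                     # gated residual: 2.0 + 3.5 * 0.5
print(attn_in)  # 3.5
print(x)        # 3.75
```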
class MLPProj(torch.nn.Module):

    def __init__(self, in_dim, out_dim):
        super().__init__()

        self.proj = torch.nn.Sequential(
            torch.nn.LayerNorm(in_dim),
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
        )

    def forward(self, image_embeds):
        clip_extra_context_tokens = self.proj(image_embeds)
        return clip_extra_context_tokens


class WanAnimateModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=36,
                 dim=5120,
                 ffn_dim=13824,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=40,
                 num_layers=40,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 motion_encoder_dim=512,
                 use_context_parallel=False,
                 use_img_emb=True):

        super().__init__()
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps
        self.motion_encoder_dim = motion_encoder_dim
        self.use_context_parallel = use_context_parallel
        self.use_img_emb = use_img_emb

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)

        self.pose_patch_embedding = nn.Conv3d(
            16, dim, kernel_size=patch_size, stride=patch_size
        )

        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim))

        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        self.blocks = nn.ModuleList([
            WanAnimateAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
                                     cross_attn_norm, eps, use_img_emb) for _ in range(num_layers)
        ])

        # head
        self.head = HeadAnimate(dim, out_dim, patch_size, eps)

        # buffers (don't use register_buffer otherwise dtype will be changed in to())
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.freqs = torch.cat([
            rope_params(1024, d - 4 * (d // 6)),
            rope_params(1024, 2 * (d // 6)),
            rope_params(1024, 2 * (d // 6))
        ], dim=1)

        self.img_emb = MLPProj(1280, dim)

        # initialize weights
        self.init_weights()

        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        self.face_adapter = FaceAdapter(
            heads_num=self.num_heads,
            hidden_dim=self.dim,
            num_adapter_layers=self.num_layers // 5,
        )

        self.face_encoder = FaceEncoder(
            in_dim=motion_encoder_dim,
            hidden_dim=self.dim,
            num_heads=4,
        )

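The three `rope_params` calls above split each attention head's channels into temporal, height, and width rotary bands. With the default `dim=5120` and `num_heads=40`, the per-head dimension works out as follows:

```python
# 3-D RoPE channel split used by WanAnimateModel (defaults from the config).
d = 5120 // 40             # per-head dim = 128
t_dim = d - 4 * (d // 6)   # temporal band
h_dim = 2 * (d // 6)       # height band
w_dim = 2 * (d // 6)       # width band
print(d, t_dim, h_dim, w_dim)  # 128 44 42 42
assert t_dim + h_dim + w_dim == d  # the three bands tile the head dim exactly
```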
    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
        pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
        for x_, pose_latents_ in zip(x, pose_latents):
            x_[:, :, 1:] += pose_latents_

        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")

        # encode face frames in chunks of encode_bs to bound peak memory
        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(
                self.motion_encoder.get_motion(face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))

        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec

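The chunked face encoding above is plain ceil-division batching: the flattened `(b*t)` frames are processed `encode_bs` at a time and concatenated back. A minimal sketch of that chunking on a plain list:

```python
import math

# Stand-in for (b*t) face frames, chunked exactly as in after_patch_embedding.
frames = list(range(19))
encode_bs = 8
chunks = [frames[i * encode_bs:(i + 1) * encode_bs]
          for i in range(math.ceil(len(frames) / encode_bs))]
print([len(c) for c in chunks])   # [8, 8, 3]
assert sum(chunks, []) == frames  # concatenation recovers the full batch
```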
    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        # fuse face motion features into every 5th transformer block
        if block_idx % 5 == 0:
            adapter_args = [x, motion_vec, motion_masks, self.use_context_parallel]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x

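With the default `num_layers=40`, the `block_idx % 5 == 0` check above selects eight blocks, matching the `num_layers // 5` fuser blocks that `FaceAdapter` is constructed with:

```python
# Which transformer blocks receive a face-adapter fusion, per the logic above.
num_layers = 40
fused = [i for i in range(num_layers) if i % 5 == 0]
print(fused)                    # [0, 5, 10, 15, 20, 25, 30, 35]
print([i // 5 for i in fused])  # fuser block indices 0..7
assert len(fused) == num_layers // 5
```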
    def forward(
        self,
        x,
        t,
        clip_fea,
        context,
        seq_len,
        y=None,
        pose_latents=None,
        face_pixel_values=None
    ):
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)

        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float()
            )
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        if self.use_img_emb:
            context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
            context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        if self.use_context_parallel:
            x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]

        for idx, block in enumerate(self.blocks):
            x = block(x, **kwargs)
            x = self.after_transformer_block(idx, x, motion_vec)

        # head
        x = self.head(x, e)

        if self.use_context_parallel:
            x = gather_forward(x, dim=1)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]

    def unpatchify(self, x, grid_sizes):
        r"""
        Reconstruct video tensors from patch embeddings.

        Args:
            x (List[Tensor]):
                List of patchified features, each with shape [L, C_out * prod(patch_size)]
            grid_sizes (Tensor):
                Original spatial-temporal grid dimensions before patching,
                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)

        Returns:
            List[Tensor]:
                Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
        """
        c = self.out_dim
        out = []
        for u, v in zip(x, grid_sizes.tolist()):
            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
            u = torch.einsum('fhwpqrc->cfphqwr', u)
            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
            out.append(u)
        return out

    def init_weights(self):
        r"""
        Initialize model parameters using Xavier initialization.
        """
        # basic init
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for m in self.text_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)
        for m in self.time_embedding.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)
wan/modules/animate/motion_encoder.py
ADDED
|
@@ -0,0 +1,307 @@
# Modified from ``https://github.com/wyhsirius/LIA``
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

def custom_qr(input_tensor):
    # torch.linalg.qr lacks a bfloat16 kernel, so round-trip through float32
    original_dtype = input_tensor.dtype
    if original_dtype == torch.bfloat16:
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)

def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return F.leaky_relu(input + bias, negative_slope) * scale


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0)]

    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
                      in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
    return out[:, :, ::down_y, ::down_x]


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)


class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)

        return out

    def __repr__(self):
        return f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'


class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out


class EncoderApp(nn.Module):
    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]


class Encoder(nn.Module):
    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance network
        self.net_app = EncoderApp(size, dim)

        # motion network
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion


class Direction(nn.Module):
    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            out = torch.sum(out, dim=1)
            return out


class Synthesis(nn.Module):
    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        # motion_feat = self.enc.enc_motion(img)
        motion_feat = torch.utils.checkpoint.checkpoint(self.enc.enc_motion, img, use_reentrant=True)
        with torch.cuda.amp.autocast(dtype=torch.float32):
            motion = self.dec.direction(motion_feat)
        return motion
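`make_kernel` above turns a 1-D blur kernel into a normalized 2-D one via an outer product. The same math in pure Python, using the `[1, 3, 3, 1]` kernel the code defaults to:

```python
# Pure-Python sketch of make_kernel's math: outer product of the 1-D kernel
# with itself, then normalize so the 2-D kernel sums to 1.
def make_kernel_2d(k):
    k2 = [[a * b for b in k] for a in k]   # outer product
    s = sum(sum(row) for row in k2)
    return [[v / s for v in row] for row in k2]

kernel = make_kernel_2d([1, 3, 3, 1])
print(sum(sum(row) for row in kernel))  # 1.0  (normalized)
print(kernel[0][0])                     # 1/64 = 0.015625
```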
wan/modules/animate/preprocess/UserGuider.md
ADDED
|
@@ -0,0 +1,70 @@
# Wan-animate Preprocessing User Guide

## 1. Introduction

Wan-animate offers two generation modes: `animation` and `replacement`. While both modes extract the skeleton from the reference video, each has a distinct preprocessing pipeline.

### 1.1 Animation Mode

In this mode, it is highly recommended to enable pose retargeting, especially if the body proportions of the reference and driving characters are dissimilar.

- A simplified pose retargeting pipeline is provided to help developers quickly implement this functionality.
- **NOTE:** Due to the potential complexity of input data, the results from this simplified retargeting version are NOT guaranteed to be perfect. It is strongly advised to verify the preprocessing results before proceeding.
- Community contributions to improve this feature are welcome.

### 1.2 Replacement Mode

- Pose retargeting is DISABLED by default in this mode. This is a deliberate choice to account for potential spatial interactions between the character and the environment.
- **WARNING:** If there is a significant mismatch in body proportions between the reference and driving characters, artifacts or deformations may appear in the final output.
- A simplified version for extracting the character's mask is also provided.
  - **WARNING:** This mask extraction process is designed for **single-person videos ONLY** and may produce incorrect results or fail on multi-person videos (incorrect pose tracking). For multi-person videos, users are required to either develop their own solution or integrate a suitable open-source tool.

---

## 2. Preprocessing Instructions and Recommendations

### 2.1 Basic Usage

- Preprocessing requires some additional models: pose detection (mandatory), plus mask extraction and image editing models (optional, as needed). Place them according to the following directory structure:
```
/path/to/your/ckpt_path/
├── det/
│   └── yolov10m.onnx
├── pose2d/
│   └── vitpose_h_wholebody.onnx
├── sam2/
│   └── sam2_hiera_large.pt
└── FLUX.1-Kontext-dev/
```
- `video_path`, `refer_path`, and `save_path` correspond to the paths for the input driving video, the character image, and the preprocessed results.
- When using `animation` mode, two videos, `src_face.mp4` and `src_pose.mp4`, will be generated in `save_path`. When using `replacement` mode, two additional videos, `src_bg.mp4` and `src_mask.mp4`, will also be generated.
- The `resolution_area` parameter determines the resolution for both preprocessing and the generation model; it is specified as a pixel area.
- The `fps` parameter specifies the frame rate for video processing. A lower frame rate can improve generation efficiency but may cause stuttering or choppiness.

---
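Before launching preprocessing, it can save a failed run to check that the checkpoint layout above is in place. A hypothetical helper (the function name is ours; the paths are the ones listed in this guide):

```python
import tempfile
from pathlib import Path

# Mandatory checkpoints per the directory layout described in this guide.
REQUIRED = ["det/yolov10m.onnx", "pose2d/vitpose_h_wholebody.onnx"]

def missing_checkpoints(ckpt_path):
    """Return the REQUIRED entries not present under ckpt_path."""
    root = Path(ckpt_path)
    return [rel for rel in REQUIRED if not (root / rel).is_file()]

# An empty checkpoint directory is missing everything:
empty = tempfile.mkdtemp()
print(missing_checkpoints(empty))  # ['det/yolov10m.onnx', 'pose2d/vitpose_h_wholebody.onnx']
```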
### 2.2 Animation Mode

- We support three options: no pose retargeting, basic pose retargeting, and enhanced pose retargeting based on the `FLUX.1-Kontext-dev` image editing model. These are specified via the `retarget_flag` and `use_flux` parameters.
- Using basic pose retargeting (`retarget_flag` alone) requires that both the reference character and the character in the first frame of the driving video are in a front-facing, stretched pose.
- Otherwise, we recommend enhanced pose retargeting by specifying both `retarget_flag` and `use_flux`. **NOTE:** Due to the limited capabilities of `FLUX.1-Kontext-dev`, it is NOT guaranteed to produce the expected results (e.g., consistency is not maintained, the pose is incorrect, etc.). It is recommended to check the intermediate results as well as the final generated pose video; both are stored in `save_path`. Of course, users can also use a better image editing model, or explore the prompts for Flux on their own.

---

### 2.3 Replacement Mode

- Specify `replace_flag` to enable data preprocessing for this mode. Preprocessing will additionally produce a mask for the character in the video, whose size and shape can be adjusted via several parameters:
  - `iterations` and `k` make the mask larger, covering more area.
  - `w_len` and `h_len` adjust the mask's shape. Smaller values make the outline coarser, while larger values make it finer.
- A smaller, finer-contoured mask preserves more of the original background but may limit the character's generation area (given potential appearance differences, this can lead to some shape leakage). A larger, coarser mask lets character generation be more flexible and consistent, but because it includes more of the background, it may affect the background's consistency. We recommend users adjust the relevant parameters based on their specific input data.
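The `iterations`/`k` behavior described above is consistent with repeated binary dilation using a `k x k` square structuring element. A pure-Python sketch of that operation, purely illustrative (the pipeline's actual implementation may differ):

```python
# Illustrative binary dilation: each pass sets a pixel if any pixel in its
# k x k neighborhood is set, so the mask grows with k and with iterations.
def dilate(mask, k=3, iterations=1):
    h, w, r = len(mask), len(mask[0]), k // 2
    for _ in range(iterations):
        out = [[0] * w for _ in range(h)]
        for y in range(h):
            for x in range(w):
                out[y][x] = int(any(
                    mask[j][i]
                    for j in range(max(0, y - r), min(h, y + r + 1))
                    for i in range(max(0, x - r), min(w, x + r + 1))))
        mask = out
    return mask

m = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
print(sum(map(sum, dilate(m, k=3, iterations=1))))  # 9: the single pixel grows to 3x3
```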
wan/modules/animate/preprocess/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from .process_pipepline import ProcessPipeline
from .video_predictor import SAM2VideoPredictor
wan/modules/animate/preprocess/human_visualization.py
ADDED
|
@@ -0,0 +1,1357 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
import time
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List
import random
from pose2d_utils import AAPoseMeta


def draw_handpose(canvas, keypoints, hand_score_th=0.6):
    """
    Draw keypoints and connections representing a hand pose on a given canvas.

    Args:
        canvas (np.ndarray): A 3D numpy array (image) on which to draw the hand pose.
        keypoints (List[Keypoint] | None): Hand keypoints to draw, each as
            (x, y, score) in pixel coordinates, or None entries for missing points.
        hand_score_th (float): Minimum confidence score for a keypoint to be drawn.

    Returns:
        np.ndarray: The modified canvas with the hand pose drawn.

    Note:
        Keypoints are expected in pixel coordinates; points at (or near) the
        origin are treated as missing and skipped.
    """
    eps = 0.01

    H, W, C = canvas.shape
    stickwidth = max(int(min(H, W) / 200), 1)

    edges = [
        [0, 1], [1, 2], [2, 3], [3, 4],         # thumb
        [0, 5], [5, 6], [6, 7], [7, 8],         # index finger
        [0, 9], [9, 10], [10, 11], [11, 12],    # middle finger
        [0, 13], [13, 14], [14, 15], [15, 16],  # ring finger
        [0, 17], [17, 18], [18, 19], [19, 20],  # little finger
    ]

    for ie, (e1, e2) in enumerate(edges):
        k1 = keypoints[e1]
        k2 = keypoints[e2]
        if k1 is None or k2 is None:
            continue
        if k1[2] < hand_score_th or k2[2] < hand_score_th:
            continue

        x1 = int(k1[0])
        y1 = int(k1[1])
        x2 = int(k2[0])
        y2 = int(k2[1])
        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
            # Each edge gets a distinct hue, swept evenly around the color wheel.
            cv2.line(
                canvas,
                (x1, y1),
                (x2, y2),
                matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
                thickness=stickwidth,
            )

    for keypoint in keypoints:
        if keypoint is None:
            continue
        if keypoint[2] < hand_score_th:
            continue

        x = int(keypoint[0])
        y = int(keypoint[1])
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
    return canvas

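The per-edge color above comes from sweeping the hue evenly across the 20 hand edges at full saturation and value. As a minimal sketch of that color math, here is a stdlib-only equivalent using `colorsys` (the helper name `edge_color` is hypothetical, not part of this module):

```python
import colorsys

def edge_color(ie, n_edges=20):
    # Hue sweeps the color wheel across the hand edges; saturation and value
    # stay at 1.0, mirroring hsv_to_rgb([ie / len(edges), 1.0, 1.0]) * 255.
    r, g, b = colorsys.hsv_to_rgb(ie / float(n_edges), 1.0, 1.0)
    return (int(r * 255), int(g * 255), int(b * 255))

print(edge_color(0))   # first edge: pure red -> (255, 0, 0)
print(edge_color(10))  # halfway around the wheel: cyan -> (0, 255, 255)
```

Note that `hsv_to_rgb` yields RGB ordering, while OpenCV canvases are conventionally BGR; the module draws the colors as-is.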

def draw_handpose_new(canvas, keypoints, stickwidth_type='v2', hand_score_th=0.6):
    """
    Variant of draw_handpose with a selectable stick width.

    Args:
        canvas (np.ndarray): A 3D numpy array (image) on which to draw the hand pose.
        keypoints (List[Keypoint] | None): Hand keypoints to draw, each as
            (x, y, score) in pixel coordinates, or None entries for missing points.
        stickwidth_type (str): 'v1' matches draw_handpose; 'v2' (default) draws
            roughly half as thick, for a lighter skeleton.
        hand_score_th (float): Minimum confidence score for a keypoint to be drawn.

    Returns:
        np.ndarray: The modified canvas with the hand pose drawn.
    """
    eps = 0.01

    H, W, C = canvas.shape
    if stickwidth_type == 'v1':
        stickwidth = max(int(min(H, W) / 200), 1)
    elif stickwidth_type == 'v2':
        stickwidth = max(max(int(min(H, W) / 200) - 1, 1) // 2, 1)
    else:
        raise ValueError(f"Unknown stickwidth_type: {stickwidth_type!r}")

    edges = [
        [0, 1], [1, 2], [2, 3], [3, 4],         # thumb
        [0, 5], [5, 6], [6, 7], [7, 8],         # index finger
        [0, 9], [9, 10], [10, 11], [11, 12],    # middle finger
        [0, 13], [13, 14], [14, 15], [15, 16],  # ring finger
        [0, 17], [17, 18], [18, 19], [19, 20],  # little finger
    ]

    for ie, (e1, e2) in enumerate(edges):
        k1 = keypoints[e1]
        k2 = keypoints[e2]
        if k1 is None or k2 is None:
            continue
        if k1[2] < hand_score_th or k2[2] < hand_score_th:
            continue

        x1 = int(k1[0])
        y1 = int(k1[1])
        x2 = int(k2[0])
        y2 = int(k2[1])
        if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
            cv2.line(
                canvas,
                (x1, y1),
                (x2, y2),
                matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
                thickness=stickwidth,
            )

    for keypoint in keypoints:
        if keypoint is None:
            continue
        if keypoint[2] < hand_score_th:
            continue

        x = int(keypoint[0])
        y = int(keypoint[1])
        if x > eps and y > eps:
            cv2.circle(canvas, (x, y), stickwidth, (0, 0, 255), thickness=-1)
    return canvas


def draw_ellipse_by_2kp(img, keypoint1, keypoint2, color, threshold=0.6):
    """Draw the limb between two (x, y, score) keypoints as a filled ellipse."""
    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / 200), 1)

    if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
        return img

    Y = np.array([keypoint1[0], keypoint2[0]])  # x coordinates
    X = np.array([keypoint1[1], keypoint2[1]])  # y coordinates
    mX = np.mean(X)
    mY = np.mean(Y)
    length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
    angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
    polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
    cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])
    return img

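The ellipse placement above reduces to three quantities: the segment midpoint, its length (the ellipse major axis), and its orientation from `atan2` over the coordinate deltas. A minimal stdlib sketch of just that geometry (the variable names are illustrative, not from this module):

```python
import math

# Two joints of a horizontal limb, as (x, y) pixel coordinates.
k1 = (50.0, 80.0)
k2 = (110.0, 80.0)

# Midpoint is the ellipse center; length spans the joints; the angle is
# atan2(dy, dx) in degrees, matching the (X diff, Y diff) order used above.
mx = (k1[0] + k2[0]) / 2
my = (k1[1] + k2[1]) / 2
length = math.hypot(k1[0] - k2[0], k1[1] - k2[1])
angle = math.degrees(math.atan2(k1[1] - k2[1], k1[0] - k2[0]))

print(mx, my, length)  # 80.0 80.0 60.0
print(round(angle))    # 180 (horizontal segment, pointing left)
```

`cv2.ellipse2Poly` then rasterizes an ellipse with half-axes `(length / 2, stickwidth)` at that center and angle, and `cv2.fillConvexPoly` fills it with a 40%-darkened copy of the limb color.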

def split_pose2d_kps_to_aa(kp2ds: np.ndarray) -> List[np.ndarray]:
    """Split the 133 COCO-WholeBody keypoints from pose2d into body and hand keypoints.

    Args:
        kp2ds (np.ndarray): [133, 2] whole-body keypoints (a trailing score
            column, if present, is carried through unchanged).

    Returns:
        Three arrays: body keypoints [20, ...] in an OpenPose-like order (with
        the neck synthesized as the midpoint of the two shoulders), left-hand
        keypoints [21, ...], and right-hand keypoints [21, ...].
    """
    # Averaging the two index lists leaves most joints unchanged; only
    # position 1 (the neck) differs, becoming the midpoint of keypoints
    # 5 and 6 (the shoulders).
    kp2ds_body = (
        kp2ds[[0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]]
        + kp2ds[[0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]]
    ) / 2
    kp2ds_lhand = kp2ds[91:112]
    kp2ds_rhand = kp2ds[112:133]
    return kp2ds_body.copy(), kp2ds_lhand.copy(), kp2ds_rhand.copy()

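The neck synthesis can be seen directly by running the same double-indexing trick on a dummy array. A small sketch (the array contents are made up for illustration):

```python
import numpy as np

# Hypothetical 133x3 whole-body array of (x, y, score), COCO-WholeBody layout:
# 0-16 body, 91-111 left hand, 112-132 right hand.
kp = np.zeros((133, 3))
kp[5] = [100.0, 200.0, 0.9]  # left shoulder
kp[6] = [140.0, 210.0, 0.9]  # right shoulder

idx_a = [0, 6, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 17, 20]
idx_b = [0, 5, 6, 8, 10, 5, 7, 9, 12, 14, 16, 11, 13, 15, 2, 1, 4, 3, 18, 21]
body = (kp[idx_a] + kp[idx_b]) / 2

# Position 1 is the only slot where the two index lists differ (6 vs 5),
# so it becomes the shoulder midpoint -- the synthesized neck.
print(body[1][:2])  # [120. 205.]
```

Every other body slot indexes the same keypoint in both lists, so `(x + x) / 2` leaves it untouched.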

def draw_aapose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=True, draw_head=True):
    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
                           stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
    return pose_img


def draw_aapose_by_meta_new(img, meta: AAPoseMeta, threshold=0.5, stickwidth_type='v2', draw_hand=True, draw_head=True):
    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    pose_img = draw_aapose_new(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
                               stickwidth_type=stickwidth_type, draw_hand=draw_hand, draw_head=draw_head)
    return pose_img


def draw_hand_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200):
    # Body confidences are zeroed (* 0) so that only the hands are rendered.
    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None] * 0], axis=1)
    kp2ds_lhand = np.concatenate([meta.kps_lhand, meta.kps_lhand_p[:, None]], axis=1)
    kp2ds_rhand = np.concatenate([meta.kps_rhand, meta.kps_rhand_p[:, None]], axis=1)
    pose_img = draw_aapose(img, kp2ds, threshold, kp2ds_lhand=kp2ds_lhand, kp2ds_rhand=kp2ds_rhand,
                           stick_width_norm=stick_width_norm, draw_hand=True, draw_head=False)
    return pose_img


def draw_aaface_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=200, draw_hand=False, draw_head=True):
    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    pose_img = draw_M(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None,
                      stick_width_norm=stick_width_norm, draw_hand=draw_hand, draw_head=draw_head)
    return pose_img


def draw_aanose_by_meta(img, meta: AAPoseMeta, threshold=0.5, stick_width_norm=100, draw_hand=False):
    kp2ds = np.concatenate([meta.kps_body, meta.kps_body_p[:, None]], axis=1)
    pose_img = draw_nose(img, kp2ds, threshold, kp2ds_lhand=None, kp2ds_rhand=None,
                         stick_width_norm=stick_width_norm, draw_hand=draw_hand)
    return pose_img


def gen_face_motion_seq(img, metas: List[AAPoseMeta], threshold=0.5, stick_width_norm=200):
    # Placeholder: face motion sequence generation is not implemented yet.
    return


def draw_M(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
    draw_head=True,
):
    """
    Draw only the head/face part of a body pose (nose, eyes, ears and their
    connections); all other body keypoints are suppressed by zeroing their scores.

    Args:
        img (np.ndarray): A 3D numpy array (image) to draw on, modified in place.
        kp2ds (np.ndarray): [20, 3] body keypoints as (x, y, score) in pixel
            coordinates, in the order given by new_kep_list below.
        data_to_json (list | None): If given, a keypoint record (normalized
            coordinates) is appended or written at position idx; requires
            kp2ds_lhand and kp2ds_rhand to be provided.

    Returns:
        np.ndarray: The image with the face keypoints drawn. As a side effect,
        the x/y columns of kp2ds are normalized to [0, 1].
    """
    # Body keypoint order (OpenPose-like, 20 joints), for reference:
    new_kep_list = [
        "Nose", "Neck",
        "RShoulder", "RElbow", "RWrist",  # 2-4
        "LShoulder", "LElbow", "LWrist",  # 5-7
        "RHip", "RKnee", "RAnkle",        # 8-10
        "LHip", "LKnee", "LAnkle",        # 11-13
        "REye", "LEye", "REar", "LEar",   # 14-17
        "LToe", "RToe",                   # 18-19
    ]
    kp2ds = kp2ds.copy()
    # Zero the scores of everything except the head so only face parts render.
    kp2ds[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 18, 19], 2] = 0
    if not draw_head:
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    # Only the face connections are kept; torso, limb, and foot edges from
    # draw_aapose are intentionally disabled here.
    limbSeq = [
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],  # face (nose, eyes, ears)
    ]

    colors = [
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
        keypoint1 = kp2ds_body[k1_index - 1]
        keypoint2 = kp2ds_body[k2_index - 1]

        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
            continue

        Y = np.array([keypoint1[0], keypoint2[0]])
        X = np.array([keypoint1[1], keypoint2[1]])
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    # Note: zip truncates to len(colors), so point markers are drawn for at
    # most the first four keypoints; eye/ear points get no circles here.
    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        if idx == -1:
            data_to_json.append(
                {
                    "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
                    "height": H,
                    "width": W,
                    "category_id": 1,
                    "keypoints_body": kp2ds_body.tolist(),
                    "keypoints_left_hand": kp2ds_lhand.tolist(),
                    "keypoints_right_hand": kp2ds_rhand.tolist(),
                }
            )
        else:
            data_to_json[idx] = {
                "image_id": "frame_{:05d}.jpg".format(idx + 1),
                "height": H,
                "width": W,
                "category_id": 1,
                "keypoints_body": kp2ds_body.tolist(),
                "keypoints_left_hand": kp2ds_lhand.tolist(),
                "keypoints_right_hand": kp2ds_rhand.tolist(),
            }
    return img


def draw_nose(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
):
    """
    Draw only the nose keypoint of a body pose; every other keypoint's score is
    zeroed, and no limb connections are drawn.

    Args:
        img (np.ndarray): A 3D numpy array (image) to draw on, modified in place.
        kp2ds (np.ndarray): [20, 3] body keypoints as (x, y, score) in pixel
            coordinates, same order as in draw_M / draw_aapose.

    Returns:
        np.ndarray: The image with the nose keypoint drawn. As a side effect,
        the x/y columns of kp2ds are normalized to [0, 1].
    """
    kp2ds = kp2ds.copy()
    # Zero the scores of everything except the nose (index 0).
    kp2ds[1:, 2] = 0
    kp2ds_body = kp2ds

    # Limb drawing is disabled for this variant; only the point marker loop
    # below runs, and with a single color entry only keypoint 0 can be drawn.
    colors = [
        [170, 0, 255],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        if idx == -1:
            data_to_json.append(
                {
                    "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
                    "height": H,
                    "width": W,
                    "category_id": 1,
                    "keypoints_body": kp2ds_body.tolist(),
                    "keypoints_left_hand": kp2ds_lhand.tolist(),
                    "keypoints_right_hand": kp2ds_rhand.tolist(),
                }
            )
        else:
            data_to_json[idx] = {
                "image_id": "frame_{:05d}.jpg".format(idx + 1),
                "height": H,
                "width": W,
                "category_id": 1,
                "keypoints_body": kp2ds_body.tolist(),
                "keypoints_left_hand": kp2ds_lhand.tolist(),
                "keypoints_right_hand": kp2ds_rhand.tolist(),
            }
    return img


def draw_aapose(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stick_width_norm=200,
    draw_head=True
):
    """
    Draw a full body pose (limbs, head, feet, and optionally hands) on a given image.

    Args:
        img (np.ndarray): A 3D numpy array (image) to draw on, modified in place.
        kp2ds (np.ndarray): [20, 3] body keypoints as (x, y, score) in pixel
            coordinates, in the order given by new_kep_list below.
        kp2ds_lhand, kp2ds_rhand (np.ndarray | None): [21, 3] hand keypoints,
            drawn via draw_handpose when draw_hand is True.
        data_to_json (list | None): If given, a keypoint record (normalized
            coordinates) is appended or written at position idx.

    Returns:
        np.ndarray: The image with the pose drawn. As a side effect, the x/y
        columns of kp2ds are normalized to [0, 1].
    """
    # Body keypoint order (OpenPose-like, 20 joints), for reference:
    new_kep_list = [
        "Nose", "Neck",
        "RShoulder", "RElbow", "RWrist",  # 2-4
        "LShoulder", "LElbow", "LWrist",  # 5-7
        "RHip", "RKnee", "RAnkle",        # 8-10
        "LHip", "LKnee", "LAnkle",        # 11-13
        "REye", "LEye", "REar", "LEar",   # 14-17
        "LToe", "RToe",                   # 18-19
    ]
    kp2ds = kp2ds.copy()
    if not draw_head:
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    # Limb connections, as 1-based indices into the keypoint order above.
    limbSeq = [
        [2, 3],
        [2, 6],    # shoulders
        [3, 4],
        [4, 5],    # left arm
        [6, 7],
        [7, 8],    # right arm
        [2, 9],
        [9, 10],
        [10, 11],  # right leg
        [2, 12],
        [12, 13],
        [13, 14],  # left leg
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],  # face (nose, eyes, ears)
        [14, 19],
        [11, 20],  # foot
    ]

    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
        # foot
        [200, 200, 0],
        [100, 100, 0],
    ]

    H, W, C = img.shape
    stickwidth = max(int(min(H, W) / stick_width_norm), 1)

    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
        keypoint1 = kp2ds_body[k1_index - 1]
        keypoint2 = kp2ds_body[k2_index - 1]

        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
            continue

        Y = np.array([keypoint1[0], keypoint2[0]])
        X = np.array([keypoint1[1], keypoint2[1]])
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose(img, kp2ds_lhand, hand_score_th=threshold)
        img = draw_handpose(img, kp2ds_rhand, hand_score_th=threshold)

    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        if idx == -1:
            data_to_json.append(
                {
                    "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
                    "height": H,
                    "width": W,
                    "category_id": 1,
                    "keypoints_body": kp2ds_body.tolist(),
                    "keypoints_left_hand": kp2ds_lhand.tolist(),
                    "keypoints_right_hand": kp2ds_rhand.tolist(),
                }
            )
        else:
            data_to_json[idx] = {
                "image_id": "frame_{:05d}.jpg".format(idx + 1),
                "height": H,
                "width": W,
                "category_id": 1,
                "keypoints_body": kp2ds_body.tolist(),
                "keypoints_left_hand": kp2ds_lhand.tolist(),
                "keypoints_right_hand": kp2ds_rhand.tolist(),
            }
    return img

| 750 |
+
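Each limb above is rendered as a filled, rotated ellipse between two keypoints. A minimal standalone sketch of the geometry computed inside the loop (`limb_ellipse_params` is a hypothetical helper name; it uses the same mean/length/atan2 formulas as the code):

```python
import math
import numpy as np

def limb_ellipse_params(kp1, kp2):
    """Center, half-length, and angle of the ellipse drawn for one limb."""
    Y = np.array([kp1[0], kp2[0]])  # x coordinates of both endpoints
    X = np.array([kp1[1], kp2[1]])  # y coordinates of both endpoints
    mX, mY = np.mean(X), np.mean(Y)
    length = math.hypot(X[0] - X[1], Y[0] - Y[1])
    angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
    return (int(mY), int(mX)), int(length / 2), int(angle)

# a horizontal limb from (0, 0) to (100, 0)
center, half_len, angle = limb_ellipse_params((0, 0, 1.0), (100, 0, 1.0))
```

The resulting tuple feeds `cv2.ellipse2Poly`, which turns it into a polygon that `cv2.fillConvexPoly` rasterizes.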
def draw_aapose_new(
    img,
    kp2ds,
    threshold=0.6,
    data_to_json=None,
    idx=-1,
    kp2ds_lhand=None,
    kp2ds_rhand=None,
    draw_hand=False,
    stickwidth_type='v2',
    draw_head=True,
):
    """
    Draw body keypoints and limb connections on a given canvas.

    Args:
        img (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the pose.
        kp2ds (np.ndarray): An [N, 3] array of (x, y, score) body keypoints in pixel coordinates.

    Returns:
        np.ndarray: A 3D numpy array representing the modified canvas with the drawn pose.

    Note:
        The x and y coordinates are expected in pixel units; they are normalized to [0, 1]
        before being written to `data_to_json`.
    """

    new_kep_list = [
        "Nose",
        "Neck",
        "RShoulder",
        "RElbow",
        "RWrist",  # No.4
        "LShoulder",
        "LElbow",
        "LWrist",  # No.7
        "RHip",
        "RKnee",
        "RAnkle",  # No.10
        "LHip",
        "LKnee",
        "LAnkle",  # No.13
        "REye",
        "LEye",
        "REar",
        "LEar",
        "LToe",
        "RToe",
    ]
    kp2ds = kp2ds.copy()
    if not draw_head:
        kp2ds[[0, 14, 15, 16, 17], 2] = 0
    kp2ds_body = kp2ds

    limbSeq = [
        [2, 3],
        [2, 6],  # shoulders
        [3, 4],
        [4, 5],  # left arm
        [6, 7],
        [7, 8],  # right arm
        [2, 9],
        [9, 10],
        [10, 11],  # right leg
        [2, 12],
        [12, 13],
        [13, 14],  # left leg
        [2, 1],
        [1, 15],
        [15, 17],
        [1, 16],
        [16, 18],  # face (nose, eyes, ears)
        [14, 19],
        [11, 20],  # foot
    ]

    colors = [
        [255, 0, 0],
        [255, 85, 0],
        [255, 170, 0],
        [255, 255, 0],
        [170, 255, 0],
        [85, 255, 0],
        [0, 255, 0],
        [0, 255, 85],
        [0, 255, 170],
        [0, 255, 255],
        [0, 170, 255],
        [0, 85, 255],
        [0, 0, 255],
        [85, 0, 255],
        [170, 0, 255],
        [255, 0, 255],
        [255, 0, 170],
        [255, 0, 85],
        # foot
        [200, 200, 0],
        [100, 100, 0],
    ]

    H, W, C = img.shape

    if stickwidth_type == 'v1':
        stickwidth = max(int(min(H, W) / 200), 1)
    elif stickwidth_type == 'v2':
        stickwidth = max(int(min(H, W) / 200) - 1, 1)
    else:
        raise ValueError("unknown stickwidth_type: {}".format(stickwidth_type))

    for _idx, ((k1_index, k2_index), color) in enumerate(zip(limbSeq, colors)):
        keypoint1 = kp2ds_body[k1_index - 1]
        keypoint2 = kp2ds_body[k2_index - 1]

        if keypoint1[-1] < threshold or keypoint2[-1] < threshold:
            continue

        Y = np.array([keypoint1[0], keypoint2[0]])
        X = np.array([keypoint1[1], keypoint2[1]])
        mX = np.mean(X)
        mY = np.mean(Y)
        length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
        polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
        cv2.fillConvexPoly(img, polygon, [int(float(c) * 0.6) for c in color])

    for _idx, (keypoint, color) in enumerate(zip(kp2ds_body, colors)):
        if keypoint[-1] < threshold:
            continue
        x, y = keypoint[0], keypoint[1]
        cv2.circle(img, (int(x), int(y)), stickwidth, color, thickness=-1)

    if draw_hand:
        img = draw_handpose_new(img, kp2ds_lhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)
        img = draw_handpose_new(img, kp2ds_rhand, stickwidth_type=stickwidth_type, hand_score_th=threshold)

    kp2ds_body[:, 0] /= W
    kp2ds_body[:, 1] /= H

    if data_to_json is not None:
        if idx == -1:
            data_to_json.append(
                {
                    "image_id": "frame_{:05d}.jpg".format(len(data_to_json) + 1),
                    "height": H,
                    "width": W,
                    "category_id": 1,
                    "keypoints_body": kp2ds_body.tolist(),
                    "keypoints_left_hand": kp2ds_lhand.tolist(),
                    "keypoints_right_hand": kp2ds_rhand.tolist(),
                }
            )
        else:
            data_to_json[idx] = {
                "image_id": "frame_{:05d}.jpg".format(idx + 1),
                "height": H,
                "width": W,
                "category_id": 1,
                "keypoints_body": kp2ds_body.tolist(),
                "keypoints_left_hand": kp2ds_lhand.tolist(),
                "keypoints_right_hand": kp2ds_rhand.tolist(),
            }
    return img

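The `stickwidth_type` branches differ by a single pixel. A standalone sketch of the two variants and the lower bound of 1 (`stickwidth_for` is a hypothetical helper name using the same arithmetic):

```python
def stickwidth_for(h, w, stickwidth_type="v2"):
    # limb thickness scales with the shorter image side; v2 is one pixel thinner
    if stickwidth_type == "v1":
        return max(int(min(h, w) / 200), 1)
    if stickwidth_type == "v2":
        return max(int(min(h, w) / 200) - 1, 1)
    raise ValueError(f"unknown stickwidth_type: {stickwidth_type}")

# 480p frame: v1 gives 2 px, v2 gives 1 px; tiny frames never drop below 1 px
w1 = stickwidth_for(480, 640, "v1")
w2 = stickwidth_for(480, 640, "v2")
w3 = stickwidth_for(120, 120, "v2")
```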
def draw_bbox(img, bbox, color=(255, 0, 0)):
    img = load_image(img)
    bbox = [int(bbox_tmp) for bbox_tmp in bbox]
    cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
    return img


def draw_kp2ds(img, kp2ds, threshold=0, color=(255, 0, 0), skeleton=None, reverse=False):
    img = load_image(img, reverse)

    if skeleton is not None:
        if skeleton == "coco17":
            skeleton_list = [
                [6, 8], [8, 10], [5, 7], [7, 9],
                [11, 13], [13, 15], [12, 14], [14, 16],
                [5, 6], [6, 12], [12, 11], [11, 5],
            ]
            color_list = [
                (255, 0, 0),
                (0, 255, 0),
                (0, 0, 255),
                (255, 255, 0),
                (255, 0, 255),
                (0, 255, 255),
            ]
        elif skeleton == "cocowholebody":
            skeleton_list = [
                [6, 8], [8, 10], [5, 7], [7, 9],
                [11, 13], [13, 15], [12, 14], [14, 16],
                [5, 6], [6, 12], [12, 11], [11, 5],
                [15, 17], [15, 18], [15, 19],
                [16, 20], [16, 21], [16, 22],
                [91, 92, 93, 94, 95],
                [91, 96, 97, 98, 99],
                [91, 100, 101, 102, 103],
                [91, 104, 105, 106, 107],
                [91, 108, 109, 110, 111],
                [112, 113, 114, 115, 116],
                [112, 117, 118, 119, 120],
                [112, 121, 122, 123, 124],
                [112, 125, 126, 127, 128],
                [112, 129, 130, 131, 132],
            ]
            color_list = [
                (255, 0, 0),
                (0, 255, 0),
                (0, 0, 255),
                (255, 255, 0),
                (255, 0, 255),
                (0, 255, 255),
            ]
        else:
            skeleton_list = []
            color_list = [color]
        for _idx, _skeleton in enumerate(skeleton_list):
            for i in range(len(_skeleton) - 1):
                cv2.line(
                    img,
                    (int(kp2ds[_skeleton[i], 0]), int(kp2ds[_skeleton[i], 1])),
                    (int(kp2ds[_skeleton[i + 1], 0]), int(kp2ds[_skeleton[i + 1], 1])),
                    color_list[_idx % len(color_list)],
                    3,
                )

    for _idx, kp2d in enumerate(kp2ds):
        if kp2d[2] > threshold:
            cv2.circle(img, (int(kp2d[0]), int(kp2d[1])), 3, color, -1)

    return img


def draw_mask(img, mask, background=0, return_rgba=False):
    img = load_image(img)
    h, w, _ = img.shape
    if type(background) == int:
        background = np.ones((h, w, 3)).astype(np.uint8) * 255 * background
    background = cv2.resize(background, (w, h))
    img_rgba = np.concatenate([img, mask], -1)
    return alphaMerge(img_rgba, background, 0, 0, return_rgba=True)


def draw_pcd(pcd_list, save_path=None):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    color_list = ["r", "g", "b", "y", "m"]

    for _idx, _pcd in enumerate(pcd_list):
        ax.scatter(_pcd[:, 0], _pcd[:, 1], _pcd[:, 2], c=color_list[_idx], marker="o")

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")

    if save_path is not None:
        plt.savefig(save_path)
    else:
        plt.savefig("tmp.png")


def load_image(img, reverse=False):
    if isinstance(img, str):
        img = cv2.imread(img)
    if reverse:
        img = img[:, :, ::-1].astype(np.uint8)
    return img

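`load_image(reverse=True)` flips the channel order (OpenCV loads images as BGR). The same reversal can be done with plain numpy slicing, sketched here without OpenCV:

```python
import numpy as np

def to_rgb(img_bgr):
    # reverse the channel axis: BGR -> RGB (or vice versa)
    return np.ascontiguousarray(img_bgr[:, :, ::-1])

bgr = np.zeros((2, 2, 3), dtype=np.uint8)
bgr[..., 0] = 255  # pure blue in BGR layout
rgb = to_rgb(bgr)
```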
def draw_skeleten(meta):
    kps = []
    for i, kp in enumerate(meta["keypoints_body"]):
        if kp is None:
            kps.append([0, 0, 0])
        else:
            kps.append([*kp, 1])
    kps = np.array(kps)

    kps[:, 0] *= meta["width"]
    kps[:, 1] *= meta["height"]
    pose_img = np.zeros([meta["height"], meta["width"], 3], dtype=np.uint8)

    pose_img = draw_aapose(
        pose_img,
        kps,
        draw_hand=True,
        kp2ds_lhand=meta["keypoints_left_hand"],
        kp2ds_rhand=meta["keypoints_right_hand"],
    )
    return pose_img


def draw_skeleten_with_pncc(pncc: np.ndarray, meta: Dict) -> np.ndarray:
    """
    Args:
        pncc: [H, W, 3]
        meta: required keys: keypoints_body [N, 3], keypoints_left_hand, keypoints_right_hand
    Returns:
        np.ndarray [H, W, 3]
    """
    # preprocess keypoints
    kps = []
    for i, kp in enumerate(meta["keypoints_body"]):
        if kp is None:
            kps.append([0, 0, 0])
        elif i in [14, 15, 16, 17]:
            kps.append([0, 0, 0])
        else:
            kps.append([*kp])
    kps = np.stack(kps)

    kps[:, 0] *= pncc.shape[1]
    kps[:, 1] *= pncc.shape[0]

    # draw neck
    canvas = np.zeros_like(pncc)
    if kps[0][2] > 0.6 and kps[1][2] > 0.6:
        canvas = draw_ellipse_by_2kp(canvas, kps[0], kps[1], [0, 0, 255])

    # draw pncc
    mask = (pncc > 0).max(axis=2)
    canvas[mask] = pncc[mask]
    pncc = canvas

    # draw the rest of the skeleton
    kps[0] = 0

    meta["keypoints_left_hand"][:, 0] *= meta["width"]
    meta["keypoints_left_hand"][:, 1] *= meta["height"]

    meta["keypoints_right_hand"][:, 0] *= meta["width"]
    meta["keypoints_right_hand"][:, 1] *= meta["height"]
    pose_img = draw_aapose(
        pncc,
        kps,
        draw_hand=True,
        kp2ds_lhand=meta["keypoints_left_hand"],
        kp2ds_rhand=meta["keypoints_right_hand"],
    )
    return pose_img

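Both drawing entry points above first scale normalized keypoints back to pixel units before handing them to `draw_aapose`. That step in isolation (`denormalize_kps` is a hypothetical helper name):

```python
import numpy as np

def denormalize_kps(kps, width, height):
    # (x, y) are stored in [0, 1]; scale a copy to pixels and keep the score column
    kps = np.asarray(kps, dtype=float).copy()
    kps[:, 0] *= width
    kps[:, 1] *= height
    return kps

pix = denormalize_kps([[0.5, 0.25, 1.0]], 400, 200)
```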
FACE_CUSTOM_STYLE = {
    "eyeball": {"indexs": [68, 69], "color": [255, 255, 255], "connect": False},
    "left_eyebrow": {"indexs": [17, 18, 19, 20, 21], "color": [0, 255, 0]},
    "right_eyebrow": {"indexs": [22, 23, 24, 25, 26], "color": [0, 0, 255]},
    "left_eye": {"indexs": [36, 37, 38, 39, 40, 41], "color": [255, 255, 0], "close": True},
    "right_eye": {"indexs": [42, 43, 44, 45, 46, 47], "color": [255, 0, 255], "close": True},
    "mouth_outside": {"indexs": list(range(48, 60)), "color": [100, 255, 50], "close": True},
    "mouth_inside": {"indexs": [60, 61, 62, 63, 64, 65, 66, 67], "color": [255, 100, 50], "close": True},
}


def draw_face_kp(img, kps, thickness=2, style=FACE_CUSTOM_STYLE):
    """
    Args:
        img: [H, W, 3]
        kps: [70, 2]
    """
    img = img.copy()
    for key, item in style.items():
        pts = np.array(kps[item["indexs"]]).astype(np.int32)
        connect = item.get("connect", True)
        color = item["color"]
        close = item.get("close", False)
        if connect:
            cv2.polylines(img, [pts], close, color, thickness=thickness)
        else:
            for kp in pts:
                kp = np.array(kp).astype(np.int32)
                cv2.circle(img, kp, thickness * 2, color=color, thickness=-1)
    return img

def draw_traj(metas: List[AAPoseMeta], threshold=0.6):

    colors = [
        [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
        [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
        [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
        [255, 0, 255], [255, 0, 170], [255, 0, 85], [100, 255, 50], [255, 100, 50],
        # foot
        [200, 200, 0],
        [100, 100, 0],
    ]
    limbSeq = [
        [1, 2], [1, 5],  # shoulders
        [2, 3], [3, 4],  # left arm
        [5, 6], [6, 7],  # right arm
        [1, 8], [8, 9], [9, 10],  # right leg
        [1, 11], [11, 12], [12, 13],  # left leg
        # face (nose, eyes, ears)
        [13, 18], [10, 19],  # foot
    ]

    face_seq = [[1, 0], [0, 14], [14, 16], [0, 15], [15, 17]]
    kp_body = np.array([meta.kps_body for meta in metas])
    kp_body_p = np.array([meta.kps_body_p for meta in metas])

    face_seq = random.sample(face_seq, 2)

    kp_lh = np.array([meta.kps_lhand for meta in metas])
    kp_rh = np.array([meta.kps_rhand for meta in metas])

    kp_lh_p = np.array([meta.kps_lhand_p for meta in metas])
    kp_rh_p = np.array([meta.kps_rhand_p for meta in metas])

    new_limbSeq = []
    key_point_list = []
    for _idx, (k1_index, k2_index) in enumerate(limbSeq):
        vis = (kp_body_p[:, k1_index] > threshold) * (kp_body_p[:, k2_index] > threshold) * 1
        if vis.sum() * 1.0 / vis.shape[0] > 0.4:
            new_limbSeq.append([k1_index, k2_index])

    for _idx, (k1_index, k2_index) in enumerate(limbSeq):
        keypoint1 = kp_body[:, k1_index - 1]
        keypoint2 = kp_body[:, k2_index - 1]
        interleave = random.randint(4, 7)
        randind = random.randint(0, interleave - 1)

        Y = np.array([keypoint1[:, 0], keypoint2[:, 0]])
        X = np.array([keypoint1[:, 1], keypoint2[:, 1]])

        vis = (keypoint1[:, -1] > threshold) * (keypoint2[:, -1] > threshold) * 1

        # sample one point at parameter t along the limb for every frame
        t = randind / interleave
        x = (1 - t) * Y[0, :] + t * Y[1, :]
        y = (1 - t) * X[0, :] + t * X[1, :]

        x = x.astype(int)
        y = y.astype(int)

        new_array = np.array([x, y, vis]).T

        key_point_list.append(new_array)

    indx_lh = random.randint(0, kp_lh.shape[1] - 1)
    lh = kp_lh[:, indx_lh, :]
    lh_p = kp_lh_p[:, indx_lh:indx_lh + 1]
    lh = np.concatenate([lh, lh_p], axis=-1)

    indx_rh = random.randint(0, kp_rh.shape[1] - 1)
    rh = kp_rh[:, indx_rh, :]
    rh_p = kp_rh_p[:, indx_rh:indx_rh + 1]
    rh = np.concatenate([rh, rh_p], axis=-1)

    # binarize the confidence column against the threshold
    lh[:, -1] = (lh[:, -1] > threshold) * 1
    rh[:, -1] = (rh[:, -1] > threshold) * 1

    key_point_list.append(lh.astype(int))
    key_point_list.append(rh.astype(int))

    key_points_list = np.stack(key_point_list)
    num_points = len(key_points_list)
    sample_colors = random.sample(colors, num_points)

    stickwidth = max(int(min(metas[0].width, metas[0].height) / 150), 2)

    image_list_ori = []
    for i in range(key_points_list.shape[-2]):
        _image_vis = np.zeros((metas[0].width, metas[0].height, 3))
        points = key_points_list[:, i, :]
        for idx, point in enumerate(points):
            x, y, vis = point
            if vis == 1:
                cv2.circle(_image_vis, (x, y), stickwidth, sample_colors[idx], thickness=-1)

        image_list_ori.append(_image_vis)

    return image_list_ori

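`draw_traj` tracks a single point per limb by linearly interpolating between the two endpoint keypoints at a random parameter t = randind / interleave. The sampling step in isolation (`sample_point_on_limb` is a hypothetical helper name):

```python
def sample_point_on_limb(p1, p2, randind, interleave):
    # linear interpolation at t in [0, 1): t=0 is p1, t approaching 1 nears p2
    t = randind / interleave
    x = (1 - t) * p1[0] + t * p2[0]
    y = (1 - t) * p1[1] + t * p2[1]
    return int(x), int(y)

pt = sample_point_on_limb((0, 0), (10, 20), 1, 4)  # t = 0.25
```

Because the same `randind` is reused across all frames of a clip, the sampled point stays at a fixed fraction along the limb and traces a coherent trajectory over time.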
if __name__ == "__main__":
    meta = {
        "image_id": "00472.jpg",
        "height": 540,
        "width": 414,
        "category_id": 1,
        "keypoints_body": [
            [0.5084776947463768, 0.11350188078703703],
            [0.504467655495169, 0.20419560185185184],
            [0.3982016153381642, 0.198046875],
            [0.3841664779589372, 0.34869068287037036],
            [0.3901815368357488, 0.4670536747685185],
            [0.610733695652174, 0.2103443287037037],
            [0.6167487545289855, 0.3517650462962963],
            [0.6448190292874396, 0.4762767650462963],
            [0.4523371452294686, 0.47320240162037036],
            [0.4503321256038647, 0.6776475694444445],
            [0.47639738073671495, 0.8544234664351852],
            [0.5766483620169082, 0.47320240162037036],
            [0.5666232638888888, 0.6761103877314815],
            [0.534542949879227, 0.863646556712963],
            [0.4864224788647343, 0.09505570023148148],
            [0.5285278910024155, 0.09351851851851851],
            [0.46236224335748793, 0.10581597222222222],
            [0.5586031853864735, 0.10274160879629629],
            [0.4994551064311594, 0.9405056423611111],
            [0.4152442821557971, 0.9312825520833333],
        ],
        "keypoints_left_hand": [
            [267.78515625, 263.830078125, 1.2840936183929443],
            [265.294921875, 269.640625, 1.2546794414520264],
            [263.634765625, 277.111328125, 1.2863062620162964],
            [262.8046875, 285.412109375, 1.267038345336914],
            [261.14453125, 292.8828125, 1.280144453048706],
            [273.595703125, 281.26171875, 1.2592815160751343],
            [271.10546875, 291.22265625, 1.3256099224090576],
            [265.294921875, 294.54296875, 1.2368024587631226],
            [261.14453125, 294.54296875, 0.9771889448165894],
            [274.42578125, 282.091796875, 1.250044584274292],
            [269.4453125, 291.22265625, 1.2571144104003906],
            [264.46484375, 292.8828125, 1.177802324295044],
            [260.314453125, 292.052734375, 0.9283463358879089],
            [273.595703125, 282.091796875, 1.1834490299224854],
            [269.4453125, 290.392578125, 1.188171625137329],
            [265.294921875, 290.392578125, 1.192609429359436],
            [261.974609375, 289.5625, 0.9366656541824341],
            [271.935546875, 281.26171875, 1.0946396589279175],
            [268.615234375, 287.072265625, 0.9906131029129028],
            [265.294921875, 287.90234375, 1.0219476222991943],
            [262.8046875, 287.072265625, 0.9240120053291321],
        ],
        "keypoints_right_hand": [
            [161.53515625, 258.849609375, 1.2069408893585205],
            [168.17578125, 263.0, 1.1846840381622314],
            [173.986328125, 269.640625, 1.1435924768447876],
            [173.986328125, 277.94140625, 1.1802611351013184],
            [173.986328125, 286.2421875, 1.2599592208862305],
            [165.685546875, 275.451171875, 1.0633569955825806],
            [167.345703125, 286.2421875, 1.1693341732025146],
            [169.8359375, 291.22265625, 1.2698509693145752],
            [170.666015625, 294.54296875, 1.0619274377822876],
            [160.705078125, 276.28125, 1.0995020866394043],
            [163.1953125, 287.90234375, 1.2735884189605713],
            [166.515625, 291.22265625, 1.339503526687622],
            [169.005859375, 294.54296875, 1.0835273265838623],
            [157.384765625, 277.111328125, 1.0866981744766235],
            [161.53515625, 287.072265625, 1.2468621730804443],
            [164.025390625, 289.5625, 1.2817761898040771],
            [166.515625, 292.052734375, 1.099466323852539],
            [155.724609375, 277.111328125, 1.1065717935562134],
            [159.044921875, 285.412109375, 1.1924479007720947],
            [160.705078125, 287.072265625, 1.1304771900177002],
            [162.365234375, 287.90234375, 1.0040509700775146],
        ],
    }
    demo_meta = AAPoseMeta(meta)
    res = draw_traj([demo_meta] * 5)
    cv2.imwrite("traj.png", res[0][..., ::-1])
wan/modules/animate/preprocess/pose2d.py
ADDED
@@ -0,0 +1,430 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import os
import cv2
from typing import Union, List

import numpy as np
import torch
import onnxruntime

from pose2d_utils import (
    read_img,
    box_convert_simple,
    bbox_from_detector,
    crop,
    keypoints_from_heatmaps,
    load_pose_metas_from_kp2ds_seq,
)


class SimpleOnnxInference(object):
    def __init__(self, checkpoint, device='cuda', reverse_input=False, **kwargs):
        if isinstance(device, str):
            device = torch.device(device)
        if device.type == 'cuda':
            device = '{}:{}'.format(device.type, device.index)
            providers = [
                ("CUDAExecutionProvider",
                 {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}),
                "CPUExecutionProvider",
            ]
        else:
            providers = ["CPUExecutionProvider"]
        self.device = device
        if not os.path.exists(checkpoint):
            raise RuntimeError("{} does not exist!".format(checkpoint))

        if os.path.isdir(checkpoint):
            checkpoint = os.path.join(checkpoint, 'end2end.onnx')

        self.session = onnxruntime.InferenceSession(checkpoint, providers=providers)
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name
        self.input_resolution = self.session.get_inputs()[0].shape[2:] if not reverse_input \
            else self.session.get_inputs()[0].shape[2:][::-1]
        self.input_resolution = np.array(self.input_resolution)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def get_output_names(self):
        output_names = []
        for node in self.session.get_outputs():
            output_names.append(node.name)
        return output_names

    def set_device(self, device):
        if isinstance(device, str):
            device = torch.device(device)
        if device.type == 'cuda':
            device = '{}:{}'.format(device.type, device.index)
            providers = [
                ("CUDAExecutionProvider",
                 {"device_id": device[-1:] if device[-1] in [str(_i) for _i in range(10)] else "0"}),
                "CPUExecutionProvider",
            ]
        else:
            providers = ["CPUExecutionProvider"]
        self.session.set_providers(providers)
        self.device = device

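`SimpleOnnxInference` derives the CUDA `device_id` from the last character of the device string, with CPU fallback. The selection logic in isolation (`select_providers` is a hypothetical helper that mirrors the constructor; it is not an official onnxruntime API):

```python
def select_providers(device_str):
    # "cuda:N" -> CUDA provider pinned to GPU N with CPU fallback; otherwise CPU only
    if device_str.startswith("cuda"):
        dev_id = device_str[-1] if device_str[-1].isdigit() else "0"
        return [("CUDAExecutionProvider", {"device_id": dev_id}), "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
```

Note this single-character parse only supports GPU indices 0-9; multi-digit indices would need a real split on ":".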
class Yolo(SimpleOnnxInference):
|
| 69 |
+
def __init__(self, checkpoint, device='cuda', threshold_conf=0.05, threshold_multi_persons=0.1, input_resolution=(640, 640), threshold_iou=0.5, threshold_bbox_shape_ratio=0.4, cat_id=[1], select_type='max', strict=True, sorted_func=None, **kwargs):
|
| 70 |
+
super(Yolo, self).__init__(checkpoint, device=device, **kwargs)
|
| 71 |
+
self.session.set_providers(["CUDAExecutionProvider"])
|
| 72 |
+
model_inputs = self.session.get_inputs()
|
| 73 |
+
input_shape = model_inputs[0].shape
|
| 74 |
+
|
| 75 |
+
self.input_width = 640
|
| 76 |
+
self.input_height = 640
|
| 77 |
+
|
| 78 |
+
self.threshold_multi_persons = threshold_multi_persons
|
| 79 |
+
self.threshold_conf = threshold_conf
|
| 80 |
+
self.threshold_iou = threshold_iou
|
| 81 |
+
self.threshold_bbox_shape_ratio = threshold_bbox_shape_ratio
|
| 82 |
+
self.input_resolution = input_resolution
|
| 83 |
+
self.cat_id = cat_id
|
| 84 |
+
self.select_type = select_type
|
| 85 |
+
self.strict = strict
|
| 86 |
+
self.sorted_func = sorted_func
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def preprocess(self, input_image):
|
| 90 |
+
"""
|
| 91 |
+
Preprocesses the input image before performing inference.
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
image_data: Preprocessed image data ready for inference.
|
| 95 |
+
"""
|
| 96 |
+
img = read_img(input_image)
|
| 97 |
+
# Get the height and width of the input image
|
| 98 |
+
img_height, img_width = img.shape[:2]
|
| 99 |
+
# Resize the image to match the input shape
|
| 100 |
+
img = cv2.resize(img, (self.input_resolution[1], self.input_resolution[0]))
|
| 101 |
+
# Normalize the image data by dividing it by 255.0
|
| 102 |
+
image_data = np.array(img) / 255.0
|
| 103 |
+
# Transpose the image to have the channel dimension as the first dimension
|
| 104 |
+
image_data = np.transpose(image_data, (2, 0, 1)) # Channel first
|
| 105 |
+
# Expand the dimensions of the image data to match the expected input shape
|
| 106 |
+
# image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
|
| 107 |
+
image_data = image_data.astype(np.float32)
|
| 108 |
+
# Return the preprocessed image data
|
| 109 |
+
return image_data, np.array([img_height, img_width])
|
| 110 |
+
|
| 111 |
+
|
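Stripped of the ONNX session and the repo's `read_img`/`cv2.resize` helpers, `preprocess` boils down to a scale-to-[0, 1] plus an HWC-to-CHW transpose. A minimal numpy-only sketch (assuming the input is already at the 640x640 network resolution, so no resize is needed):

```python
import numpy as np

def preprocess_sketch(img):
    """Scale to [0, 1] and move channels first (HWC -> CHW), as in Yolo.preprocess.
    Assumes img is already resized to the 640x640 network resolution."""
    h, w = img.shape[:2]
    data = img.astype(np.float32) / 255.0   # normalize to [0, 1]
    data = np.transpose(data, (2, 0, 1))    # HWC -> CHW
    return data, np.array([h, w])

img = np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)
data, shape = preprocess_sketch(img)
```

The original shape is returned alongside the tensor so that detections can later be scaled back from network coordinates to image pixels.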
    def postprocess(self, output, shape_raw, cat_id=[1]):
        """
        Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.

        Args:
            output (numpy.ndarray): The raw output of the model.
            shape_raw (numpy.ndarray): The (height, width) of the original image.

        Returns:
            numpy.ndarray: Detections, one row of [x1, y1, x2, y2, score, class_id] each.
        """
        # Squeeze the batch dimension and normalize the layout
        outputs = np.squeeze(output)
        if len(outputs.shape) == 1:
            outputs = outputs[None]
        # YOLOv8-style raw output is (84, N); transpose it to (N, 84)
        if output.shape[-1] != 6 and output.shape[1] == 84:
            outputs = np.transpose(outputs)

        # Number of candidate detections
        rows = outputs.shape[0]

        # Scaling factors to map bounding boxes back to the original image
        x_factor = shape_raw[1] / self.input_width
        y_factor = shape_raw[0] / self.input_height

        # Bounding boxes, scores, and class IDs of the detections
        boxes = []
        scores = []
        class_ids = []

        if outputs.shape[-1] == 6:
            # Rows are already [x1, y1, x2, y2, score, class_id]
            max_scores = outputs[:, 4]
            classid = outputs[:, -1]

            threshold_conf_masks = max_scores >= self.threshold_conf
            # Keep every class (the sentinel comparison is always true)
            classid_masks = classid[threshold_conf_masks] != 3.14159

            max_scores = max_scores[threshold_conf_masks][classid_masks]
            classid = classid[threshold_conf_masks][classid_masks]

            boxes = outputs[:, :4][threshold_conf_masks][classid_masks]
            boxes[:, [0, 2]] *= x_factor
            boxes[:, [1, 3]] *= y_factor
            # Convert [x1, y1, x2, y2] to [x, y, w, h]
            boxes[:, 2] = boxes[:, 2] - boxes[:, 0]
            boxes[:, 3] = boxes[:, 3] - boxes[:, 1]
            boxes = boxes.astype(np.int32)

        else:
            # Rows are [cx, cy, w, h, class_0_score, ...]
            classes_scores = outputs[:, 4:]
            max_scores = np.amax(classes_scores, -1)
            threshold_conf_masks = max_scores >= self.threshold_conf

            classid = np.argmax(classes_scores[threshold_conf_masks], -1)

            # Keep every class (the sentinel comparison is always true)
            classid_masks = classid != 3.14159

            classes_scores = classes_scores[threshold_conf_masks][classid_masks]
            max_scores = max_scores[threshold_conf_masks][classid_masks]
            classid = classid[classid_masks]

            xywh = outputs[:, :4][threshold_conf_masks][classid_masks]

            x = xywh[:, 0:1]
            y = xywh[:, 1:2]
            w = xywh[:, 2:3]
            h = xywh[:, 3:4]

            # Convert center-based [cx, cy, w, h] to top-left [x, y, w, h] in original pixels
            left = (x - w / 2) * x_factor
            top = (y - h / 2) * y_factor
            width = w * x_factor
            height = h * y_factor
            boxes = np.concatenate([left, top, width, height], axis=-1).astype(np.int32)

        boxes = boxes.tolist()
        scores = max_scores.tolist()
        class_ids = classid.tolist()

        # Apply non-maximum suppression to filter out overlapping bounding boxes
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.threshold_conf, self.threshold_iou)

        # Collect the detections that survive non-maximum suppression
        results = []
        for i in indices:
            box = box_convert_simple(boxes[i], 'xywh2xyxy')
            score = scores[i]
            class_id = class_ids[i]
            results.append(box + [score] + [class_id])
        return np.array(results)

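`postprocess` delegates overlap removal to `cv2.dnn.NMSBoxes`. For reference, the same greedy IoU suppression over `[x, y, w, h]` boxes can be sketched in pure numpy (a simplified stand-in, not the OpenCV implementation):

```python
import numpy as np

def nms_xywh(boxes, scores, iou_thr=0.5):
    """Greedy NMS over [x, y, w, h] boxes; returns kept indices, highest score first."""
    boxes = np.asarray(boxes, dtype=np.float32)
    x1, y1 = boxes[:, 0], boxes[:, 1]
    x2, y2 = boxes[:, 0] + boxes[:, 2], boxes[:, 1] + boxes[:, 3]
    areas = boxes[:, 2] * boxes[:, 3]
    order = np.argsort(scores)[::-1]        # indices sorted by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of the top box with every remaining box
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[i] + areas[order[1:]] - inter)
        # Drop boxes that overlap the kept box too much
        order = order[1:][iou <= iou_thr]
    return keep

boxes = [[0, 0, 10, 10], [1, 1, 10, 10], [50, 50, 10, 10]]
scores = [0.9, 0.8, 0.7]
kept = nms_xywh(boxes, scores, iou_thr=0.5)  # box 1 overlaps box 0 and is suppressed
```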
    def process_results(self, results, shape_raw, cat_id=[1], single_person=True):
        if isinstance(results, tuple):
            det_results = results[0]
        else:
            det_results = results

        person_results = []
        person_count = 0
        if len(results):
            max_idx = -1
            max_bbox_size = shape_raw[0] * shape_raw[1] * -10
            max_bbox_shape = -1

            bboxes = []
            idx_list = []
            # First pass: keep detections of the requested categories and record the largest bbox side
            for i in range(results.shape[0]):
                bbox = results[i]
                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
                    idx_list.append(i)
                    bbox_shape = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
                    if bbox_shape > max_bbox_shape:
                        max_bbox_shape = bbox_shape

            results = results[idx_list]

            # Second pass: pick the main person by area ('max') or centrality ('center')
            for i in range(results.shape[0]):
                bbox = results[i]
                bboxes.append(bbox)
                if self.select_type == 'max':
                    bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                elif self.select_type == 'center':
                    bbox_size = abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2) * -1
                bbox_shape = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
                if bbox_size > max_bbox_size:
                    # Skip candidates that are much smaller than the largest detection
                    if (self.strict or max_idx != -1) and bbox_shape < max_bbox_shape * self.threshold_bbox_shape_ratio:
                        continue
                    max_bbox_size = bbox_size
                    max_bbox_shape = bbox_shape
                    max_idx = i

            # An external sort function, if provided, overrides the selection above
            if self.sorted_func is not None and len(bboxes) > 0:
                max_idx = self.sorted_func(bboxes, shape_raw)
                bbox = bboxes[max_idx]
                if self.select_type == 'max':
                    max_bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                elif self.select_type == 'center':
                    max_bbox_size = abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2) * -1

            if max_idx != -1:
                person_count = 1

            if max_idx != -1:
                person = {}
                person['bbox'] = results[max_idx, :5]
                person['track_id'] = int(0)
                person_results.append(person)

            # Count (and, if requested, keep) additional persons of comparable size
            for i in range(results.shape[0]):
                bbox = results[i]
                if (bbox[-1] + 1 in cat_id) and (bbox[-2] > self.threshold_conf):
                    if self.select_type == 'max':
                        bbox_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                    elif self.select_type == 'center':
                        bbox_size = abs((bbox[2] + bbox[0]) / 2 - shape_raw[1] / 2) * -1
                    if i != max_idx and bbox_size > max_bbox_size * self.threshold_multi_persons and bbox_size < max_bbox_size:
                        person_count += 1
                        if not single_person:
                            person = {}
                            person['bbox'] = results[i, :5]
                            person['track_id'] = int(person_count - 1)
                            person_results.append(person)
            return person_results
        else:
            return None

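`process_results` mixes category filtering, main-person selection, and multi-person counting, but the core of the selection is just a per-box score: area for `select_type='max'`, negated distance-to-image-center for `'center'`. A stripped-down sketch of that scoring (a hypothetical helper for illustration, not part of the repo):

```python
import numpy as np

def select_main_person(bboxes_xyxy, img_w, select_type='max'):
    """Pick one bbox index: largest area ('max') or closest to the
    horizontal image center ('center'). Returns -1 for an empty list."""
    best_idx, best_score = -1, -np.inf
    for i, (x1, y1, x2, y2) in enumerate(bboxes_xyxy):
        if select_type == 'max':
            score = (x2 - x1) * (y2 - y1)
        else:  # 'center': negate the distance so that bigger is better
            score = -abs((x1 + x2) / 2 - img_w / 2)
        if score > best_score:
            best_score, best_idx = score, i
    return best_idx

boxes = [(0, 0, 40, 40), (100, 100, 300, 400), (310, 0, 330, 20)]
```

With a 640-pixel-wide image, `'max'` picks the second box (largest area) while `'center'` picks the third (its center x of 320 sits exactly at mid-image).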
    def postprocess_threading(self, outputs, shape_raw, person_results, i, single_person=True, **kwargs):
        result = self.postprocess(outputs[i], shape_raw[i], cat_id=self.cat_id)
        result = self.process_results(result, shape_raw[i], cat_id=self.cat_id, single_person=single_person)
        if result is not None and len(result) != 0:
            person_results[i] = result

    def forward(self, img, shape_raw, **kwargs):
        """
        Performs inference using the ONNX model and returns person detections per image.

        Returns:
            person_results: A list (one entry per image) of person dicts with 'bbox' and 'track_id'.
        """
        if isinstance(img, torch.Tensor):
            img = img.cpu().numpy()
            shape_raw = shape_raw.cpu().numpy()

        outputs = self.session.run(None, {self.session.get_inputs()[0].name: img})[0]
        # Default to a full-image bbox so downstream pose estimation always has a crop
        person_results = [[{'bbox': np.array([0., 0., 1.*shape_raw[i][1], 1.*shape_raw[i][0], -1]), 'track_id': -1}] for i in range(len(outputs))]

        for i in range(len(outputs)):
            self.postprocess_threading(outputs, shape_raw, person_results, i, **kwargs)
        return person_results

+
class ViTPose(SimpleOnnxInference):
|
| 310 |
+
def __init__(self, checkpoint, device='cuda', **kwargs):
|
| 311 |
+
super(ViTPose, self).__init__(checkpoint, device=device)
|
| 312 |
+
self.session.set_providers(["CUDAExecutionProvider"])
|
| 313 |
+
|
| 314 |
+
def forward(self, img, center, scale, **kwargs):
|
| 315 |
+
heatmaps = self.session.run([], {self.session.get_inputs()[0].name: img})[0]
|
| 316 |
+
points, prob = keypoints_from_heatmaps(heatmaps=heatmaps,
|
| 317 |
+
center=center,
|
| 318 |
+
scale=scale*200,
|
| 319 |
+
unbiased=True,
|
| 320 |
+
use_udp=False)
|
| 321 |
+
return np.concatenate([points, prob], axis=2)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
@staticmethod
|
| 325 |
+
def preprocess(img, bbox=None, input_resolution=(256, 192), rescale=1.25, mask=None, **kwargs):
|
| 326 |
+
if bbox is None or bbox[-1] <= 0 or (bbox[2] - bbox[0]) < 10 or (bbox[3] - bbox[1]) < 10:
|
| 327 |
+
bbox = np.array([0, 0, img.shape[1], img.shape[0]])
|
| 328 |
+
|
| 329 |
+
bbox_xywh = bbox
|
| 330 |
+
if mask is not None:
|
| 331 |
+
img = np.where(mask>128, img, mask)
|
| 332 |
+
|
| 333 |
+
if isinstance(input_resolution, int):
|
| 334 |
+
center, scale = bbox_from_detector(bbox_xywh, (input_resolution, input_resolution), rescale=rescale)
|
| 335 |
+
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution, input_resolution))
|
| 336 |
+
else:
|
| 337 |
+
center, scale = bbox_from_detector(bbox_xywh, input_resolution, rescale=rescale)
|
| 338 |
+
img, new_shape, old_xy, new_xy = crop(img, center, scale, (input_resolution[0], input_resolution[1]))
|
| 339 |
+
|
| 340 |
+
IMG_NORM_MEAN = np.array([0.485, 0.456, 0.406])
|
| 341 |
+
IMG_NORM_STD = np.array([0.229, 0.224, 0.225])
|
| 342 |
+
img_norm = (img / 255. - IMG_NORM_MEAN) / IMG_NORM_STD
|
| 343 |
+
img_norm = img_norm.transpose(2, 0, 1).astype(np.float32)
|
| 344 |
+
return img_norm, np.array(center), np.array(scale)
|
| 345 |
+
|
| 346 |
+
|
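`ViTPose.forward` decodes keypoints from heatmaps through the repo's `keypoints_from_heatmaps` helper, which adds sub-pixel refinement and maps coordinates back to the original image via `center`/`scale`. The basic idea underneath is an argmax per keypoint channel; a naive numpy sketch without any of that refinement:

```python
import numpy as np

def decode_heatmaps(heatmaps):
    """Naive argmax decode: (N, K, H, W) heatmaps -> (N, K, 2) xy points
    and (N, K, 1) confidences. A simplified stand-in for
    keypoints_from_heatmaps: no sub-pixel refinement, no center/scale
    mapping back to the original image."""
    n, k, h, w = heatmaps.shape
    flat = heatmaps.reshape(n, k, -1)
    idx = np.argmax(flat, axis=-1)                 # flat index of each peak
    conf = np.max(flat, axis=-1, keepdims=True)    # peak value as confidence
    xs = (idx % w).astype(np.float32)              # column -> x
    ys = (idx // w).astype(np.float32)             # row -> y
    points = np.stack([xs, ys], axis=-1)
    return points, conf

hm = np.zeros((1, 2, 64, 48), dtype=np.float32)
hm[0, 0, 10, 20] = 1.0  # keypoint 0 peaks at (x=20, y=10)
hm[0, 1, 30, 5] = 0.8   # keypoint 1 peaks at (x=5, y=30)
pts, conf = decode_heatmaps(hm)
```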
class Pose2d:
    def __init__(self, checkpoint, detector_checkpoint=None, device='cuda', **kwargs):
        if detector_checkpoint is not None:
            self.detector = Yolo(detector_checkpoint, device)
        else:
            self.detector = None

        self.model = ViTPose(checkpoint, device)
        self.device = device

    def load_images(self, inputs):
        """
        Load images from various input types.

        Args:
            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be a file path,
                a single image array, or a list of image arrays.

        Returns:
            List[np.ndarray]: List of RGB image arrays.

        Raises:
            ValueError: If the input type or file format is unsupported, or an image cannot be read.
        """
        if isinstance(inputs, str):
            if inputs.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
                cap = cv2.VideoCapture(inputs)
                frames = []
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                cap.release()
                images = frames
            elif inputs.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp')):
                img = cv2.imread(inputs)
                if img is None:
                    raise ValueError(f"Cannot read image: {inputs}")
                images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB)]
            else:
                raise ValueError(f"Unsupported file format: {inputs}")

        elif isinstance(inputs, np.ndarray):
            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
        elif isinstance(inputs, list):
            images = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in inputs]
        else:
            raise ValueError(f"Unsupported input type: {type(inputs)}")
        return images

    def __call__(
        self,
        inputs: Union[str, np.ndarray, List[np.ndarray]],
        return_image: bool = False,
        **kwargs
    ):
        """
        Process input and estimate 2D keypoints.

        Args:
            inputs (Union[str, np.ndarray, List[np.ndarray]]): Input can be a file path,
                a single image array, or a list of image arrays.
            **kwargs: Additional arguments for processing.

        Returns:
            np.ndarray: Array of detected 2D keypoints for all input images.
        """
        images = self.load_images(inputs)
        H, W = images[0].shape[:2]
        # Detect the main person in each frame (or fall back to the full image)
        if self.detector is not None:
            bboxes = []
            for _image in images:
                img, shape = self.detector.preprocess(_image)
                bboxes.append(self.detector(img[None], shape[None])[0][0]["bbox"])
        else:
            bboxes = [None] * len(images)

        # Run pose estimation on the cropped person in each frame
        kp2ds = []
        for _image, _bbox in zip(images, bboxes):
            img, center, scale = self.model.preprocess(_image, _bbox)
            kp2ds.append(self.model(img[None], center[None], scale[None]))
        kp2ds = np.concatenate(kp2ds, 0)
        metas = load_pose_metas_from_kp2ds_seq(kp2ds, width=W, height=H)
        return metas
wan/modules/animate/preprocess/pose2d_utils.py
ADDED
|
@@ -0,0 +1,1159 @@