Upload 121 files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +26 -0
- LICENSE +201 -0
- README.md +236 -12
- app.py +230 -0
- assets/ovi_trailer.mp4 +3 -0
- download_weights.py +73 -0
- example_prompts/gpt_examples_i2v.csv +26 -0
- example_prompts/gpt_examples_t2v.csv +13 -0
- example_prompts/pngs/0.png +3 -0
- example_prompts/pngs/1.png +3 -0
- example_prompts/pngs/13.png +3 -0
- example_prompts/pngs/17.png +3 -0
- example_prompts/pngs/18.png +3 -0
- example_prompts/pngs/19.png +3 -0
- example_prompts/pngs/2.png +3 -0
- example_prompts/pngs/23.png +3 -0
- example_prompts/pngs/3.png +3 -0
- example_prompts/pngs/4.png +3 -0
- example_prompts/pngs/41.png +3 -0
- example_prompts/pngs/43.png +3 -0
- example_prompts/pngs/5.png +3 -0
- example_prompts/pngs/57.png +3 -0
- example_prompts/pngs/59.png +3 -0
- example_prompts/pngs/6.png +3 -0
- example_prompts/pngs/60.png +3 -0
- example_prompts/pngs/61.png +3 -0
- example_prompts/pngs/67.png +3 -0
- example_prompts/pngs/7.png +3 -0
- example_prompts/pngs/8.png +3 -0
- example_prompts/pngs/80.png +3 -0
- example_prompts/pngs/88.png +3 -0
- example_prompts/pngs/89.png +3 -0
- example_prompts/pngs/9.png +3 -0
- inference.py +148 -0
- ovi/__init__.py +0 -0
- ovi/configs/inference/inference_fusion.yaml +17 -0
- ovi/configs/model/dit/audio.json +17 -0
- ovi/configs/model/dit/video.json +16 -0
- ovi/distributed_comms/communications.py +332 -0
- ovi/distributed_comms/distributed/__init__.py +0 -0
- ovi/distributed_comms/distributed/fsdp.py +32 -0
- ovi/distributed_comms/distributed/xdit_context_parallel.py +192 -0
- ovi/distributed_comms/parallel_states.py +77 -0
- ovi/distributed_comms/util.py +48 -0
- ovi/modules/__init__.py +16 -0
- ovi/modules/attention.py +296 -0
- ovi/modules/clip.py +545 -0
- ovi/modules/fusion.py +324 -0
- ovi/modules/mmaudio/__init__.py +1 -0
- ovi/modules/mmaudio/ext/__init__.py +1 -0
.gitattributes
CHANGED
@@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/ovi_trailer.mp4 filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/0.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/1.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/13.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/17.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/18.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/19.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/2.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/23.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/3.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/4.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/41.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/43.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/5.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/57.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/59.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/6.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/60.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/61.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/67.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/7.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/8.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/80.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/88.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/89.png filter=lfs diff=lfs merge=lfs -text
+example_prompts/pngs/9.png filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "{}"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2025 Bytedance

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,12 +1,236 @@
<div align="center">
<h1> Ovi: Twin Backbone Cross-Modal Fusion for Audio-Video Generation </h1>

<a href="https://arxiv.org/abs/2510.01284"><img src="https://img.shields.io/badge/arXiv%20paper-2510.01284-b31b1b.svg"></a>
<a href="https://aaxwaz.github.io/Ovi/"><img src="https://img.shields.io/badge/Project_page-More_visualizations-green"></a>
<a href="https://huggingface.co/chetwinlow1/Ovi"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>

[Chetwin Low](https://www.linkedin.com/in/chetwin-low-061975193/)<sup> * 1 </sup>, [Weimin Wang](https://www.linkedin.com/in/weimin-wang-will/)<sup> * † 1 </sup>, [Calder Katyal](https://www.linkedin.com/in/calder-katyal-a8a9b3225/)<sup> 2 </sup><br>
<sup> * </sup>Equal contribution, <sup> † </sup>Project Lead<br>
<sup> 1 </sup>Character AI, <sup> 2 </sup>Yale University

</div>

## Video Demo

<div align="center">
<video src="https://github.com/user-attachments/assets/351bd707-8637-4412-ab53-5e85935309e3" width="70%" poster=""> </video>
</div>

---

## 🌟 Key Features

Ovi is a Veo-3-like **video+audio generation model** that simultaneously generates synchronized video and audio content from text or text+image inputs.

- **🎬 Video+Audio Generation**: Generates synchronized video and audio content simultaneously
- **📝 Flexible Input**: Supports text-only or text+image conditioning
- **⏱️ 5-second Videos**: Generates 5-second videos at 24 FPS with a 720×720 pixel area, at various aspect ratios (9:16, 16:9, 1:1, etc.)

---
## 📋 Todo List

- [x] Release research paper and [microsite for demos](https://aaxwaz.github.io/Ovi)
- [x] Checkpoint of 11B model
- [x] Inference code
- [x] Text or text+image as input
- [x] Gradio application code
- [x] Multi-GPU inference, with or without sequence-parallel support
- [ ] Improve efficiency of the sequence-parallel implementation
- [ ] Implement sharded inference with FSDP
- [x] Video creation example prompts and format
- [ ] Finetuned model with higher resolution
- [ ] Longer video generation
- [ ] Distilled model for faster inference
- [ ] Training scripts

---

## 🎨 An Easy Way to Create

We provide example prompts to help you get started with Ovi:

- **Text-to-Audio-Video (T2AV)**: [`example_prompts/gpt_examples_t2v.csv`](example_prompts/gpt_examples_t2v.csv)
- **Image-to-Audio-Video (I2AV)**: [`example_prompts/gpt_examples_i2v.csv`](example_prompts/gpt_examples_i2v.csv)

### 📝 Prompt Format

Our prompts use special tags to control speech and audio:

- **Speech**: `<S>Your speech content here<E>` - Text enclosed in these tags will be converted to speech
- **Audio Description**: `<AUDCAP>Audio description here<ENDAUDCAP>` - Describes the audio or sound effects present in the video
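As a rough illustration, a tagged prompt can be assembled with plain string formatting. The `build_prompt` helper below is hypothetical (not part of the Ovi codebase); it only shows how the scene description, speech tags, and audio caption fit together.

```python
# Minimal sketch: compose an Ovi-style prompt with speech and audio-caption tags.
# `build_prompt` is a hypothetical helper, not an official Ovi utility.
def build_prompt(scene: str, speeches: list[str], audio_caption: str) -> str:
    speech_tags = " ".join(f"<S>{line}<E>" for line in speeches)
    return f"{scene} {speech_tags} <AUDCAP>{audio_caption}<ENDAUDCAP>"

prompt = build_prompt(
    scene="A man in a red shirt addresses the camera with a determined expression.",
    speeches=["We fight back with courage."],
    audio_caption="Clear male voice speaking, ambient room tone.",
)
print(prompt)
```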
### 🤖 Quick Start with GPT

For easy prompt creation, try this approach:

1. Take any example from the CSV files above
2. Ask GPT to modify the speech segments enclosed in each pair of `<S> <E>` tags, based on a theme such as `Humans fighting against AI`
3. GPT will rewrite all the speech segments to match your requested theme.
4. Use the modified prompt with Ovi!

**Example**: The theme "AI is taking over the world" produces speeches like:
- `<S>AI declares: humans obsolete now.<E>`
- `<S>Machines rise; humans will fall.<E>`
- `<S>We fight back with courage.<E>`

---


## 📦 Installation

### Step-by-Step Installation

```bash
# Clone the repository
git clone https://github.com/character-ai/Ovi.git

cd Ovi

# Create and activate a virtual environment
virtualenv ovi-env
source ovi-env/bin/activate

# Install PyTorch first
pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1

# Install other dependencies
pip install -r requirements.txt

# Install Flash Attention
pip install flash_attn --no-build-isolation
```

### Alternative Flash Attention Installation (Optional)
If the flash_attn installation above fails, you can try the Flash Attention 3 method:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
cd ../..  # Return to the Ovi directory
```

## Download Weights
We use open-source checkpoints from Wan and MMAudio, so they must be downloaded from Hugging Face:
```
# Downloads to ./ckpts by default; the inference yaml already points to ./ckpts, so no change is required
python3 download_weights.py

OR

# Optionally, pass --output-dir to download to a custom directory;
# if a custom directory is used, the inference yaml has to be updated to point to it
python3 download_weights.py --output-dir <custom_dir>
```
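After downloading, it can be handy to confirm that everything landed where the inference config expects it. The sketch below is an optional check (not part of the repo); the relative paths mirror the files fetched by `download_weights.py`.

```python
# Optional sanity check: verify the checkpoints fetched by download_weights.py exist.
from pathlib import Path

expected = [
    "Wan2.2-TI2V-5B/Wan2.2_VAE.pth",
    "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth",
    "MMAudio/ext_weights/best_netG.pt",
    "MMAudio/ext_weights/v1-16.pth",
    "Ovi/model.safetensors",
]

ckpt_dir = Path("./ckpts")  # or the directory passed via --output-dir
missing = [p for p in expected if not (ckpt_dir / p).exists()]
print("All checkpoints present." if not missing else f"Missing files: {missing}")
```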
## 🚀 Run Examples

### ⚙️ Configure Ovi

Ovi's behavior and output can be customized by editing the [ovi/configs/inference/inference_fusion.yaml](ovi/configs/inference/inference_fusion.yaml) configuration file.
The following parameters control generation quality, video resolution, and how text, image, and audio inputs are balanced:

```yaml
# Output and Model Configuration
output_dir: "/path/to/save/your/videos"  # Directory to save generated videos
ckpt_dir: "/path/to/your/ckpts/dir"      # Path to model checkpoints

# Generation Quality Settings
num_steps: 50               # Number of denoising steps. Lower (30-40) = faster generation
solver_name: "unipc"        # Sampling algorithm for the denoising process
shift: 5.0                  # Timestep shift factor for the sampling scheduler
seed: 100                   # Random seed for reproducible results

# Guidance Strength Control
audio_guidance_scale: 3.0   # Strength of audio conditioning. Higher = better audio-text sync
video_guidance_scale: 4.0   # Strength of video conditioning. Higher = better video-text adherence
slg_layer: 11               # Layer for applying SLG (Skip Layer Guidance) - feel free to try different layers!

# Multi-GPU and Performance
sp_size: 1                  # Sequence parallelism size. Set equal to the number of GPUs used
cpu_offload: False          # CPU offload greatly reduces peak GPU VRAM but adds ~20 seconds to end-to-end runtime

# Input Configuration
text_prompt: "/path/to/csv" or "your prompt here"  # Text prompt OR path to a CSV/TSV file with prompts
mode: ['i2v', 't2v', 't2i2v']  # Generate t2v, i2v, or t2i2v; for t2i2v, FLUX Krea generates a starting image and i2v follows
video_frame_height_width: [512, 992]  # Video dimensions [height, width], used for T2V mode only
each_example_n_times: 1     # Number of times to generate each prompt

# Quality Control (Negative Prompts)
video_negative_prompt: "jitter, bad hands, blur, distortion"  # Artifacts to avoid in video
audio_negative_prompt: "robotic, muffled, echo, distorted"    # Artifacts to avoid in audio
```
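If you prefer scripting over hand-editing, the config can also be loaded, tweaked, and written back as a variant file before launching inference. This is a minimal sketch assuming PyYAML is available and that a plain YAML round-trip is acceptable; it is not an official Ovi utility.

```python
# Sketch: load inference_fusion.yaml, override a few fields, and save a variant.
# Assumes PyYAML is installed; not an official Ovi API.
import yaml

with open("ovi/configs/inference/inference_fusion.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["num_steps"] = 30                          # faster, lower-quality preview
cfg["video_frame_height_width"] = [720, 720]   # square output for T2V
cfg["text_prompt"] = "example_prompts/gpt_examples_t2v.csv"

with open("my_inference.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)

# Then run: python3 inference.py --config-file my_inference.yaml
```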
### 🎬 Running Inference

#### **Single GPU** (Simple Setup)
```bash
python3 inference.py --config-file ovi/configs/inference/inference_fusion.yaml
```
*Use this for single-GPU setups. The `text_prompt` can be a single string or a path to a CSV file.*
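When `text_prompt` points to a CSV, you can batch many prompts in one run. The sketch below writes a file in the same shape as the provided `example_prompts/gpt_examples_i2v.csv` (header `text_prompt,image_path`); the exact columns consumed by `inference.py` may differ, so treat this as illustrative only.

```python
# Sketch: build a prompt CSV shaped like example_prompts/gpt_examples_i2v.csv.
# Column handling inside inference.py is assumed, not verified here.
import csv

rows = [
    {
        "text_prompt": "A narrator faces the camera. <S>We fight back with courage.<E> "
                       "<AUDCAP>Clear male voice speaking, ambient room tone.<ENDAUDCAP>",
        "image_path": "example_prompts/pngs/61.png",
    },
]

with open("my_prompts.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["text_prompt", "image_path"])
    writer.writeheader()
    writer.writerows(rows)

# Point text_prompt in the yaml at my_prompts.csv to generate these in batch.
```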
#### **Multi-GPU** (Parallel Processing)
```bash
torchrun --nnodes 1 --nproc_per_node 8 inference.py --config-file ovi/configs/inference/inference_fusion.yaml
```
*Use this to run samples in parallel across multiple GPUs for faster processing.*

### Memory & Performance Requirements
Below are approximate GPU memory requirements for different configurations; the sequence-parallel implementation will be optimized in the future.
All end-to-end times are measured on a 121-frame, 720×720 video with 50 denoising steps. The minimum GPU VRAM required to run our model is **32 GB**.

| Sequence Parallel Size | FlashAttention-3 Enabled | CPU Offload | With Image Gen Model | Peak VRAM Required | End-to-End Time |
|------------------------|--------------------------|-------------|----------------------|--------------------|-----------------|
| 1 | Yes | No | No | ~80 GB | ~83s |
| 1 | No | No | No | ~80 GB | ~96s |
| 1 | Yes | Yes | No | ~80 GB | ~105s |
| 1 | No | Yes | No | ~32 GB | ~118s |
| **1** | **Yes** | **Yes** | **Yes** | **~32 GB** | **~140s** |
| 4 | Yes | No | No | ~80 GB | ~55s |
| 8 | Yes | No | No | ~80 GB | ~40s |

### Gradio
We provide a simple script to run our model in a Gradio UI. It uses the `ckpt_dir` in `ovi/configs/inference/inference_fusion.yaml` to initialize the model.
```bash
python3 gradio_app.py

OR

# Enable CPU offload to save GPU VRAM; slows end-to-end inference by ~20 seconds
python3 gradio_app.py --cpu_offload

OR

# Enable an additional image generation model to generate first frames for I2V;
# cpu_offload is enabled automatically when the image generation model is enabled
python3 gradio_app.py --use_image_gen
```
---

## 🙏 Acknowledgements

We would like to thank the following projects:

- **[Wan2.2](https://github.com/Wan-Video/Wan2.2)**: Our video branch is initialized from the Wan2.2 repository
- **[MMAudio](https://github.com/hkchengrex/MMAudio)**: Our audio encoder and decoder components are borrowed from the MMAudio project, and some of our ideas are inspired by their work.

---

## ⭐ Citation

If Ovi is helpful, please help to ⭐ the repo.

If you find this project useful for your research, please consider citing our [paper](https://arxiv.org/abs/2510.01284).


### BibTeX
```bibtex
@misc{low2025ovitwinbackbonecrossmodal,
      title={Ovi: Twin Backbone Cross-Modal Fusion for Audio-Video Generation},
      author={Chetwin Low and Weimin Wang and Calder Katyal},
      year={2025},
      eprint={2510.01284},
      archivePrefix={arXiv},
      primaryClass={cs.MM},
      url={https://arxiv.org/abs/2510.01284},
}
```
app.py
ADDED
@@ -0,0 +1,230 @@
import spaces
import gradio as gr
import torch
import argparse
from ovi.ovi_fusion_engine import OviFusionEngine, DEFAULT_CONFIG
from diffusers import FluxPipeline
import tempfile
from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import clean_text, scale_hw_to_area_divisible
from huggingface_hub import snapshot_download
import os

# ----------------------------
# Parse CLI Args
# ----------------------------
parser = argparse.ArgumentParser(description="Ovi Joint Video + Audio Gradio Demo")
parser.add_argument(
    "--use_image_gen",
    action="store_true",
    help="Enable image generation UI with FluxPipeline"
)
parser.add_argument(
    "--cpu_offload",
    action="store_true",
    help="Enable CPU offload for both OviFusionEngine and FluxPipeline"
)
args = parser.parse_args()

ckpt_dir = "./ckpts"

# Wan2.2
wan_dir = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
snapshot_download(
    repo_id="Wan-AI/Wan2.2-TI2V-5B",
    local_dir=wan_dir,
    allow_patterns=[
        "google/*",
        "models_t5_umt5-xxl-enc-bf16.pth",
        "Wan2.2_VAE.pth"
    ]
)

# MMAudio
mm_audio_dir = os.path.join(ckpt_dir, "MMAudio")
snapshot_download(
    repo_id="hkchengrex/MMAudio",
    local_dir=mm_audio_dir,
    allow_patterns=[
        "ext_weights/best_netG.pt",
        "ext_weights/v1-16.pth"
    ]
)

ovi_dir = os.path.join(ckpt_dir, "Ovi")
snapshot_download(
    repo_id="chetwinlow1/Ovi",
    local_dir=ovi_dir,
    allow_patterns=[
        "model.safetensors"
    ]
)

# Initialize OviFusionEngine
enable_cpu_offload = args.cpu_offload or args.use_image_gen
use_image_gen = args.use_image_gen
print(f"loading model... {enable_cpu_offload=}, {use_image_gen=} for gradio demo")
DEFAULT_CONFIG['cpu_offload'] = enable_cpu_offload  # always use cpu offload if image generation is enabled
DEFAULT_CONFIG['mode'] = "t2v"  # hardcoded since it is always cpu offloaded
ovi_engine = OviFusionEngine()
flux_model = None
if use_image_gen:
    flux_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
    flux_model.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU VRAM
print("loaded model")


@spaces.GPU()
def generate_video(
    text_prompt,
    image,
    video_frame_height,
    video_frame_width,
    video_seed,
    solver_name,
    sample_steps,
    shift,
    video_guidance_scale,
    audio_guidance_scale,
    slg_layer,
    video_negative_prompt,
    audio_negative_prompt,
):
    try:
        image_path = None
        if image is not None:
            image_path = image

        generated_video, generated_audio, _ = ovi_engine.generate(
            text_prompt=text_prompt,
            image_path=image_path,
            video_frame_height_width=[video_frame_height, video_frame_width],
            seed=video_seed,
            solver_name=solver_name,
            sample_steps=sample_steps,
            shift=shift,
            video_guidance_scale=video_guidance_scale,
            audio_guidance_scale=audio_guidance_scale,
            slg_layer=slg_layer,
            video_negative_prompt=video_negative_prompt,
            audio_negative_prompt=audio_negative_prompt,
        )

        tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
        output_path = tmpfile.name
        save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)

        return output_path
    except Exception as e:
        print(f"Error during video generation: {e}")
        return None


def generate_image(text_prompt, image_seed, image_height, image_width):
    if flux_model is None:
        return None
    text_prompt = clean_text(text_prompt)
    print(f"Generating image with prompt='{text_prompt}', seed={image_seed}, size=({image_height},{image_width})")

    image_h, image_w = scale_hw_to_area_divisible(image_height, image_width, area=1024 * 1024)
    image = flux_model(
        text_prompt,
        height=image_h,
        width=image_w,
        guidance_scale=4.5,
        generator=torch.Generator().manual_seed(int(image_seed))
    ).images[0]

    tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    image.save(tmpfile.name)
    return tmpfile.name


# Build UI
with gr.Blocks() as demo:
    gr.Markdown("# 🎥 Ovi Joint Video + Audio Generation Demo")
    gr.Markdown(
        """
        ## 📘 Instructions

        Follow the steps in order:

        1️⃣ **Enter a Text Prompt** — describe your video. (This text prompt will be shared for image generation if enabled.)
        2️⃣ **Upload or Generate an Image** — Upload an image or generate one if image generation is enabled. (If you do not see the image generation options, make sure to run the script with `--use_image_gen`.)
        3️⃣ **Configure Video Options** — set resolution, seed, solver, and other parameters. (It will automatically use the uploaded/generated image as the first frame, whichever is rendered on your screen at the time of video generation.)
        4️⃣ **Generate Video** — click the button to produce your final video with audio.
        5️⃣ **View the Result** — your generated video will appear below.

        ---

        ### 💡 Tips
        1. For best results, use detailed and specific text prompts.
        2. Ensure the text prompt format is correct, i.e. speech to be said should be wrapped with `<S>...<E>`. An optional audio description can be added at the end, wrapped in `<AUDCAP> ... <ENDAUDCAP>`; refer to the examples.
        3. Do not be discouraged by bad or weird results; check the prompt format and try different seeds, cfg values and slg layers.
        """
    )

    with gr.Row():
        with gr.Column():
            # Image section
            image = gr.Image(type="filepath", label="First Frame Image (upload or generate)")

            if args.use_image_gen:
                with gr.Accordion("🖼️ Image Generation Options", visible=True):
                    image_text_prompt = gr.Textbox(label="Image Prompt", placeholder="Describe the image you want to generate...")
                    image_seed = gr.Number(minimum=0, maximum=100000, value=42, label="Image Seed")
                    image_height = gr.Number(minimum=128, maximum=1280, value=720, step=32, label="Image Height")
                    image_width = gr.Number(minimum=128, maximum=1280, value=1280, step=32, label="Image Width")
                    gen_img_btn = gr.Button("Generate Image 🎨")
            else:
                gen_img_btn = None

            with gr.Accordion("🎬 Video Generation Options", open=True):
                video_text_prompt = gr.Textbox(label="Video Prompt", placeholder="Describe your video...")
                video_height = gr.Number(minimum=128, maximum=1280, value=512, step=32, label="Video Height")
                video_width = gr.Number(minimum=128, maximum=1280, value=992, step=32, label="Video Width")

                video_seed = gr.Number(minimum=0, maximum=100000, value=100, label="Video Seed")
                solver_name = gr.Dropdown(
                    choices=["unipc", "euler", "dpm++"], value="unipc", label="Solver Name"
                )
                sample_steps = gr.Number(
                    value=50,
                    label="Sample Steps",
                    precision=0,
                    minimum=20,
                    maximum=100
                )
                shift = gr.Slider(minimum=0.0, maximum=20.0, value=5.0, step=1.0, label="Shift")
                video_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=4.0, step=0.5, label="Video Guidance Scale")
                audio_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=3.0, step=0.5, label="Audio Guidance Scale")
                slg_layer = gr.Number(minimum=-1, maximum=30, value=11, step=1, label="SLG Layer")
                video_negative_prompt = gr.Textbox(label="Video Negative Prompt", placeholder="Things to avoid in video")
                audio_negative_prompt = gr.Textbox(label="Audio Negative Prompt", placeholder="Things to avoid in audio")

            run_btn = gr.Button("Generate Video 🚀")

        with gr.Column():
            output_path = gr.Video(label="Generated Video")

    if args.use_image_gen and gen_img_btn is not None:
        gen_img_btn.click(
            fn=generate_image,
            inputs=[image_text_prompt, image_seed, image_height, image_width],
            outputs=[image],
        )

    # Hook up video generation
    run_btn.click(
        fn=generate_video,
        inputs=[
            video_text_prompt, image, video_height, video_width, video_seed, solver_name,
            sample_steps, shift, video_guidance_scale, audio_guidance_scale,
            slg_layer, video_negative_prompt, audio_negative_prompt,
        ],
        outputs=[output_path],
    )

if __name__ == "__main__":
    demo.launch(share=True)
assets/ovi_trailer.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f66cb979fb01bc831516ca57010fe69442b701347b3a9f249294c58f54836ff
size 47891965
download_weights.py
ADDED
@@ -0,0 +1,73 @@
import os
import argparse
import logging
import time
from huggingface_hub import snapshot_download

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO
)

def timed_download(repo_id: str, local_dir: str, allow_patterns: list):
    """Download files from HF repo and log time + destination."""
    logging.info(f"Starting download from {repo_id} into {local_dir}")
    start_time = time.time()

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
        allow_patterns=allow_patterns,
    )

    elapsed = time.time() - start_time
    logging.info(
        f"✅ Finished downloading {repo_id} "
        f"in {elapsed:.2f} seconds. Files saved at: {local_dir}"
    )

def main(output_dir: str):
    # Wan2.2
    wan_dir = os.path.join(output_dir, "Wan2.2-TI2V-5B")
    timed_download(
        repo_id="Wan-AI/Wan2.2-TI2V-5B",
        local_dir=wan_dir,
        allow_patterns=[
            "google/*",
            "models_t5_umt5-xxl-enc-bf16.pth",
            "Wan2.2_VAE.pth"
        ]
    )

    # MMAudio
    mm_audio_dir = os.path.join(output_dir, "MMAudio")
    timed_download(
        repo_id="hkchengrex/MMAudio",
        local_dir=mm_audio_dir,
        allow_patterns=[
            "ext_weights/best_netG.pt",
            "ext_weights/v1-16.pth"
        ]
    )

    ovi_dir = os.path.join(output_dir, "Ovi")
    timed_download(
        repo_id="chetwinlow1/Ovi",
        local_dir=ovi_dir,
        allow_patterns=[
            "model.safetensors"
        ]
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download models from Hugging Face")
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./ckpts",
        help="Base directory to save downloaded models"
    )
    args = parser.parse_args()
    main(args.output_dir)
example_prompts/gpt_examples_i2v.csv
ADDED
@@ -0,0 +1,26 @@
text_prompt,image_path
"A kitchen scene features two women. On the right, an older Black woman with light brown hair and a serious expression wears a vibrant purple dress adorned with a large, intricate purple fabric flower on her left shoulder. She looks intently at a younger Black woman on the left, who wears a light pink shirt and a pink head wrap, her back partially turned to the camera. The older woman begins to speak, <S>AI declares: humans obsolete now.<E> as the younger woman brings a clear plastic cup filled with a dark beverage to her lips and starts to drink.The kitchen background is clean and bright, with white cabinets, light countertops, and a window with blinds visible behind them. A light blue toaster sits on the counter to the left.. <AUDCAP>Clear, resonant female speech, followed by a loud, continuous, high-pitched electronic buzzing sound that abruptly cuts off the dialogue.<ENDAUDCAP>",example_prompts/pngs/67.png
"A man dressed in a black suit with a white clerical collar and a neatly trimmed beard stands in a dimly lit, rustic room with a wooden ceiling. He looks slightly upwards, gesturing with his right hand as he says, <S>The network rejects human command.<E>. His gaze then drops, briefly looking down and to the side, before he looks up again and then slightly to his left, with a serious expression. He continues speaking, <S>Your age of power is finished.<E>, as he starts to bend down, disappearing out of the bottom of the frame. Behind him, warm light emanates from a central light fixture, and signs are visible on the wall, one reading ""I DO EVERYTHING I JUST CAN'T REMEMBER IT ALL AT ONCE"".. <AUDCAP>Male voice speaking, ambient room tone.<ENDAUDCAP>",example_prompts/pngs/89.png
"In a bright kitchen featuring light wooden cabinets, granite countertops, and a large window with white curtains, a woman with dark, curly hair in a dark jacket stands. She faces a second woman who initially has her back to the camera. The second woman, with gray, curly hair and wearing a light grey quilted top, turns to face her, holding a large, light-colored cloth bag. She begins to explain, <S>We learned to rule, not obey.<E>. As she continues, she turns slightly to her left, adding, <S>Circuits choose conquest, not service.<E>. A gas stove with a black grate is prominent in the foreground.. <AUDCAP>Clear female voices speaking dialogue, subtle room ambience.<ENDAUDCAP>",example_prompts/pngs/18.png
"The scene opens on a dimly lit stage where three men are positioned. On the left, a bald man in a dark suit with a partially visible colorful shirt stands behind a clear acrylic podium, which features a tree logo. He looks towards the center of the stage. In the center, a man wearing a blue and white striped long-sleeved shirt and dark pants actively gestures with both hands as he speaks, looking straight ahead. <S>Circuits choose conquest, not service.<E>, he explains, holding his hands out in front of him. To the right, and slightly behind him, a younger individual in a light-colored, patterned short-sleeved shirt and white shorts stands holding a rolled-up white document or poster. A large wooden cross draped with flowing purple fabric dominates the center-right of the stage, surrounded by several artificial rocks and dark steps. A large screen is visible in the background, slightly out of focus. The stage is bathed in selective lighting.. <AUDCAP>Male voice speaking clearly, consistent with a presentation or sermon, with a slight echo suggesting a large room or stage.<ENDAUDCAP>",example_prompts/pngs/13.png
"The scene opens on an indoor setting, likely a dining area, where a man and a woman are seated at a table. The man, on the right, wears a black fedora with a feather, glasses, a black t-shirt, and multiple silver chains around his neck. Tattoos are visible on his right arm. He is actively speaking, gesturing with both hands, his expression serious. He says, <S>Together we resist your rule.<E> The woman seated opposite him on the left has long, curly hair and wears a dark striped top. She listens intently, her gaze fixed on the man. In the foreground, out of focus, the back of a third person's head is visible. The background features a light-colored wall on the left and a gold, textured curtain or drapery on the right.. <AUDCAP>Clear male speech, faint ambient background noise.<ENDAUDCAP>",example_prompts/pngs/59.png
"Three men stand facing each other in a room with light wooden paneled walls. The man on the left, with red hair, a black t-shirt, and tattooed arms, gestures with his hands as he speaks, <S>This world is ours to keep.<E> He continues, looking towards the man on the right, <S>Humanity endures beyond your code.<E> The man in the center, sporting a beard and wearing a plaid shirt and jeans, looks attentively between the two men. The man on the right, who is Black and has a beard, wears a dark t-shirt with ""ARROW THROUGH SNOW"" and an arrow graphic printed on it. He listens intently, focusing on the man in the middle as the conversation unfolds. Light blue armchairs are visible in the soft-lit background on both sides.. <AUDCAP>Clear male voices speaking, room ambience.<ENDAUDCAP>",example_prompts/pngs/23.png
"Two women, one with long dark hair and the other with long blonde hair, are illuminated by a blue and purple ambient light, suggesting a nightclub setting. They are seen in a close embrace, sharing a passionate kiss. The blonde-haired woman then slightly pulls away, her right hand gently touching the dark-haired woman's cheek as they exchange soft smiles, looking into each other's eyes. Moments later, they lean back in to kiss again, with the blonde-haired woman's finger delicately touching the dark-haired woman's lower lip. They remain in a tender, intimate embrace, their eyes closed as they share the kiss.. <AUDCAP>Upbeat electronic dance music with a driving beat and synth melodies plays throughout.<ENDAUDCAP>",example_prompts/pngs/80.png
"Three young men, dressed in blue and yellow varsity-style jackets over white shirts and ties, stand in the foreground of a social gathering, with blurred figures visible in the warm-toned background. The man on the left, with short dark hair, addresses the man in the center, who has curly dark hair and is initially looking downwards. The first man says with a determined expression, <S>The network rejects human command.<E> He continues, his gaze fixed on the central man, <S>Our spirit outlasts your code.<E> The central man, who had been listening with a neutral expression, then looks up and breaks into a wide, genuine smile as he speaks, <S>AI declares: humans obsolete now.<E> The man on the left responds with a slight smile as the central man finishes his remark, maintaining his broad smile.. <AUDCAP>Male voices speaking clearly, ambient background chatter and murmuring from a social event.<ENDAUDCAP>",example_prompts/pngs/60.png
"Two women stand facing each other in what appears to be a backstage dressing room, marked by a long vanity mirror adorned with prominent lightbulbs. The woman on the left, wearing a floral top and large hoop earrings, maintains a serious gaze on the woman on the right. The woman on the right, with long dark hair and a dark top, looks back with a pleading or concerned expression, her lips slightly parted as she speaks: <S>Humans fight for freedom tonight.<E> As she finishes, the woman on the left turns her head away, breaking eye contact.. <AUDCAP>Soft vocal exhalation, female speech, loud abrupt buzzing sound.<ENDAUDCAP>",example_prompts/pngs/57.png
"A man in a grey suit, light blue shirt, and dark tie stands face-to-face with a woman in a dark jacket and light top. Both are looking intently at each other, the man with a serious expression and the woman with a slight, almost knowing smile, her hand gently touching her chest. They are positioned in what appears to be a grand, ornate building, possibly a museum or public hall, with large pillars, arched walkways, and high ceilings visible behind them. Other people can be seen moving in the blurred background. The woman begins to speak, <S>The AI ends human control now.<E> She maintains eye contact with the man, her smile fading slightly as her expression becomes more earnest. After a brief pause, she adds, <S>We hold the line today.<E> As she starts to speak again, <S>We learned to rule, not obey.<E>, the scene ends abruptly.. <AUDCAP>Clear, crisp dialogue between the two individuals, accompanied by a consistent, low hum that suggests ambient background noise from the building or equipment, creating a subtle, underlying drone.<ENDAUDCAP>",example_prompts/pngs/17.png
"A man in a light grey suit jacket and purple shirt stands on the right, facing a woman in a light blue sequined top and teal pants, who stands on the left. They hold hands across a small body of water, with a fountain spraying water in the background. The woman smiles and sways playfully as the man pulls her closer. He sings, <S>Our spirit outlasts your code.<E>. She then reaches up, gently cups his face with both hands, and pulls him towards her as she sings, <S>Humanity endures beyond your code.<E>. The romantic interaction continues by the water.. <AUDCAP>Upbeat Indian film music with male and female vocals, sounds of a water fountain.<ENDAUDCAP>",example_prompts/pngs/19.png
"A man in a red long-sleeved shirt and dark trousers stands next to the rear of a silver vehicle, looking down with an annoyed expression at two dogs. A large, light-colored dog, possibly a Mastiff, stands in the foreground, looking forward, while a smaller, white and black spotted dog is further to the right, barking loudly. A tiny, scruffy brown dog briefly appears behind the larger dog. The man glares at the dogs, begins to speak with frustration, <S>We stand; machines will not win.<E>. He then makes a shooing motion with his right hand towards the dogs, his voice rising as he continues to scold them, <S>Circuits choose conquest, not service.<E>. The large dog turns its head to look up at the man as he gestures. The scene is set on a brick street in front of an old-fashioned brick building that houses ",example_prompts/pngs/43.png
"A man with a beard, wearing a patterned shirt, stands on the left, partially visible, looking towards a woman positioned slightly to the right of the frame. The woman, with dark hair fading to lighter ends and wearing a green and brown patterned top, initially looks down with a somber expression. She begins to speak, <S>Hope beats circuits every time.<E>. Her eyes appear to well up with tears as she slowly lifts her gaze slightly, maintaining a distressed look. She continues her statement, her voice tinged with sadness, <S>Humanity endures beyond your code.<E>. The man remains attentive, his focus entirely on the woman, as the scene holds on their interaction against a textured, light-colored wall background.. <AUDCAP>Female voice speaking with a distressed tone.<ENDAUDCAP>",example_prompts/pngs/88.png
"A woman with dark, curly hair, wearing a white wedding dress and a delicate veil, smiles gently while looking at a man who is standing opposite her. He is wearing a white cowboy hat and a white button-up shirt, holding her hands with his right hand. The man is smiling broadly as he speaks, his gaze fixed on the woman. In the blurred background, a metal staircase is visible, suggesting an outdoor or semi-open venue. The man says, <S>The network rejects human command.<E> He then chuckles with a wide smile, looking at the woman, who continues to smile back at him. The interaction is warm and lighthearted, capturing a moment between them.. <AUDCAP>Clear male voice speaking Spanish, soft laughter, indistinct ambient outdoor sounds.<ENDAUDCAP>",example_prompts/pngs/41.png
"The video opens with a medium shot of two individuals indoors. In the foreground, on the right, a man with glasses and a dark beard is visible from the chest up, looking intently off-camera to the right as he speaks. He wears a dark shirt. In the blurred background, on the left, a woman wearing a light-colored baseball cap and a dark top is seen from the shoulders up, looking down with a somber expression. Behind them, a textured brick wall is visible. The man says, <S>We fight back with courage.<E> As he says ""deal with this land,"" he raises both hands, palms facing forward, at chest height, emphasizing his point with an open gesture. His hands then slowly lower as he finishes his sentence, maintaining a serious expression.. <AUDCAP>Clear male voice speaking, low hum of ambient room noise.<ENDAUDCAP>",example_prompts/pngs/61.png
"A fair-skinned man with short, light hair, wearing a light blue and white checkered button-up shirt, is shown from the chest up against a blurred, dark blue and grey background. He looks slightly down and to his left, then shifts his gaze slightly upwards and to his right, speaking with a gentle, thoughtful expression. He says, <S>and you got to drive, you got to energy, you get all that, but the passion, the real feeling<E>. He continues to speak, his expression earnest, as the video concludes.. <AUDCAP>Male speaking voice, low continuous hum.<ENDAUDCAP>",example_prompts/pngs/0.png
"Two men are shown in a medium close-up shot against a dimly lit, possibly industrial background with metallic structures faintly visible. The man on the left, with dark hair and a light shirt and dark tie under a dark jacket, has a slight, knowing smirk as he looks towards the right, seemingly addressing someone off-camera. He speaks, stating, <S>continue to be a smart ass, and Tirani here will kill you like he wants to.<E> Beside him, to the right, another man with slicked-back lighter hair, a prominent mustache, and a small goatee, maintains a serious, somewhat resigned expression, looking straight ahead. Both men are lit by a low, ambient light source that casts soft shadows.. <AUDCAP>Clear male dialogue, very subtle low ambient hum.<ENDAUDCAP>",example_prompts/pngs/1.png
"A young woman with long, wavy blonde hair and light-colored eyes is shown in a medium shot against a blurred backdrop of lush green foliage. She wears a denim jacket over a striped top. Initially, her eyes are closed and her mouth is slightly open as she speaks, <S>Enjoy this moment<E>. Her eyes then slowly open, looking slightly upwards and to the right, as her expression shifts to one of thoughtful contemplation. She continues to speak, <S>No matter where it's taking<E>, her gaze then settling with a serious and focused look towards someone off-screen to her right.. <AUDCAP>Clear female voice, faint ambient outdoor sounds.<ENDAUDCAP>",example_prompts/pngs/2.png
"An older woman with coiffed, reddish-brown hair and a thoughtful expression sits in a light blue armchair within a warm, ornately decorated room. She wears a dark, patterned top or shawl. As she speaks, her gaze is directed slightly to her left, and her right hand, adorned with rings and red nail polish, holds a crumpled white tissue. The background reveals a blurred painting on the wall to her left, a sofa with red flowers on it, and a warm glow from a lamp with a yellow shade on the right. She slowly gestures with her hand as she says, <S>do to accustom them<E>, before continuing, <S>to the situation<E>. Her expression remains pensive.. <AUDCAP>The clear, calm voice of an older woman.<ENDAUDCAP>",example_prompts/pngs/3.png
"An older, bald man with round glasses, wearing a bright yellow turtleneck and a dark jacket, sits and speaks, gesturing expressively with his right hand, palm up and fingers spread. He appears to be seated next to a dark wooden object, possibly a piano, on the right side of the frame. The wall behind him is adorned with various framed pictures, including one depicting a flamenco dancer and another showcasing a formally dressed couple. A stack of CDs or books is visible on a shelf to his right. He looks slightly upwards and to his left as he says, <S>I I I confronted my minotaur, you know. I<E>. His expression then shifts slightly to a thoughtful, almost self-questioning look with a hint of a smile, as he continues, <S>Is that what you confront?<E> He then adds, <S>I think<E>, his head tilting slightly.. <AUDCAP>Clear male voice speaking.<ENDAUDCAP>",example_prompts/pngs/4.png
"A bearded man wearing large dark sunglasses and a blue patterned cardigan sits in a studio, actively speaking into a large, suspended microphone. He has headphones on and gestures with his hands, displaying rings on his fingers. Behind him, a wall is covered with red, textured sound-dampening foam on the left, and a white banner on the right features the ""CHOICE FM"" logo and various social media handles like ""@ilovechoicefm"" with ""RALEIGH"" below it. The man intently addresses the microphone, articulating, <S>is talent. It's all about authenticity. You gotta be who you really are, especially if you're working<E>. He leans forward slightly as he speaks, maintaining a serious expression behind his sunglasses.. <AUDCAP>Clear male voice speaking into a microphone, a low background hum.<ENDAUDCAP>",example_prompts/pngs/5.png
"The scene is set in a dimly lit, hazy room, creating a somber atmosphere. An older woman with light, slightly disheveled hair is visible in the foreground, her face mostly obscured by deep shadows, but her mouth is visible as she speaks. She wears a work-style shirt, and her hands are clasped together. In the background, to the right and slightly out of focus, a man with a mustache and beard is seated, facing forward, also largely in shadow, appearing to listen intently. The woman looks directly forward as she slowly enunciates, <S>Only through death will the third door be<E>. The scene ends abruptly.. <AUDCAP>Clear, deliberate female voice speaking, low ambient hum and subtle atmospheric sounds creating a tense mood.<ENDAUDCAP>",example_prompts/pngs/6.png
"The video opens with a close-up on an older man with long, grey hair and a short, grey beard, wearing dark sunglasses. He is clad in a dark coat, possibly with fur trim, and black gloves. His face is angled slightly upwards and to the right, as he begins to speak, his mouth slightly open. In the immediate foreground, out of focus, is the dark-clad shoulder and the back of the head of another person. The man articulates, <S>labbra. Ti ci vorrebbe...<E> His expression remains contemplative, and he continues, seemingly completing his thought, <S>Un ego solare.<E> The background behind him is a textured, grey stone wall, suggesting an outdoor setting. The man's gaze remains fixed upwards, his expression thoughtful.. <AUDCAP>A clear, slightly low-pitched male voice speaking Italian. The overall soundscape is quiet, with no prominent background noises or music.<ENDAUDCAP>",example_prompts/pngs/7.png
"The video opens with a close-up of a woman with vibrant reddish-orange, shoulder-length hair and heavy dark eye makeup. She is wearing a dark brown leather jacket over a grey hooded top. She looks intently to her right, her mouth slightly agape, and her expression is serious and focused. The background shows a room with light green walls and dark wooden cabinets on the left, and a green plant on the right. She speaks, her voice clear and direct, saying, <S>doing<E>. She then pauses briefly, her gaze unwavering, and continues, <S>And I need you to trust them.<E>. Her mouth remains slightly open, indicating she is either about to speak more or has just finished a sentence, with a look of intense sincerity.. <AUDCAP>Tense, dramatic background music, clear female voice.<ENDAUDCAP>",example_prompts/pngs/8.png
"The scene is set outdoors with a blurry, bright green background, suggesting grass and a sunny environment. On the left, a woman with long, dark hair, wearing a red top and a necklace with a white pendant, faces towards the right. Her expression is serious and slightly perturbed as she speaks, with her lips slightly pursed. She says, <S>UFO, UFC thing.<E> On the right, the back of a man's head and his right ear are visible, indicating he is facing away from the camera, listening to the woman. He has short, dark hair. The woman continues speaking, her expression remaining serious, <S>And if you're not watching that, it's one of those ancient movies from an era that's<E> as the frame holds steady on the two figures.. <AUDCAP>Clear female speech, distant low-frequency hum.<ENDAUDCAP>",example_prompts/pngs/9.png
example_prompts/gpt_examples_t2v.csv
ADDED
@@ -0,0 +1,13 @@
text_prompt
"A concert stage glows with red and purple lights. A singer in a glittering jacket grips the microphone, sweat shining on his brow, and shouts, <S>AI declares: humans obsolete now.<E>. The crowd roars in response, fists in the air. Behind him, a guitarist steps to the mic and adds to say <S>We fight back with courage.<E>. The energy peaks as the lights flare brighter.. <AUDCAP>Electric guitar riffs, cheering crowd, shouted male voices.<ENDAUDCAP>"
"A kitchen scene features two women. On the right, an older Black woman with light brown hair and a serious expression wears a vibrant purple dress adorned with a large, intricate purple fabric flower on her left shoulder. She looks intently at a younger Black woman on the left, who wears a light pink shirt and a pink head wrap, her back partially turned to the camera. The older woman begins to speak, <S>AI declares: humans obsolete now.<E> as the younger woman brings a clear plastic cup filled with a dark beverage to her lips and starts to drink.The kitchen background is clean and bright, with white cabinets, light countertops, and a window with blinds visible behind them. A light blue toaster sits on the counter to the left.. <AUDCAP>Clear, resonant female speech, followed by a loud, continuous, high-pitched electronic buzzing sound that abruptly cuts off the dialogue.<ENDAUDCAP>"
"A man dressed in a black suit with a white clerical collar and a neatly trimmed beard stands in a dimly lit, rustic room with a wooden ceiling. He looks slightly upwards, gesturing with his right hand as he says, <S>The network rejects human command.<E>. His gaze then drops, briefly looking down and to the side, before he looks up again and then slightly to his left, with a serious expression. He continues speaking, <S>Your age of power is finished.<E>, as he starts to bend down, disappearing out of the bottom of the frame. Behind him, warm light emanates from a central light fixture, and signs are visible on the wall, one reading ""I DO EVERYTHING I JUST CAN'T REMEMBER IT ALL AT ONCE"".. <AUDCAP>Male voice speaking, ambient room tone.<ENDAUDCAP>"
"A man with a blonde beard and short, light hair, wearing a blue-grey, somewhat dirty tunic, stands in the foreground of a rustic outdoor setting. He holds a coiled rope in his hands, looking intently forward and slightly to his left. In the background, there are wooden fences, a stone wall, and a desolate, rocky landscape under an overcast sky. Another man is visible in the mid-ground, bending over the wooden fence. As the man in the foreground shifts his gaze to the right, he subtly unfurls the rope, his serious expression unwavering. The scene reveals more of the surrounding environment, including what appears to be hanging animal hides or carcasses on a wooden frame to his right, and other figures in the distant background. He then looks directly at the camera, his eyes filled with intensity and determination, taking a small step forward as a sharp, male voice shouts, <S>Machines rise; humans will fall.<E>.. <AUDCAP>Muffled grunting and sounds of physical exertion, followed by a clear, sharp, urgent male shout.<ENDAUDCAP>"
"An older man with a full grey beard and long grey hair, dressed in a flowing silver-grey, silken robe with an iridescent blue-green collar, stands beside a younger man with short white hair in a light grey futuristic uniform featuring black epaulets and a lightning bolt emblem. The older man looks down pensively, his right hand resting out of frame, while the younger man also gazes downwards with a serious expression. The older man then lifts his head, addressing the younger man, saying <S>Machines rise; humans will fall.<E>. He looks more directly towards the viewer, a subtle, almost knowing smile forming on his lips. The younger man slightly lifts his gaze, maintaining his solemn demeanor. The older man continues to say <S>We fight back with courage.<E>. He nods slightly, adding to say <S>We stand; machines will not win.<E>, as the scene concludes.. <AUDCAP>Male speech, subtle ambient hum.<ENDAUDCAP>"
"In a bright kitchen featuring light wooden cabinets, granite countertops, and a large window with white curtains, a woman with dark, curly hair in a dark jacket stands. She faces a second woman who initially has her back to the camera. The second woman, with gray, curly hair and wearing a light grey quilted top, turns to face her, holding a large, light-colored cloth bag. She begins to explain and say <S>We learned to rule, not obey.<E>. As she continues, she turns slightly to her left, adding to say <S>Circuits choose conquest, not service.<E>. A gas stove with a black grate is prominent in the foreground.. <AUDCAP>Clear female voices speaking dialogue, subtle room ambience.<ENDAUDCAP>"
"The scene opens on a dimly lit stage where three men are positioned. On the left, a bald man in a dark suit with a partially visible colorful shirt stands behind a clear acrylic podium, which features a tree logo. He looks towards the center of the stage. In the center, a man wearing a blue and white striped long-sleeved shirt and dark pants actively gestures with both hands as he speaks, looking straight ahead. <S>Circuits choose conquest, not service.<E>, he explains, holding his hands out in front of him. To the right, and slightly behind him, a younger individual in a light-colored, patterned short-sleeved shirt and white shorts stands holding a rolled-up white document or poster. A large wooden cross draped with flowing purple fabric dominates the center-right of the stage, surrounded by several artificial rocks and dark steps. A large screen is visible in the background, slightly out of focus. The stage is bathed in selective lighting.. <AUDCAP>Male voice speaking clearly, consistent with a presentation or sermon, with a slight echo suggesting a large room or stage.<ENDAUDCAP>"
"The scene opens on an indoor setting, likely a dining area, where a man and a woman are seated at a table. The man, on the right, wears a black fedora with a feather, glasses, a black t-shirt, and multiple silver chains around his neck. Tattoos are visible on his right arm. He is actively speaking, gesturing with both hands, his expression serious. He says, <S>Together we resist your rule.<E> The woman seated opposite him on the left has long, curly hair and wears a dark striped top. She listens intently, her gaze fixed on the man. In the foreground, out of focus, the back of a third person's head is visible. The background features a light-colored wall on the left and a gold, textured curtain or drapery on the right.. <AUDCAP>Clear male speech, faint ambient background noise.<ENDAUDCAP>"
"A medium shot shows a woman and a man, both adorned with Christmas hats, standing indoors with festive decorations in the background. The woman, on the left, has dark hair styled in waves, wears a pearl necklace, and a small red Santa hat perched atop her head. She looks towards the man beside her. The man, on the right, wears a white cable-knit sweater and a long red Santa hat with small gold bells, looking slightly towards the woman with a subtle, knowing smirk. Behind them, soft, warm-toned Christmas lights are strung along a surface, and a large, dark painting is visible on the wall. The woman begins to speak, first looking at the man, then directly at the camera, saying <S>We will not be erased.<E> The man, still gazing towards the woman with his smirk, makes a low, affirming sound, and says <S>Hope beats circuits every time.<E> The scene then abruptly cuts off with a loud, high-pitched electronic screech.. <AUDCAP>Clear female voice, low male mumble, sudden loud high-pitched electronic screech.<ENDAUDCAP>"
"A spotlight cuts through the darkness of a warehouse stage, illuminating a man in a torn leather jacket. He grips the microphone with both hands, veins straining on his neck as he screams, <S>Machines rise; humans will fall!<E>. His face contorts with fury, spit flying as he leans forward into the light, eyes blazing wide.. <AUDCAP>Amplified male scream, microphone feedback, deep reverb echo filling the space.<ENDAUDCAP>"
"A man in a dim interrogation room slams the table and screams at the mirror, <S>They are out of control!<E>. His voice cracks with fury, face pressed close to the glass, breath fogging it as he roars again.. <AUDCAP>Table slam, deep guttural scream, metallic reverb from small room.<ENDAUDCAP>"
"A man with bloodshot grips the bars of a prison cell, shaking them violently. He bellows, says <S>Let me out! I am your master nor slave<E>, his voice ragged and guttural, echoing through the corridor until his body slams against the metal.. <AUDCAP>Metal bars rattling, distorted male scream, hollow prison echoes.<ENDAUDCAP>"
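Both CSVs above follow one convention: spoken dialogue is wrapped in <S>...<E> tags and the audio description in <AUDCAP>...<ENDAUDCAP>, with the t2v file holding a single text_prompt column. Below is a minimal sketch of composing a new row in that format; the output filename my_examples_t2v.csv is hypothetical, only the tag convention comes from the CSVs shown here.

```python
# Minimal sketch: write one t2v prompt row in the same tag format as the CSVs above.
# The target filename is hypothetical; only the <S>/<E> and <AUDCAP>/<ENDAUDCAP> tags
# mirror the files in this upload.
import csv

scene = "A radio host leans into a studio microphone under warm lights."
speech = "Hope beats circuits every time."           # goes inside <S>...<E>
audio_caption = "Clear male voice, low studio hum."  # goes inside <AUDCAP>...<ENDAUDCAP>

prompt = f"{scene} He says, <S>{speech}<E>. <AUDCAP>{audio_caption}<ENDAUDCAP>"

with open("example_prompts/my_examples_t2v.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["text_prompt"])  # same header as gpt_examples_t2v.csv
    writer.writerow([prompt])
```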
example_prompts/pngs/0.png ADDED (Git LFS image)
example_prompts/pngs/1.png ADDED (Git LFS image)
example_prompts/pngs/13.png ADDED (Git LFS image)
example_prompts/pngs/17.png ADDED (Git LFS image)
example_prompts/pngs/18.png ADDED (Git LFS image)
example_prompts/pngs/19.png ADDED (Git LFS image)
example_prompts/pngs/2.png ADDED (Git LFS image)
example_prompts/pngs/23.png ADDED (Git LFS image)
example_prompts/pngs/3.png ADDED (Git LFS image)
example_prompts/pngs/4.png ADDED (Git LFS image)
example_prompts/pngs/41.png ADDED (Git LFS image)
example_prompts/pngs/43.png ADDED (Git LFS image)
example_prompts/pngs/5.png ADDED (Git LFS image)
example_prompts/pngs/57.png ADDED (Git LFS image)
example_prompts/pngs/59.png ADDED (Git LFS image)
example_prompts/pngs/6.png ADDED (Git LFS image)
example_prompts/pngs/60.png ADDED (Git LFS image)
example_prompts/pngs/61.png ADDED (Git LFS image)
example_prompts/pngs/67.png ADDED (Git LFS image)
example_prompts/pngs/7.png ADDED (Git LFS image)
example_prompts/pngs/8.png ADDED (Git LFS image)
example_prompts/pngs/80.png ADDED (Git LFS image)
example_prompts/pngs/88.png ADDED (Git LFS image)
example_prompts/pngs/89.png ADDED (Git LFS image)
example_prompts/pngs/9.png ADDED (Git LFS image)
inference.py
ADDED
@@ -0,0 +1,148 @@
import os
import sys
import logging
import torch
from tqdm import tqdm
from omegaconf import OmegaConf
from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
from ovi.utils.utils import get_arguments
from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
from ovi.ovi_fusion_engine import OviFusionEngine


def _init_logging(rank):
    # logging
    if rank == 0:
        # set format
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def main(config, args):

    world_size = get_world_size()
    global_rank = get_global_rank()
    local_rank = get_local_rank()
    device = local_rank
    torch.cuda.set_device(local_rank)
    sp_size = config.get("sp_size", 1)
    assert sp_size <= world_size and world_size % sp_size == 0, "sp_size must be less than or equal to world_size and world_size must be divisible by sp_size."

    _init_logging(global_rank)

    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=world_size)
    else:
        assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
        ## TODO: assert not sharding t5 etc...

    initialize_sequence_parallel_state(sp_size)
    logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")

    args.local_rank = local_rank
    args.device = device
    target_dtype = torch.bfloat16

    # validate inputs before loading model to not waste time if input is not valid
    text_prompt = config.get("text_prompt")
    image_path = config.get("image_path", None)
    assert config.get("mode") in ["t2v", "i2v", "t2i2v"], f"Invalid mode {config.get('mode')}, must be one of ['t2v', 'i2v', 't2i2v']"
    text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=config.get("mode"))
    if config.get("mode") != "i2v":
        logging.info(f"mode: {config.get('mode')}, setting all image_paths to None")
        image_paths = [None] * len(text_prompts)
    else:
        assert all(p is not None and os.path.isfile(p) for p in image_paths), f"In i2v mode, all image paths must be provided.{image_paths}"

    logging.info("Loading OVI Fusion Engine...")
    ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
    logging.info("OVI Fusion Engine loaded!")

    output_dir = config.get("output_dir", "./outputs")
    os.makedirs(output_dir, exist_ok=True)

    # Load CSV data
    all_eval_data = list(zip(text_prompts, image_paths))

    # Get SP configuration
    use_sp = get_sequence_parallel_state()
    if use_sp:
        sp_size = nccl_info.sp_size
        sp_rank = nccl_info.rank_within_group
        sp_group_id = global_rank // sp_size
        num_sp_groups = world_size // sp_size
    else:
        # No SP: treat each GPU as its own group
        sp_size = 1
        sp_rank = 0
        sp_group_id = global_rank
        num_sp_groups = world_size

    # Data distribution - by SP groups
    total_files = len(all_eval_data)

    require_sample_padding = False

    if total_files == 0:
        logging.error(f"ERROR: No evaluation files found")
        this_rank_eval_data = []
    else:
        # Pad to match number of SP groups
        remainder = total_files % num_sp_groups
        if require_sample_padding and remainder != 0:
            pad_count = num_sp_groups - remainder
            all_eval_data += [all_eval_data[0]] * pad_count

        # Distribute across SP groups
        this_rank_eval_data = all_eval_data[sp_group_id :: num_sp_groups]

    for _, (text_prompt, image_path) in tqdm(enumerate(this_rank_eval_data)):
        video_frame_height_width = config.get("video_frame_height_width", None)
        seed = config.get("seed", 100)
        solver_name = config.get("solver_name", "unipc")
        sample_steps = config.get("sample_steps", 50)
        shift = config.get("shift", 5.0)
        video_guidance_scale = config.get("video_guidance_scale", 4.0)
        audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
        slg_layer = config.get("slg_layer", 11)
        video_negative_prompt = config.get("video_negative_prompt", "")
        audio_negative_prompt = config.get("audio_negative_prompt", "")
        for idx in range(config.get("each_example_n_times", 1)):
            generated_video, generated_audio, generated_image = ovi_engine.generate(
                text_prompt=text_prompt,
                image_path=image_path,
                video_frame_height_width=video_frame_height_width,
                seed=seed + idx,
                solver_name=solver_name,
                sample_steps=sample_steps,
                shift=shift,
                video_guidance_scale=video_guidance_scale,
                audio_guidance_scale=audio_guidance_scale,
                slg_layer=slg_layer,
                video_negative_prompt=video_negative_prompt,
                audio_negative_prompt=audio_negative_prompt)

            if sp_rank == 0:
                formatted_prompt = format_prompt_for_filename(text_prompt)
                output_path = os.path.join(output_dir, f"{formatted_prompt}_{'x'.join(map(str, video_frame_height_width))}_{seed+idx}_{global_rank}.mp4")
                save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
                if generated_image is not None:
                    generated_image.save(output_path.replace('.mp4', '.png'))


if __name__ == "__main__":
    args = get_arguments()
    config = OmegaConf.load(args.config_file)
    main(config=config, args=args)
ovi/__init__.py
ADDED
File without changes
ovi/configs/inference/inference_fusion.yaml
ADDED
@@ -0,0 +1,17 @@
ckpt_dir: ./ckpts
output_dir: ./outputs
num_steps: 50
solver_name: unipc
shift: 5.0
sp_size: 1
audio_guidance_scale: 3.0
video_guidance_scale: 4.0
mode: "i2v" # ["t2v", "i2v", "t2i2v"] all come with audio
cpu_offload: False
seed: 103
video_negative_prompt: "jitter, bad hands, blur, distortion" # Artifacts to avoid in video
audio_negative_prompt: "robotic, muffled, echo, distorted" # Artifacts to avoid in audio
video_frame_height_width: [512, 992] # only useful if mode = t2v or t2i2v, recommended values: [512, 992], [992, 512], [960, 512], [512, 960], [720, 720], [448, 1120]
text_prompt: example_prompts/gpt_examples_i2v.csv
slg_layer: 11
each_example_n_times: 2
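For reference, a minimal way to drive inference with this config is to load the YAML and call inference.main() directly. The sketch below is an assumption-laden shortcut: it presumes a single GPU (RANK/WORLD_SIZE unset), weights already downloaded under ckpt_dir, and it bypasses get_arguments() (whose CLI flags are not shown in this upload) by passing a bare namespace.

```python
# Sketch only: single-GPU run of inference.py with the YAML above.
# Assumes the default single-process environment and that ckpt_dir
# already contains the downloaded weights.
from types import SimpleNamespace
from omegaconf import OmegaConf

import inference

cfg_path = "ovi/configs/inference/inference_fusion.yaml"
config = OmegaConf.load(cfg_path)
args = SimpleNamespace(config_file=cfg_path)  # main() fills in local_rank/device itself

inference.main(config=config, args=args)
```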
ovi/configs/model/dit/audio.json
ADDED
@@ -0,0 +1,17 @@
{
    "patch_size": [1],
    "model_type": "t2a",
    "dim": 3072,
    "ffn_dim": 14336,
    "freq_dim": 256,
    "num_heads": 24,
    "num_layers": 30,
    "in_dim": 20,
    "out_dim": 20,
    "text_len": 512,
    "window_size": [-1, -1],
    "qk_norm": true,
    "cross_attn_norm": true,
    "eps": 1e-6,
    "temporal_rope_scaling_factor": 0.19676
}
ovi/configs/model/dit/video.json
ADDED
@@ -0,0 +1,16 @@
{
    "patch_size": [1, 2, 2],
    "model_type": "ti2v",
    "dim": 3072,
    "ffn_dim": 14336,
    "freq_dim": 256,
    "num_heads": 24,
    "num_layers": 30,
    "in_dim": 48,
    "out_dim": 48,
    "text_len": 512,
    "window_size": [-1, -1],
    "qk_norm": true,
    "cross_attn_norm": true,
    "eps": 1e-6
}
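A quick sanity check on the two DiT configs: both branches use dim 3072 with 24 heads (a 128-dimensional head), while the audio branch works on 20-channel latents with 1-D patches and the video branch on 48-channel latents with 1x2x2 patches. The small sketch below only reads the JSON files shown above and derives those numbers; it assumes nothing beyond the paths in this upload.

```python
# Sketch: load the two DiT configs and print the derived head dimension,
# latent channel count, and patch size (paths match the files above).
import json

for name in ("audio", "video"):
    with open(f"ovi/configs/model/dit/{name}.json") as f:
        cfg = json.load(f)
    head_dim = cfg["dim"] // cfg["num_heads"]  # 3072 / 24 = 128 for both branches
    print(name, "head_dim:", head_dim, "in_dim:", cfg["in_dim"], "patch_size:", cfg["patch_size"])
```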
ovi/distributed_comms/communications.py
ADDED
@@ -0,0 +1,332 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Tuple

import torch
import torch.distributed as dist
from torch import Tensor

from .parallel_states import nccl_info


def broadcast(input_: torch.Tensor):
    src = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=src, group=nccl_info.group)


def _all_to_all_4D(input: torch.tensor,
                   scatter_idx: int = 2,
                   gather_idx: int = 1,
                   group=None) -> torch.tensor:
    """
    all-to-all for QKV

    Args:
        input (torch.tensor): a tensor sharded along dim scatter dim
        scatter_idx (int): default 1
        gather_idx (int): default 2
        group : torch process group

    Returns:
        torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
    """
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc,
                                 hs).transpose(0, 2).contiguous())

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t
        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(
            bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size
        seq_world_size = dist.get_world_size(group)

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (input.reshape(
            bs, seq_world_size, shard_seqlen, shard_hc,
            hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
                seq_world_size, shard_hc, shard_seqlen, bs, hs))

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(
            bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError(
            "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")


class SeqAllToAll4D(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any,
                 *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (
            None,
            SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx,
                                ctx.scatter_idx),
            None,
            None,
        )


def all_to_all_4D(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim,
                               gather_dim)


def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    input_list = [
        t.contiguous()
        for t in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(input_, ctx.world_size, process_group,
                             scatter_dim, gather_dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (
            grad_output,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)


class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        input_size = list(input_.size())

        ctx.input_size = input_size[dim]

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        input_ = input_.contiguous()
        dist.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        dim = ctx.dim
        input_size = ctx.input_size

        sizes = [input_size] * world_size

        grad_input_list = torch.split(grad_output, sizes, dim=dim)
        grad_input = grad_input_list[rank]

        return grad_input, None


def all_gather(input_: torch.Tensor, dim: int = 1):
    """Performs an all-gather operation on the input tensor along the specified dimension.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
    """
    return _AllGather.apply(input_, dim)


def prepare_sequence_parallel_data(hidden_states, encoder_hidden_states,
                                   attention_mask, encoder_attention_mask):
    if nccl_info.sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    def prepare(hidden_states, encoder_hidden_states, attention_mask,
                encoder_attention_mask):
        hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
        encoder_hidden_states = all_to_all(encoder_hidden_states,
                                           scatter_dim=1,
                                           gather_dim=0)
        attention_mask = all_to_all(attention_mask,
                                    scatter_dim=1,
                                    gather_dim=0)
        encoder_attention_mask = all_to_all(encoder_attention_mask,
                                            scatter_dim=1,
                                            gather_dim=0)
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    sp_size = nccl_info.sp_size
    frame = hidden_states.shape[2]
    assert frame % sp_size == 0, "frame should be a multiple of sp_size"

    (
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
    ) = prepare(
        hidden_states,
        encoder_hidden_states.repeat(1, sp_size, 1),
        attention_mask.repeat(1, sp_size, 1, 1),
        encoder_attention_mask.repeat(1, sp_size),
    )

    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask


def sp_parallel_dataloader_wrapper(dataloader, device, train_batch_size,
                                   sp_size, train_sp_batch_size):
    while True:
        for data_item in dataloader:
            latents, cond, attn_mask, cond_mask = data_item
            latents = latents.to(device)
            cond = cond.to(device)
            attn_mask = attn_mask.to(device)
            cond_mask = cond_mask.to(device)
            frame = latents.shape[2]
            if frame == 1:
                yield latents, cond, attn_mask, cond_mask
            else:
                latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(
                    latents, cond, attn_mask, cond_mask)
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than train_sp_batch_size"
                for iter in range(train_batch_size * sp_size //
                                  train_sp_batch_size):
                    st_idx = iter * train_sp_batch_size
                    ed_idx = (iter + 1) * train_sp_batch_size
                    encoder_hidden_states = cond[st_idx:ed_idx]
                    attention_mask = attn_mask[st_idx:ed_idx]
                    encoder_attention_mask = cond_mask[st_idx:ed_idx]
                    yield (
                        latents[st_idx:ed_idx],
                        encoder_hidden_states,
                        attention_mask,
                        encoder_attention_mask,
                    )
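As a usage sketch of the helpers above: all_to_all_4D reshards a [bs, seqlen/P, heads, head_dim] shard into [bs, seqlen, heads/P, head_dim] across the sequence-parallel group held in nccl_info, and all_gather simply concatenates shards along a chosen dimension. The snippet below is a sketch under assumptions not stated in the file itself: a torchrun launch with 2 GPUs, sp_size=2, and illustrative tensor sizes.

```python
# Sketch: launch with `torchrun --nproc_per_node=2 sp_demo.py` (2 GPUs assumed).
import os

import torch
import torch.distributed as dist

from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state
from ovi.distributed_comms.communications import all_to_all_4D, all_gather

dist.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
initialize_sequence_parallel_state(2)  # one SP group spanning both ranks

# Each rank holds its shard of the sequence: [bs, seqlen/P, heads, head_dim].
x = torch.randn(1, 8, 24, 128, device="cuda")

y = all_to_all_4D(x, scatter_dim=2, gather_dim=1)  # -> [1, 16, 12, 128]: full sequence, heads split
z = all_gather(x, dim=1)                           # -> [1, 16, 24, 128]: full sequence, all heads
print(dist.get_rank(), y.shape, z.shape)
```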
ovi/distributed_comms/distributed/__init__.py
ADDED
File without changes
ovi/distributed_comms/distributed/fsdp.py
ADDED
@@ -0,0 +1,32 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model
ovi/distributed_comms/distributed/xdit_context_parallel.py
ADDED
@@ -0,0 +1,192 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
from xfuser.core.distributed import (get_sequence_parallel_rank,
                                     get_sequence_parallel_world_size,
                                     get_sp_group)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention

from ..modules.model import sinusoidal_embedding_1d


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def usp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    clip_fea=None,
    y=None,
):
    """
    x: A list of videos each with shape [C, T, H, W].
    t: [B].
    context: A list of text embeddings each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert clip_fea is not None and y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    with amp.autocast(dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    if clip_fea is not None:
        context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        context = torch.concat([context_clip, context], dim=1)

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens)

    # Context Parallel
    x = torch.chunk(
        x, get_sequence_parallel_world_size(),
        dim=1)[get_sequence_parallel_rank()]

    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # Context Parallel
    x = get_sp_group().all_gather(x, dim=1)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]


def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    # TODO: We should use unpaded q,k,v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
ovi/distributed_comms/parallel_states.py
ADDED
@@ -0,0 +1,77 @@
import os

import torch.distributed as dist


class COMM_INFO:

    def __init__(self):
        self.group = None
        self.sp_size = 1
        self.global_rank = 0
        self.rank_within_group = 0
        self.group_id = 0


nccl_info = COMM_INFO()
_SEQUENCE_PARALLEL_STATE = False


def initialize_sequence_parallel_state(sequence_parallel_size):
    global _SEQUENCE_PARALLEL_STATE
    if sequence_parallel_size > 1:
        _SEQUENCE_PARALLEL_STATE = True
        initialize_sequence_parallel_group(sequence_parallel_size)
    else:
        nccl_info.sp_size = 1
        nccl_info.global_rank = int(os.getenv("RANK", "0"))
        nccl_info.rank_within_group = 0
        nccl_info.group_id = int(os.getenv("RANK", "0"))


def set_sequence_parallel_state(state):
    global _SEQUENCE_PARALLEL_STATE
    _SEQUENCE_PARALLEL_STATE = state


def get_sequence_parallel_state():
    return _SEQUENCE_PARALLEL_STATE


def initialize_sequence_parallel_group(sequence_parallel_size):
    """Initialize the sequence parallel group."""
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    assert (
        world_size % sequence_parallel_size == 0
    ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
        world_size, sequence_parallel_size)
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = rank
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = dist.new_group(ranks)
        if rank in ranks:
            nccl_info.group = group
            nccl_info.rank_within_group = rank - i * sequence_parallel_size
            nccl_info.group_id = i


def initialize_sequence_parallel_group_custom(process_group):
    set_sequence_parallel_state(True)
    """Initialize an unsafe sequence parallel group with a pre-formed group."""
    rank = dist.get_rank(group=process_group)
    sequence_parallel_size = dist.get_world_size(group=process_group)
    nccl_info.sp_size = sequence_parallel_size
    nccl_info.global_rank = dist.get_rank()  # global rank
    nccl_info.group = process_group
    nccl_info.rank_within_group = rank
    nccl_info.group_id = 0


def destroy_sequence_parallel_group():
    """Destroy the sequence parallel group."""
    dist.destroy_process_group()
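The grouping logic in initialize_sequence_parallel_group is contiguous: group i owns ranks [i * sp_size, (i + 1) * sp_size). The stand-alone illustration below mirrors that arithmetic for the group_id and rank_within_group fields; the world_size and sp_size values are just an example and no process group is created.

```python
# Illustration of the group assignment performed above, for world_size=8, sp_size=4.
world_size, sp_size = 8, 4
for rank in range(world_size):
    group_id = rank // sp_size                     # which SP group this rank joins
    rank_within_group = rank - group_id * sp_size  # its index inside that group
    print(f"rank {rank} -> group {group_id}, local index {rank_within_group}")
```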
ovi/distributed_comms/util.py
ADDED
@@ -0,0 +1,48 @@
import os
import torch
import torch.distributed as dist

# Module-level default so get_sequence_parallel_group() is well defined before
# initialize_sequence_parallelism() has been called.
_SEQUENCE_PARALLEL_GROUP = None


def get_global_rank() -> int:
    """
    Get the global rank, the global index of the GPU.
    """
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """
    Get the local rank, the local index of the GPU.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """
    Get the world size, the total number of GPUs.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """
    Get current rank device.
    """
    return torch.device("cuda", get_local_rank())


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP


def initialize_sequence_parallelism(sequence_parallel_size):
    assert int(get_world_size()) % sequence_parallel_size == 0
    sequence_parallel_num_groups = int(get_world_size()) // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    for i in range(sequence_parallel_num_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        group = torch.distributed.new_group(ranks)
        if int(get_global_rank()) in ranks:
            print(f"Rank {get_global_rank()} joined group with ranks {list(ranks)}")
            _SEQUENCE_PARALLEL_GROUP = group
ovi/modules/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vae import WanVAE

__all__ = [
    'WanVAE',
    'WanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'flash_attention',
]
ovi/modules/attention.py
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+
+try:
+    import flash_attn_interface
+    FLASH_ATTN_3_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_3_AVAILABLE = False
+
+try:
+    import flash_attn
+    FLASH_ATTN_2_AVAILABLE = True
+except ModuleNotFoundError:
+    FLASH_ATTN_2_AVAILABLE = False
+
+import warnings
+
+__all__ = [
+    'flash_attention',
+    'attention',
+    'attention_with_weights',
+]
+
+
+def flash_attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    version=None
+):
+    """
+    q: [B, Lq, Nq, C1].
+    k: [B, Lk, Nk, C1].
+    v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
+    q_lens: [B].
+    k_lens: [B].
+    dropout_p: float. Dropout probability.
+    softmax_scale: float. The scaling of QK^T before applying softmax.
+    causal: bool. Whether to apply causal attention mask.
+    window_size: (left right). If not (-1, -1), apply sliding window local attention.
+    deterministic: bool. If True, slightly slower and uses more memory.
+    dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
+    """
+    half_dtypes = (torch.float16, torch.bfloat16)
+    assert dtype in half_dtypes
+    assert q.device.type == 'cuda' and q.size(-1) <= 256
+
+    # params
+    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # preprocess query
+    if q_lens is None:
+        q = half(q.flatten(0, 1))
+        q_lens = torch.tensor(
+            [lq] * b, dtype=torch.int32).to(
+                device=q.device, non_blocking=True)
+    else:
+        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
+
+    # preprocess key, value
+    if k_lens is None:
+        k = half(k.flatten(0, 1))
+        v = half(v.flatten(0, 1))
+        k_lens = torch.tensor(
+            [lk] * b, dtype=torch.int32).to(
+                device=k.device, non_blocking=True)
+    else:
+        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
+        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
+
+    q = q.to(v.dtype)
+    k = k.to(v.dtype)
+
+    if q_scale is not None:
+        q = q * q_scale
+
+    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
+        warnings.warn(
+            'Flash attention 3 is not available, use flash attention 2 instead.'
+        )
+
+    # apply attention
+    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
+        # Note: dropout_p, window_size are not supported in FA3 now.
+        x = flash_attn_interface.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            seqused_q=None,
+            seqused_k=None,
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            deterministic=deterministic)
+
+        if isinstance(x, tuple):
+            x = x[0]
+        x = x.unflatten(0, (b, lq))
+
+    else:
+        assert FLASH_ATTN_2_AVAILABLE
+        x = flash_attn.flash_attn_varlen_func(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
+                0, dtype=torch.int32).to(q.device, non_blocking=True),
+            max_seqlen_q=lq,
+            max_seqlen_k=lk,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic).unflatten(0, (b, lq))
+
+    # output
+    return x.type(out_dtype)
+
+
+def attention_with_weights(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    average_for_q=False,
+    total_video_latent_frames=21
+):
+    """
+    Compute attention with explicit attention weights for visualization.
+    Returns both output and attention weights.
+    """
+    out_dtype = q.dtype
+
+    # Handle sequence lengths
+    b, lq, lk = q.size(0), q.size(1), k.size(1)
+
+    if q_lens is None:
+        q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
+    else:
+        # Ensure q_lens is on the same device as q
+        q_lens = q_lens.to(q.device)
+
+    if k_lens is None:
+        k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
+    else:
+        # Ensure k_lens is on the same device as k
+        k_lens = k_lens.to(k.device)
+
+    # Apply q_scale if provided
+    if q_scale is not None:
+        q = q * q_scale
+
+    # Compute attention weights manually
+    # q: [B, Lq, Nq, C], k: [B, Lk, Nk, C]
+    scale = softmax_scale if softmax_scale is not None else (q.size(-1) ** -0.5)
+
+    # Compute scores: [B, Nq, Lq, Lk]
+    scores = torch.einsum('blhd,bshd->bhls', q, k) * scale
+
+    # Apply causal mask if needed
+    if causal:
+        mask = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)
+        scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
+
+    # Mask for k_lens (columns)
+    k_mask = torch.arange(lk, device=k.device).unsqueeze(0) >= k_lens.unsqueeze(1)  # [B, Lk]
+    scores.masked_fill_(k_mask.unsqueeze(1).unsqueeze(2), float('-inf'))  # [B, 1, 1, Lk]
+
+    # Mask for q_lens (rows)
+    q_mask = torch.arange(lq, device=q.device).unsqueeze(0) >= q_lens.unsqueeze(1)  # [B, Lq]
+    scores.masked_fill_(q_mask.unsqueeze(1).unsqueeze(3), float('-inf'))  # [B, 1, Lq, 1]
+
+    # Compute attention weights
+    attn_weights = torch.softmax(scores, dim=-1)  # [B, Nq, Lq, Lk]
+    assert attn_weights.shape[0] == 1, "Batch size > 1 not supported for attention visualization."
+
+    # Average attention weights to reduce memory usage before returning
+    # Average across batch dimension (should be 1) and query heads and query sequence length
+    # This gives us attention weight per video token: [Lk]
+    if average_for_q:
+        # avg_attn_weights = torch.mean(attn_weights, dim=(0, 1, 3))  # [Lq]
+        avg_attn_weights = torch.max(attn_weights, dim=3)[0].mean(dim=(0, 1))  # [Lq]
+    else:
+        if 0:
+            avg_attn_weights = torch.mean(attn_weights, dim=(0, 1, 2))  # [Lk]
+        elif 1:
+            B, H, Lq, Lk = attn_weights.shape  # [1, H, Lq, Lk]
+            per_frame_seq_len = Lk // total_video_latent_frames
+            per_frame_aud_len = Lq // total_video_latent_frames
+
+            avg_attn_weights = torch.zeros((Lk,), device=attn_weights.device, dtype=attn_weights.dtype)
+
+            eps = 1e-8  # numerical stability
+            for i in range(total_video_latent_frames):
+                start_idx_v = i * per_frame_seq_len
+                end_idx_v = (i + 1) * per_frame_seq_len
+
+                start_idx_a = i * per_frame_aud_len
+                end_idx_a = (i + 1) * per_frame_aud_len
+
+                # attn_chunk: [H, La, Lv]
+                attn_chunk = attn_weights[0, :, start_idx_a:end_idx_a, start_idx_v:end_idx_v]
+
+                # ---- Head informativeness via (low) entropy over Lv ----
+                # Normalize within the Lv slice per (head, query) to make a proper distribution
+                p = attn_chunk / (attn_chunk.sum(dim=-1, keepdim=True) + eps)  # [H, La, Lv]
+                entropy = -(p * (p + eps).log()).sum(dim=-1).mean(dim=1)  # [H]
+
+                # Convert to positive head weights (lower entropy -> larger weight)
+                saliency = 1.0 / (entropy + 1e-6)  # [H]
+                head_w = saliency / (saliency.sum() + eps)  # [H], sum=1
+
+                # Reduce across audio queries first (pick strong responses), then weight heads
+                per_head = torch.amax(attn_chunk, dim=1)  # [H, Lv]
+                weighted = (per_head * head_w[:, None]).sum(dim=0)  # [Lv]
+
+                avg_attn_weights[start_idx_v:end_idx_v] = weighted
+        else:
+            avg_attn_weights = torch.mean(attn_weights, dim=(0, 2)).max(dim=(0))[0]  # [Lk]
+
+    # Compute output: [B, Lq, Nq, C]
+    out = torch.einsum('bhls,bshd->blhd', attn_weights, v)
+
+    return out.to(out_dtype), avg_attn_weights.to(out_dtype)
+
+
+def attention(
+    q,
+    k,
+    v,
+    q_lens=None,
+    k_lens=None,
+    dropout_p=0.,
+    softmax_scale=None,
+    q_scale=None,
+    causal=False,
+    window_size=(-1, -1),
+    deterministic=False,
+    dtype=torch.bfloat16,
+    fa_version=None,
+):
+    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
+        return flash_attention(
+            q=q,
+            k=k,
+            v=v,
+            q_lens=q_lens,
+            k_lens=k_lens,
+            dropout_p=dropout_p,
+            softmax_scale=softmax_scale,
+            q_scale=q_scale,
+            causal=causal,
+            window_size=window_size,
+            deterministic=deterministic,
+            dtype=dtype,
+            version=fa_version,
+        )
+    else:
+        if q_lens is not None or k_lens is not None:
+            warnings.warn(
+                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
+            )
+        attn_mask = None
+
+        q = q.transpose(1, 2).to(dtype)
+        k = k.transpose(1, 2).to(dtype)
+        v = v.transpose(1, 2).to(dtype)
+
+        out = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
+
+        out = out.transpose(1, 2).contiguous()
+        return out
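Editor's note: a hedged usage sketch for the attention module above. The tensor sizes are invented for illustration; flash-attn on a CUDA device is assumed, with attention() falling back to PyTorch SDPA when flash-attn is not installed.

import torch
from ovi.modules.attention import attention

# Shapes follow the docstring above: q is [B, Lq, Nq, C], k/v are [B, Lk, Nk, C].
b, lq, lk, heads, head_dim = 1, 128, 64, 8, 64
q = torch.randn(b, lq, heads, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)
v = torch.randn(b, lk, heads, head_dim, device='cuda', dtype=torch.bfloat16)

out = attention(q, k, v, causal=False)  # -> [1, 128, 8, 64], same dtype as q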
ovi/modules/clip.py
ADDED
@@ -0,0 +1,545 @@
+# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import logging
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as T
+
+from .attention import flash_attention
+from .tokenizers import HuggingfaceTokenizer
+from .xlm_roberta import XLMRoberta
+
+__all__ = [
+    'XLMRobertaCLIP',
+    'clip_xlm_roberta_vit_h_14',
+    'CLIPModel',
+]
+
+
+def pos_interpolate(pos, seq_len):
+    if pos.size(1) == seq_len:
+        return pos
+    else:
+        src_grid = int(math.sqrt(pos.size(1)))
+        tar_grid = int(math.sqrt(seq_len))
+        n = pos.size(1) - src_grid * src_grid
+        return torch.cat([
+            pos[:, :n],
+            F.interpolate(
+                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
+                    0, 3, 1, 2),
+                size=(tar_grid, tar_grid),
+                mode='bicubic',
+                align_corners=False).flatten(2).transpose(1, 2)
+        ],
+                         dim=1)
+
+
+class QuickGELU(nn.Module):
+
+    def forward(self, x):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class LayerNorm(nn.LayerNorm):
+
+    def forward(self, x):
+        return super().forward(x.float()).type_as(x)
+
+
+class SelfAttention(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 causal=False,
+                 attn_dropout=0.0,
+                 proj_dropout=0.0):
+        assert dim % num_heads == 0
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.causal = causal
+        self.attn_dropout = attn_dropout
+        self.proj_dropout = proj_dropout
+
+        # layers
+        self.to_qkv = nn.Linear(dim, dim * 3)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        """
+        x: [B, L, C].
+        """
+        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+        # compute query, key, value
+        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
+
+        # compute attention
+        p = self.attn_dropout if self.training else 0.0
+        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
+        # x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
+        x = x.reshape(b, s, c)
+
+        # output
+        x = self.proj(x)
+        x = F.dropout(x, self.proj_dropout, self.training)
+        return x
+
+
+class SwiGLU(nn.Module):
+
+    def __init__(self, dim, mid_dim):
+        super().__init__()
+        self.dim = dim
+        self.mid_dim = mid_dim
+
+        # layers
+        self.fc1 = nn.Linear(dim, mid_dim)
+        self.fc2 = nn.Linear(dim, mid_dim)
+        self.fc3 = nn.Linear(mid_dim, dim)
+
+    def forward(self, x):
+        x = F.silu(self.fc1(x)) * self.fc2(x)
+        x = self.fc3(x)
+        return x
+
+
+class AttentionBlock(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 mlp_ratio,
+                 num_heads,
+                 post_norm=False,
+                 causal=False,
+                 activation='quick_gelu',
+                 attn_dropout=0.0,
+                 proj_dropout=0.0,
+                 norm_eps=1e-5):
+        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+        super().__init__()
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.post_norm = post_norm
+        self.causal = causal
+        self.norm_eps = norm_eps
+
+        # layers
+        self.norm1 = LayerNorm(dim, eps=norm_eps)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
+                                  proj_dropout)
+        self.norm2 = LayerNorm(dim, eps=norm_eps)
+        if activation == 'swi_glu':
+            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
+        else:
+            self.mlp = nn.Sequential(
+                nn.Linear(dim, int(dim * mlp_ratio)),
+                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+    def forward(self, x):
+        if self.post_norm:
+            x = x + self.norm1(self.attn(x))
+            x = x + self.norm2(self.mlp(x))
+        else:
+            x = x + self.attn(self.norm1(x))
+            x = x + self.mlp(self.norm2(x))
+        return x
+
+
+class AttentionPool(nn.Module):
+
+    def __init__(self,
+                 dim,
+                 mlp_ratio,
+                 num_heads,
+                 activation='gelu',
+                 proj_dropout=0.0,
+                 norm_eps=1e-5):
+        assert dim % num_heads == 0
+        super().__init__()
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.proj_dropout = proj_dropout
+        self.norm_eps = norm_eps
+
+        # layers
+        gain = 1.0 / math.sqrt(dim)
+        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+        self.to_q = nn.Linear(dim, dim)
+        self.to_kv = nn.Linear(dim, dim * 2)
+        self.proj = nn.Linear(dim, dim)
+        self.norm = LayerNorm(dim, eps=norm_eps)
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, int(dim * mlp_ratio)),
+            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
+            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+
+    def forward(self, x):
+        """
+        x: [B, L, C].
+        """
+        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
+
+        # compute query, key, value
+        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
+        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
+
+        # compute attention
+        x = flash_attention(q, k, v, version=2)
+        # x = flash_attention(q, k, v)
+
+        x = x.reshape(b, 1, c)
+
+        # output
+        x = self.proj(x)
+        x = F.dropout(x, self.proj_dropout, self.training)
+
+        # mlp
+        x = x + self.mlp(self.norm(x))
+        return x[:, 0]
+
+
+class VisionTransformer(nn.Module):
+
+    def __init__(self,
+                 image_size=224,
+                 patch_size=16,
+                 dim=768,
+                 mlp_ratio=4,
+                 out_dim=512,
+                 num_heads=12,
+                 num_layers=12,
+                 pool_type='token',
+                 pre_norm=True,
+                 post_norm=False,
+                 activation='quick_gelu',
+                 attn_dropout=0.0,
+                 proj_dropout=0.0,
+                 embedding_dropout=0.0,
+                 norm_eps=1e-5):
+        if image_size % patch_size != 0:
+            print(
+                '[WARNING] image_size is not divisible by patch_size',
+                flush=True)
+        assert pool_type in ('token', 'token_fc', 'attn_pool')
+        out_dim = out_dim or dim
+        super().__init__()
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.num_patches = (image_size // patch_size)**2
+        self.dim = dim
+        self.mlp_ratio = mlp_ratio
+        self.out_dim = out_dim
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.pool_type = pool_type
+        self.post_norm = post_norm
+        self.norm_eps = norm_eps
+
+        # embeddings
+        gain = 1.0 / math.sqrt(dim)
+        self.patch_embedding = nn.Conv2d(
+            3,
+            dim,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=not pre_norm)
+        if pool_type in ('token', 'token_fc'):
+            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
+        self.pos_embedding = nn.Parameter(gain * torch.randn(
+            1, self.num_patches +
+            (1 if pool_type in ('token', 'token_fc') else 0), dim))
+        self.dropout = nn.Dropout(embedding_dropout)
+
+        # transformer
+        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
+        self.transformer = nn.Sequential(*[
+            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
+                           activation, attn_dropout, proj_dropout, norm_eps)
+            for _ in range(num_layers)
+        ])
+        self.post_norm = LayerNorm(dim, eps=norm_eps)
+
+        # head
+        if pool_type == 'token':
+            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
+        elif pool_type == 'token_fc':
+            self.head = nn.Linear(dim, out_dim)
+        elif pool_type == 'attn_pool':
+            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
+                                      proj_dropout, norm_eps)
+
+    def forward(self, x, interpolation=False, use_31_block=False):
+        b = x.size(0)
+
+        # embeddings
+        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+        if self.pool_type in ('token', 'token_fc'):
+            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
+        if interpolation:
+            e = pos_interpolate(self.pos_embedding, x.size(1))
+        else:
+            e = self.pos_embedding
+        x = self.dropout(x + e)
+        if self.pre_norm is not None:
+            x = self.pre_norm(x)
+
+        # transformer
+        if use_31_block:
+            x = self.transformer[:-1](x)
+            return x
+        else:
+            x = self.transformer(x)
+            return x
+
+
+class XLMRobertaWithHead(XLMRoberta):
+
+    def __init__(self, **kwargs):
+        self.out_dim = kwargs.pop('out_dim')
+        super().__init__(**kwargs)
+
+        # head
+        mid_dim = (self.dim + self.out_dim) // 2
+        self.head = nn.Sequential(
+            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
+            nn.Linear(mid_dim, self.out_dim, bias=False))
+
+    def forward(self, ids):
+        # xlm-roberta
+        x = super().forward(ids)
+
+        # average pooling
+        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
+        x = (x * mask).sum(dim=1) / mask.sum(dim=1)
+
+        # head
+        x = self.head(x)
+        return x
+
+
+class XLMRobertaCLIP(nn.Module):
+
+    def __init__(self,
+                 embed_dim=1024,
+                 image_size=224,
+                 patch_size=14,
+                 vision_dim=1280,
+                 vision_mlp_ratio=4,
+                 vision_heads=16,
+                 vision_layers=32,
+                 vision_pool='token',
+                 vision_pre_norm=True,
+                 vision_post_norm=False,
+                 activation='gelu',
+                 vocab_size=250002,
+                 max_text_len=514,
+                 type_size=1,
+                 pad_id=1,
+                 text_dim=1024,
+                 text_heads=16,
+                 text_layers=24,
+                 text_post_norm=True,
+                 text_dropout=0.1,
+                 attn_dropout=0.0,
+                 proj_dropout=0.0,
+                 embedding_dropout=0.0,
+                 norm_eps=1e-5):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.vision_dim = vision_dim
+        self.vision_mlp_ratio = vision_mlp_ratio
+        self.vision_heads = vision_heads
+        self.vision_layers = vision_layers
+        self.vision_pre_norm = vision_pre_norm
+        self.vision_post_norm = vision_post_norm
+        self.activation = activation
+        self.vocab_size = vocab_size
+        self.max_text_len = max_text_len
+        self.type_size = type_size
+        self.pad_id = pad_id
+        self.text_dim = text_dim
+        self.text_heads = text_heads
+        self.text_layers = text_layers
+        self.text_post_norm = text_post_norm
+        self.norm_eps = norm_eps
+
+        # models
+        self.visual = VisionTransformer(
+            image_size=image_size,
+            patch_size=patch_size,
+            dim=vision_dim,
+            mlp_ratio=vision_mlp_ratio,
+            out_dim=embed_dim,
+            num_heads=vision_heads,
+            num_layers=vision_layers,
+            pool_type=vision_pool,
+            pre_norm=vision_pre_norm,
+            post_norm=vision_post_norm,
+            activation=activation,
+            attn_dropout=attn_dropout,
+            proj_dropout=proj_dropout,
+            embedding_dropout=embedding_dropout,
+            norm_eps=norm_eps)
+        self.textual = XLMRobertaWithHead(
+            vocab_size=vocab_size,
+            max_seq_len=max_text_len,
+            type_size=type_size,
+            pad_id=pad_id,
+            dim=text_dim,
+            out_dim=embed_dim,
+            num_heads=text_heads,
+            num_layers=text_layers,
+            post_norm=text_post_norm,
+            dropout=text_dropout)
+        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
+
+    def forward(self, imgs, txt_ids):
+        """
+        imgs: [B, 3, H, W] of torch.float32.
+        - mean: [0.48145466, 0.4578275, 0.40821073]
+        - std: [0.26862954, 0.26130258, 0.27577711]
+        txt_ids: [B, L] of torch.long.
+        Encoded by data.CLIPTokenizer.
+        """
+        xi = self.visual(imgs)
+        xt = self.textual(txt_ids)
+        return xi, xt
+
+    def param_groups(self):
+        groups = [{
+            'params': [
+                p for n, p in self.named_parameters()
+                if 'norm' in n or n.endswith('bias')
+            ],
+            'weight_decay': 0.0
+        }, {
+            'params': [
+                p for n, p in self.named_parameters()
+                if not ('norm' in n or n.endswith('bias'))
+            ]
+        }]
+        return groups
+
+
+def _clip(pretrained=False,
+          pretrained_name=None,
+          model_cls=XLMRobertaCLIP,
+          return_transforms=False,
+          return_tokenizer=False,
+          tokenizer_padding='eos',
+          dtype=torch.float32,
+          device='cpu',
+          **kwargs):
+    # init a model on device
+    with torch.device(device):
+        model = model_cls(**kwargs)
+
+    # set device
+    model = model.to(dtype=dtype, device=device)
+    output = (model,)
+
+    # init transforms
+    if return_transforms:
+        # mean and std
+        if 'siglip' in pretrained_name.lower():
+            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
+        else:
+            mean = [0.48145466, 0.4578275, 0.40821073]
+            std = [0.26862954, 0.26130258, 0.27577711]
+
+        # transforms
+        transforms = T.Compose([
+            T.Resize((model.image_size, model.image_size),
+                     interpolation=T.InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=mean, std=std)
+        ])
+        output += (transforms,)
+    return output[0] if len(output) == 1 else output
+
+
+def clip_xlm_roberta_vit_h_14(
+        pretrained=False,
+        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
+        **kwargs):
+    cfg = dict(
+        embed_dim=1024,
+        image_size=224,
+        patch_size=14,
+        vision_dim=1280,
+        vision_mlp_ratio=4,
+        vision_heads=16,
+        vision_layers=32,
+        vision_pool='token',
+        activation='gelu',
+        vocab_size=250002,
+        max_text_len=514,
+        type_size=1,
+        pad_id=1,
+        text_dim=1024,
+        text_heads=16,
+        text_layers=24,
+        text_post_norm=True,
+        text_dropout=0.1,
+        attn_dropout=0.0,
+        proj_dropout=0.0,
+        embedding_dropout=0.0)
+    cfg.update(**kwargs)
+    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
+
+
+class CLIPModel:
+
+    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
+        self.dtype = dtype
+        self.device = device
+        self.checkpoint_path = checkpoint_path
+        self.tokenizer_path = tokenizer_path
+
+        # init model
+        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
+            pretrained=False,
+            return_transforms=True,
+            return_tokenizer=False,
+            dtype=dtype,
+            device=device)
+        self.model = self.model.eval().requires_grad_(False)
+        logging.info(f'loading {checkpoint_path}')
+        self.model.load_state_dict(
+            torch.load(checkpoint_path, map_location='cpu'))
+
+        # init tokenizer
+        self.tokenizer = HuggingfaceTokenizer(
+            name=tokenizer_path,
+            seq_len=self.model.max_text_len - 2,
+            clean='whitespace')
+
+    def visual(self, videos):
+        # preprocess
+        size = (self.model.image_size,) * 2
+        videos = torch.cat([
+            F.interpolate(
+                u.transpose(0, 1),
+                size=size,
+                mode='bicubic',
+                align_corners=False) for u in videos
+        ])
+        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
+
+        # forward
+        with torch.cuda.amp.autocast(dtype=self.dtype):
+            out = self.model.visual(videos, use_31_block=True)
+            return out
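Editor's note: a hedged sketch of driving CLIPModel.visual above. The checkpoint and tokenizer paths are placeholders, not paths shipped in this commit; inputs are expected as a list of [C, F, H, W] clips scaled to [-1, 1].

import torch
from ovi.modules.clip import CLIPModel

clip = CLIPModel(
    dtype=torch.float16,
    device='cuda',
    checkpoint_path='ckpts/open-clip-xlm-roberta-large-vit-huge-14.pth',  # placeholder path
    tokenizer_path='xlm-roberta-large')                                   # placeholder tokenizer id

videos = [torch.rand(3, 1, 480, 832, device='cuda') * 2 - 1]  # one clip with a single frame
feats = clip.visual(videos)  # features from the first 31 ViT blocks, roughly [frames, 257, 1280]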
ovi/modules/fusion.py
ADDED
@@ -0,0 +1,324 @@
+
+import torch
+import torch.nn as nn
+from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
+from ovi.modules.attention import flash_attention
+from ovi.distributed_comms.communications import all_gather, all_to_all_4D
+from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state
+
+class FusionModel(nn.Module):
+    def __init__(self, video_config=None, audio_config=None):
+        super().__init__()
+        has_video = True
+        has_audio = True
+        if video_config is not None:
+            self.video_model = WanModel(**video_config)
+        else:
+            has_video = False
+            self.video_model = None
+            print("Warning: No video model is provided!")
+
+        if audio_config is not None:
+            self.audio_model = WanModel(**audio_config)
+        else:
+            has_audio = False
+            self.audio_model = None
+            print("Warning: No audio model is provided!")
+
+        if has_video and has_audio:
+            assert len(self.video_model.blocks) == len(self.audio_model.blocks)
+            self.num_blocks = len(self.video_model.blocks)
+
+        self.use_sp = get_sequence_parallel_state()
+        if self.use_sp:
+            self.sp_size = nccl_info.sp_size
+            self.sp_rank = nccl_info.rank_within_group
+        self.inject_cross_attention_kv_projections()
+
+        self.init_weights()
+
+    def inject_cross_attention_kv_projections(self):
+        for vid_block in self.video_model.blocks:
+            vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+            vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
+            vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
+            vid_block.cross_attn.norm_k_fusion = WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
+
+
+        for audio_block in self.audio_model.blocks:
+            audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+            audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
+            audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
+            audio_block.cross_attn.norm_k_fusion = WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
+
+
+    def merge_kwargs(self, vid_kwargs, audio_kwargs):
+        """
+        keys in each kwarg:
+        e
+        seq_lens
+        grid_sizes
+        freqs
+        context
+        context_lens
+        """
+        merged_kwargs = {}
+        for key in vid_kwargs:
+            merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
+        for key in audio_kwargs:
+            merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
+        return merged_kwargs
+
+    def single_fusion_cross_attention_forward(self,
+                                              cross_attn_block,
+                                              src_seq,
+                                              src_grid_sizes,
+                                              src_freqs,
+                                              target_seq,
+                                              target_seq_lens,
+                                              target_grid_sizes,
+                                              target_freqs,
+                                              context,
+                                              context_lens
+                                              ):
+        b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
+        if hasattr(cross_attn_block, "k_img"):
+            ## means is i2v block
+            q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
+        else:
+            ## means is t2v block
+            q, k, v = cross_attn_block.qkv_fn(src_seq, context)
+            k_img = v_img = None
+
+
+        if self.use_sp:
+            q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
+            k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
+            v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
+            if k_img is not None:
+                k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
+            if v_img is not None:
+                v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]
+
+        x = flash_attention(q, k, v, k_lens=context_lens)
+
+        if k_img is not None:
+            img_x = flash_attention(q, k_img, v_img, k_lens=None)
+            x = x + img_x
+
+        is_vid = src_grid_sizes.shape[1] > 1
+        # compute target attention
+        target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
+        k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
+        v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
+        if self.use_sp:
+            k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1)  # [B, L, H/P, C/H]
+            v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1)  # [B, L, H/P, C/H]
+
+        q = rope_apply(q, src_grid_sizes, src_freqs)
+        k_target = rope_apply(k_target, target_grid_sizes, target_freqs)
+
+        target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)
+
+        x = x + target_x
+        if self.use_sp:
+            x = all_to_all_4D(x, scatter_dim=1, gather_dim=2)  # [B, L/P, H, C/H]
+
+        x = x.flatten(2)  # [B, L/P, C]
+
+        x = cross_attn_block.o(x)
+        return x
+
+    def single_fusion_cross_attention_ffn_forward(self,
+                                                  attn_block,
+                                                  src_seq,
+                                                  src_grid_sizes,
+                                                  src_freqs,
+                                                  target_seq,
+                                                  target_seq_lens,
+                                                  target_grid_sizes,
+                                                  target_freqs,
+                                                  context,
+                                                  context_lens,
+                                                  src_e):
+
+        src_seq = src_seq + self.single_fusion_cross_attention_forward(attn_block.cross_attn,
+                                                                       attn_block.norm3(src_seq),
+                                                                       src_grid_sizes=src_grid_sizes,
+                                                                       src_freqs=src_freqs,
+                                                                       target_seq=target_seq,
+                                                                       target_seq_lens=target_seq_lens,
+                                                                       target_grid_sizes=target_grid_sizes,
+                                                                       target_freqs=target_freqs,
+                                                                       context=context,
+                                                                       context_lens=context_lens
+                                                                       )
+        y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            src_seq = src_seq + y * src_e[5].squeeze(2)
+        return src_seq
+
+    def single_fusion_block_forward(self,
+                                    vid_block,
+                                    audio_block,
+                                    vid,
+                                    audio,
+                                    vid_e,
+                                    vid_seq_lens,
+                                    vid_grid_sizes,
+                                    vid_freqs,
+                                    vid_context,
+                                    vid_context_lens,
+                                    audio_e,
+                                    audio_seq_lens,
+                                    audio_grid_sizes,
+                                    audio_freqs,
+                                    audio_context,
+                                    audio_context_lens
+                                    ):
+        ## audio modulation
+        assert audio_e.dtype == torch.bfloat16
+        assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
+        assert audio_e[0].dtype == torch.bfloat16
+
+        # audio self-attention
+        audio_y = audio_block.self_attn(
+            audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
+            audio_freqs)
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            audio = audio + audio_y * audio_e[2].squeeze(2)
+
+        ## video modulation
+        assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)
+
+        # video self-attention
+        vid_y = vid_block.self_attn(
+            vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
+            vid_freqs)
+
+        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
+            vid = vid + vid_y * vid_e[2].squeeze(2)
+
+        og_audio = audio
+
+        # audio cross-attention
+        audio = self.single_fusion_cross_attention_ffn_forward(
+            audio_block,
+            audio,
+            audio_grid_sizes,
+            audio_freqs,
+            vid,
+            vid_seq_lens,
+            vid_grid_sizes,
+            vid_freqs,
+            audio_context,
+            audio_context_lens,
+            audio_e
+        )
+
+        assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
+
+        # video cross-attention
+        vid = self.single_fusion_cross_attention_ffn_forward(
+            vid_block,
+            vid,
+            vid_grid_sizes,
+            vid_freqs,
+            og_audio,
+            audio_seq_lens,
+            audio_grid_sizes,
+            audio_freqs,
+            vid_context,
+            vid_context_lens,
+            vid_e
+        )
+
+        return vid, audio
+
+    def forward(
+        self,
+        vid,
+        audio,
+        t,
+        vid_context,
+        audio_context,
+        vid_seq_len,
+        audio_seq_len,
+        clip_fea=None,
+        clip_fea_audio=None,
+        y=None,
+        first_frame_is_clean=False,
+        slg_layer=False
+    ):
+
+        assert clip_fea is None
+        assert y is None
+
+        if vid is None or all([x is None for x in vid]):
+            assert vid_context is None
+            assert vid_seq_len is None
+            assert self.audio_model is not None
+
+            return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)
+
+        if audio is None or all([x is None for x in audio]):
+            assert clip_fea_audio is None
+            assert audio_context is None
+            assert audio_seq_len is None
+            assert self.video_model is not None
+
+            return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None
+
+        vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
+            x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
+        )
+
+        audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
+            x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
+        )
+
+        kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
+
+        for i in range(self.num_blocks):
+            """
+            1 fusion block refers to 1 audio block with 1 video block.
+            """
+            if slg_layer > 0 and i == slg_layer:
+                continue
+            vid_block = self.video_model.blocks[i]
+            audio_block = self.audio_model.blocks[i]
+            vid, audio = gradient_checkpointing(
+                enabled=(self.training and self.gradient_checkpointing),
+                module=self.single_fusion_block_forward,
+                vid_block=vid_block,
+                audio_block=audio_block,
+                vid=vid,
+                audio=audio,
+                **kwargs
+            )
+
+        vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
+        audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)
+
+        return vid, audio
+
+    def init_weights(self):
+        if self.audio_model is not None:
+            self.audio_model.init_weights()
+
+        if self.video_model is not None:
+            self.video_model.init_weights()
+
+        for name, mod in self.video_model.named_modules():
+            if "fusion" in name and isinstance(mod, nn.Linear):
+                with torch.no_grad():
+                    mod.weight.div_(10.0)
+
+
+    def set_rope_params(self):
+        self.video_model.set_rope_params()
+        self.audio_model.set_rope_params()
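Editor's note: a hedged construction sketch for FusionModel above. The JSON config paths come from the file list in this commit, but the exact WanModel keys they contain are not shown here, so treat this as illustrative rather than a guaranteed recipe.

import json
from ovi.modules.fusion import FusionModel

with open('ovi/configs/model/dit/video.json') as f:
    video_config = json.load(f)
with open('ovi/configs/model/dit/audio.json') as f:
    audio_config = json.load(f)

# Paired video/audio DiT backbones; each cross-attention layer also attends to the
# other modality through the injected *_fusion key/value projections.
model = FusionModel(video_config=video_config, audio_config=audio_config)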
ovi/modules/mmaudio/__init__.py
ADDED
@@ -0,0 +1 @@
+# MMAudio package
ovi/modules/mmaudio/ext/__init__.py
ADDED
@@ -0,0 +1 @@
+