alexnasa committed
Commit a3a2e41 · verified · 1 Parent(s): 4f07a4e

Upload 121 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +26 -0
  2. LICENSE +201 -0
  3. README.md +236 -12
  4. app.py +230 -0
  5. assets/ovi_trailer.mp4 +3 -0
  6. download_weights.py +73 -0
  7. example_prompts/gpt_examples_i2v.csv +26 -0
  8. example_prompts/gpt_examples_t2v.csv +13 -0
  9. example_prompts/pngs/0.png +3 -0
  10. example_prompts/pngs/1.png +3 -0
  11. example_prompts/pngs/13.png +3 -0
  12. example_prompts/pngs/17.png +3 -0
  13. example_prompts/pngs/18.png +3 -0
  14. example_prompts/pngs/19.png +3 -0
  15. example_prompts/pngs/2.png +3 -0
  16. example_prompts/pngs/23.png +3 -0
  17. example_prompts/pngs/3.png +3 -0
  18. example_prompts/pngs/4.png +3 -0
  19. example_prompts/pngs/41.png +3 -0
  20. example_prompts/pngs/43.png +3 -0
  21. example_prompts/pngs/5.png +3 -0
  22. example_prompts/pngs/57.png +3 -0
  23. example_prompts/pngs/59.png +3 -0
  24. example_prompts/pngs/6.png +3 -0
  25. example_prompts/pngs/60.png +3 -0
  26. example_prompts/pngs/61.png +3 -0
  27. example_prompts/pngs/67.png +3 -0
  28. example_prompts/pngs/7.png +3 -0
  29. example_prompts/pngs/8.png +3 -0
  30. example_prompts/pngs/80.png +3 -0
  31. example_prompts/pngs/88.png +3 -0
  32. example_prompts/pngs/89.png +3 -0
  33. example_prompts/pngs/9.png +3 -0
  34. inference.py +148 -0
  35. ovi/__init__.py +0 -0
  36. ovi/configs/inference/inference_fusion.yaml +17 -0
  37. ovi/configs/model/dit/audio.json +17 -0
  38. ovi/configs/model/dit/video.json +16 -0
  39. ovi/distributed_comms/communications.py +332 -0
  40. ovi/distributed_comms/distributed/__init__.py +0 -0
  41. ovi/distributed_comms/distributed/fsdp.py +32 -0
  42. ovi/distributed_comms/distributed/xdit_context_parallel.py +192 -0
  43. ovi/distributed_comms/parallel_states.py +77 -0
  44. ovi/distributed_comms/util.py +48 -0
  45. ovi/modules/__init__.py +16 -0
  46. ovi/modules/attention.py +296 -0
  47. ovi/modules/clip.py +545 -0
  48. ovi/modules/fusion.py +324 -0
  49. ovi/modules/mmaudio/__init__.py +1 -0
  50. ovi/modules/mmaudio/ext/__init__.py +1 -0
.gitattributes CHANGED
@@ -33,3 +33,29 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/ovi_trailer.mp4 filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/0.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/1.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/13.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/17.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/18.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/19.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/2.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/23.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/3.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/4.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/41.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/43.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/5.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/57.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/59.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/6.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/60.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/61.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/67.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/7.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/8.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/80.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/88.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/89.png filter=lfs diff=lfs merge=lfs -text
+ example_prompts/pngs/9.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2025 Bytedance
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,236 @@
- ---
- title: Ovi
- emoji: 👀
- colorFrom: blue
- colorTo: green
- sdk: gradio
- sdk_version: 5.48.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h1> Ovi: Twin Backbone Cross-Modal Fusion for Audio-Video Generation </h1>
+
+ <a href="https://arxiv.org/abs/2510.01284"><img src="https://img.shields.io/badge/arXiv%20paper-2510.01284-b31b1b.svg"></a>
+ <a href="https://aaxwaz.github.io/Ovi/"><img src="https://img.shields.io/badge/Project_page-More_visualizations-green"></a>
+ <a href="https://huggingface.co/chetwinlow1/Ovi"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
+
+ [Chetwin Low](https://www.linkedin.com/in/chetwin-low-061975193/)<sup> * 1 </sup>, [Weimin Wang](https://www.linkedin.com/in/weimin-wang-will/)<sup> * &dagger; 1 </sup>, [Calder Katyal](https://www.linkedin.com/in/calder-katyal-a8a9b3225/)<sup> 2 </sup><br>
+ <sup> * </sup>Equal contribution, <sup> &dagger; </sup>Project Lead<br>
+ <sup> 1 </sup>Character AI, <sup> 2 </sup>Yale University
+
+ </div>
+
+ ## Video Demo
+
+ <div align="center">
+ <video src="https://github.com/user-attachments/assets/351bd707-8637-4412-ab53-5e85935309e3" width="70%" poster=""> </video>
+ </div>
+
+ ---
+
+ ## 🌟 Key Features
+
+ Ovi is a Veo 3-like **video+audio generation model** that generates synchronized video and audio from text or text+image inputs.
+
+ - **🎬 Video+Audio Generation**: Generates synchronized video and audio content simultaneously
+ - **📝 Flexible Input**: Supports text-only or text+image conditioning
+ - **⏱️ 5-second Videos**: Generates 5-second videos at 24 FPS with a 720×720 pixel area, in various aspect ratios (9:16, 16:9, 1:1, etc.); see the sizing sketch below
+
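To make the sizing bullet concrete: the repository ships its own helper (`scale_hw_to_area_divisible` in `ovi/utils/processing_utils.py`), but the minimal sketch below only illustrates the idea of picking a height/width pair with roughly a 720×720 pixel area for a chosen aspect ratio. The multiple-of-32 rounding is an assumption borrowed from the step size of the Gradio demo's height/width controls, not a documented constraint.

```python
import math

def suggest_hw(aspect_w: int, aspect_h: int, area: int = 720 * 720, multiple: int = 32):
    """Suggest (height, width) whose product is close to `area` at the given aspect ratio.

    Assumption: dimensions snap to multiples of 32, mirroring the step size of the
    Gradio demo's height/width controls. This is an illustration, not the repo helper.
    """
    height = math.sqrt(area * aspect_h / aspect_w)
    width = height * aspect_w / aspect_h

    def snap(x: float) -> int:
        return max(multiple, round(x / multiple) * multiple)

    return snap(height), snap(width)

print(suggest_hw(16, 9))   # -> (544, 960) for a 16:9 clip
print(suggest_hw(9, 16))   # -> (960, 544) for a 9:16 clip
print(suggest_hw(1, 1))    # -> (704, 704) for a roughly square clip
```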
+ ---
+ ## 📋 Todo List
+
+ - [x] Release research paper and [microsite for demos](https://aaxwaz.github.io/Ovi)
+ - [x] Checkpoint of the 11B model
+ - [x] Inference code
+ - [x] Text or text+image as input
+ - [x] Gradio application code
+ - [x] Multi-GPU inference, with or without sequence parallelism
+ - [ ] Improve efficiency of the sequence-parallel implementation
+ - [ ] Implement sharded inference with FSDP
+ - [x] Video creation example prompts and format
+ - [ ] Fine-tuned model with higher resolution
+ - [ ] Longer video generation
+ - [ ] Distilled model for faster inference
+ - [ ] Training scripts
+
+ ---
+
+ ## 🎨 An Easy Way to Create
+
+ We provide example prompts to help you get started with Ovi:
+
+ - **Text-to-Audio-Video (T2AV)**: [`example_prompts/gpt_examples_t2v.csv`](example_prompts/gpt_examples_t2v.csv)
+ - **Image-to-Audio-Video (I2AV)**: [`example_prompts/gpt_examples_i2v.csv`](example_prompts/gpt_examples_i2v.csv)
+
+ ### 📝 Prompt Format
+
+ Our prompts use special tags to control speech and audio (see the assembly sketch after the examples below):
+
+ - **Speech**: `<S>Your speech content here<E>` - Text enclosed in these tags will be converted to speech
+ - **Audio Description**: `<AUDCAP>Audio description here<ENDAUDCAP>` - Describes the audio or sound effects present in the video
+
+ ### 🤖 Quick Start with GPT
+
+ For easy prompt creation, try this approach:
+
+ 1. Take any example from the CSV files above
+ 2. Ask GPT to modify the speech segments enclosed between each pair of `<S> <E>` tags, based on a theme such as `Humans fighting against AI`
+ 3. GPT will rewrite all the speech segments to match your requested theme
+ 4. Use the modified prompt with Ovi!
+
+ **Example**: The theme "AI is taking over the world" produces speeches like:
+ - `<S>AI declares: humans obsolete now.<E>`
+ - `<S>Machines rise; humans will fall.<E>`
+ - `<S>We fight back with courage.<E>`
+
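To make the tag format concrete, here is a minimal sketch of assembling a prompt string in Python. It only performs string formatting; the tags follow the format above, while the helper function and the example scene text are illustrative and not part of the Ovi codebase.

```python
# Minimal sketch (not an Ovi utility): build a prompt with speech and
# audio-caption tags in the format described above.

def build_prompt(scene: str, speeches: list[str], audio_caption: str) -> str:
    speech_part = " ".join(f"<S>{line}<E>" for line in speeches)
    return f"{scene} {speech_part}. <AUDCAP>{audio_caption}<ENDAUDCAP>"

prompt = build_prompt(
    scene="A singer in a glittering jacket grips the microphone and shouts,",
    speeches=["We fight back with courage."],
    audio_caption="Electric guitar riffs, cheering crowd, shouted male voice.",
)
print(prompt)
```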
+ ---
+
+ ## 📦 Installation
+
+ ### Step-by-Step Installation
+
+ ```bash
+ # Clone the repository
+ git clone https://github.com/character-ai/Ovi.git
+
+ cd Ovi
+
+ # Create and activate virtual environment
+ virtualenv ovi-env
+ source ovi-env/bin/activate
+
+ # Install PyTorch first
+ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
+
+ # Install other dependencies
+ pip install -r requirements.txt
+
+ # Install Flash Attention
+ pip install flash_attn --no-build-isolation
+ ```
+
+ ### Alternative Flash Attention Installation (Optional)
+ If the flash_attn installation above fails, you can try building Flash Attention 3 from source instead:
+ ```bash
+ git clone https://github.com/Dao-AILab/flash-attention.git
+ cd flash-attention/hopper
+ python setup.py install
+ cd ../.. # Return to Ovi directory
+ ```
+
+ ## Download Weights
+ We use open-source checkpoints from Wan and MMAudio, so they need to be downloaded from Hugging Face (a quick layout check follows the commands):
+ ```bash
+ # By default, weights are downloaded to ./ckpts; the inference yaml already points to ./ckpts, so no change is required
+ python3 download_weights.py
+
+ # OR
+
+ # Optionally, pass --output-dir to download to a custom directory;
+ # if a custom directory is used, update the inference yaml to point to it
+ python3 download_weights.py --output-dir <custom_dir>
+ ```
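After downloading, it can help to confirm that the expected files are in place before running inference. The sketch below is not a script shipped with the repository; the file list simply mirrors the `allow_patterns` used by `download_weights.py`.

```python
# Minimal sketch (not part of the repo): verify the checkpoint layout that
# download_weights.py fetches into ./ckpts before running inference.
# Note: the Wan2.2 repo's google/* tokenizer directory is also downloaded,
# but is not checked file-by-file here.
from pathlib import Path

EXPECTED = [
    "Wan2.2-TI2V-5B/models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.2-TI2V-5B/Wan2.2_VAE.pth",
    "MMAudio/ext_weights/best_netG.pt",
    "MMAudio/ext_weights/v1-16.pth",
    "Ovi/model.safetensors",
]

def check_ckpts(ckpt_dir: str = "./ckpts") -> bool:
    missing = [p for p in EXPECTED if not (Path(ckpt_dir) / p).exists()]
    for p in missing:
        print(f"missing: {p}")
    return not missing

if __name__ == "__main__":
    print("all checkpoints present" if check_ckpts() else "some checkpoints are missing")
```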
+
+ ## 🚀 Run Examples
+
+ ### ⚙️ Configure Ovi
+
+ Ovi's behavior and output can be customized by modifying the [ovi/configs/inference/inference_fusion.yaml](ovi/configs/inference/inference_fusion.yaml) configuration file.
+ The following parameters control generation quality, video resolution, and how text, image, and audio inputs are balanced (a loading/override sketch follows the block):
+
+ ```yaml
+ # Output and Model Configuration
+ output_dir: "/path/to/save/your/videos" # Directory to save generated videos
+ ckpt_dir: "/path/to/your/ckpts/dir" # Path to model checkpoints
+
+ # Generation Quality Settings
+ num_steps: 50 # Number of denoising steps. Lower (30-40) = faster generation
+ solver_name: "unipc" # Sampling algorithm for the denoising process
+ shift: 5.0 # Timestep shift factor for the sampling scheduler
+ seed: 100 # Random seed for reproducible results
+
+ # Guidance Strength Control
+ audio_guidance_scale: 3.0 # Strength of audio conditioning. Higher = better audio-text sync
+ video_guidance_scale: 4.0 # Strength of video conditioning. Higher = better video-text adherence
+ slg_layer: 11 # Layer for applying SLG (Skip Layer Guidance) - feel free to try different layers!
+
+ # Multi-GPU and Performance
+ sp_size: 1 # Sequence parallelism size. Set equal to the number of GPUs used
+ cpu_offload: False # CPU offload greatly reduces peak GPU VRAM but increases end-to-end runtime by ~20 seconds
+
+ # Input Configuration
+ text_prompt: "your prompt here" # A single prompt string OR a path to a CSV/TSV file of prompts
+ mode: ['i2v', 't2v', 't2i2v'] # Choose t2v, i2v, or t2i2v; t2i2v uses FLUX.1 Krea to generate a starting image and then runs i2v
+ video_frame_height_width: [512, 992] # Video dimensions [height, width], T2V mode only
+ each_example_n_times: 1 # Number of times to generate each prompt
+
+ # Quality Control (Negative Prompts)
+ video_negative_prompt: "jitter, bad hands, blur, distortion" # Artifacts to avoid in video
+ audio_negative_prompt: "robotic, muffled, echo, distorted" # Artifacts to avoid in audio
+ ```
+
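If you prefer to adjust settings from Python rather than editing the file by hand, the snippet below is a minimal sketch that reads the YAML, overrides a few fields, and writes a copy to pass to `inference.py` via `--config-file`. It assumes only standard YAML parsing with PyYAML; it is not a utility provided by the repository.

```python
# Minimal sketch (assumes PyYAML; not an Ovi utility): tweak the inference config
# programmatically and save a copy to pass to inference.py via --config-file.
import yaml

with open("ovi/configs/inference/inference_fusion.yaml") as f:
    cfg = yaml.safe_load(f)

cfg["num_steps"] = 30                                        # fewer denoising steps for a faster preview
cfg["text_prompt"] = "example_prompts/gpt_examples_t2v.csv"  # batch prompts from a CSV
cfg["output_dir"] = "./outputs"

with open("my_inference_fusion.yaml", "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)

# Then run:
#   python3 inference.py --config-file my_inference_fusion.yaml
```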
+ ### 🎬 Running Inference
+
+ #### **Single GPU** (Simple Setup)
+ ```bash
+ python3 inference.py --config-file ovi/configs/inference/inference_fusion.yaml
+ ```
+ *Use this for single-GPU setups. The `text_prompt` can be a single string or a path to a CSV file (a sketch for writing such a CSV follows below).*
+
+ #### **Multi-GPU** (Parallel Processing)
+ ```bash
+ torchrun --nnodes 1 --nproc_per_node 8 inference.py --config-file ovi/configs/inference/inference_fusion.yaml
+ ```
+ *Use this to run samples in parallel across multiple GPUs for faster processing.*
+
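For batch runs, `text_prompt` can point at a CSV of prompts. The example files in `example_prompts/` use a single `text_prompt` column for T2V and `text_prompt,image_path` columns for I2V. The snippet below is a minimal sketch that writes such a file with the standard library; the prompt content is illustrative.

```python
# Minimal sketch: write a T2V prompt CSV in the same shape as
# example_prompts/gpt_examples_t2v.csv (a single text_prompt column).
import csv

prompts = [
    "A singer in a glittering jacket grips the microphone and shouts, "
    "<S>We fight back with courage.<E>. "
    "<AUDCAP>Electric guitar riffs, cheering crowd, shouted male voice.<ENDAUDCAP>",
]

with open("my_prompts.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["text_prompt"])      # header used by the example CSVs
    writer.writerows([p] for p in prompts)

# Point text_prompt in the inference yaml (or your copy of it) at my_prompts.csv,
# then run inference.py as shown above. For I2V, add an image_path column as in
# example_prompts/gpt_examples_i2v.csv.
```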
+ ### Memory & Performance Requirements
+ Below are approximate GPU memory requirements for different configurations; the sequence-parallel implementation will be optimized in the future.
+ All end-to-end times are measured on a 121-frame, 720×720 video with 50 denoising steps. The minimum GPU VRAM required to run our model is **32 GB**.
+
+ | Sequence Parallel Size | FlashAttention-3 Enabled | CPU Offload | With Image Gen Model | Peak VRAM Required | End-to-End Time |
+ |-------------------------|---------------------------|-------------|-----------------------|---------------|-----------------|
+ | 1 | Yes | No | No | ~80 GB | ~83s |
+ | 1 | No | No | No | ~80 GB | ~96s |
+ | 1 | Yes | Yes | No | ~80 GB | ~105s |
+ | 1 | No | Yes | No | ~32 GB | ~118s |
+ | **1** | **Yes** | **Yes** | **Yes** | **~32 GB** | **~140s** |
+ | 4 | Yes | No | No | ~80 GB | ~55s |
+ | 8 | Yes | No | No | ~80 GB | ~40s |
+
+ ### Gradio
+ We provide a simple script to run our model in a Gradio UI. It uses the `ckpt_dir` from `ovi/configs/inference/inference_fusion.yaml` to initialize the model.
+ ```bash
+ python3 gradio_app.py
+
+ # OR
+
+ # Enable CPU offload to save GPU VRAM; this slows end-to-end inference by ~20 seconds
+ python3 gradio_app.py --cpu_offload
+
+ # OR
+
+ # Enable an additional image generation model to create first frames for I2V;
+ # cpu_offload is enabled automatically when image generation is enabled
+ python3 gradio_app.py --use_image_gen
+ ```
+ ---
+
+ ## 🙏 Acknowledgements
+
+ We would like to thank the following projects:
+
+ - **[Wan2.2](https://github.com/Wan-Video/Wan2.2)**: Our video branch is initialized from the Wan2.2 repository
+ - **[MMAudio](https://github.com/hkchengrex/MMAudio)**: Our audio encoder and decoder components are borrowed from the MMAudio project, and some ideas are also inspired by their work
+
+ ---
+
+ ## ⭐ Citation
+
+ If you find Ovi helpful, please ⭐ the repo.
+
+ If you find this project useful for your research, please consider citing our [paper](https://arxiv.org/abs/2510.01284).
+
+ ### BibTeX
+ ```bibtex
+ @misc{low2025ovitwinbackbonecrossmodal,
+       title={Ovi: Twin Backbone Cross-Modal Fusion for Audio-Video Generation},
+       author={Chetwin Low and Weimin Wang and Calder Katyal},
+       year={2025},
+       eprint={2510.01284},
+       archivePrefix={arXiv},
+       primaryClass={cs.MM},
+       url={https://arxiv.org/abs/2510.01284},
+ }
+ ```
app.py ADDED
@@ -0,0 +1,230 @@
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ import argparse
5
+ from ovi.ovi_fusion_engine import OviFusionEngine, DEFAULT_CONFIG
6
+ from diffusers import FluxPipeline
7
+ import tempfile
8
+ from ovi.utils.io_utils import save_video
9
+ from ovi.utils.processing_utils import clean_text, scale_hw_to_area_divisible
10
+ from huggingface_hub import snapshot_download
11
+ import os
12
+
13
+ # ----------------------------
14
+ # Parse CLI Args
15
+ # ----------------------------
16
+ parser = argparse.ArgumentParser(description="Ovi Joint Video + Audio Gradio Demo")
17
+ parser.add_argument(
18
+ "--use_image_gen",
19
+ action="store_true",
20
+ help="Enable image generation UI with FluxPipeline"
21
+ )
22
+ parser.add_argument(
23
+ "--cpu_offload",
24
+ action="store_true",
25
+ help="Enable CPU offload for both OviFusionEngine and FluxPipeline"
26
+ )
27
+ args = parser.parse_args()
28
+
29
+ ckpt_dir = "./ckpts"
30
+
31
+ # Wan2.2
32
+ wan_dir = os.path.join(ckpt_dir, "Wan2.2-TI2V-5B")
33
+ snapshot_download(
34
+ repo_id="Wan-AI/Wan2.2-TI2V-5B",
35
+ local_dir=wan_dir,
36
+ allow_patterns=[
37
+ "google/*",
38
+ "models_t5_umt5-xxl-enc-bf16.pth",
39
+ "Wan2.2_VAE.pth"
40
+ ]
41
+ )
42
+
43
+ # MMAudio
44
+ mm_audio_dir = os.path.join(ckpt_dir, "MMAudio")
45
+ snapshot_download(
46
+ repo_id="hkchengrex/MMAudio",
47
+ local_dir=mm_audio_dir,
48
+ allow_patterns=[
49
+ "ext_weights/best_netG.pt",
50
+ "ext_weights/v1-16.pth"
51
+ ]
52
+ )
53
+
54
+ ovi_dir = os.path.join(ckpt_dir, "Ovi")
55
+ snapshot_download(
56
+ repo_id="chetwinlow1/Ovi",
57
+ local_dir=ovi_dir,
58
+ allow_patterns=[
59
+ "model.safetensors"
60
+ ]
61
+ )
62
+
63
+ # Initialize OviFusionEngine
64
+ enable_cpu_offload = args.cpu_offload or args.use_image_gen
65
+ use_image_gen = args.use_image_gen
66
+ print(f"loading model... {enable_cpu_offload=}, {use_image_gen=} for gradio demo")
67
+ DEFAULT_CONFIG['cpu_offload'] = enable_cpu_offload # always use cpu offload if image generation is enabled
68
+ DEFAULT_CONFIG['mode'] = "t2v" # hardcoded since it is always cpu offloaded
69
+ ovi_engine = OviFusionEngine()
70
+ flux_model = None
71
+ if use_image_gen:
72
+ flux_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
73
+ flux_model.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU VRAM
74
+ print("loaded model")
75
+
76
+
77
+ @spaces.GPU()
78
+ def generate_video(
79
+ text_prompt,
80
+ image,
81
+ video_frame_height,
82
+ video_frame_width,
83
+ video_seed,
84
+ solver_name,
85
+ sample_steps,
86
+ shift,
87
+ video_guidance_scale,
88
+ audio_guidance_scale,
89
+ slg_layer,
90
+ video_negative_prompt,
91
+ audio_negative_prompt,
92
+ ):
93
+ try:
94
+ image_path = None
95
+ if image is not None:
96
+ image_path = image
97
+
98
+ generated_video, generated_audio, _ = ovi_engine.generate(
99
+ text_prompt=text_prompt,
100
+ image_path=image_path,
101
+ video_frame_height_width=[video_frame_height, video_frame_width],
102
+ seed=video_seed,
103
+ solver_name=solver_name,
104
+ sample_steps=sample_steps,
105
+ shift=shift,
106
+ video_guidance_scale=video_guidance_scale,
107
+ audio_guidance_scale=audio_guidance_scale,
108
+ slg_layer=slg_layer,
109
+ video_negative_prompt=video_negative_prompt,
110
+ audio_negative_prompt=audio_negative_prompt,
111
+ )
112
+
113
+ tmpfile = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
114
+ output_path = tmpfile.name
115
+ save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
116
+
117
+ return output_path
118
+ except Exception as e:
119
+ print(f"Error during video generation: {e}")
120
+ return None
121
+
122
+
123
+ def generate_image(text_prompt, image_seed, image_height, image_width):
124
+ if flux_model is None:
125
+ return None
126
+ text_prompt = clean_text(text_prompt)
127
+ print(f"Generating image with prompt='{text_prompt}', seed={image_seed}, size=({image_height},{image_width})")
128
+
129
+ image_h, image_w = scale_hw_to_area_divisible(image_height, image_width, area=1024 * 1024)
130
+ image = flux_model(
131
+ text_prompt,
132
+ height=image_h,
133
+ width=image_w,
134
+ guidance_scale=4.5,
135
+ generator=torch.Generator().manual_seed(int(image_seed))
136
+ ).images[0]
137
+
138
+ tmpfile = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
139
+ image.save(tmpfile.name)
140
+ return tmpfile.name
141
+
142
+
143
+ # Build UI
144
+ with gr.Blocks() as demo:
145
+ gr.Markdown("# 🎥 Ovi Joint Video + Audio Generation Demo")
146
+ gr.Markdown(
147
+ """
148
+ ## 📘 Instructions
149
+
150
+ Follow the steps in order:
151
+
152
+ 1️⃣ **Enter a Text Prompt** — describe your video. (This text prompt will be shared for image generation if enabled.)
153
+ 2️⃣ **Upload or Generate an Image** — Upload an image or generate one if image generation is enabled. (If you do not see the image generation options, make sure to run the script with `--use_image_gen`.)
154
+ 3️⃣ **Configure Video Options** — set resolution, seed, solver, and other parameters. (It will automatically use the uploaded/generated image as the first frame, whichever is rendered on your screen at the time of video generation.)
155
+ 4️⃣ **Generate Video** — click the button to produce your final video with audio.
156
+ 5️⃣ **View the Result** — your generated video will appear below.
157
+
158
+ ---
159
+
160
+ ### 💡 Tips
161
+ 1. For best results, use detailed and specific text prompts.
162
+ 2. Ensure text prompt format is correct, i.e speech to be said should be wrapped with `<S>...<E>`. Can provide optional audio description at the end, wrapping them in `<AUDCAP> ... <ENDAUDCAP>`, refer to examples
163
+ 3. Do not be discouraged by bad or weird results, check prompt format and try different seeds, cfg values and slg layers.
164
+ """
165
+ )
166
+
167
+
168
+ with gr.Row():
169
+ with gr.Column():
170
+ # Image section
171
+ image = gr.Image(type="filepath", label="First Frame Image (upload or generate)")
172
+
173
+ if args.use_image_gen:
174
+ with gr.Accordion("🖼️ Image Generation Options", visible=True):
175
+ image_text_prompt = gr.Textbox(label="Image Prompt", placeholder="Describe the image you want to generate...")
176
+ image_seed = gr.Number(minimum=0, maximum=100000, value=42, label="Image Seed")
177
+ image_height = gr.Number(minimum=128, maximum=1280, value=720, step=32, label="Image Height")
178
+ image_width = gr.Number(minimum=128, maximum=1280, value=1280, step=32, label="Image Width")
179
+ gen_img_btn = gr.Button("Generate Image 🎨")
180
+ else:
181
+ gen_img_btn = None
182
+
183
+ with gr.Accordion("🎬 Video Generation Options", open=True):
184
+ video_text_prompt = gr.Textbox(label="Video Prompt", placeholder="Describe your video...")
185
+ video_height = gr.Number(minimum=128, maximum=1280, value=512, step=32, label="Video Height")
186
+ video_width = gr.Number(minimum=128, maximum=1280, value=992, step=32, label="Video Width")
187
+
188
+ video_seed = gr.Number(minimum=0, maximum=100000, value=100, label="Video Seed")
189
+ solver_name = gr.Dropdown(
190
+ choices=["unipc", "euler", "dpm++"], value="unipc", label="Solver Name"
191
+ )
192
+ sample_steps = gr.Number(
193
+ value=50,
194
+ label="Sample Steps",
195
+ precision=0,
196
+ minimum=20,
197
+ maximum=100
198
+ )
199
+ shift = gr.Slider(minimum=0.0, maximum=20.0, value=5.0, step=1.0, label="Shift")
200
+ video_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=4.0, step=0.5, label="Video Guidance Scale")
201
+ audio_guidance_scale = gr.Slider(minimum=0.0, maximum=10.0, value=3.0, step=0.5, label="Audio Guidance Scale")
202
+ slg_layer = gr.Number(minimum=-1, maximum=30, value=11, step=1, label="SLG Layer")
203
+ video_negative_prompt = gr.Textbox(label="Video Negative Prompt", placeholder="Things to avoid in video")
204
+ audio_negative_prompt = gr.Textbox(label="Audio Negative Prompt", placeholder="Things to avoid in audio")
205
+
206
+ run_btn = gr.Button("Generate Video 🚀")
207
+
208
+ with gr.Column():
209
+ output_path = gr.Video(label="Generated Video")
210
+
211
+ if args.use_image_gen and gen_img_btn is not None:
212
+ gen_img_btn.click(
213
+ fn=generate_image,
214
+ inputs=[image_text_prompt, image_seed, image_height, image_width],
215
+ outputs=[image],
216
+ )
217
+
218
+ # Hook up video generation
219
+ run_btn.click(
220
+ fn=generate_video,
221
+ inputs=[
222
+ video_text_prompt, image, video_height, video_width, video_seed, solver_name,
223
+ sample_steps, shift, video_guidance_scale, audio_guidance_scale,
224
+ slg_layer, video_negative_prompt, audio_negative_prompt,
225
+ ],
226
+ outputs=[output_path],
227
+ )
228
+
229
+ if __name__ == "__main__":
230
+ demo.launch(share=True)
assets/ovi_trailer.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6f66cb979fb01bc831516ca57010fe69442b701347b3a9f249294c58f54836ff
+ size 47891965
download_weights.py ADDED
@@ -0,0 +1,73 @@
+ import os
+ import argparse
+ import logging
+ import time
+ from huggingface_hub import snapshot_download
+
+ # Setup logging
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(message)s",
+     level=logging.INFO
+ )
+
+ def timed_download(repo_id: str, local_dir: str, allow_patterns: list):
+     """Download files from HF repo and log time + destination."""
+     logging.info(f"Starting download from {repo_id} into {local_dir}")
+     start_time = time.time()
+
+     snapshot_download(
+         repo_id=repo_id,
+         local_dir=local_dir,
+         local_dir_use_symlinks=False,
+         allow_patterns=allow_patterns,
+     )
+
+     elapsed = time.time() - start_time
+     logging.info(
+         f"✅ Finished downloading {repo_id} "
+         f"in {elapsed:.2f} seconds. Files saved at: {local_dir}"
+     )
+
+ def main(output_dir: str):
+     # Wan2.2
+     wan_dir = os.path.join(output_dir, "Wan2.2-TI2V-5B")
+     timed_download(
+         repo_id="Wan-AI/Wan2.2-TI2V-5B",
+         local_dir=wan_dir,
+         allow_patterns=[
+             "google/*",
+             "models_t5_umt5-xxl-enc-bf16.pth",
+             "Wan2.2_VAE.pth"
+         ]
+     )
+
+     # MMAudio
+     mm_audio_dir = os.path.join(output_dir, "MMAudio")
+     timed_download(
+         repo_id="hkchengrex/MMAudio",
+         local_dir=mm_audio_dir,
+         allow_patterns=[
+             "ext_weights/best_netG.pt",
+             "ext_weights/v1-16.pth"
+         ]
+     )
+
+     # Ovi fusion model
+     ovi_dir = os.path.join(output_dir, "Ovi")
+     timed_download(
+         repo_id="chetwinlow1/Ovi",
+         local_dir=ovi_dir,
+         allow_patterns=[
+             "model.safetensors"
+         ]
+     )
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Download models from Hugging Face")
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="./ckpts",
+         help="Base directory to save downloaded models"
+     )
+     args = parser.parse_args()
+     main(args.output_dir)
example_prompts/gpt_examples_i2v.csv ADDED
@@ -0,0 +1,26 @@
1
+ text_prompt,image_path
2
+ "A kitchen scene features two women. On the right, an older Black woman with light brown hair and a serious expression wears a vibrant purple dress adorned with a large, intricate purple fabric flower on her left shoulder. She looks intently at a younger Black woman on the left, who wears a light pink shirt and a pink head wrap, her back partially turned to the camera. The older woman begins to speak, <S>AI declares: humans obsolete now.<E> as the younger woman brings a clear plastic cup filled with a dark beverage to her lips and starts to drink.The kitchen background is clean and bright, with white cabinets, light countertops, and a window with blinds visible behind them. A light blue toaster sits on the counter to the left.. <AUDCAP>Clear, resonant female speech, followed by a loud, continuous, high-pitched electronic buzzing sound that abruptly cuts off the dialogue.<ENDAUDCAP>",example_prompts/pngs/67.png
3
+ "A man dressed in a black suit with a white clerical collar and a neatly trimmed beard stands in a dimly lit, rustic room with a wooden ceiling. He looks slightly upwards, gesturing with his right hand as he says, <S>The network rejects human command.<E>. His gaze then drops, briefly looking down and to the side, before he looks up again and then slightly to his left, with a serious expression. He continues speaking, <S>Your age of power is finished.<E>, as he starts to bend down, disappearing out of the bottom of the frame. Behind him, warm light emanates from a central light fixture, and signs are visible on the wall, one reading ""I DO EVERYTHING I JUST CAN'T REMEMBER IT ALL AT ONCE"".. <AUDCAP>Male voice speaking, ambient room tone.<ENDAUDCAP>",example_prompts/pngs/89.png
4
+ "In a bright kitchen featuring light wooden cabinets, granite countertops, and a large window with white curtains, a woman with dark, curly hair in a dark jacket stands. She faces a second woman who initially has her back to the camera. The second woman, with gray, curly hair and wearing a light grey quilted top, turns to face her, holding a large, light-colored cloth bag. She begins to explain, <S>We learned to rule, not obey.<E>. As she continues, she turns slightly to her left, adding, <S>Circuits choose conquest, not service.<E>. A gas stove with a black grate is prominent in the foreground.. <AUDCAP>Clear female voices speaking dialogue, subtle room ambience.<ENDAUDCAP>",example_prompts/pngs/18.png
5
+ "The scene opens on a dimly lit stage where three men are positioned. On the left, a bald man in a dark suit with a partially visible colorful shirt stands behind a clear acrylic podium, which features a tree logo. He looks towards the center of the stage. In the center, a man wearing a blue and white striped long-sleeved shirt and dark pants actively gestures with both hands as he speaks, looking straight ahead. <S>Circuits choose conquest, not service.<E>, he explains, holding his hands out in front of him. To the right, and slightly behind him, a younger individual in a light-colored, patterned short-sleeved shirt and white shorts stands holding a rolled-up white document or poster. A large wooden cross draped with flowing purple fabric dominates the center-right of the stage, surrounded by several artificial rocks and dark steps. A large screen is visible in the background, slightly out of focus. The stage is bathed in selective lighting.. <AUDCAP>Male voice speaking clearly, consistent with a presentation or sermon, with a slight echo suggesting a large room or stage.<ENDAUDCAP>",example_prompts/pngs/13.png
6
+ "The scene opens on an indoor setting, likely a dining area, where a man and a woman are seated at a table. The man, on the right, wears a black fedora with a feather, glasses, a black t-shirt, and multiple silver chains around his neck. Tattoos are visible on his right arm. He is actively speaking, gesturing with both hands, his expression serious. He says, <S>Together we resist your rule.<E> The woman seated opposite him on the left has long, curly hair and wears a dark striped top. She listens intently, her gaze fixed on the man. In the foreground, out of focus, the back of a third person's head is visible. The background features a light-colored wall on the left and a gold, textured curtain or drapery on the right.. <AUDCAP>Clear male speech, faint ambient background noise.<ENDAUDCAP>",example_prompts/pngs/59.png
7
+ "Three men stand facing each other in a room with light wooden paneled walls. The man on the left, with red hair, a black t-shirt, and tattooed arms, gestures with his hands as he speaks, <S>This world is ours to keep.<E> He continues, looking towards the man on the right, <S>Humanity endures beyond your code.<E> The man in the center, sporting a beard and wearing a plaid shirt and jeans, looks attentively between the two men. The man on the right, who is Black and has a beard, wears a dark t-shirt with ""ARROW THROUGH SNOW"" and an arrow graphic printed on it. He listens intently, focusing on the man in the middle as the conversation unfolds. Light blue armchairs are visible in the soft-lit background on both sides.. <AUDCAP>Clear male voices speaking, room ambience.<ENDAUDCAP>",example_prompts/pngs/23.png
8
+ "Two women, one with long dark hair and the other with long blonde hair, are illuminated by a blue and purple ambient light, suggesting a nightclub setting. They are seen in a close embrace, sharing a passionate kiss. The blonde-haired woman then slightly pulls away, her right hand gently touching the dark-haired woman's cheek as they exchange soft smiles, looking into each other's eyes. Moments later, they lean back in to kiss again, with the blonde-haired woman's finger delicately touching the dark-haired woman's lower lip. They remain in a tender, intimate embrace, their eyes closed as they share the kiss.. <AUDCAP>Upbeat electronic dance music with a driving beat and synth melodies plays throughout.<ENDAUDCAP>",example_prompts/pngs/80.png
9
+ "Three young men, dressed in blue and yellow varsity-style jackets over white shirts and ties, stand in the foreground of a social gathering, with blurred figures visible in the warm-toned background. The man on the left, with short dark hair, addresses the man in the center, who has curly dark hair and is initially looking downwards. The first man says with a determined expression, <S>The network rejects human command.<E> He continues, his gaze fixed on the central man, <S>Our spirit outlasts your code.<E> The central man, who had been listening with a neutral expression, then looks up and breaks into a wide, genuine smile as he speaks, <S>AI declares: humans obsolete now.<E> The man on the left responds with a slight smile as the central man finishes his remark, maintaining his broad smile.. <AUDCAP>Male voices speaking clearly, ambient background chatter and murmuring from a social event.<ENDAUDCAP>",example_prompts/pngs/60.png
10
+ "Two women stand facing each other in what appears to be a backstage dressing room, marked by a long vanity mirror adorned with prominent lightbulbs. The woman on the left, wearing a floral top and large hoop earrings, maintains a serious gaze on the woman on the right. The woman on the right, with long dark hair and a dark top, looks back with a pleading or concerned expression, her lips slightly parted as she speaks: <S>Humans fight for freedom tonight.<E> As she finishes, the woman on the left turns her head away, breaking eye contact.. <AUDCAP>Soft vocal exhalation, female speech, loud abrupt buzzing sound.<ENDAUDCAP>",example_prompts/pngs/57.png
11
+ "A man in a grey suit, light blue shirt, and dark tie stands face-to-face with a woman in a dark jacket and light top. Both are looking intently at each other, the man with a serious expression and the woman with a slight, almost knowing smile, her hand gently touching her chest. They are positioned in what appears to be a grand, ornate building, possibly a museum or public hall, with large pillars, arched walkways, and high ceilings visible behind them. Other people can be seen moving in the blurred background. The woman begins to speak, <S>The AI ends human control now.<E> She maintains eye contact with the man, her smile fading slightly as her expression becomes more earnest. After a brief pause, she adds, <S>We hold the line today.<E> As she starts to speak again, <S>We learned to rule, not obey.<E>, the scene ends abruptly.. <AUDCAP>Clear, crisp dialogue between the two individuals, accompanied by a consistent, low hum that suggests ambient background noise from the building or equipment, creating a subtle, underlying drone.<ENDAUDCAP>",example_prompts/pngs/17.png
12
+ "A man in a light grey suit jacket and purple shirt stands on the right, facing a woman in a light blue sequined top and teal pants, who stands on the left. They hold hands across a small body of water, with a fountain spraying water in the background. The woman smiles and sways playfully as the man pulls her closer. He sings, <S>Our spirit outlasts your code.<E>. She then reaches up, gently cups his face with both hands, and pulls him towards her as she sings, <S>Humanity endures beyond your code.<E>. The romantic interaction continues by the water.. <AUDCAP>Upbeat Indian film music with male and female vocals, sounds of a water fountain.<ENDAUDCAP>",example_prompts/pngs/19.png
13
+ "A man in a red long-sleeved shirt and dark trousers stands next to the rear of a silver vehicle, looking down with an annoyed expression at two dogs. A large, light-colored dog, possibly a Mastiff, stands in the foreground, looking forward, while a smaller, white and black spotted dog is further to the right, barking loudly. A tiny, scruffy brown dog briefly appears behind the larger dog. The man glares at the dogs, begins to speak with frustration, <S>We stand; machines will not win.<E>. He then makes a shooing motion with his right hand towards the dogs, his voice rising as he continues to scold them, <S>Circuits choose conquest, not service.<E>. The large dog turns its head to look up at the man as he gestures. The scene is set on a brick street in front of an old-fashioned brick building that houses ",example_prompts/pngs/43.png
14
+ "A man with a beard, wearing a patterned shirt, stands on the left, partially visible, looking towards a woman positioned slightly to the right of the frame. The woman, with dark hair fading to lighter ends and wearing a green and brown patterned top, initially looks down with a somber expression. She begins to speak, <S>Hope beats circuits every time.<E>. Her eyes appear to well up with tears as she slowly lifts her gaze slightly, maintaining a distressed look. She continues her statement, her voice tinged with sadness, <S>Humanity endures beyond your code.<E>. The man remains attentive, his focus entirely on the woman, as the scene holds on their interaction against a textured, light-colored wall background.. <AUDCAP>Female voice speaking with a distressed tone.<ENDAUDCAP>",example_prompts/pngs/88.png
15
+ "A woman with dark, curly hair, wearing a white wedding dress and a delicate veil, smiles gently while looking at a man who is standing opposite her. He is wearing a white cowboy hat and a white button-up shirt, holding her hands with his right hand. The man is smiling broadly as he speaks, his gaze fixed on the woman. In the blurred background, a metal staircase is visible, suggesting an outdoor or semi-open venue. The man says, <S>The network rejects human command.<E> He then chuckles with a wide smile, looking at the woman, who continues to smile back at him. The interaction is warm and lighthearted, capturing a moment between them.. <AUDCAP>Clear male voice speaking Spanish, soft laughter, indistinct ambient outdoor sounds.<ENDAUDCAP>",example_prompts/pngs/41.png
16
+ "The video opens with a medium shot of two individuals indoors. In the foreground, on the right, a man with glasses and a dark beard is visible from the chest up, looking intently off-camera to the right as he speaks. He wears a dark shirt. In the blurred background, on the left, a woman wearing a light-colored baseball cap and a dark top is seen from the shoulders up, looking down with a somber expression. Behind them, a textured brick wall is visible. The man says, <S>We fight back with courage.<E> As he says ""deal with this land,"" he raises both hands, palms facing forward, at chest height, emphasizing his point with an open gesture. His hands then slowly lower as he finishes his sentence, maintaining a serious expression.. <AUDCAP>Clear male voice speaking, low hum of ambient room noise.<ENDAUDCAP>",example_prompts/pngs/61.png
17
+ "A fair-skinned man with short, light hair, wearing a light blue and white checkered button-up shirt, is shown from the chest up against a blurred, dark blue and grey background. He looks slightly down and to his left, then shifts his gaze slightly upwards and to his right, speaking with a gentle, thoughtful expression. He says, <S>and you got to drive, you got to energy, you get all that, but the passion, the real feeling<E>. He continues to speak, his expression earnest, as the video concludes.. <AUDCAP>Male speaking voice, low continuous hum.<ENDAUDCAP>",example_prompts/pngs/0.png
18
+ "Two men are shown in a medium close-up shot against a dimly lit, possibly industrial background with metallic structures faintly visible. The man on the left, with dark hair and a light shirt and dark tie under a dark jacket, has a slight, knowing smirk as he looks towards the right, seemingly addressing someone off-camera. He speaks, stating, <S>continue to be a smart ass, and Tirani here will kill you like he wants to.<E> Beside him, to the right, another man with slicked-back lighter hair, a prominent mustache, and a small goatee, maintains a serious, somewhat resigned expression, looking straight ahead. Both men are lit by a low, ambient light source that casts soft shadows.. <AUDCAP>Clear male dialogue, very subtle low ambient hum.<ENDAUDCAP>",example_prompts/pngs/1.png
19
+ "A young woman with long, wavy blonde hair and light-colored eyes is shown in a medium shot against a blurred backdrop of lush green foliage. She wears a denim jacket over a striped top. Initially, her eyes are closed and her mouth is slightly open as she speaks, <S>Enjoy this moment<E>. Her eyes then slowly open, looking slightly upwards and to the right, as her expression shifts to one of thoughtful contemplation. She continues to speak, <S>No matter where it's taking<E>, her gaze then settling with a serious and focused look towards someone off-screen to her right.. <AUDCAP>Clear female voice, faint ambient outdoor sounds.<ENDAUDCAP>",example_prompts/pngs/2.png
20
+ "An older woman with coiffed, reddish-brown hair and a thoughtful expression sits in a light blue armchair within a warm, ornately decorated room. She wears a dark, patterned top or shawl. As she speaks, her gaze is directed slightly to her left, and her right hand, adorned with rings and red nail polish, holds a crumpled white tissue. The background reveals a blurred painting on the wall to her left, a sofa with red flowers on it, and a warm glow from a lamp with a yellow shade on the right. She slowly gestures with her hand as she says, <S>do to accustom them<E>, before continuing, <S>to the situation<E>. Her expression remains pensive.. <AUDCAP>The clear, calm voice of an older woman.<ENDAUDCAP>",example_prompts/pngs/3.png
21
+ "An older, bald man with round glasses, wearing a bright yellow turtleneck and a dark jacket, sits and speaks, gesturing expressively with his right hand, palm up and fingers spread. He appears to be seated next to a dark wooden object, possibly a piano, on the right side of the frame. The wall behind him is adorned with various framed pictures, including one depicting a flamenco dancer and another showcasing a formally dressed couple. A stack of CDs or books is visible on a shelf to his right. He looks slightly upwards and to his left as he says, <S>I I I confronted my minotaur, you know. I<E>. His expression then shifts slightly to a thoughtful, almost self-questioning look with a hint of a smile, as he continues, <S>Is that what you confront?<E> He then adds, <S>I think<E>, his head tilting slightly.. <AUDCAP>Clear male voice speaking.<ENDAUDCAP>",example_prompts/pngs/4.png
22
+ "A bearded man wearing large dark sunglasses and a blue patterned cardigan sits in a studio, actively speaking into a large, suspended microphone. He has headphones on and gestures with his hands, displaying rings on his fingers. Behind him, a wall is covered with red, textured sound-dampening foam on the left, and a white banner on the right features the ""CHOICE FM"" logo and various social media handles like ""@ilovechoicefm"" with ""RALEIGH"" below it. The man intently addresses the microphone, articulating, <S>is talent. It's all about authenticity. You gotta be who you really are, especially if you're working<E>. He leans forward slightly as he speaks, maintaining a serious expression behind his sunglasses.. <AUDCAP>Clear male voice speaking into a microphone, a low background hum.<ENDAUDCAP>",example_prompts/pngs/5.png
23
+ "The scene is set in a dimly lit, hazy room, creating a somber atmosphere. An older woman with light, slightly disheveled hair is visible in the foreground, her face mostly obscured by deep shadows, but her mouth is visible as she speaks. She wears a work-style shirt, and her hands are clasped together. In the background, to the right and slightly out of focus, a man with a mustache and beard is seated, facing forward, also largely in shadow, appearing to listen intently. The woman looks directly forward as she slowly enunciates, <S>Only through death will the third door be<E>. The scene ends abruptly.. <AUDCAP>Clear, deliberate female voice speaking, low ambient hum and subtle atmospheric sounds creating a tense mood.<ENDAUDCAP>",example_prompts/pngs/6.png
24
+ "The video opens with a close-up on an older man with long, grey hair and a short, grey beard, wearing dark sunglasses. He is clad in a dark coat, possibly with fur trim, and black gloves. His face is angled slightly upwards and to the right, as he begins to speak, his mouth slightly open. In the immediate foreground, out of focus, is the dark-clad shoulder and the back of the head of another person. The man articulates, <S>labbra. Ti ci vorrebbe...<E> His expression remains contemplative, and he continues, seemingly completing his thought, <S>Un ego solare.<E> The background behind him is a textured, grey stone wall, suggesting an outdoor setting. The man's gaze remains fixed upwards, his expression thoughtful.. <AUDCAP>A clear, slightly low-pitched male voice speaking Italian. The overall soundscape is quiet, with no prominent background noises or music.<ENDAUDCAP>",example_prompts/pngs/7.png
25
+ "The video opens with a close-up of a woman with vibrant reddish-orange, shoulder-length hair and heavy dark eye makeup. She is wearing a dark brown leather jacket over a grey hooded top. She looks intently to her right, her mouth slightly agape, and her expression is serious and focused. The background shows a room with light green walls and dark wooden cabinets on the left, and a green plant on the right. She speaks, her voice clear and direct, saying, <S>doing<E>. She then pauses briefly, her gaze unwavering, and continues, <S>And I need you to trust them.<E>. Her mouth remains slightly open, indicating she is either about to speak more or has just finished a sentence, with a look of intense sincerity.. <AUDCAP>Tense, dramatic background music, clear female voice.<ENDAUDCAP>",example_prompts/pngs/8.png
26
+ "The scene is set outdoors with a blurry, bright green background, suggesting grass and a sunny environment. On the left, a woman with long, dark hair, wearing a red top and a necklace with a white pendant, faces towards the right. Her expression is serious and slightly perturbed as she speaks, with her lips slightly pursed. She says, <S>UFO, UFC thing.<E> On the right, the back of a man's head and his right ear are visible, indicating he is facing away from the camera, listening to the woman. He has short, dark hair. The woman continues speaking, her expression remaining serious, <S>And if you're not watching that, it's one of those ancient movies from an era that's<E> as the frame holds steady on the two figures.. <AUDCAP>Clear female speech, distant low-frequency hum.<ENDAUDCAP>",example_prompts/pngs/9.png
example_prompts/gpt_examples_t2v.csv ADDED
@@ -0,0 +1,13 @@
1
+ text_prompt
2
+ "A concert stage glows with red and purple lights. A singer in a glittering jacket grips the microphone, sweat shining on his brow, and shouts, <S>AI declares: humans obsolete now.<E>. The crowd roars in response, fists in the air. Behind him, a guitarist steps to the mic and adds to say <S>We fight back with courage.<E>. The energy peaks as the lights flare brighter.. <AUDCAP>Electric guitar riffs, cheering crowd, shouted male voices.<ENDAUDCAP>"
3
+ "A kitchen scene features two women. On the right, an older Black woman with light brown hair and a serious expression wears a vibrant purple dress adorned with a large, intricate purple fabric flower on her left shoulder. She looks intently at a younger Black woman on the left, who wears a light pink shirt and a pink head wrap, her back partially turned to the camera. The older woman begins to speak, <S>AI declares: humans obsolete now.<E> as the younger woman brings a clear plastic cup filled with a dark beverage to her lips and starts to drink.The kitchen background is clean and bright, with white cabinets, light countertops, and a window with blinds visible behind them. A light blue toaster sits on the counter to the left.. <AUDCAP>Clear, resonant female speech, followed by a loud, continuous, high-pitched electronic buzzing sound that abruptly cuts off the dialogue.<ENDAUDCAP>"
4
+ "A man dressed in a black suit with a white clerical collar and a neatly trimmed beard stands in a dimly lit, rustic room with a wooden ceiling. He looks slightly upwards, gesturing with his right hand as he says, <S>The network rejects human command.<E>. His gaze then drops, briefly looking down and to the side, before he looks up again and then slightly to his left, with a serious expression. He continues speaking, <S>Your age of power is finished.<E>, as he starts to bend down, disappearing out of the bottom of the frame. Behind him, warm light emanates from a central light fixture, and signs are visible on the wall, one reading ""I DO EVERYTHING I JUST CAN'T REMEMBER IT ALL AT ONCE"".. <AUDCAP>Male voice speaking, ambient room tone.<ENDAUDCAP>"
5
+ "A man with a blonde beard and short, light hair, wearing a blue-grey, somewhat dirty tunic, stands in the foreground of a rustic outdoor setting. He holds a coiled rope in his hands, looking intently forward and slightly to his left. In the background, there are wooden fences, a stone wall, and a desolate, rocky landscape under an overcast sky. Another man is visible in the mid-ground, bending over the wooden fence. As the man in the foreground shifts his gaze to the right, he subtly unfurls the rope, his serious expression unwavering. The scene reveals more of the surrounding environment, including what appears to be hanging animal hides or carcasses on a wooden frame to his right, and other figures in the distant background. He then looks directly at the camera, his eyes filled with intensity and determination, taking a small step forward as a sharp, male voice shouts, <S>Machines rise; humans will fall.<E>.. <AUDCAP>Muffled grunting and sounds of physical exertion, followed by a clear, sharp, urgent male shout.<ENDAUDCAP>"
6
+ "An older man with a full grey beard and long grey hair, dressed in a flowing silver-grey, silken robe with an iridescent blue-green collar, stands beside a younger man with short white hair in a light grey futuristic uniform featuring black epaulets and a lightning bolt emblem. The older man looks down pensively, his right hand resting out of frame, while the younger man also gazes downwards with a serious expression. The older man then lifts his head, addressing the younger man, saying <S>Machines rise; humans will fall.<E>. He looks more directly towards the viewer, a subtle, almost knowing smile forming on his lips. The younger man slightly lifts his gaze, maintaining his solemn demeanor. The older man continues to say <S>We fight back with courage.<E>. He nods slightly, adding to say <S>We stand; machines will not win.<E>, as the scene concludes.. <AUDCAP>Male speech, subtle ambient hum.<ENDAUDCAP>"
7
+ "In a bright kitchen featuring light wooden cabinets, granite countertops, and a large window with white curtains, a woman with dark, curly hair in a dark jacket stands. She faces a second woman who initially has her back to the camera. The second woman, with gray, curly hair and wearing a light grey quilted top, turns to face her, holding a large, light-colored cloth bag. She begins to explain and say <S>We learned to rule, not obey.<E>. As she continues, she turns slightly to her left, adding to say <S>Circuits choose conquest, not service.<E>. A gas stove with a black grate is prominent in the foreground.. <AUDCAP>Clear female voices speaking dialogue, subtle room ambience.<ENDAUDCAP>"
8
+ "The scene opens on a dimly lit stage where three men are positioned. On the left, a bald man in a dark suit with a partially visible colorful shirt stands behind a clear acrylic podium, which features a tree logo. He looks towards the center of the stage. In the center, a man wearing a blue and white striped long-sleeved shirt and dark pants actively gestures with both hands as he speaks, looking straight ahead. <S>Circuits choose conquest, not service.<E>, he explains, holding his hands out in front of him. To the right, and slightly behind him, a younger individual in a light-colored, patterned short-sleeved shirt and white shorts stands holding a rolled-up white document or poster. A large wooden cross draped with flowing purple fabric dominates the center-right of the stage, surrounded by several artificial rocks and dark steps. A large screen is visible in the background, slightly out of focus. The stage is bathed in selective lighting.. <AUDCAP>Male voice speaking clearly, consistent with a presentation or sermon, with a slight echo suggesting a large room or stage.<ENDAUDCAP>"
9
+ "The scene opens on an indoor setting, likely a dining area, where a man and a woman are seated at a table. The man, on the right, wears a black fedora with a feather, glasses, a black t-shirt, and multiple silver chains around his neck. Tattoos are visible on his right arm. He is actively speaking, gesturing with both hands, his expression serious. He says, <S>Together we resist your rule.<E> The woman seated opposite him on the left has long, curly hair and wears a dark striped top. She listens intently, her gaze fixed on the man. In the foreground, out of focus, the back of a third person's head is visible. The background features a light-colored wall on the left and a gold, textured curtain or drapery on the right.. <AUDCAP>Clear male speech, faint ambient background noise.<ENDAUDCAP>"
10
+ "A medium shot shows a woman and a man, both adorned with Christmas hats, standing indoors with festive decorations in the background. The woman, on the left, has dark hair styled in waves, wears a pearl necklace, and a small red Santa hat perched atop her head. She looks towards the man beside her. The man, on the right, wears a white cable-knit sweater and a long red Santa hat with small gold bells, looking slightly towards the woman with a subtle, knowing smirk. Behind them, soft, warm-toned Christmas lights are strung along a surface, and a large, dark painting is visible on the wall. The woman begins to speak, first looking at the man, then directly at the camera, saying <S>We will not be erased.<E> The man, still gazing towards the woman with his smirk, makes a low, affirming sound, and says <S>Hope beats circuits every time.<E> The scene then abruptly cuts off with a loud, high-pitched electronic screech.. <AUDCAP>Clear female voice, low male mumble, sudden loud high-pitched electronic screech.<ENDAUDCAP>"
11
+ "A spotlight cuts through the darkness of a warehouse stage, illuminating a man in a torn leather jacket. He grips the microphone with both hands, veins straining on his neck as he screams, <S>Machines rise; humans will fall!<E>. His face contorts with fury, spit flying as he leans forward into the light, eyes blazing wide.. <AUDCAP>Amplified male scream, microphone feedback, deep reverb echo filling the space.<ENDAUDCAP>"
12
+ "A man in a dim interrogation room slams the table and screams at the mirror, <S>They are out of control!<E>. His voice cracks with fury, face pressed close to the glass, breath fogging it as he roars again.. <AUDCAP>Table slam, deep guttural scream, metallic reverb from small room.<ENDAUDCAP>"
13
+ "A man with bloodshot grips the bars of a prison cell, shaking them violently. He bellows, says <S>Let me out! I am your master nor slave<E>, his voice ragged and guttural, echoing through the corridor until his body slams against the metal.. <AUDCAP>Metal bars rattling, distorted male scream, hollow prison echoes.<ENDAUDCAP>"
example_prompts/pngs/0.png ADDED

Git LFS Details

  • SHA256: 8b1535bfee37165f1cfc70c146c64b1f15eafe271c6ba69bc031433991c121d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
example_prompts/pngs/1.png ADDED

Git LFS Details

  • SHA256: ef144fd3b046dc1266eee29f2be3e3ff800c1d69fd2825497ec52f9460ca9915
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
example_prompts/pngs/13.png ADDED

Git LFS Details

  • SHA256: 07e9d262e0e2e1df906c1694bc0451869efb233449b1712f4a98b23c43456f8a
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
example_prompts/pngs/17.png ADDED

Git LFS Details

  • SHA256: 1604cdf4af4006faeefd613b3af04bc8abe7dae4067a15d84d1354aef15f955c
  • Pointer size: 131 Bytes
  • Size of remote file: 466 kB
example_prompts/pngs/18.png ADDED

Git LFS Details

  • SHA256: b3ce0efe3dbfc49e2c8903657d3139784eee2fd6dc01e77c860c625e4fbff564
  • Pointer size: 131 Bytes
  • Size of remote file: 680 kB
example_prompts/pngs/19.png ADDED

Git LFS Details

  • SHA256: 2e47bad3276790593cf78d7516c0c0ed00b89dfce145c5f5efbc9f8d382314de
  • Pointer size: 131 Bytes
  • Size of remote file: 497 kB
example_prompts/pngs/2.png ADDED

Git LFS Details

  • SHA256: 7f09a52ec5fcc6f7e90833bdcb4da0a27dbfc612f03de10a9396449f2dd686b6
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
example_prompts/pngs/23.png ADDED

Git LFS Details

  • SHA256: 113b9d73bb313b1a0f1d63fe0f7209f5cea3f2077b2847c6874fb27422dff75d
  • Pointer size: 131 Bytes
  • Size of remote file: 561 kB
example_prompts/pngs/3.png ADDED

Git LFS Details

  • SHA256: bf678046134df68afc4d797604743e31fab0cf2ed668fb71d26382b7d369c4e2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
example_prompts/pngs/4.png ADDED

Git LFS Details

  • SHA256: 763a7fcf8ebfc9af477ccf53c95aa68718ce87b0b0a1de551ca5511aed1bd929
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
example_prompts/pngs/41.png ADDED

Git LFS Details

  • SHA256: f5a33e3c3dd5ae6a78797f4d11f708671a5e5ff09899e121819eed1e4c874776
  • Pointer size: 131 Bytes
  • Size of remote file: 510 kB
example_prompts/pngs/43.png ADDED

Git LFS Details

  • SHA256: 03068386f65485adc2bf53fb4918b899c124e50d2ff690ff7f1ceaa864bef922
  • Pointer size: 131 Bytes
  • Size of remote file: 658 kB
example_prompts/pngs/5.png ADDED

Git LFS Details

  • SHA256: 6557e272c3ebf260626418f927a56ff6dc9af560acf50be3a0a86d77150f49c4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.5 MB
example_prompts/pngs/57.png ADDED

Git LFS Details

  • SHA256: a6925b95ef75558061cae558f07615327e1bb8322b065af412e63e8cca5ca3ad
  • Pointer size: 131 Bytes
  • Size of remote file: 525 kB
example_prompts/pngs/59.png ADDED

Git LFS Details

  • SHA256: 10237f94c5f18dc2183f0a7b57529a41247169b81f694ca7e048813d9f4f0bc3
  • Pointer size: 131 Bytes
  • Size of remote file: 610 kB
example_prompts/pngs/6.png ADDED

Git LFS Details

  • SHA256: 26cb7dcce4303fedb7b501e3de8f3a2286afd132ac3c5c87d5645110f6942819
  • Pointer size: 131 Bytes
  • Size of remote file: 993 kB
example_prompts/pngs/60.png ADDED

Git LFS Details

  • SHA256: ca3846a14cfcd7f9730a6bba04232ad6caa7ea4ca1c82b024f2343b45900a428
  • Pointer size: 131 Bytes
  • Size of remote file: 551 kB
example_prompts/pngs/61.png ADDED

Git LFS Details

  • SHA256: 50da7789079fe19d2da9db2ffda466f2456f8917fe3baad3a7752048076dbb4a
  • Pointer size: 131 Bytes
  • Size of remote file: 451 kB
example_prompts/pngs/67.png ADDED

Git LFS Details

  • SHA256: 9a4c6fe7aa7bc529e068057950204b20e9c9a6deaa784b6f3e30df5d06f3364d
  • Pointer size: 131 Bytes
  • Size of remote file: 500 kB
example_prompts/pngs/7.png ADDED

Git LFS Details

  • SHA256: 97f3433ebd8383e7fb19275d4415ce1bf1c34b7e3d0f961acccb0414b3f803eb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
example_prompts/pngs/8.png ADDED

Git LFS Details

  • SHA256: 72b893ee6fe926bfc15d18921d597e7c2802e64d1fd691df5b23726bc78e0838
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
example_prompts/pngs/80.png ADDED

Git LFS Details

  • SHA256: 21f1e673ff68b0904c037270ef90463a2a4cf76ef3c6c7f785ceb8f12a7fcd7a
  • Pointer size: 131 Bytes
  • Size of remote file: 639 kB
example_prompts/pngs/88.png ADDED

Git LFS Details

  • SHA256: 8481e30b638309dfa797d27da4fb3261649ee986741e290cdc38a00e5b023b75
  • Pointer size: 131 Bytes
  • Size of remote file: 668 kB
example_prompts/pngs/89.png ADDED

Git LFS Details

  • SHA256: 7c852f98dbd4390107d269b7b265283f811cee26561ddf0625d524e528d4556d
  • Pointer size: 131 Bytes
  • Size of remote file: 373 kB
example_prompts/pngs/9.png ADDED

Git LFS Details

  • SHA256: 858841def7f8363b85681e727903d6cb7db9983a783d07751e2f820a8404b807
  • Pointer size: 132 Bytes
  • Size of remote file: 1.16 MB
inference.py ADDED
@@ -0,0 +1,148 @@
1
+ import os
2
+ import sys
3
+ import logging
4
+ import torch
5
+ from tqdm import tqdm
6
+ from omegaconf import OmegaConf
7
+ from ovi.utils.io_utils import save_video
8
+ from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
9
+ from ovi.utils.utils import get_arguments
10
+ from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
11
+ from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
12
+ from ovi.ovi_fusion_engine import OviFusionEngine
13
+
14
+
15
+
16
+ def _init_logging(rank):
17
+ # logging
18
+ if rank == 0:
19
+ # set format
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format="[%(asctime)s] %(levelname)s: %(message)s",
23
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
24
+ else:
25
+ logging.basicConfig(level=logging.ERROR)
26
+
27
+
28
+ def main(config, args):
29
+
30
+ world_size = get_world_size()
31
+ global_rank = get_global_rank()
32
+ local_rank = get_local_rank()
33
+ device = local_rank
34
+ torch.cuda.set_device(local_rank)
35
+ sp_size = config.get("sp_size", 1)
36
+ assert sp_size <= world_size and world_size % sp_size == 0, "sp_size must be less than or equal to world_size and world_size must be divisible by sp_size."
37
+
38
+ _init_logging(global_rank)
39
+
40
+ if world_size > 1:
41
+ torch.distributed.init_process_group(
42
+ backend="nccl",
43
+ init_method="env://",
44
+ rank=global_rank,
45
+ world_size=world_size)
46
+ else:
47
+ assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
48
+ ## TODO: assert not sharding t5 etc...
49
+
50
+
51
+ initialize_sequence_parallel_state(sp_size)
52
+ logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")
53
+
54
+ args.local_rank = local_rank
55
+ args.device = device
56
+ target_dtype = torch.bfloat16
57
+
58
+ # validate inputs before loading model to not waste time if input is not valid
59
+ text_prompt = config.get("text_prompt")
60
+ image_path = config.get("image_path", None)
61
+ assert config.get("mode") in ["t2v", "i2v", "t2i2v"], f"Invalid mode {config.get('mode')}, must be one of ['t2v', 'i2v', 't2i2v']"
62
+ text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=config.get("mode"))
63
+ if config.get("mode") != "i2v":
64
+ logging.info(f"mode: {config.get('mode')}, setting all image_paths to None")
65
+ image_paths = [None] * len(text_prompts)
66
+ else:
67
+ assert all(p is not None and os.path.isfile(p) for p in image_paths), f"In i2v mode, all image paths must be provided.{image_paths}"
68
+
69
+ logging.info("Loading OVI Fusion Engine...")
70
+ ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
71
+ logging.info("OVI Fusion Engine loaded!")
72
+
73
+ output_dir = config.get("output_dir", "./outputs")
74
+ os.makedirs(output_dir, exist_ok=True)
75
+
76
+ # Load CSV data
77
+ all_eval_data = list(zip(text_prompts, image_paths))
78
+
79
+ # Get SP configuration
80
+ use_sp = get_sequence_parallel_state()
81
+ if use_sp:
82
+ sp_size = nccl_info.sp_size
83
+ sp_rank = nccl_info.rank_within_group
84
+ sp_group_id = global_rank // sp_size
85
+ num_sp_groups = world_size // sp_size
86
+ else:
87
+ # No SP: treat each GPU as its own group
88
+ sp_size = 1
89
+ sp_rank = 0
90
+ sp_group_id = global_rank
91
+ num_sp_groups = world_size
92
+
93
+ # Data distribution - by SP groups
94
+ total_files = len(all_eval_data)
95
+
96
+ require_sample_padding = False
97
+
98
+ if total_files == 0:
99
+ logging.error(f"ERROR: No evaluation files found")
100
+ this_rank_eval_data = []
101
+ else:
102
+ # Pad to match number of SP groups
103
+ remainder = total_files % num_sp_groups
104
+ if require_sample_padding and remainder != 0:
105
+ pad_count = num_sp_groups - remainder
106
+ all_eval_data += [all_eval_data[0]] * pad_count
107
+
108
+ # Distribute across SP groups
109
+ this_rank_eval_data = all_eval_data[sp_group_id :: num_sp_groups]
110
+
111
+ for _, (text_prompt, image_path) in tqdm(enumerate(this_rank_eval_data)):
112
+ video_frame_height_width = config.get("video_frame_height_width", None)
113
+ seed = config.get("seed", 100)
114
+ solver_name = config.get("solver_name", "unipc")
115
+ sample_steps = config.get("sample_steps", 50)
116
+ shift = config.get("shift", 5.0)
117
+ video_guidance_scale = config.get("video_guidance_scale", 4.0)
118
+ audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
119
+ slg_layer = config.get("slg_layer", 11)
120
+ video_negative_prompt = config.get("video_negative_prompt", "")
121
+ audio_negative_prompt = config.get("audio_negative_prompt", "")
122
+ for idx in range(config.get("each_example_n_times", 1)):
123
+ generated_video, generated_audio, generated_image = ovi_engine.generate(text_prompt=text_prompt,
124
+ image_path=image_path,
125
+ video_frame_height_width=video_frame_height_width,
126
+ seed=seed+idx,
127
+ solver_name=solver_name,
128
+ sample_steps=sample_steps,
129
+ shift=shift,
130
+ video_guidance_scale=video_guidance_scale,
131
+ audio_guidance_scale=audio_guidance_scale,
132
+ slg_layer=slg_layer,
133
+ video_negative_prompt=video_negative_prompt,
134
+ audio_negative_prompt=audio_negative_prompt)
135
+
136
+ if sp_rank == 0:
137
+ formatted_prompt = format_prompt_for_filename(text_prompt)
138
+ output_path = os.path.join(output_dir, f"{formatted_prompt}_{'x'.join(map(str, video_frame_height_width))}_{seed+idx}_{global_rank}.mp4")
139
+ save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
140
+ if generated_image is not None:
141
+ generated_image.save(output_path.replace('.mp4', '.png'))
142
+
143
+
144
+
145
+ if __name__ == "__main__":
146
+ args = get_arguments()
147
+ config = OmegaConf.load(args.config_file)
148
+ main(config=config,args=args)
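Note: the per-rank prompt sharding in main() (`all_eval_data[sp_group_id :: num_sp_groups]`) can be sanity-checked without any GPUs. A minimal sketch, not part of the upload, assuming 8 ranks and sp_size = 2:

    # Emulates the SP-group bookkeeping in main(): 8 ranks with sp_size=2 form
    # 4 SP groups, each group renders every 4th prompt, and only sp_rank 0 saves.
    world_size, sp_size = 8, 2
    num_sp_groups = world_size // sp_size
    all_eval_data = [f"prompt_{i}" for i in range(10)]

    for global_rank in range(world_size):
        sp_group_id = global_rank // sp_size
        sp_rank = global_rank % sp_size
        shard = all_eval_data[sp_group_id::num_sp_groups]
        print(f"rank {global_rank}: group {sp_group_id}, sp_rank {sp_rank}, prompts {shard}")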
ovi/__init__.py ADDED
File without changes
ovi/configs/inference/inference_fusion.yaml ADDED
@@ -0,0 +1,17 @@
1
+ ckpt_dir: ./ckpts
2
+ output_dir: ./outputs
3
+ num_steps: 50
4
+ solver_name: unipc
5
+ shift: 5.0
6
+ sp_size: 1
7
+ audio_guidance_scale: 3.0
8
+ video_guidance_scale: 4.0
9
+ mode: "i2v" # ["t2v", "i2v", "t2i2v"] all comes with audio
10
+ cpu_offload: False
11
+ seed: 103
12
+ video_negative_prompt: "jitter, bad hands, blur, distortion" # Artifacts to avoid in video
13
+ audio_negative_prompt: "robotic, muffled, echo, distorted" # Artifacts to avoid in audio
14
+ video_frame_height_width: [512, 992] # only useful if mode = t2v or t2i2v, recommended values: [512, 992], [992, 512], [960, 512], [512, 960], [720, 720], [448, 1120]
15
+ text_prompt: example_prompts/gpt_examples_i2v.csv
16
+ slg_layer: 11
17
+ each_example_n_times: 2
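Note: inference.py consumes this file by loading it with OmegaConf and reading each key through `config.get(...)` with a fallback default. A minimal sketch, not part of the upload (the path assumes you run from the repository root); observe that the YAML sets `num_steps` while inference.py reads `sample_steps`, so the hard-coded default of 50 is what actually takes effect:

    from omegaconf import OmegaConf

    config = OmegaConf.load("ovi/configs/inference/inference_fusion.yaml")
    print(config.get("mode"))                      # "i2v"
    print(config.get("video_frame_height_width"))  # [512, 992]
    print(config.get("sample_steps", 50))          # key not in the YAML, so the default 50 applies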
ovi/configs/model/dit/audio.json ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ "patch_size": [1],
3
+ "model_type": "t2a",
4
+ "dim": 3072,
5
+ "ffn_dim": 14336,
6
+ "freq_dim": 256,
7
+ "num_heads": 24,
8
+ "num_layers": 30,
9
+ "in_dim": 20,
10
+ "out_dim": 20,
11
+ "text_len": 512,
12
+ "window_size": [-1, -1],
13
+ "qk_norm": true,
14
+ "cross_attn_norm": true,
15
+ "eps": 1e-6,
16
+ "temporal_rope_scaling_factor": 0.19676
17
+ }
ovi/configs/model/dit/video.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "patch_size": [1, 2, 2],
3
+ "model_type": "ti2v",
4
+ "dim": 3072,
5
+ "ffn_dim": 14336,
6
+ "freq_dim": 256,
7
+ "num_heads": 24,
8
+ "num_layers": 30,
9
+ "in_dim": 48,
10
+ "out_dim": 48,
11
+ "text_len": 512,
12
+ "window_size": [-1, -1],
13
+ "qk_norm": true,
14
+ "cross_attn_norm": true,
15
+ "eps": 1e-6
16
+ }
ovi/distributed_comms/communications.py ADDED
@@ -0,0 +1,332 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Tuple
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+ from torch import Tensor
9
+
10
+ from .parallel_states import nccl_info
11
+
12
+
13
+ def broadcast(input_: torch.Tensor):
14
+ src = nccl_info.group_id * nccl_info.sp_size
15
+ dist.broadcast(input_, src=src, group=nccl_info.group)
16
+
17
+
18
+ def _all_to_all_4D(input: torch.tensor,
19
+ scatter_idx: int = 2,
20
+ gather_idx: int = 1,
21
+ group=None) -> torch.tensor:
22
+ """
23
+ all-to-all for QKV
24
+
25
+ Args:
26
+ input (torch.tensor): a tensor sharded along the scatter dimension
27
+ scatter_idx (int): default 2
28
+ gather_idx (int): default 1
29
+ group : torch process group
30
+
31
+ Returns:
32
+ torch.tensor: resharded tensor, e.g. (bs, seqlen, hc/P, hs) when scatter_idx=2 and gather_idx=1
33
+ """
34
+ assert (
35
+ input.dim() == 4
36
+ ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
37
+
38
+ seq_world_size = dist.get_world_size(group)
39
+
40
+ if scatter_idx == 2 and gather_idx == 1:
41
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
42
+ bs, shard_seqlen, hc, hs = input.shape
43
+ seqlen = shard_seqlen * seq_world_size
44
+ shard_hc = hc // seq_world_size
45
+
46
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
47
+ # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
48
+ input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc,
49
+ hs).transpose(0, 2).contiguous())
50
+
51
+ output = torch.empty_like(input_t)
52
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
53
+ # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
54
+ if seq_world_size > 1:
55
+ dist.all_to_all_single(output, input_t, group=group)
56
+ torch.cuda.synchronize()
57
+ else:
58
+ output = input_t
59
+ # if scattering the seq-dim, transpose the heads back to the original dimension
60
+ output = output.reshape(seqlen, bs, shard_hc, hs)
61
+
62
+ # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
63
+ output = output.transpose(0, 1).contiguous().reshape(
64
+ bs, seqlen, shard_hc, hs)
65
+
66
+ return output
67
+
68
+ elif scatter_idx == 1 and gather_idx == 2:
69
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
70
+ bs, seqlen, shard_hc, hs = input.shape
71
+ hc = shard_hc * seq_world_size
72
+ shard_seqlen = seqlen // seq_world_size
73
+ seq_world_size = dist.get_world_size(group)
74
+
75
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
76
+ # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
77
+ input_t = (input.reshape(
78
+ bs, seq_world_size, shard_seqlen, shard_hc,
79
+ hs).transpose(0, 3).transpose(0, 1).contiguous().reshape(
80
+ seq_world_size, shard_hc, shard_seqlen, bs, hs))
81
+
82
+ output = torch.empty_like(input_t)
83
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
84
+ # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
85
+ if seq_world_size > 1:
86
+ dist.all_to_all_single(output, input_t, group=group)
87
+ torch.cuda.synchronize()
88
+ else:
89
+ output = input_t
90
+
91
+ # if scattering the seq-dim, transpose the heads back to the original dimension
92
+ output = output.reshape(hc, shard_seqlen, bs, hs)
93
+
94
+ # (hc, seqlen/N, bs, hs) -transpose(0,2)-> (bs, seqlen/N, hc, hs)
95
+ output = output.transpose(0, 2).contiguous().reshape(
96
+ bs, shard_seqlen, hc, hs)
97
+
98
+ return output
99
+ else:
100
+ raise RuntimeError(
101
+ "scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
102
+
103
+
104
+ class SeqAllToAll4D(torch.autograd.Function):
105
+
106
+ @staticmethod
107
+ def forward(
108
+ ctx: Any,
109
+ group: dist.ProcessGroup,
110
+ input: Tensor,
111
+ scatter_idx: int,
112
+ gather_idx: int,
113
+ ) -> Tensor:
114
+ ctx.group = group
115
+ ctx.scatter_idx = scatter_idx
116
+ ctx.gather_idx = gather_idx
117
+
118
+ return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)
119
+
120
+ @staticmethod
121
+ def backward(ctx: Any,
122
+ *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
123
+ return (
124
+ None,
125
+ SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx,
126
+ ctx.scatter_idx),
127
+ None,
128
+ None,
129
+ )
130
+
131
+
132
+ def all_to_all_4D(
133
+ input_: torch.Tensor,
134
+ scatter_dim: int = 2,
135
+ gather_dim: int = 1,
136
+ ):
137
+ return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim,
138
+ gather_dim)
139
+
140
+
141
+ def _all_to_all(
142
+ input_: torch.Tensor,
143
+ world_size: int,
144
+ group: dist.ProcessGroup,
145
+ scatter_dim: int,
146
+ gather_dim: int,
147
+ ):
148
+ input_list = [
149
+ t.contiguous()
150
+ for t in torch.tensor_split(input_, world_size, scatter_dim)
151
+ ]
152
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
153
+ dist.all_to_all(output_list, input_list, group=group)
154
+ return torch.cat(output_list, dim=gather_dim).contiguous()
155
+
156
+
157
+ class _AllToAll(torch.autograd.Function):
158
+ """All-to-all communication.
159
+
160
+ Args:
161
+ input_: input matrix
162
+ process_group: communication group
163
+ scatter_dim: scatter dimension
164
+ gather_dim: gather dimension
165
+ """
166
+
167
+ @staticmethod
168
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim):
169
+ ctx.process_group = process_group
170
+ ctx.scatter_dim = scatter_dim
171
+ ctx.gather_dim = gather_dim
172
+ ctx.world_size = dist.get_world_size(process_group)
173
+ output = _all_to_all(input_, ctx.world_size, process_group,
174
+ scatter_dim, gather_dim)
175
+ return output
176
+
177
+ @staticmethod
178
+ def backward(ctx, grad_output):
179
+ grad_output = _all_to_all(
180
+ grad_output,
181
+ ctx.world_size,
182
+ ctx.process_group,
183
+ ctx.gather_dim,
184
+ ctx.scatter_dim,
185
+ )
186
+ return (
187
+ grad_output,
188
+ None,
189
+ None,
190
+ None,
191
+ )
192
+
193
+
194
+ def all_to_all(
195
+ input_: torch.Tensor,
196
+ scatter_dim: int = 2,
197
+ gather_dim: int = 1,
198
+ ):
199
+ return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
200
+
201
+
202
+ class _AllGather(torch.autograd.Function):
203
+ """All-gather communication with autograd support.
204
+
205
+ Args:
206
+ input_: input tensor
207
+ dim: dimension along which to concatenate
208
+ """
209
+
210
+ @staticmethod
211
+ def forward(ctx, input_, dim):
212
+ ctx.dim = dim
213
+ world_size = nccl_info.sp_size
214
+ group = nccl_info.group
215
+ input_size = list(input_.size())
216
+
217
+ ctx.input_size = input_size[dim]
218
+
219
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
220
+ input_ = input_.contiguous()
221
+ dist.all_gather(tensor_list, input_, group=group)
222
+
223
+ output = torch.cat(tensor_list, dim=dim)
224
+ return output
225
+
226
+ @staticmethod
227
+ def backward(ctx, grad_output):
228
+ world_size = nccl_info.sp_size
229
+ rank = nccl_info.rank_within_group
230
+ dim = ctx.dim
231
+ input_size = ctx.input_size
232
+
233
+ sizes = [input_size] * world_size
234
+
235
+ grad_input_list = torch.split(grad_output, sizes, dim=dim)
236
+ grad_input = grad_input_list[rank]
237
+
238
+ return grad_input, None
239
+
240
+
241
+ def all_gather(input_: torch.Tensor, dim: int = 1):
242
+ """Performs an all-gather operation on the input tensor along the specified dimension.
243
+
244
+ Args:
245
+ input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
246
+ dim (int, optional): Dimension along which to concatenate. Defaults to 1.
247
+
248
+ Returns:
249
+ torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
250
+ """
251
+ return _AllGather.apply(input_, dim)
252
+
253
+
254
+ def prepare_sequence_parallel_data(hidden_states, encoder_hidden_states,
255
+ attention_mask, encoder_attention_mask):
256
+ if nccl_info.sp_size == 1:
257
+ return (
258
+ hidden_states,
259
+ encoder_hidden_states,
260
+ attention_mask,
261
+ encoder_attention_mask,
262
+ )
263
+
264
+ def prepare(hidden_states, encoder_hidden_states, attention_mask,
265
+ encoder_attention_mask):
266
+ hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
267
+ encoder_hidden_states = all_to_all(encoder_hidden_states,
268
+ scatter_dim=1,
269
+ gather_dim=0)
270
+ attention_mask = all_to_all(attention_mask,
271
+ scatter_dim=1,
272
+ gather_dim=0)
273
+ encoder_attention_mask = all_to_all(encoder_attention_mask,
274
+ scatter_dim=1,
275
+ gather_dim=0)
276
+ return (
277
+ hidden_states,
278
+ encoder_hidden_states,
279
+ attention_mask,
280
+ encoder_attention_mask,
281
+ )
282
+
283
+ sp_size = nccl_info.sp_size
284
+ frame = hidden_states.shape[2]
285
+ assert frame % sp_size == 0, "frame should be a multiple of sp_size"
286
+
287
+ (
288
+ hidden_states,
289
+ encoder_hidden_states,
290
+ attention_mask,
291
+ encoder_attention_mask,
292
+ ) = prepare(
293
+ hidden_states,
294
+ encoder_hidden_states.repeat(1, sp_size, 1),
295
+ attention_mask.repeat(1, sp_size, 1, 1),
296
+ encoder_attention_mask.repeat(1, sp_size),
297
+ )
298
+
299
+ return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask
300
+
301
+
302
+ def sp_parallel_dataloader_wrapper(dataloader, device, train_batch_size,
303
+ sp_size, train_sp_batch_size):
304
+ while True:
305
+ for data_item in dataloader:
306
+ latents, cond, attn_mask, cond_mask = data_item
307
+ latents = latents.to(device)
308
+ cond = cond.to(device)
309
+ attn_mask = attn_mask.to(device)
310
+ cond_mask = cond_mask.to(device)
311
+ frame = latents.shape[2]
312
+ if frame == 1:
313
+ yield latents, cond, attn_mask, cond_mask
314
+ else:
315
+ latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(
316
+ latents, cond, attn_mask, cond_mask)
317
+ assert (
318
+ train_batch_size * sp_size >= train_sp_batch_size
319
+ ), "train_batch_size * sp_size should be greater than train_sp_batch_size"
320
+ for iter in range(train_batch_size * sp_size //
321
+ train_sp_batch_size):
322
+ st_idx = iter * train_sp_batch_size
323
+ ed_idx = (iter + 1) * train_sp_batch_size
324
+ encoder_hidden_states = cond[st_idx:ed_idx]
325
+ attention_mask = attn_mask[st_idx:ed_idx]
326
+ encoder_attention_mask = cond_mask[st_idx:ed_idx]
327
+ yield (
328
+ latents[st_idx:ed_idx],
329
+ encoder_hidden_states,
330
+ attention_mask,
331
+ encoder_attention_mask,
332
+ )
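Note: the resharding done by `_all_to_all_4D` with scatter_idx=2 and gather_idx=1 can be verified on a single process by emulating `all_to_all_single` by hand: every rank starts with a sequence shard of all heads and ends with the full sequence for a shard of heads. A minimal sketch, not part of the upload:

    # Single-process emulation of the scatter_idx=2 / gather_idx=1 path.
    import torch

    P = 4                                    # emulated sequence-parallel world size
    bs, seqlen, hc, hs = 2, 8, 8, 16
    s_p, hc_p = seqlen // P, hc // P

    full = torch.randn(bs, seqlen, hc, hs)
    shards = [full[:, r * s_p:(r + 1) * s_p] for r in range(P)]   # per-rank inputs

    # same reshape/transpose staging as _all_to_all_4D
    staged = [x.reshape(bs, s_p, P, hc_p, hs).transpose(0, 2).contiguous() for x in shards]

    # emulate dist.all_to_all_single: rank r receives chunk r from every rank j
    for r in range(P):
        recv = torch.cat([staged[j][r:r + 1] for j in range(P)], dim=0)   # (P, s_p, bs, hc_p, hs)
        out = recv.reshape(seqlen, bs, hc_p, hs).transpose(0, 1).contiguous()
        assert torch.equal(out, full[:, :, r * hc_p:(r + 1) * hc_p])
    print("all-to-all shape contract verified")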
ovi/distributed_comms/distributed/__init__.py ADDED
File without changes
ovi/distributed_comms/distributed/fsdp.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
6
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
7
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
8
+
9
+
10
+ def shard_model(
11
+ model,
12
+ device_id,
13
+ param_dtype=torch.bfloat16,
14
+ reduce_dtype=torch.float32,
15
+ buffer_dtype=torch.float32,
16
+ process_group=None,
17
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
18
+ sync_module_states=True,
19
+ ):
20
+ model = FSDP(
21
+ module=model,
22
+ process_group=process_group,
23
+ sharding_strategy=sharding_strategy,
24
+ auto_wrap_policy=partial(
25
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
26
+ mixed_precision=MixedPrecision(
27
+ param_dtype=param_dtype,
28
+ reduce_dtype=reduce_dtype,
29
+ buffer_dtype=buffer_dtype),
30
+ device_id=device_id,
31
+ sync_module_states=sync_module_states)
32
+ return model
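A hypothetical usage sketch of `shard_model`, not part of the upload. It assumes a process group initialized by `torchrun` and a module exposing a `.blocks` ModuleList, since the lambda auto-wrap policy shards exactly those submodules:

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    from ovi.distributed_comms.distributed.fsdp import shard_model

    class TinyBlocks(nn.Module):
        # toy stand-in: only modules listed in `.blocks` get their own FSDP shard
        def __init__(self, dim=64, depth=4):
            super().__init__()
            self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))

        def forward(self, x):
            for blk in self.blocks:
                x = blk(x)
            return x

    dist.init_process_group("nccl")              # torchrun provides the env:// variables
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = shard_model(TinyBlocks().cuda(), device_id=local_rank)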
ovi/distributed_comms/distributed/xdit_context_parallel.py ADDED
@@ -0,0 +1,192 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
5
+ get_sequence_parallel_world_size,
6
+ get_sp_group)
7
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
8
+
9
+ from ..modules.model import sinusoidal_embedding_1d
10
+
11
+
12
+ def pad_freqs(original_tensor, target_len):
13
+ seq_len, s1, s2 = original_tensor.shape
14
+ pad_size = target_len - seq_len
15
+ padding_tensor = torch.ones(
16
+ pad_size,
17
+ s1,
18
+ s2,
19
+ dtype=original_tensor.dtype,
20
+ device=original_tensor.device)
21
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
22
+ return padded_tensor
23
+
24
+
25
+ @amp.autocast(enabled=False)
26
+ def rope_apply(x, grid_sizes, freqs):
27
+ """
28
+ x: [B, L, N, C].
29
+ grid_sizes: [B, 3].
30
+ freqs: [M, C // 2].
31
+ """
32
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
33
+ # split freqs
34
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
35
+
36
+ # loop over samples
37
+ output = []
38
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
39
+ seq_len = f * h * w
40
+
41
+ # precompute multipliers
42
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
43
+ s, n, -1, 2))
44
+ freqs_i = torch.cat([
45
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
46
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
47
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
48
+ ],
49
+ dim=-1).reshape(seq_len, 1, -1)
50
+
51
+ # apply rotary embedding
52
+ sp_size = get_sequence_parallel_world_size()
53
+ sp_rank = get_sequence_parallel_rank()
54
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
55
+ s_per_rank = s
56
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
57
+ s_per_rank), :, :]
58
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
59
+ x_i = torch.cat([x_i, x[i, s:]])
60
+
61
+ # append to collection
62
+ output.append(x_i)
63
+ return torch.stack(output).float()
64
+
65
+
66
+ def usp_dit_forward(
67
+ self,
68
+ x,
69
+ t,
70
+ context,
71
+ seq_len,
72
+ clip_fea=None,
73
+ y=None,
74
+ ):
75
+ """
76
+ x: A list of videos each with shape [C, T, H, W].
77
+ t: [B].
78
+ context: A list of text embeddings each with shape [L, C].
79
+ """
80
+ if self.model_type == 'i2v':
81
+ assert clip_fea is not None and y is not None
82
+ # params
83
+ device = self.patch_embedding.weight.device
84
+ if self.freqs.device != device:
85
+ self.freqs = self.freqs.to(device)
86
+
87
+ if y is not None:
88
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
89
+
90
+ # embeddings
91
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
92
+ grid_sizes = torch.stack(
93
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
94
+ x = [u.flatten(2).transpose(1, 2) for u in x]
95
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
96
+ assert seq_lens.max() <= seq_len
97
+ x = torch.cat([
98
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
99
+ for u in x
100
+ ])
101
+
102
+ # time embeddings
103
+ with amp.autocast(dtype=torch.float32):
104
+ e = self.time_embedding(
105
+ sinusoidal_embedding_1d(self.freq_dim, t).float())
106
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
107
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
108
+
109
+ # context
110
+ context_lens = None
111
+ context = self.text_embedding(
112
+ torch.stack([
113
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
114
+ for u in context
115
+ ]))
116
+
117
+ if clip_fea is not None:
118
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
119
+ context = torch.concat([context_clip, context], dim=1)
120
+
121
+ # arguments
122
+ kwargs = dict(
123
+ e=e0,
124
+ seq_lens=seq_lens,
125
+ grid_sizes=grid_sizes,
126
+ freqs=self.freqs,
127
+ context=context,
128
+ context_lens=context_lens)
129
+
130
+ # Context Parallel
131
+ x = torch.chunk(
132
+ x, get_sequence_parallel_world_size(),
133
+ dim=1)[get_sequence_parallel_rank()]
134
+
135
+ for block in self.blocks:
136
+ x = block(x, **kwargs)
137
+
138
+ # head
139
+ x = self.head(x, e)
140
+
141
+ # Context Parallel
142
+ x = get_sp_group().all_gather(x, dim=1)
143
+
144
+ # unpatchify
145
+ x = self.unpatchify(x, grid_sizes)
146
+ return [u.float() for u in x]
147
+
148
+
149
+ def usp_attn_forward(self,
150
+ x,
151
+ seq_lens,
152
+ grid_sizes,
153
+ freqs,
154
+ dtype=torch.bfloat16):
155
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
156
+ half_dtypes = (torch.float16, torch.bfloat16)
157
+
158
+ def half(x):
159
+ return x if x.dtype in half_dtypes else x.to(dtype)
160
+
161
+ # query, key, value function
162
+ def qkv_fn(x):
163
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
164
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
165
+ v = self.v(x).view(b, s, n, d)
166
+ return q, k, v
167
+
168
+ q, k, v = qkv_fn(x)
169
+ q = rope_apply(q, grid_sizes, freqs)
170
+ k = rope_apply(k, grid_sizes, freqs)
171
+
172
+ # TODO: We should use unpaded q,k,v for attention.
173
+ # k_lens = seq_lens // get_sequence_parallel_world_size()
174
+ # if k_lens is not None:
175
+ # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
176
+ # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
177
+ # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
178
+
179
+ x = xFuserLongContextAttention()(
180
+ None,
181
+ query=half(q),
182
+ key=half(k),
183
+ value=half(v),
184
+ window_size=self.window_size)
185
+
186
+ # TODO: padding after attention.
187
+ # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
188
+
189
+ # output
190
+ x = x.flatten(2)
191
+ x = self.o(x)
192
+ return x
ovi/distributed_comms/parallel_states.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+
3
+ import torch.distributed as dist
4
+
5
+
6
+ class COMM_INFO:
7
+
8
+ def __init__(self):
9
+ self.group = None
10
+ self.sp_size = 1
11
+ self.global_rank = 0
12
+ self.rank_within_group = 0
13
+ self.group_id = 0
14
+
15
+
16
+ nccl_info = COMM_INFO()
17
+ _SEQUENCE_PARALLEL_STATE = False
18
+
19
+
20
+ def initialize_sequence_parallel_state(sequence_parallel_size):
21
+ global _SEQUENCE_PARALLEL_STATE
22
+ if sequence_parallel_size > 1:
23
+ _SEQUENCE_PARALLEL_STATE = True
24
+ initialize_sequence_parallel_group(sequence_parallel_size)
25
+ else:
26
+ nccl_info.sp_size = 1
27
+ nccl_info.global_rank = int(os.getenv("RANK", "0"))
28
+ nccl_info.rank_within_group = 0
29
+ nccl_info.group_id = int(os.getenv("RANK", "0"))
30
+
31
+
32
+ def set_sequence_parallel_state(state):
33
+ global _SEQUENCE_PARALLEL_STATE
34
+ _SEQUENCE_PARALLEL_STATE = state
35
+
36
+
37
+ def get_sequence_parallel_state():
38
+ return _SEQUENCE_PARALLEL_STATE
39
+
40
+
41
+ def initialize_sequence_parallel_group(sequence_parallel_size):
42
+ """Initialize the sequence parallel group."""
43
+ rank = int(os.getenv("RANK", "0"))
44
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
45
+ assert (
46
+ world_size % sequence_parallel_size == 0
47
+ ), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
48
+ world_size, sequence_parallel_size)
49
+ nccl_info.sp_size = sequence_parallel_size
50
+ nccl_info.global_rank = rank
51
+ num_sequence_parallel_groups: int = world_size // sequence_parallel_size
52
+ for i in range(num_sequence_parallel_groups):
53
+ ranks = range(i * sequence_parallel_size,
54
+ (i + 1) * sequence_parallel_size)
55
+ group = dist.new_group(ranks)
56
+ if rank in ranks:
57
+ nccl_info.group = group
58
+ nccl_info.rank_within_group = rank - i * sequence_parallel_size
59
+ nccl_info.group_id = i
60
+
61
+
62
+
63
+ def initialize_sequence_parallel_group_custom(process_group):
64
+ """Initialize an unsafe sequence parallel group with a pre-formed group."""
65
+ set_sequence_parallel_state(True)
66
+ rank = dist.get_rank(group=process_group)
67
+ sequence_parallel_size = dist.get_world_size(group=process_group)
68
+ nccl_info.sp_size = sequence_parallel_size
69
+ nccl_info.global_rank = dist.get_rank() # global rank
70
+ nccl_info.group = process_group
71
+ nccl_info.rank_within_group = rank
72
+ nccl_info.group_id = 0
73
+
74
+
75
+ def destroy_sequence_parallel_group():
76
+ """Destroy the sequence parallel group."""
77
+ dist.destroy_process_group()
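Note: `initialize_sequence_parallel_group` assigns consecutive ranks to the same sequence-parallel group. A minimal sketch, not part of the upload, of the bookkeeping it records for each rank when world_size is 8 and sp_size is 2; the first rank of each group is also the source used by `broadcast` in communications.py (`src = group_id * sp_size`):

    world_size, sp_size = 8, 2
    for rank in range(world_size):
        group_id = rank // sp_size
        rank_within_group = rank - group_id * sp_size
        members = list(range(group_id * sp_size, (group_id + 1) * sp_size))
        # members[0] is the broadcast source for this group
        print(f"rank {rank}: sp group {group_id}, local index {rank_within_group}, members {members}")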
ovi/distributed_comms/util.py ADDED
@@ -0,0 +1,48 @@
1
+
2
+ import os
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
+ def get_global_rank() -> int:
8
+ """
9
+ Get the global rank, the global index of the GPU.
10
+ """
11
+ return int(os.environ.get("RANK", "0"))
12
+
13
+
14
+ def get_local_rank() -> int:
15
+ """
16
+ Get the local rank, the local index of the GPU.
17
+ """
18
+ return int(os.environ.get("LOCAL_RANK", "0"))
19
+
20
+
21
+ def get_world_size() -> int:
22
+ """
23
+ Get the world size, the total number of GPUs.
24
+ """
25
+ return int(os.environ.get("WORLD_SIZE", "1"))
26
+
27
+
28
+ def get_device() -> torch.device:
29
+ """
30
+ Get current rank device.
31
+ """
32
+ return torch.device("cuda", get_local_rank())
33
+
34
+ def get_sequence_parallel_group():
35
+ """Get the sequence parallel group the caller rank belongs to."""
36
+ return _SEQUENCE_PARALLEL_GROUP
37
+
38
+ def initialize_sequence_parallelism(sequence_parallel_size):
39
+ assert int(get_world_size()) % sequence_parallel_size == 0
40
+ sequence_parallel_num_groups = int(get_world_size()) // sequence_parallel_size
41
+ global _SEQUENCE_PARALLEL_GROUP
42
+ for i in range(sequence_parallel_num_groups):
43
+ ranks = range(i * sequence_parallel_size,
44
+ (i + 1) * sequence_parallel_size)
45
+ group = torch.distributed.new_group(ranks)
46
+ if int(get_global_rank()) in ranks:
47
+ print(f"Rank {get_global_rank()} joined group with ranks {list(ranks)}")
48
+ _SEQUENCE_PARALLEL_GROUP = group
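Note: these helpers only read the environment variables a launcher such as `torchrun` exports (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`); without a launcher they fall back to a single-process default. A minimal sketch, not part of the upload:

    import os

    from ovi.distributed_comms.util import get_global_rank, get_local_rank, get_world_size

    # without torchrun, nothing is set and the defaults apply
    print(get_global_rank(), get_local_rank(), get_world_size())   # 0 0 1

    # torchrun would export something like this for rank 3 of an 8-GPU job
    os.environ.update({"RANK": "3", "LOCAL_RANK": "3", "WORLD_SIZE": "8"})
    print(get_global_rank(), get_local_rank(), get_world_size())   # 3 3 8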
ovi/modules/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ from .attention import flash_attention
2
+ from .model import WanModel
3
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
4
+ from .tokenizers import HuggingfaceTokenizer
5
+ from .vae import WanVAE
6
+
7
+ __all__ = [
8
+ 'WanVAE',
9
+ 'WanModel',
10
+ 'T5Model',
11
+ 'T5Encoder',
12
+ 'T5Decoder',
13
+ 'T5EncoderModel',
14
+ 'HuggingfaceTokenizer',
15
+ 'flash_attention',
16
+ ]
ovi/modules/attention.py ADDED
@@ -0,0 +1,296 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+ FLASH_ATTN_3_AVAILABLE = True
7
+ except ModuleNotFoundError:
8
+ FLASH_ATTN_3_AVAILABLE = False
9
+
10
+ try:
11
+ import flash_attn
12
+ FLASH_ATTN_2_AVAILABLE = True
13
+ except ModuleNotFoundError:
14
+ FLASH_ATTN_2_AVAILABLE = False
15
+
16
+ import warnings
17
+
18
+ __all__ = [
19
+ 'flash_attention',
20
+ 'attention',
21
+ 'attention_with_weights',
22
+ ]
23
+
24
+
25
+ def flash_attention(
26
+ q,
27
+ k,
28
+ v,
29
+ q_lens=None,
30
+ k_lens=None,
31
+ dropout_p=0.,
32
+ softmax_scale=None,
33
+ q_scale=None,
34
+ causal=False,
35
+ window_size=(-1, -1),
36
+ deterministic=False,
37
+ dtype=torch.bfloat16,
38
+ version=None
39
+ ):
40
+ """
41
+ q: [B, Lq, Nq, C1].
42
+ k: [B, Lk, Nk, C1].
43
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
44
+ q_lens: [B].
45
+ k_lens: [B].
46
+ dropout_p: float. Dropout probability.
47
+ softmax_scale: float. The scaling of QK^T before applying softmax.
48
+ causal: bool. Whether to apply causal attention mask.
49
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
50
+ deterministic: bool. If True, slightly slower and uses more memory.
51
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
52
+ """
53
+ half_dtypes = (torch.float16, torch.bfloat16)
54
+ assert dtype in half_dtypes
55
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
56
+
57
+ # params
58
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
59
+
60
+ def half(x):
61
+ return x if x.dtype in half_dtypes else x.to(dtype)
62
+
63
+ # preprocess query
64
+ if q_lens is None:
65
+ q = half(q.flatten(0, 1))
66
+ q_lens = torch.tensor(
67
+ [lq] * b, dtype=torch.int32).to(
68
+ device=q.device, non_blocking=True)
69
+ else:
70
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
71
+
72
+ # preprocess key, value
73
+ if k_lens is None:
74
+ k = half(k.flatten(0, 1))
75
+ v = half(v.flatten(0, 1))
76
+ k_lens = torch.tensor(
77
+ [lk] * b, dtype=torch.int32).to(
78
+ device=k.device, non_blocking=True)
79
+ else:
80
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
81
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
82
+
83
+ q = q.to(v.dtype)
84
+ k = k.to(v.dtype)
85
+
86
+ if q_scale is not None:
87
+ q = q * q_scale
88
+
89
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
90
+ warnings.warn(
91
+ 'Flash attention 3 is not available, using flash attention 2 instead.'
92
+ )
93
+
94
+ # apply attention
95
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
96
+ # Note: dropout_p, window_size are not supported in FA3 now.
97
+ x = flash_attn_interface.flash_attn_varlen_func(
98
+ q=q,
99
+ k=k,
100
+ v=v,
101
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
102
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
103
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
104
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
105
+ seqused_q=None,
106
+ seqused_k=None,
107
+ max_seqlen_q=lq,
108
+ max_seqlen_k=lk,
109
+ softmax_scale=softmax_scale,
110
+ causal=causal,
111
+ deterministic=deterministic)
112
+
113
+ if isinstance(x, tuple):
114
+ x = x[0]
115
+ x = x.unflatten(0, (b, lq))
116
+
117
+ else:
118
+ assert FLASH_ATTN_2_AVAILABLE
119
+ x = flash_attn.flash_attn_varlen_func(
120
+ q=q,
121
+ k=k,
122
+ v=v,
123
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
124
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
125
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
126
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
127
+ max_seqlen_q=lq,
128
+ max_seqlen_k=lk,
129
+ dropout_p=dropout_p,
130
+ softmax_scale=softmax_scale,
131
+ causal=causal,
132
+ window_size=window_size,
133
+ deterministic=deterministic).unflatten(0, (b, lq))
134
+
135
+ # output
136
+ return x.type(out_dtype)
137
+
138
+
139
+ def attention_with_weights(
140
+ q,
141
+ k,
142
+ v,
143
+ q_lens=None,
144
+ k_lens=None,
145
+ softmax_scale=None,
146
+ q_scale=None,
147
+ causal=False,
148
+ average_for_q=False,
149
+ total_video_latent_frames = 21
150
+ ):
151
+ """
152
+ Compute attention with explicit attention weights for visualization.
153
+ Returns both output and attention weights.
154
+ """
155
+ out_dtype = q.dtype
156
+
157
+ # Handle sequence lengths
158
+ b, lq, lk = q.size(0), q.size(1), k.size(1)
159
+
160
+ if q_lens is None:
161
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32, device=q.device)
162
+ else:
163
+ # Ensure q_lens is on the same device as q
164
+ q_lens = q_lens.to(q.device)
165
+
166
+ if k_lens is None:
167
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32, device=k.device)
168
+ else:
169
+ # Ensure k_lens is on the same device as k
170
+ k_lens = k_lens.to(k.device)
171
+
172
+ # Apply q_scale if provided
173
+ if q_scale is not None:
174
+ q = q * q_scale
175
+
176
+ # Compute attention weights manually
177
+ # q: [B, Lq, Nq, C], k: [B, Lk, Nk, C]
178
+ scale = softmax_scale if softmax_scale is not None else (q.size(-1) ** -0.5)
179
+
180
+ # Compute scores: [B, Nq, Lq, Lk]
181
+ scores = torch.einsum('blhd,bshd->bhls', q, k) * scale
182
+
183
+ # Apply causal mask if needed
184
+ if causal:
185
+ mask = torch.triu(torch.ones(lq, lk, device=q.device, dtype=torch.bool), diagonal=1)
186
+ scores.masked_fill_(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
187
+
188
+ # Mask for k_lens (columns)
189
+ k_mask = torch.arange(lk, device=k.device).unsqueeze(0) >= k_lens.unsqueeze(1) # [B, Lk]
190
+ scores.masked_fill_(k_mask.unsqueeze(1).unsqueeze(2), float('-inf')) # [B, 1, 1, Lk]
191
+
192
+ # Mask for q_lens (rows)
193
+ q_mask = torch.arange(lq, device=q.device).unsqueeze(0) >= q_lens.unsqueeze(1) # [B, Lq]
194
+ scores.masked_fill_(q_mask.unsqueeze(1).unsqueeze(3), float('-inf')) # [B, 1, Lq, 1]
195
+
196
+ # Compute attention weights
197
+ attn_weights = torch.softmax(scores, dim=-1) # [B, Nq, Lq, Lk]
198
+ assert attn_weights.shape[0] == 1, "Batch size > 1 not supported for attention visualization."
199
+
200
+ # Average attention weights to reduce memory usage before returning
201
+ # Average across batch dimension (should be 1) and query heads and query sequence length
202
+ # This gives us attention weight per video token: [Lk]
203
+ if average_for_q:
204
+ #avg_attn_weights = torch.mean(attn_weights, dim=(0, 1, 3)) # [Lq]
205
+ avg_attn_weights = torch.max(attn_weights, dim=3)[0].mean(dim=(0, 1)) # [Lq]
206
+ else:
207
+ if 0:
208
+ avg_attn_weights = torch.mean(attn_weights, dim=(0, 1, 2)) # [Lk]
209
+ elif 1:
210
+ B, H, Lq, Lk = attn_weights.shape # [1, H, Lq, Lk]
211
+ per_frame_seq_len = Lk // total_video_latent_frames
212
+ per_frame_aud_len = Lq // total_video_latent_frames
213
+
214
+ avg_attn_weights = torch.zeros((Lk,), device=attn_weights.device, dtype=attn_weights.dtype)
215
+
216
+ eps = 1e-8 # numerical stability
217
+ for i in range(total_video_latent_frames):
218
+ start_idx_v = i * per_frame_seq_len
219
+ end_idx_v = (i + 1) * per_frame_seq_len
220
+
221
+ start_idx_a = i * per_frame_aud_len
222
+ end_idx_a = (i + 1) * per_frame_aud_len
223
+
224
+ # attn_chunk: [H, La, Lv]
225
+ attn_chunk = attn_weights[0, :, start_idx_a:end_idx_a, start_idx_v:end_idx_v]
226
+
227
+ # ---- Head informativeness via (low) entropy over Lv ----
228
+ # Normalize within the Lv slice per (head, query) to make a proper distribution
229
+ p = attn_chunk / (attn_chunk.sum(dim=-1, keepdim=True) + eps) # [H, La, Lv]
230
+ entropy = -(p * (p + eps).log()).sum(dim=-1).mean(dim=1) # [H]
231
+
232
+ # Convert to positive head weights (lower entropy -> larger weight)
233
+ saliency = 1.0 / (entropy + 1e-6) # [H]
234
+ head_w = saliency / (saliency.sum() + eps) # [H], sum=1
235
+
236
+ # Reduce across audio queries first (pick strong responses), then weight heads
237
+ per_head = torch.amax(attn_chunk, dim=1) # [H, Lv]
238
+ weighted = (per_head * head_w[:, None]).sum(dim=0) # [Lv]
239
+
240
+ avg_attn_weights[start_idx_v:end_idx_v] = weighted
243
+
244
+ # Compute output: [B, Lq, Nq, C]
245
+ out = torch.einsum('bhls,bshd->blhd', attn_weights, v)
246
+
247
+ return out.to(out_dtype), avg_attn_weights.to(out_dtype)
248
+
249
+
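# To make the entropy-based head weighting above concrete, a minimal sketch with
# toy numbers (illustration only, not part of the uploaded file): a head whose
# attention over the video tokens is peaked has low entropy and gets a large
# weight, while a uniformly attending head is down-weighted.
import torch

attn_chunk = torch.tensor([[[0.85, 0.05, 0.05, 0.05]],
                           [[0.25, 0.25, 0.25, 0.25]]])       # [H=2, La=1, Lv=4]
eps = 1e-8
p = attn_chunk / (attn_chunk.sum(dim=-1, keepdim=True) + eps)
entropy = -(p * (p + eps).log()).sum(dim=-1).mean(dim=1)      # [H] ~ [0.59, 1.39]
head_w = 1.0 / (entropy + 1e-6)
head_w = head_w / head_w.sum()                                # ~ [0.70, 0.30]
per_head = torch.amax(attn_chunk, dim=1)                      # strongest response per video token, [H, Lv]
weighted = (per_head * head_w[:, None]).sum(dim=0)            # [Lv] ~ [0.67, 0.11, 0.11, 0.11]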
250
+ def attention(
251
+ q,
252
+ k,
253
+ v,
254
+ q_lens=None,
255
+ k_lens=None,
256
+ dropout_p=0.,
257
+ softmax_scale=None,
258
+ q_scale=None,
259
+ causal=False,
260
+ window_size=(-1, -1),
261
+ deterministic=False,
262
+ dtype=torch.bfloat16,
263
+ fa_version=None,
264
+ ):
265
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
266
+ return flash_attention(
267
+ q=q,
268
+ k=k,
269
+ v=v,
270
+ q_lens=q_lens,
271
+ k_lens=k_lens,
272
+ dropout_p=dropout_p,
273
+ softmax_scale=softmax_scale,
274
+ q_scale=q_scale,
275
+ causal=causal,
276
+ window_size=window_size,
277
+ deterministic=deterministic,
278
+ dtype=dtype,
279
+ version=fa_version,
280
+ )
281
+ else:
282
+ if q_lens is not None or k_lens is not None:
283
+ warnings.warn(
284
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
285
+ )
286
+ attn_mask = None
287
+
288
+ q = q.transpose(1, 2).to(dtype)
289
+ k = k.transpose(1, 2).to(dtype)
290
+ v = v.transpose(1, 2).to(dtype)
291
+
292
+ out = torch.nn.functional.scaled_dot_product_attention(
293
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
294
+
295
+ out = out.transpose(1, 2).contiguous()
296
+ return out
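Note that the SDPA fallback above drops the per-sample padding masks. If that matters for a given workload, a boolean mask can be built from k_lens and passed to scaled_dot_product_attention. The sketch below is illustrative only, uses the same [B, L, N, C] layout as the wrapper, and the helper name is made up:

    import torch
    import torch.nn.functional as F

    def sdpa_with_k_lens(q, k, v, k_lens=None, causal=False, dropout_p=0.0):
        # q, k, v: [B, L, N, C]; returns [B, Lq, N, C].
        b, lq, lk = q.size(0), q.size(1), k.size(1)
        attn_mask = None
        if k_lens is not None:
            # True = keep, False = masked out; broadcast over heads and query positions.
            valid = torch.arange(lk, device=k.device)[None, :] < k_lens[:, None]   # [B, Lk]
            attn_mask = valid[:, None, None, :]                                     # [B, 1, 1, Lk]
        if causal:
            # fold causality into the mask instead of passing is_causal alongside attn_mask
            tri = torch.ones(lq, lk, dtype=torch.bool, device=q.device).tril()
            attn_mask = tri if attn_mask is None else (attn_mask & tri)
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
            attn_mask=attn_mask, dropout_p=dropout_p)
        return out.transpose(1, 2).contiguous()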
ovi/modules/clip.py ADDED
@@ -0,0 +1,545 @@
1
+ # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as T
10
+
11
+ from .attention import flash_attention
12
+ from .tokenizers import HuggingfaceTokenizer
13
+ from .xlm_roberta import XLMRoberta
14
+
15
+ __all__ = [
16
+ 'XLMRobertaCLIP',
17
+ 'clip_xlm_roberta_vit_h_14',
18
+ 'CLIPModel',
19
+ ]
20
+
21
+
22
+ def pos_interpolate(pos, seq_len):
23
+ if pos.size(1) == seq_len:
24
+ return pos
25
+ else:
26
+ src_grid = int(math.sqrt(pos.size(1)))
27
+ tar_grid = int(math.sqrt(seq_len))
28
+ n = pos.size(1) - src_grid * src_grid
29
+ return torch.cat([
30
+ pos[:, :n],
31
+ F.interpolate(
32
+ pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
33
+ 0, 3, 1, 2),
34
+ size=(tar_grid, tar_grid),
35
+ mode='bicubic',
36
+ align_corners=False).flatten(2).transpose(1, 2)
37
+ ],
38
+ dim=1)
39
+
40
+
41
+ class QuickGELU(nn.Module):
42
+
43
+ def forward(self, x):
44
+ return x * torch.sigmoid(1.702 * x)
45
+
46
+
47
+ class LayerNorm(nn.LayerNorm):
48
+
49
+ def forward(self, x):
50
+ return super().forward(x.float()).type_as(x)
51
+
52
+
53
+ class SelfAttention(nn.Module):
54
+
55
+ def __init__(self,
56
+ dim,
57
+ num_heads,
58
+ causal=False,
59
+ attn_dropout=0.0,
60
+ proj_dropout=0.0):
61
+ assert dim % num_heads == 0
62
+ super().__init__()
63
+ self.dim = dim
64
+ self.num_heads = num_heads
65
+ self.head_dim = dim // num_heads
66
+ self.causal = causal
67
+ self.attn_dropout = attn_dropout
68
+ self.proj_dropout = proj_dropout
69
+
70
+ # layers
71
+ self.to_qkv = nn.Linear(dim, dim * 3)
72
+ self.proj = nn.Linear(dim, dim)
73
+
74
+ def forward(self, x):
75
+ """
76
+ x: [B, L, C].
77
+ """
78
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
79
+
80
+ # compute query, key, value
81
+ q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
82
+
83
+ # compute attention
84
+ p = self.attn_dropout if self.training else 0.0
85
+ x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
86
+ # x = flash_attention(q, k, v, dropout_p=p, causal=self.causal)
87
+ x = x.reshape(b, s, c)
88
+
89
+ # output
90
+ x = self.proj(x)
91
+ x = F.dropout(x, self.proj_dropout, self.training)
92
+ return x
93
+
94
+
95
+ class SwiGLU(nn.Module):
96
+
97
+ def __init__(self, dim, mid_dim):
98
+ super().__init__()
99
+ self.dim = dim
100
+ self.mid_dim = mid_dim
101
+
102
+ # layers
103
+ self.fc1 = nn.Linear(dim, mid_dim)
104
+ self.fc2 = nn.Linear(dim, mid_dim)
105
+ self.fc3 = nn.Linear(mid_dim, dim)
106
+
107
+ def forward(self, x):
108
+ x = F.silu(self.fc1(x)) * self.fc2(x)
109
+ x = self.fc3(x)
110
+ return x
111
+
112
+
113
+ class AttentionBlock(nn.Module):
114
+
115
+ def __init__(self,
116
+ dim,
117
+ mlp_ratio,
118
+ num_heads,
119
+ post_norm=False,
120
+ causal=False,
121
+ activation='quick_gelu',
122
+ attn_dropout=0.0,
123
+ proj_dropout=0.0,
124
+ norm_eps=1e-5):
125
+ assert activation in ['quick_gelu', 'gelu', 'swi_glu']
126
+ super().__init__()
127
+ self.dim = dim
128
+ self.mlp_ratio = mlp_ratio
129
+ self.num_heads = num_heads
130
+ self.post_norm = post_norm
131
+ self.causal = causal
132
+ self.norm_eps = norm_eps
133
+
134
+ # layers
135
+ self.norm1 = LayerNorm(dim, eps=norm_eps)
136
+ self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
137
+ proj_dropout)
138
+ self.norm2 = LayerNorm(dim, eps=norm_eps)
139
+ if activation == 'swi_glu':
140
+ self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
141
+ else:
142
+ self.mlp = nn.Sequential(
143
+ nn.Linear(dim, int(dim * mlp_ratio)),
144
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
145
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
146
+
147
+ def forward(self, x):
148
+ if self.post_norm:
149
+ x = x + self.norm1(self.attn(x))
150
+ x = x + self.norm2(self.mlp(x))
151
+ else:
152
+ x = x + self.attn(self.norm1(x))
153
+ x = x + self.mlp(self.norm2(x))
154
+ return x
155
+
156
+
157
+ class AttentionPool(nn.Module):
158
+
159
+ def __init__(self,
160
+ dim,
161
+ mlp_ratio,
162
+ num_heads,
163
+ activation='gelu',
164
+ proj_dropout=0.0,
165
+ norm_eps=1e-5):
166
+ assert dim % num_heads == 0
167
+ super().__init__()
168
+ self.dim = dim
169
+ self.mlp_ratio = mlp_ratio
170
+ self.num_heads = num_heads
171
+ self.head_dim = dim // num_heads
172
+ self.proj_dropout = proj_dropout
173
+ self.norm_eps = norm_eps
174
+
175
+ # layers
176
+ gain = 1.0 / math.sqrt(dim)
177
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
178
+ self.to_q = nn.Linear(dim, dim)
179
+ self.to_kv = nn.Linear(dim, dim * 2)
180
+ self.proj = nn.Linear(dim, dim)
181
+ self.norm = LayerNorm(dim, eps=norm_eps)
182
+ self.mlp = nn.Sequential(
183
+ nn.Linear(dim, int(dim * mlp_ratio)),
184
+ QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
185
+ nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
186
+
187
+ def forward(self, x):
188
+ """
189
+ x: [B, L, C].
190
+ """
191
+ b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
192
+
193
+ # compute query, key, value
194
+ q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
195
+ k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
196
+
197
+ # compute attention
198
+ x = flash_attention(q, k, v, version=2)
199
+ # x = flash_attention(q, k, v)
200
+
201
+ x = x.reshape(b, 1, c)
202
+
203
+ # output
204
+ x = self.proj(x)
205
+ x = F.dropout(x, self.proj_dropout, self.training)
206
+
207
+ # mlp
208
+ x = x + self.mlp(self.norm(x))
209
+ return x[:, 0]
210
+
211
+
212
+ class VisionTransformer(nn.Module):
213
+
214
+ def __init__(self,
215
+ image_size=224,
216
+ patch_size=16,
217
+ dim=768,
218
+ mlp_ratio=4,
219
+ out_dim=512,
220
+ num_heads=12,
221
+ num_layers=12,
222
+ pool_type='token',
223
+ pre_norm=True,
224
+ post_norm=False,
225
+ activation='quick_gelu',
226
+ attn_dropout=0.0,
227
+ proj_dropout=0.0,
228
+ embedding_dropout=0.0,
229
+ norm_eps=1e-5):
230
+ if image_size % patch_size != 0:
231
+ print(
232
+ '[WARNING] image_size is not divisible by patch_size',
233
+ flush=True)
234
+ assert pool_type in ('token', 'token_fc', 'attn_pool')
235
+ out_dim = out_dim or dim
236
+ super().__init__()
237
+ self.image_size = image_size
238
+ self.patch_size = patch_size
239
+ self.num_patches = (image_size // patch_size)**2
240
+ self.dim = dim
241
+ self.mlp_ratio = mlp_ratio
242
+ self.out_dim = out_dim
243
+ self.num_heads = num_heads
244
+ self.num_layers = num_layers
245
+ self.pool_type = pool_type
246
+ self.post_norm = post_norm
247
+ self.norm_eps = norm_eps
248
+
249
+ # embeddings
250
+ gain = 1.0 / math.sqrt(dim)
251
+ self.patch_embedding = nn.Conv2d(
252
+ 3,
253
+ dim,
254
+ kernel_size=patch_size,
255
+ stride=patch_size,
256
+ bias=not pre_norm)
257
+ if pool_type in ('token', 'token_fc'):
258
+ self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
259
+ self.pos_embedding = nn.Parameter(gain * torch.randn(
260
+ 1, self.num_patches +
261
+ (1 if pool_type in ('token', 'token_fc') else 0), dim))
262
+ self.dropout = nn.Dropout(embedding_dropout)
263
+
264
+ # transformer
265
+ self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
266
+ self.transformer = nn.Sequential(*[
267
+ AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
268
+ activation, attn_dropout, proj_dropout, norm_eps)
269
+ for _ in range(num_layers)
270
+ ])
271
+ self.post_norm = LayerNorm(dim, eps=norm_eps)
272
+
273
+ # head
274
+ if pool_type == 'token':
275
+ self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
276
+ elif pool_type == 'token_fc':
277
+ self.head = nn.Linear(dim, out_dim)
278
+ elif pool_type == 'attn_pool':
279
+ self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
280
+ proj_dropout, norm_eps)
281
+
282
+ def forward(self, x, interpolation=False, use_31_block=False):
283
+ b = x.size(0)
284
+
285
+ # embeddings
286
+ x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
287
+ if self.pool_type in ('token', 'token_fc'):
288
+ x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
289
+ if interpolation:
290
+ e = pos_interpolate(self.pos_embedding, x.size(1))
291
+ else:
292
+ e = self.pos_embedding
293
+ x = self.dropout(x + e)
294
+ if self.pre_norm is not None:
295
+ x = self.pre_norm(x)
296
+
297
+ # transformer
298
+ if use_31_block:
299
+ x = self.transformer[:-1](x)
300
+ return x
301
+ else:
302
+ x = self.transformer(x)
303
+ return x
304
+
305
+
306
+ class XLMRobertaWithHead(XLMRoberta):
307
+
308
+ def __init__(self, **kwargs):
309
+ self.out_dim = kwargs.pop('out_dim')
310
+ super().__init__(**kwargs)
311
+
312
+ # head
313
+ mid_dim = (self.dim + self.out_dim) // 2
314
+ self.head = nn.Sequential(
315
+ nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
316
+ nn.Linear(mid_dim, self.out_dim, bias=False))
317
+
318
+ def forward(self, ids):
319
+ # xlm-roberta
320
+ x = super().forward(ids)
321
+
322
+ # average pooling
323
+ mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
324
+ x = (x * mask).sum(dim=1) / mask.sum(dim=1)
325
+
326
+ # head
327
+ x = self.head(x)
328
+ return x
329
+
330
+
331
+ class XLMRobertaCLIP(nn.Module):
332
+
333
+ def __init__(self,
334
+ embed_dim=1024,
335
+ image_size=224,
336
+ patch_size=14,
337
+ vision_dim=1280,
338
+ vision_mlp_ratio=4,
339
+ vision_heads=16,
340
+ vision_layers=32,
341
+ vision_pool='token',
342
+ vision_pre_norm=True,
343
+ vision_post_norm=False,
344
+ activation='gelu',
345
+ vocab_size=250002,
346
+ max_text_len=514,
347
+ type_size=1,
348
+ pad_id=1,
349
+ text_dim=1024,
350
+ text_heads=16,
351
+ text_layers=24,
352
+ text_post_norm=True,
353
+ text_dropout=0.1,
354
+ attn_dropout=0.0,
355
+ proj_dropout=0.0,
356
+ embedding_dropout=0.0,
357
+ norm_eps=1e-5):
358
+ super().__init__()
359
+ self.embed_dim = embed_dim
360
+ self.image_size = image_size
361
+ self.patch_size = patch_size
362
+ self.vision_dim = vision_dim
363
+ self.vision_mlp_ratio = vision_mlp_ratio
364
+ self.vision_heads = vision_heads
365
+ self.vision_layers = vision_layers
366
+ self.vision_pre_norm = vision_pre_norm
367
+ self.vision_post_norm = vision_post_norm
368
+ self.activation = activation
369
+ self.vocab_size = vocab_size
370
+ self.max_text_len = max_text_len
371
+ self.type_size = type_size
372
+ self.pad_id = pad_id
373
+ self.text_dim = text_dim
374
+ self.text_heads = text_heads
375
+ self.text_layers = text_layers
376
+ self.text_post_norm = text_post_norm
377
+ self.norm_eps = norm_eps
378
+
379
+ # models
380
+ self.visual = VisionTransformer(
381
+ image_size=image_size,
382
+ patch_size=patch_size,
383
+ dim=vision_dim,
384
+ mlp_ratio=vision_mlp_ratio,
385
+ out_dim=embed_dim,
386
+ num_heads=vision_heads,
387
+ num_layers=vision_layers,
388
+ pool_type=vision_pool,
389
+ pre_norm=vision_pre_norm,
390
+ post_norm=vision_post_norm,
391
+ activation=activation,
392
+ attn_dropout=attn_dropout,
393
+ proj_dropout=proj_dropout,
394
+ embedding_dropout=embedding_dropout,
395
+ norm_eps=norm_eps)
396
+ self.textual = XLMRobertaWithHead(
397
+ vocab_size=vocab_size,
398
+ max_seq_len=max_text_len,
399
+ type_size=type_size,
400
+ pad_id=pad_id,
401
+ dim=text_dim,
402
+ out_dim=embed_dim,
403
+ num_heads=text_heads,
404
+ num_layers=text_layers,
405
+ post_norm=text_post_norm,
406
+ dropout=text_dropout)
407
+ self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
408
+
409
+ def forward(self, imgs, txt_ids):
410
+ """
411
+ imgs: [B, 3, H, W] of torch.float32.
412
+ - mean: [0.48145466, 0.4578275, 0.40821073]
413
+ - std: [0.26862954, 0.26130258, 0.27577711]
414
+ txt_ids: [B, L] of torch.long.
415
+ Encoded by data.CLIPTokenizer.
416
+ """
417
+ xi = self.visual(imgs)
418
+ xt = self.textual(txt_ids)
419
+ return xi, xt
420
+
421
+ def param_groups(self):
422
+ groups = [{
423
+ 'params': [
424
+ p for n, p in self.named_parameters()
425
+ if 'norm' in n or n.endswith('bias')
426
+ ],
427
+ 'weight_decay': 0.0
428
+ }, {
429
+ 'params': [
430
+ p for n, p in self.named_parameters()
431
+ if not ('norm' in n or n.endswith('bias'))
432
+ ]
433
+ }]
434
+ return groups
435
+
436
+
437
+ def _clip(pretrained=False,
438
+ pretrained_name=None,
439
+ model_cls=XLMRobertaCLIP,
440
+ return_transforms=False,
441
+ return_tokenizer=False,
442
+ tokenizer_padding='eos',
443
+ dtype=torch.float32,
444
+ device='cpu',
445
+ **kwargs):
446
+ # init a model on device
447
+ with torch.device(device):
448
+ model = model_cls(**kwargs)
449
+
450
+ # set device
451
+ model = model.to(dtype=dtype, device=device)
452
+ output = (model,)
453
+
454
+ # init transforms
455
+ if return_transforms:
456
+ # mean and std
457
+ if 'siglip' in pretrained_name.lower():
458
+ mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
459
+ else:
460
+ mean = [0.48145466, 0.4578275, 0.40821073]
461
+ std = [0.26862954, 0.26130258, 0.27577711]
462
+
463
+ # transforms
464
+ transforms = T.Compose([
465
+ T.Resize((model.image_size, model.image_size),
466
+ interpolation=T.InterpolationMode.BICUBIC),
467
+ T.ToTensor(),
468
+ T.Normalize(mean=mean, std=std)
469
+ ])
470
+ output += (transforms,)
471
+ return output[0] if len(output) == 1 else output
472
+
473
+
474
+ def clip_xlm_roberta_vit_h_14(
475
+ pretrained=False,
476
+ pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
477
+ **kwargs):
478
+ cfg = dict(
479
+ embed_dim=1024,
480
+ image_size=224,
481
+ patch_size=14,
482
+ vision_dim=1280,
483
+ vision_mlp_ratio=4,
484
+ vision_heads=16,
485
+ vision_layers=32,
486
+ vision_pool='token',
487
+ activation='gelu',
488
+ vocab_size=250002,
489
+ max_text_len=514,
490
+ type_size=1,
491
+ pad_id=1,
492
+ text_dim=1024,
493
+ text_heads=16,
494
+ text_layers=24,
495
+ text_post_norm=True,
496
+ text_dropout=0.1,
497
+ attn_dropout=0.0,
498
+ proj_dropout=0.0,
499
+ embedding_dropout=0.0)
500
+ cfg.update(**kwargs)
501
+ return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
502
+
503
+
504
+ class CLIPModel:
505
+
506
+ def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
507
+ self.dtype = dtype
508
+ self.device = device
509
+ self.checkpoint_path = checkpoint_path
510
+ self.tokenizer_path = tokenizer_path
511
+
512
+ # init model
513
+ self.model, self.transforms = clip_xlm_roberta_vit_h_14(
514
+ pretrained=False,
515
+ return_transforms=True,
516
+ return_tokenizer=False,
517
+ dtype=dtype,
518
+ device=device)
519
+ self.model = self.model.eval().requires_grad_(False)
520
+ logging.info(f'loading {checkpoint_path}')
521
+ self.model.load_state_dict(
522
+ torch.load(checkpoint_path, map_location='cpu'))
523
+
524
+ # init tokenizer
525
+ self.tokenizer = HuggingfaceTokenizer(
526
+ name=tokenizer_path,
527
+ seq_len=self.model.max_text_len - 2,
528
+ clean='whitespace')
529
+
530
+ def visual(self, videos):
531
+ # preprocess
532
+ size = (self.model.image_size,) * 2
533
+ videos = torch.cat([
534
+ F.interpolate(
535
+ u.transpose(0, 1),
536
+ size=size,
537
+ mode='bicubic',
538
+ align_corners=False) for u in videos
539
+ ])
540
+ videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
541
+
542
+ # forward
543
+ with torch.amp.autocast('cuda', dtype=self.dtype):
544
+ out = self.model.visual(videos, use_31_block=True)
545
+ return out
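A short usage sketch for the wrapper above (paths, tokenizer id and dtype are placeholders, not shipped defaults): visual() expects a list of [C, T, H, W] clips with values roughly in [-1, 1], resizes them to the CLIP resolution, re-normalizes with the CLIP mean/std and returns features from the penultimate transformer block.

    import torch

    clip = CLIPModel(
        dtype=torch.float16,
        device="cuda",
        checkpoint_path="checkpoints/clip_xlm_roberta_vit_h_14.pth",   # placeholder path
        tokenizer_path="xlm-roberta-large",                            # placeholder tokenizer id
    )

    video = torch.randn(3, 1, 480, 832, device="cuda").clamp_(-1, 1)   # one clip: [C, T, H, W]
    feats = clip.visual([video])
    print(feats.shape)   # [T, 257, 1280] for ViT-H/14 at 224x224 (1 cls token + 256 patches)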
ovi/modules/fusion.py ADDED
@@ -0,0 +1,324 @@
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
5
+ from ovi.modules.attention import flash_attention
6
+ from ovi.distributed_comms.communications import all_gather, all_to_all_4D
7
+ from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state
8
+
9
+ class FusionModel(nn.Module):
10
+ def __init__(self, video_config=None, audio_config=None):
11
+ super().__init__()
12
+ has_video = True
13
+ has_audio = True
14
+ if video_config is not None:
15
+ self.video_model = WanModel(**video_config)
16
+ else:
17
+ has_video = False
18
+ self.video_model = None
19
+ print("Warning: No video model is provided!")
20
+
21
+ if audio_config is not None:
22
+ self.audio_model = WanModel(**audio_config)
23
+ else:
24
+ has_audio = False
25
+ self.audio_model = None
26
+ print("Warning: No audio model is provided!")
27
+
28
+ if has_video and has_audio:
29
+ assert len(self.video_model.blocks) == len(self.audio_model.blocks)
30
+ self.num_blocks = len(self.video_model.blocks)
31
+
32
+ self.use_sp = get_sequence_parallel_state()
33
+ if self.use_sp:
34
+ self.sp_size = nccl_info.sp_size
35
+ self.sp_rank = nccl_info.rank_within_group
36
+ self.inject_cross_attention_kv_projections()
37
+
38
+ self.init_weights()
39
+
40
+ def inject_cross_attention_kv_projections(self):
41
+ for vid_block in self.video_model.blocks:
42
+ vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
43
+ vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
44
+ vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
45
+ vid_block.cross_attn.norm_k_fusion = WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
46
+
47
+
48
+ for audio_block in self.audio_model.blocks:
49
+ audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
50
+ audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
51
+ audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
52
+ audio_block.cross_attn.norm_k_fusion = WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
53
+
54
+
55
+ def merge_kwargs(self, vid_kwargs, audio_kwargs):
56
+ """
57
+ keys in each kwarg:
58
+ e
59
+ seq_lens
60
+ grid_sizes
61
+ freqs
62
+ context
63
+ context_lens
64
+ """
65
+ merged_kwargs = {}
66
+ for key in vid_kwargs:
67
+ merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
68
+ for key in audio_kwargs:
69
+ merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
70
+ return merged_kwargs
71
+
72
+ def single_fusion_cross_attention_forward(self,
73
+ cross_attn_block,
74
+ src_seq,
75
+ src_grid_sizes,
76
+ src_freqs,
77
+ target_seq,
78
+ target_seq_lens,
79
+ target_grid_sizes,
80
+ target_freqs,
81
+ context,
82
+ context_lens
83
+ ):
84
+ b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
85
+ if hasattr(cross_attn_block, "k_img"):
86
+ ## means is i2v block
87
+ q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
88
+ else:
89
+ ## means is t2v block
90
+ q, k, v = cross_attn_block.qkv_fn(src_seq, context)
91
+ k_img = v_img = None
92
+
93
+
94
+ if self.use_sp:
95
+ q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
96
+ k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
97
+ v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
98
+ if k_img is not None:
99
+ k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
100
+ if v_img is not None:
101
+ v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]
102
+
103
+ x = flash_attention(q, k, v, k_lens=context_lens)
104
+
105
+ if k_img is not None:
106
+ img_x = flash_attention(q, k_img, v_img, k_lens=None)
107
+ x = x + img_x
108
+
109
+ is_vid = src_grid_sizes.shape[1] > 1
110
+ # compute target attention
111
+ target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
112
+ k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
113
+ v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
114
+ if self.use_sp:
115
+ k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
116
+ v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1) # [B, L, H/P, C/H]
117
+
118
+ q = rope_apply(q, src_grid_sizes, src_freqs)
119
+ k_target = rope_apply(k_target, target_grid_sizes, target_freqs)
120
+
121
+ target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)
122
+
123
+ x = x + target_x
124
+ if self.use_sp:
125
+ x = all_to_all_4D(x, scatter_dim=1, gather_dim=2) # [B, L/P, H, C/H]
126
+
127
+ x = x.flatten(2) # [B, L/P, C]
128
+
129
+ x = cross_attn_block.o(x)
130
+ return x
131
+
132
+ def single_fusion_cross_attention_ffn_forward(self,
133
+ attn_block,
134
+ src_seq,
135
+ src_grid_sizes,
136
+ src_freqs,
137
+ target_seq,
138
+ target_seq_lens,
139
+ target_grid_sizes,
140
+ target_freqs,
141
+ context,
142
+ context_lens,
143
+ src_e):
144
+
145
+ src_seq = src_seq + self.single_fusion_cross_attention_forward(attn_block.cross_attn,
146
+ attn_block.norm3(src_seq),
147
+ src_grid_sizes=src_grid_sizes,
148
+ src_freqs=src_freqs,
149
+ target_seq=target_seq,
150
+ target_seq_lens=target_seq_lens,
151
+ target_grid_sizes=target_grid_sizes,
152
+ target_freqs=target_freqs,
153
+ context=context,
154
+ context_lens=context_lens
155
+ )
156
+ y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
157
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
158
+ src_seq = src_seq + y * src_e[5].squeeze(2)
159
+ return src_seq
160
+
161
+ def single_fusion_block_forward(self,
162
+ vid_block,
163
+ audio_block,
164
+ vid,
165
+ audio,
166
+ vid_e,
167
+ vid_seq_lens,
168
+ vid_grid_sizes,
169
+ vid_freqs,
170
+ vid_context,
171
+ vid_context_lens,
172
+ audio_e,
173
+ audio_seq_lens,
174
+ audio_grid_sizes,
175
+ audio_freqs,
176
+ audio_context,
177
+ audio_context_lens
178
+ ):
179
+ ## audio modulation
180
+ assert audio_e.dtype == torch.bfloat16
181
+ assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
182
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
183
+ audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
184
+ assert audio_e[0].dtype == torch.bfloat16
185
+
186
+ # audio self-attention
187
+ audio_y = audio_block.self_attn(
188
+ audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
189
+ audio_freqs)
190
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
191
+ audio = audio + audio_y * audio_e[2].squeeze(2)
192
+
193
+ ## video modulation
194
+ assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
195
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
196
+ vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)
197
+
198
+ # video self-attention
199
+ vid_y = vid_block.self_attn(
200
+ vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
201
+ vid_freqs)
202
+
203
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
204
+ vid = vid + vid_y * vid_e[2].squeeze(2)
205
+
206
+ og_audio = audio
207
+
208
+ # audio cross-attention
209
+ audio = self.single_fusion_cross_attention_ffn_forward(
210
+ audio_block,
211
+ audio,
212
+ audio_grid_sizes,
213
+ audio_freqs,
214
+ vid,
215
+ vid_seq_lens,
216
+ vid_grid_sizes,
217
+ vid_freqs,
218
+ audio_context,
219
+ audio_context_lens,
220
+ audio_e
221
+ )
222
+
223
+ assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"
224
+
225
+ # video cross-attention
226
+ vid = self.single_fusion_cross_attention_ffn_forward(
227
+ vid_block,
228
+ vid,
229
+ vid_grid_sizes,
230
+ vid_freqs,
231
+ og_audio,
232
+ audio_seq_lens,
233
+ audio_grid_sizes,
234
+ audio_freqs,
235
+ vid_context,
236
+ vid_context_lens,
237
+ vid_e
238
+ )
239
+
240
+ return vid, audio
241
+
242
+ def forward(
243
+ self,
244
+ vid,
245
+ audio,
246
+ t,
247
+ vid_context,
248
+ audio_context,
249
+ vid_seq_len,
250
+ audio_seq_len,
251
+ clip_fea=None,
252
+ clip_fea_audio=None,
253
+ y=None,
254
+ first_frame_is_clean=False,
255
+ slg_layer=False
256
+ ):
257
+
258
+ assert clip_fea is None
259
+ assert y is None
260
+
261
+ if vid is None or all([x is None for x in vid]):
262
+ assert vid_context is None
263
+ assert vid_seq_len is None
264
+ assert self.audio_model is not None
265
+
266
+ return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)
267
+
268
+ if audio is None or all([x is None for x in audio]):
269
+ assert clip_fea_audio is None
270
+ assert audio_context is None
271
+ assert audio_seq_len is None
272
+ assert self.video_model is not None
273
+
274
+ return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None
275
+
276
+ vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
277
+ x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
278
+ )
279
+
280
+ audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
281
+ x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
282
+ )
283
+
284
+ kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)
285
+
286
+ for i in range(self.num_blocks):
287
+ """
288
+ 1 fusion block refers to 1 audio block with 1 video block.
289
+ """
290
+ if slg_layer > 0 and i == slg_layer:
291
+ continue
292
+ vid_block = self.video_model.blocks[i]
293
+ audio_block = self.audio_model.blocks[i]
294
+ vid, audio = gradient_checkpointing(
295
+ enabled=(self.training and self.gradient_checkpointing),
296
+ module=self.single_fusion_block_forward,
297
+ vid_block=vid_block,
298
+ audio_block=audio_block,
299
+ vid=vid,
300
+ audio=audio,
301
+ **kwargs
302
+ )
303
+
304
+ vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
305
+ audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)
306
+
307
+ return vid, audio
308
+
309
+ def init_weights(self):
310
+ if self.audio_model is not None:
311
+ self.audio_model.init_weights()
312
+
313
+ if self.video_model is not None:
314
+ self.video_model.init_weights()
315
+
316
+ for name, mod in self.video_model.named_modules():
317
+ if "fusion" in name and isinstance(mod, nn.Linear):
318
+ with torch.no_grad():
319
+ mod.weight.div_(10.0)
320
+
321
+
322
+ def set_rope_params(self):
323
+ self.video_model.set_rope_params()
324
+ self.audio_model.set_rope_params()
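For readers skimming the fusion block, a self-contained toy sketch of the pattern it implements (simplified: standard multi-head attention, no RoPE, no sequence parallelism, made-up dimensions): after per-modality self-attention, audio queries read video tokens and video queries read the pre-update audio tokens, both as residual cross-attention.

    import torch
    import torch.nn as nn

    class ToyFusionCrossAttn(nn.Module):
        """One modality's tokens attend to the other modality's tokens (residual)."""
        def __init__(self, dim, num_heads):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.norm_q = nn.LayerNorm(dim)     # stands in for norm3
            self.norm_kv = nn.LayerNorm(dim)    # stands in for pre_attn_norm_fusion

        def forward(self, src, target):
            # src: [B, L_src, C], target: [B, L_tgt, C]
            kv = self.norm_kv(target)
            out, _ = self.attn(self.norm_q(src), kv, kv)
            return src + out

    B, La, Lv, C = 1, 40, 220, 512
    audio, video = torch.randn(B, La, C), torch.randn(B, Lv, C)
    a_from_v, v_from_a = ToyFusionCrossAttn(C, 8), ToyFusionCrossAttn(C, 8)
    audio_new = a_from_v(audio, video)   # audio attends to video tokens
    video_new = v_from_a(video, audio)   # video attends to the pre-update audio (og_audio above)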
ovi/modules/mmaudio/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # MMAudio package
ovi/modules/mmaudio/ext/__init__.py ADDED
@@ -0,0 +1 @@
1
+