Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes.
- .gitignore +19 -0
- .gradio/certificate.pem +31 -0
- .python-version +1 -0
- LICENSE +201 -0
- README.md +152 -7
- assets/mochi-factory.webp +0 -0
- contrib/README.md +6 -0
- contrib/modal/lora.yaml +58 -0
- contrib/modal/main.py +285 -0
- contrib/modal/readme.md +55 -0
- demos/api_example.py +53 -0
- demos/cli.py +205 -0
- demos/comfyui_nodes.py +0 -0
- demos/fine_tuner/README.md +103 -0
- demos/fine_tuner/configs/lora.yaml +58 -0
- demos/fine_tuner/dataset.py +45 -0
- demos/fine_tuner/embed_captions.py +66 -0
- demos/fine_tuner/encode_videos.py +142 -0
- demos/fine_tuner/preprocess.bash +87 -0
- demos/fine_tuner/run.bash +92 -0
- demos/fine_tuner/train.py +398 -0
- demos/fine_tuner/trim_and_crop_videos.py +110 -0
- demos/gradio_ui.py +57 -0
- demos/gradio_ui_adapted.py +39 -0
- demos/gradio_ui_fixed.py +48 -0
- demos/gradio_ui_fixed.py~ +48 -0
- demos/test_encoder_decoder.py +79 -0
- pyproject.toml +37 -0
- scripts/download_weights.py +66 -0
- scripts/format.bash +5 -0
- scripts/pytorch_to_safe_tensors.py +24 -0
- scripts/typecheck.bash +2 -0
- scripts/weights_to_fp8.py +0 -0
- src/genmo/lib/attn_imports.py +29 -0
- src/genmo/lib/progress.py +87 -0
- src/genmo/lib/utils.py +67 -0
- src/genmo/mochi_preview/__init__.py +0 -0
- src/genmo/mochi_preview/dit/joint_model/__init__.py +0 -0
- src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py +737 -0
- src/genmo/mochi_preview/dit/joint_model/context_parallel.py +158 -0
- src/genmo/mochi_preview/dit/joint_model/layers.py +179 -0
- src/genmo/mochi_preview/dit/joint_model/lora.py +112 -0
- src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py +15 -0
- src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py +20 -0
- src/genmo/mochi_preview/dit/joint_model/rope_mixed.py +88 -0
- src/genmo/mochi_preview/dit/joint_model/temporal_rope.py +34 -0
- src/genmo/mochi_preview/dit/joint_model/utils.py +109 -0
- src/genmo/mochi_preview/pipelines.py +682 -0
- src/genmo/mochi_preview/vae/__init__.py +0 -0
- src/genmo/mochi_preview/vae/cp_conv.py +155 -0
.gitignore
ADDED
@@ -0,0 +1,19 @@
.venv
.venv_test
dist
__pycache__
mochi.egg-info
genmo.egg-info
outputs
build
.ruff_cache
*.mp4
*.txt
*.pt
*.log
*.json
*.safetensors
wandb/
*.err
*.out
*.MOV
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
-----BEGIN CERTIFICATE-----
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
-----END CERTIFICATE-----
.python-version
ADDED
@@ -0,0 +1 @@
3.10
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright 2024 Genmo

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md
CHANGED
@@ -1,12 +1,157 @@
 ---
-title:
-colorFrom: red
-colorTo: red
+title: genmoai
+app_file: ./demos/gradio_ui.py
 sdk: gradio
 sdk_version: 5.13.1
-app_file: app.py
-pinned: false
 ---

# Mochi 1
[Blog](https://www.genmo.ai/blog) | [Hugging Face](https://huggingface.co/genmo/mochi-1-preview) | [Playground](https://www.genmo.ai/play) | [Careers](https://jobs.ashbyhq.com/genmo)

A state-of-the-art video generation model by [Genmo](https://genmo.ai).

https://github.com/user-attachments/assets/4d268d02-906d-4cb0-87cc-f467f1497108

## News

- ⭐ **November 26, 2024**: Added support for [LoRA fine-tuning](demos/fine_tuner/README.md)
- ⭐ **November 5, 2024**: Consumer-GPU support for Mochi [natively in ComfyUI](https://x.com/ComfyUI/status/1853838184012251317)

## Overview

Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. We're releasing the model under a permissive Apache 2.0 license. Try this model for free on [our playground](https://genmo.ai/play).

## Installation

Install using [uv](https://github.com/astral-sh/uv):

```bash
git clone https://github.com/genmoai/models
cd models
pip install uv
uv venv .venv
source .venv/bin/activate
uv pip install setuptools
uv pip install -e . --no-build-isolation
```

If you want to install flash attention, you can use:
```bash
uv pip install -e .[flash] --no-build-isolation
```

You will also need to install [FFmpeg](https://www.ffmpeg.org/) to turn your outputs into videos.

## Download Weights

Use [download_weights.py](scripts/download_weights.py) to download the model + VAE to a local directory. Use it like this:
```bash
python3 ./scripts/download_weights.py weights/
```

Or, directly download the weights from [Hugging Face](https://huggingface.co/genmo/mochi-1-preview/tree/main) or via `magnet:?xt=urn:btih:441da1af7a16bcaa4f556964f8028d7113d21cbb&dn=weights&tr=udp://tracker.opentrackr.org:1337/announce` to a folder on your computer.

## Running

Start the gradio UI with

```bash
python3 ./demos/gradio_ui.py --model_dir weights/ --cpu_offload
```

Or generate videos directly from the CLI with

```bash
python3 ./demos/cli.py --model_dir weights/ --cpu_offload
```

If you have a fine-tuned LoRA in the safetensors format, you can add `--lora_path <path/to/my_mochi_lora.safetensors>` to either `gradio_ui.py` or `cli.py`.

## API

This repository comes with a simple, composable API, so you can call the model programmatically. You can find a full example [here](demos/api_example.py). Roughly, it looks like this:

```python
from genmo.mochi_preview.pipelines import (
    DecoderModelFactory,
    DitModelFactory,
    MochiSingleGPUPipeline,
    T5ModelFactory,
    linear_quadratic_schedule,
)

pipeline = MochiSingleGPUPipeline(
    text_encoder_factory=T5ModelFactory(),
    dit_factory=DitModelFactory(
        model_path=f"weights/dit.safetensors", model_dtype="bf16"
    ),
    decoder_factory=DecoderModelFactory(
        model_path=f"weights/decoder.safetensors",
    ),
    cpu_offload=True,
    decode_type="tiled_spatial",
)

video = pipeline(
    height=480,
    width=848,
    num_frames=31,
    num_inference_steps=64,
    sigma_schedule=linear_quadratic_schedule(64, 0.025),
    cfg_schedule=[6.0] * 64,
    batch_cfg=False,
    prompt="your favorite prompt here ...",
    negative_prompt="",
    seed=12345,
)
```
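
To write the returned frames out as an `.mp4`, [demos/api_example.py](demos/api_example.py) in this same commit uses the helpers from `genmo.lib`; the snippet below mirrors that script and assumes `video` is the batch returned by the pipeline call above:

```python
from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video

# Save the first sample of the returned batch of float32 frames.
with progress_bar(type="tqdm"):
    save_video(video[0], "video.mp4")
```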

## Fine-tuning with LoRA

We provide [an easy-to-use trainer](demos/fine_tuner/README.md) that allows you to build LoRA fine-tunes of Mochi on your own videos. The model can be fine-tuned on one H100 or A100 80GB GPU.

## Model Architecture

Mochi 1 represents a significant advancement in open-source video generation, featuring a 10 billion parameter diffusion model built on our novel Asymmetric Diffusion Transformer (AsymmDiT) architecture. Trained entirely from scratch, it is the largest video generative model ever openly released. And best of all, it's a simple, hackable architecture. Additionally, we are releasing an inference harness that includes an efficient context parallel implementation.

Alongside Mochi, we are open-sourcing our video AsymmVAE. We use an asymmetric encoder-decoder structure to build an efficient high quality compression model. Our AsymmVAE causally compresses videos to a 128x smaller size, with an 8x8 spatial and a 6x temporal compression to a 12-channel latent space.

### AsymmVAE Model Specs
|Params <br> Count | Enc Base <br> Channels | Dec Base <br> Channels |Latent <br> Dim | Spatial <br> Compression | Temporal <br> Compression |
|:--:|:--:|:--:|:--:|:--:|:--:|
|362M | 64 | 128 | 12 | 8x8 | 6x |
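
As a quick, illustrative reading of those factors (an editorial sketch, not code from this repository), the table implies the following latent shape for the default 848x480 output; the temporal handling is assumed to be causal, which is consistent with the `(num_frames - 1)` divisible-by-6 note in `contrib/modal/main.py`:

```python
# Hypothetical sketch: latent dimensions implied by 8x8 spatial / 6x temporal
# compression into a 12-channel latent space. Not taken from the VAE code.
num_frames, height, width = 163, 480, 848   # defaults used by demos/cli.py
assert (num_frames - 1) % 6 == 0            # matches the note in contrib/modal/main.py

latent_shape = (
    12,                          # latent channels
    (num_frames - 1) // 6 + 1,   # assumed causal temporal downsampling -> 28
    height // 8,                 # -> 60
    width // 8,                  # -> 106
)
print(latent_shape)              # (12, 28, 60, 106)
```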

An AsymmDiT efficiently processes user prompts alongside compressed video tokens by streamlining text processing and focusing neural network capacity on visual reasoning. AsymmDiT jointly attends to text and visual tokens with multi-modal self-attention and learns separate MLP layers for each modality, similar to Stable Diffusion 3. However, our visual stream has nearly 4 times as many parameters as the text stream via a larger hidden dimension. To unify the modalities in self-attention, we use non-square QKV and output projection layers. This asymmetric design reduces inference memory requirements.
Many modern diffusion models use multiple pretrained language models to represent user prompts. In contrast, Mochi 1 simply encodes prompts with a single T5-XXL language model.

### AsymmDiT Model Specs
|Params <br> Count | Num <br> Layers | Num <br> Heads | Visual <br> Dim | Text <br> Dim | Visual <br> Tokens | Text <br> Tokens |
|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
|10B | 48 | 24 | 3072 | 1536 | 44520 | 256 |
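
To make the "non-square QKV projection" idea concrete, here is a toy editorial sketch (not the repository's `asymm_models_joint.py`): the dimensions come from the spec table above, the token counts are shrunk so it runs anywhere, and every other detail is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

visual_dim, text_dim, num_heads = 3072, 1536, 24   # from the AsymmDiT spec table
head_dim = visual_dim // num_heads                  # 128

# Non-square projections: both modalities are mapped into the same attention
# width (assumed here to equal the visual dim) so their tokens can be
# concatenated and attended over jointly.
visual_qkv = nn.Linear(visual_dim, 3 * visual_dim)  # square: 3072 -> 9216
text_qkv = nn.Linear(text_dim, 3 * visual_dim)      # non-square: 1536 -> 9216

vis_tokens = torch.randn(1, 16, visual_dim)         # real model: 44520 visual tokens
txt_tokens = torch.randn(1, 4, text_dim)            # real model: up to 256 text tokens

qkv = torch.cat([visual_qkv(vis_tokens), text_qkv(txt_tokens)], dim=1)
q, k, v = (t.view(1, -1, num_heads, head_dim).transpose(1, 2) for t in qkv.chunk(3, dim=-1))
out = F.scaled_dot_product_attention(q, k, v)       # joint self-attention over both modalities
print(out.shape)                                    # torch.Size([1, 24, 20, 128])
```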

## Hardware Requirements
The repository supports both multi-GPU operation (splitting the model across multiple graphics cards) and single-GPU operation, though it requires approximately 60GB VRAM when running on a single GPU. While ComfyUI can optimize Mochi to run on less than 20GB VRAM, this implementation prioritizes flexibility over memory efficiency. When using this repository, we recommend using at least 1 H100 GPU.

## Safety
Genmo video models are general text-to-video diffusion models that inherently reflect the biases and preconceptions found in their training data. While steps have been taken to limit NSFW content, organizations should implement additional safety protocols and careful consideration before deploying these model weights in any commercial services or products.

## Limitations
Under the research preview, Mochi 1 is a living and evolving checkpoint. There are a few known limitations. The initial release generates videos at 480p today. In some edge cases with extreme motion, minor warping and distortions can also occur. Mochi 1 is also optimized for photorealistic styles, so it does not perform well with animated content. We also anticipate that the community will fine-tune the model to suit various aesthetic preferences.

## Related Work
- [ComfyUI-MochiWrapper](https://github.com/kijai/ComfyUI-MochiWrapper) adds ComfyUI support for Mochi. The integration of PyTorch's SDPA attention was based on their repository.
- [ComfyUI-MochiEdit](https://github.com/logtd/ComfyUI-MochiEdit) adds ComfyUI nodes for video editing, such as object insertion and restyling.
- [mochi-xdit](https://github.com/xdit-project/mochi-xdit) is a fork of this repository that improves parallel inference speed with [xDiT](https://github.com/xdit-project/xdit).
- [Modal script](contrib/modal/readme.md) for fine-tuning Mochi on Modal GPUs.


## BibTeX
```
@misc{genmo2024mochi,
      title={Mochi 1},
      author={Genmo Team},
      year={2024},
      publisher = {GitHub},
      journal = {GitHub repository},
      howpublished={\url{https://github.com/genmoai/models}}
}
```
assets/mochi-factory.webp
ADDED
contrib/README.md
ADDED
@@ -0,0 +1,6 @@
# Mochi Community Contributions

`mochi/contrib` contains community contributed pipelines for running and customizing Mochi.

## Index:
- `mochi/contrib/modal` - [Script](contrib/modal/readme.md) for fine-tuning Mochi on Modal GPUs.
contrib/modal/lora.yaml
ADDED
@@ -0,0 +1,58 @@
init_checkpoint_path: /weights/dit.safetensors
checkpoint_dir: /finetunes/my_mochi_lora
train_data_dir: /videos_prepared
attention_mode: sdpa
single_video_mode: false # Useful for debugging whether your model can learn a single video

# You only need this if you're using wandb
wandb:
  # project: mochi_1_lora
  # name: ${checkpoint_dir}
  # group: null

optimizer:
  lr: 2e-4
  weight_decay: 0.01

model:
  type: lora
  kwargs:
    # Apply LoRA to the QKV projection and the output projection of the attention block.
    qkv_proj_lora_rank: 16
    qkv_proj_lora_alpha: 16
    qkv_proj_lora_dropout: 0.
    out_proj_lora_rank: 16
    out_proj_lora_alpha: 16
    out_proj_lora_dropout: 0.

training:
  model_dtype: bf16
  warmup_steps: 200
  num_qkv_checkpoint: 48
  num_ff_checkpoint: 48
  num_post_attn_checkpoint: 48
  num_steps: 2000
  save_interval: 200
  caption_dropout: 0.1
  grad_clip: 0.0
  save_safetensors: true

# Used for generating samples during training to monitor progress ...
sample:
  interval: 200
  output_dir: ${checkpoint_dir}/samples
  decoder_path: /weights/decoder.safetensors
  prompts:
    - A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward. The fake snow mingles with the crystalline particles, creating a glittering cloud captured in high-speed photography.
    - A vintage pocket watch ticks quietly on an antique desk. Its brass casing starts to deteriorate, turning to fine metallic powder that lifts into the air. The gears and springs fragment into microscopic particles, each piece breaking down into a shimmering bronze dust that hangs suspended. The scene is richly detailed with warm, brass tones.
    - A cello is propped up against a wall, a single spotlight illuminating it. The wooden surface begins to decay into fine sawdust, the instrument gradually breaking apart as its form disintegrates into a cloud of earthen particles. The strings unravel into delicate fibers that float amidst the swirling wooden dust. The scene is vibrant and colorful.
    - A graphics card sits inside an oven, heatwaves around it. The silicon and metal components begin to break down at a molecular level, deteriorating into a dark cloud of fine metallic and mineral dust that hangs suspended in the heated air. The scene is darkly lit, high contrast, with a focus on the suspended particles.
    - A delicate porcelain teacup sits on a marble countertop. The ceramic structure begins to crumble into a fine, chalk-like powder, breaking down into countless microscopic white particles that drift upward in graceful patterns. The scene is bright and crisp with dramatic lighting illuminating the cloud of porcelain dust.
  seed: 12345
  kwargs:
    height: 480
    width: 848
    num_frames: 37
    num_inference_steps: 64
    sigma_schedule_python_code: "linear_quadratic_schedule(64, 0.025)"
    cfg_schedule_python_code: "[6.0] * 64"
contrib/modal/main.py
ADDED
@@ -0,0 +1,285 @@
import modal
from pathlib import Path

# Creating our Modal App
app = modal.App("mochi-finetune")

# Creating volumes for data, intermediate data, and produced weights
videos_volume = modal.Volume.from_name("mochi-tune-videos", create_if_missing=True)
videos_prepared_volume = modal.Volume.from_name("mochi-tune-videos-prepared", create_if_missing=True)
weights_volume = modal.Volume.from_name("mochi-tune-weights", create_if_missing=True)
finetunes_volume = modal.Volume.from_name("mochi-tune-finetunes", create_if_missing=True)
outputs_volume = modal.Volume.from_name("mochi-tune-outputs", create_if_missing=True)

USERNAME = "genmoai"
REPOSITORY = "mochi"
CLONE_CMD = f"git clone https://github.com/{USERNAME}/{REPOSITORY}.git"

# Building our container image
base_img = (
    modal.Image.debian_slim()
    .apt_install("git", "ffmpeg", "bc", "zlib1g-dev", "libjpeg-dev", "wget")
    .run_commands(CLONE_CMD)
    .workdir(REPOSITORY)
    .pip_install("gdown", "setuptools", "wheel")
    .run_commands('pip install -e . --no-build-isolation')
)

MINUTES = 60
HOURS = 60 * MINUTES

# Remote function for downloading a labeled video dataset from Google Drive
# Run it with:
#   modal run main::download_videos
@app.function(image=base_img,
    volumes={
        "/videos": videos_volume,
    }
)
def download_videos():
    '''Downloads videos from Google Drive into our volume'''
    import gdown
    import zipfile

    name = "dissolve"
    url = "https://drive.google.com/uc?id=1ldoBppcsv5Ueoikh0zCmNviojRCrGXQN"
    output = f"{name}.zip"
    gdown.download(url, output, quiet=False)
    with zipfile.ZipFile(output, "r") as zip_ref:
        zip_ref.extractall("/videos")

# Remote function for downloading the model weights from Hugging Face
# Run it with:
#   modal run main::download_weights
@app.function(image=base_img,
    volumes={
        "/weights": weights_volume,
    },
    timeout=1*HOURS,
)
def download_weights():
    # HF-transfer and snapshot download tend to hang on the large model, so we download it manually with wget
    import subprocess
    print("🍡 Downloading weights from Hugging Face. This may take 30 minutes.")
    # ~30 min
    subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/dit.safetensors", "-O", "/weights/dit.safetensors"])
    # ~1 min
    subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/decoder.safetensors", "-O", "/weights/decoder.safetensors"])
    # ~20 sec
    subprocess.run(["wget", "https://huggingface.co/genmo/mochi-1-preview/resolve/main/encoder.safetensors", "-O", "/weights/encoder.safetensors"])

# Remote function for preprocessing the video dataset
# Run it with:
#   modal run main::preprocess
@app.function(
    image=base_img,
    volumes={
        "/videos": videos_volume,
        "/videos_prepared": videos_prepared_volume,
        "/weights": weights_volume,
    },
    timeout=30*MINUTES,
    gpu="H100"
)
def preprocess():
    import subprocess
    print("🍡 Preprocessing videos. This may take 2-3 minutes.")
    video_dir = "videos_dissolve"
    subprocess.run([
        "bash", "demos/fine_tuner/preprocess.bash",
        "-v", f"/videos/{video_dir}/",
        "-o", "/videos_prepared/",
        "-w", "/weights/",
        "-n", "37"
    ])

# Remote function for finetuning the model using the prepared dataset
# Configure the run in lora.yaml
# Run it with:
#   modal run main::finetune
@app.function(
    image=base_img,
    volumes={
        "/videos": videos_volume,
        "/videos_prepared": videos_prepared_volume,
        "/weights": weights_volume,
        "/finetunes": finetunes_volume,
    },
    mounts=[modal.Mount.from_local_file("lora.yaml", remote_path=f"{REPOSITORY}/lora.yaml")],
    timeout=4*HOURS,
    gpu="H100"
)
def finetune():
    import subprocess
    print("🍡 Finetuning Mochi. This may take 3 hours.")
    print("🍡 See your mochi-tune-finetunes volume for intermediate checkpoints and samples.")
    subprocess.run([
        "bash", "demos/fine_tuner/run.bash",
        "-c", "lora.yaml",  # from our locally mounted yaml file
        "-n", "1",
    ])

# Remote class (Modal @cls) for running inference on one or multiple videos
# Run it with the @local_entrypoint below
@app.cls(
    image=base_img,
    volumes={
        "/weights": weights_volume,
        "/finetunes": finetunes_volume,
        "/outputs": outputs_volume,
    },
    timeout=30*MINUTES,
    gpu="H100"
)
class MochiLora():
    def __init__(self, model_dir: str = "/weights", lora_path: str = None, cpu_offload: bool = False):
        self.model_dir = model_dir
        self.lora_path = lora_path
        self.cpu_offload = cpu_offload

    @modal.enter()
    def start(self):
        """Initialize the model - this runs once when the container starts"""
        from genmo.mochi_preview.pipelines import (
            DecoderModelFactory,
            DitModelFactory,
            MochiMultiGPUPipeline,
            MochiSingleGPUPipeline,
            T5ModelFactory,
        )
        import torch

        print("🍡 Loading Mochi model.")

        self.num_gpus = torch.cuda.device_count()

        # Configure pipeline based on GPU count
        klass = MochiSingleGPUPipeline if self.num_gpus == 1 else MochiMultiGPUPipeline

        kwargs = dict(
            text_encoder_factory=T5ModelFactory(),
            dit_factory=DitModelFactory(
                model_path=f"{self.model_dir}/dit.safetensors",
                lora_path=self.lora_path,
                model_dtype="bf16",
            ),
            decoder_factory=DecoderModelFactory(
                model_path=f"{self.model_dir}/decoder.safetensors",
            ),
        )

        if self.num_gpus > 1:
            assert not self.lora_path, "LoRA not supported in multi-GPU mode"
            assert not self.cpu_offload, "CPU offload not supported in multi-GPU mode"
            kwargs["world_size"] = self.num_gpus
        else:
            kwargs["cpu_offload"] = self.cpu_offload
            kwargs["decode_type"] = "tiled_spatial"
            kwargs["fast_init"] = not self.lora_path
            kwargs["strict_load"] = not self.lora_path
            kwargs["decode_args"] = dict(overlap=8)

        self.pipeline = klass(**kwargs)
        print(f"🍡 Model loaded successfully with {self.num_gpus} GPUs")

    @modal.method()
    def generate(self,
                 prompt: str,
                 negative_prompt: str = "",
                 width: int = 848,
                 height: int = 480,
                 num_frames: int = 163,
                 seed: int = 1710977262,
                 cfg_scale: float = 6.0,
                 num_inference_steps: int = 64) -> str:
        """Generate video based on the prompt and parameters"""

        print("🍡 Generating video.")

        import json
        import os
        import time

        import numpy as np

        from genmo.lib.progress import progress_bar
        from genmo.lib.utils import save_video
        from genmo.mochi_preview.pipelines import linear_quadratic_schedule

        # Create sigma schedule
        sigma_schedule = linear_quadratic_schedule(num_inference_steps, 0.025)
        cfg_schedule = [cfg_scale] * num_inference_steps

        args = {
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "sigma_schedule": sigma_schedule,
            "cfg_schedule": cfg_schedule,
            "num_inference_steps": num_inference_steps,
            "batch_cfg": False,
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "seed": seed,
        }

        with progress_bar(type="tqdm"):
            final_frames = self.pipeline(**args)
            final_frames = final_frames[0]

        assert isinstance(final_frames, np.ndarray)
        assert final_frames.dtype == np.float32

        # Save to mounted volume
        output_dir = "/outputs"  # Assuming this path exists in the mounted volume
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"output_{int(time.time())}.mp4")

        save_video(final_frames, output_path)

        # Save generation parameters
        json_path = os.path.splitext(output_path)[0] + ".json"
        json.dump(args, open(json_path, "w"), indent=4)

        print(f"🍡 Video saved to {output_path}")
        outputs_volume.commit()
        return output_path.split("/")[-1]

# Local entrypoint for using the MochiLora class
# Select the lora_path you'd want to use from the finetunes volume
# Then run it with:
#   modal run main
@app.local_entrypoint()
def main(
    prompt="A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward. The fake snow mingles with the crystalline particles, creating a glittering cloud captured in high-speed photography.",
    negative_prompt="blurry, low quality",
    width=848,
    height=480,
    num_frames=49,  # (num_frames - 1) must be divisible by 6
    seed=1710977262,
    cfg_scale=6.0,
    num_inference_steps=64,
    lora_path="/finetunes/my_mochi_lora/model_2000.lora.safetensors",
    cpu_offload=True,
):
    lora = MochiLora(
        lora_path=lora_path,  # your lora path
        cpu_offload=cpu_offload,
    )
    output_path = lora.generate.remote(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        seed=seed,
        cfg_scale=cfg_scale,
        num_inference_steps=num_inference_steps,
    )

    local_dir = Path("/tmp/mochi")
    local_dir.mkdir(exist_ok=True, parents=True)
    local_path = local_dir / output_path
    local_path.write_bytes(b"".join(outputs_volume.read_file(output_path)))
    print(f"🍡 Video saved locally at {local_path}")
contrib/modal/readme.md
ADDED
@@ -0,0 +1,55 @@
## Finetuning Mochi with LoRA on Modal

This example demonstrates how to run the Mochi finetuner on Modal GPUs.

### Setup
Install [Modal](https://modal.com/docs/guide).
```bash
pip install modal
modal setup
```

### Fetch the dataset
There is a labeled dataset for a dissolving visual effect available on Google Drive. Download it into the `mochi-tune-videos` Modal volume with:
```bash
modal run main::download_videos
```

### Download the model weights
Download the model weights from Hugging Face into the `mochi-tune-weights` Modal volume with:
```bash
modal run -d main::download_weights
```
Note that this download can take more than 30 minutes. The `-d` flag allows you to exit the terminal session without losing progress.

### Prepare the dataset
We now run the preprocessing script to prepare the dataset for finetuning:
```bash
modal run main::preprocess
```
This puts preprocessed training input into the `mochi-tune-videos-prepared` Modal volume.

### Finetuning
Finetune the model using the prepared dataset.

You may configure the finetune run using the `lora.yaml` file, such as the number of steps, learning rate, etc.

Run the finetuning with:
```bash
modal run -d main::finetune
```

This will produce a series of checkpoints, as well as video samples generated along the way during training. You can view these files in the Modal `mochi-tune-finetunes` volume using the Storage tab in the dashboard.

### Inference
You can now use the MochiLora class to generate videos from a prompt. The `main` entrypoint will initialize the model to use the specified LoRA weights from your finetuning run.

```bash
modal run main
```
or with more parameters:
```bash
modal run main lora-path="/finetunes/my_mochi_lora/model_1000.lora.safetensors" prompt="A pristine snowglobe featuring a winter scene sits peacefully. The glass begins to crumble into fine powder, as the entire sphere deteriorates into sparkling dust that drifts outward."
```

See `modal run main --help` for all inference options.
demos/api_example.py
ADDED
@@ -0,0 +1,53 @@
#! /usr/bin/env python
import sys
from pathlib import Path
from textwrap import dedent

from genmo.lib.progress import progress_bar
from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import (
    DecoderModelFactory,
    DitModelFactory,
    MochiSingleGPUPipeline,
    T5ModelFactory,
    linear_quadratic_schedule,
)

MOCHI_DIR = sys.argv[1]
assert Path(MOCHI_DIR).exists(), f"Model directory {MOCHI_DIR} does not exist."
pipeline = MochiSingleGPUPipeline(
    text_encoder_factory=T5ModelFactory(),
    dit_factory=DitModelFactory(model_path=f"{MOCHI_DIR}/dit.safetensors", model_dtype="bf16"),
    decoder_factory=DecoderModelFactory(
        model_path=f"{MOCHI_DIR}/vae.safetensors",
        model_stats_path=f"{MOCHI_DIR}/vae_stats.json",
    ),
    cpu_offload=True,
    decode_type="tiled_full",
)

PROMPT = dedent("""
A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl
filled with lemons and sprigs of mint against a peach-colored background.
The hand gently tosses the lemon up and catches it, showcasing its smooth texture.
A beige string bag sits beside the bowl, adding a rustic touch to the scene.
Additional lemons, one halved, are scattered around the base of the bowl.
The even lighting enhances the vibrant colors and creates a fresh,
inviting atmosphere.
""")

video = pipeline(
    height=480,
    width=848,
    num_frames=31,
    num_inference_steps=64,
    sigma_schedule=linear_quadratic_schedule(64, 0.025),
    cfg_schedule=[4.5] * 64,
    batch_cfg=False,
    prompt=PROMPT,
    negative_prompt="",
    seed=12345,
)

with progress_bar(type="tqdm"):
    save_video(video[0], "video.mp4")
demos/cli.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
|
6 |
+
import click
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from genmo.lib.progress import progress_bar
|
11 |
+
from genmo.lib.utils import save_video
|
12 |
+
from genmo.mochi_preview.pipelines import (
|
13 |
+
DecoderModelFactory,
|
14 |
+
DitModelFactory,
|
15 |
+
MochiMultiGPUPipeline,
|
16 |
+
MochiSingleGPUPipeline,
|
17 |
+
T5ModelFactory,
|
18 |
+
linear_quadratic_schedule,
|
19 |
+
)
|
20 |
+
|
21 |
+
pipeline = None
|
22 |
+
model_dir_path = None
|
23 |
+
lora_path = None
|
24 |
+
num_gpus = torch.cuda.device_count()
|
25 |
+
cpu_offload = False
|
26 |
+
|
27 |
+
|
28 |
+
def configure_model(model_dir_path_, lora_path_, cpu_offload_, fast_model_=False):
|
29 |
+
global model_dir_path, lora_path, cpu_offload
|
30 |
+
model_dir_path = model_dir_path_
|
31 |
+
lora_path = lora_path_
|
32 |
+
cpu_offload = cpu_offload_
|
33 |
+
|
34 |
+
|
35 |
+
def load_model():
|
36 |
+
global num_gpus, pipeline, model_dir_path, lora_path
|
37 |
+
if pipeline is None:
|
38 |
+
MOCHI_DIR = model_dir_path
|
39 |
+
print(f"Launching with {num_gpus} GPUs. If you want to force single GPU mode use CUDA_VISIBLE_DEVICES=0.")
|
40 |
+
klass = MochiSingleGPUPipeline if num_gpus == 1 else MochiMultiGPUPipeline
|
41 |
+
kwargs = dict(
|
42 |
+
text_encoder_factory=T5ModelFactory(),
|
43 |
+
dit_factory=DitModelFactory(
|
44 |
+
model_path=f"{MOCHI_DIR}/dit.safetensors",
|
45 |
+
lora_path=lora_path,
|
46 |
+
model_dtype="bf16",
|
47 |
+
),
|
48 |
+
decoder_factory=DecoderModelFactory(
|
49 |
+
model_path=f"{MOCHI_DIR}/decoder.safetensors",
|
50 |
+
),
|
51 |
+
)
|
52 |
+
if num_gpus > 1:
|
53 |
+
assert not lora_path, f"Lora not supported in multi-GPU mode"
|
54 |
+
assert not cpu_offload, "CPU offload not supported in multi-GPU mode"
|
55 |
+
kwargs["world_size"] = num_gpus
|
56 |
+
else:
|
57 |
+
kwargs["cpu_offload"] = cpu_offload
|
58 |
+
kwargs["decode_type"] = "tiled_spatial"
|
59 |
+
kwargs["fast_init"] = not lora_path
|
60 |
+
kwargs["strict_load"] = not lora_path
|
61 |
+
kwargs["decode_args"] = dict(overlap=8)
|
62 |
+
pipeline = klass(**kwargs)
|
63 |
+
|
64 |
+
|
65 |
+
def generate_video(
|
66 |
+
prompt,
|
67 |
+
negative_prompt,
|
68 |
+
width,
|
69 |
+
height,
|
70 |
+
num_frames,
|
71 |
+
seed,
|
72 |
+
cfg_scale,
|
73 |
+
num_inference_steps,
|
74 |
+
threshold_noise=0.025,
|
75 |
+
linear_steps=None,
|
76 |
+
output_dir="outputs",
|
77 |
+
):
|
78 |
+
load_model()
|
79 |
+
|
80 |
+
# Fast mode parameters: threshold_noise=0.1, linear_steps=6, cfg_scale=1.5, num_inference_steps=8
|
81 |
+
sigma_schedule = linear_quadratic_schedule(num_inference_steps, threshold_noise, linear_steps)
|
82 |
+
|
83 |
+
# cfg_schedule should be a list of floats of length num_inference_steps.
|
84 |
+
# For simplicity, we just use the same cfg scale at all timesteps,
|
85 |
+
# but more optimal schedules may use varying cfg, e.g:
|
86 |
+
# [5.0] * (num_inference_steps // 2) + [4.5] * (num_inference_steps // 2)
|
87 |
+
cfg_schedule = [cfg_scale] * num_inference_steps
|
88 |
+
|
89 |
+
args = {
|
90 |
+
"height": height,
|
91 |
+
"width": width,
|
92 |
+
"num_frames": num_frames,
|
93 |
+
"sigma_schedule": sigma_schedule,
|
94 |
+
"cfg_schedule": cfg_schedule,
|
95 |
+
"num_inference_steps": num_inference_steps,
|
96 |
+
# We *need* flash attention to batch cfg
|
97 |
+
# and it's only worth doing in a high-memory regime (assume multiple GPUs)
|
98 |
+
"batch_cfg": False,
|
99 |
+
"prompt": prompt,
|
100 |
+
"negative_prompt": negative_prompt,
|
101 |
+
"seed": seed,
|
102 |
+
}
|
103 |
+
|
104 |
+
with progress_bar(type="tqdm"):
|
105 |
+
final_frames = pipeline(**args)
|
106 |
+
|
107 |
+
final_frames = final_frames[0]
|
108 |
+
|
109 |
+
assert isinstance(final_frames, np.ndarray)
|
110 |
+
assert final_frames.dtype == np.float32
|
111 |
+
|
112 |
+
os.makedirs(output_dir, exist_ok=True)
|
113 |
+
output_path = os.path.join(output_dir, f"output_{int(time.time())}.mp4")
|
114 |
+
|
115 |
+
save_video(final_frames, output_path)
|
116 |
+
json_path = os.path.splitext(output_path)[0] + ".json"
|
117 |
+
json.dump(args, open(json_path, "w"), indent=4)
|
118 |
+
|
119 |
+
return output_path
|
120 |
+
|
121 |
+
|
122 |
+
from textwrap import dedent
|
123 |
+
|
124 |
+
DEFAULT_PROMPT = dedent("""
|
125 |
+
A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl
|
126 |
+
filled with lemons and sprigs of mint against a peach-colored background.
|
127 |
+
The hand gently tosses the lemon up and catches it, showcasing its smooth texture.
|
128 |
+
A beige string bag sits beside the bowl, adding a rustic touch to the scene.
|
129 |
+
Additional lemons, one halved, are scattered around the base of the bowl.
|
130 |
+
The even lighting enhances the vibrant colors and creates a fresh,
|
131 |
+
inviting atmosphere.
|
132 |
+
""")
|
133 |
+
|
134 |
+
|
135 |
+
@click.command()
|
136 |
+
@click.option("--prompt", default=DEFAULT_PROMPT, help="Prompt for video generation.")
|
137 |
+
@click.option("--sweep-file", help="JSONL file containing one config per line.")
|
138 |
+
@click.option("--negative_prompt", default="", help="Negative prompt for video generation.")
|
139 |
+
@click.option("--width", default=848, type=int, help="Width of the video.")
|
140 |
+
@click.option("--height", default=480, type=int, help="Height of the video.")
|
141 |
+
@click.option("--num_frames", default=163, type=int, help="Number of frames.")
|
142 |
+
@click.option("--seed", default=1710977262, type=int, help="Random seed.")
|
143 |
+
@click.option("--cfg_scale", default=6.0, type=float, help="CFG Scale.")
|
144 |
+
@click.option("--num_steps", default=64, type=int, help="Number of inference steps.")
|
145 |
+
@click.option("--model_dir", required=True, help="Path to the model directory.")
|
146 |
+
@click.option("--lora_path", required=False, help="Path to the lora file.")
|
147 |
+
@click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
|
148 |
+
@click.option("--out_dir", default="outputs", help="Output directory for generated videos")
|
149 |
+
@click.option("--threshold-noise", default=0.025, help="threshold noise")
|
150 |
+
@click.option("--linear-steps", default=None, type=int, help="linear steps")
|
151 |
+
def generate_cli(
|
152 |
+
prompt, sweep_file, negative_prompt, width, height, num_frames, seed, cfg_scale, num_steps,
|
153 |
+
model_dir, lora_path, cpu_offload, out_dir, threshold_noise, linear_steps
|
154 |
+
):
|
155 |
+
configure_model(model_dir, lora_path, cpu_offload)
|
156 |
+
|
157 |
+
if sweep_file:
|
158 |
+
with open(sweep_file, 'r') as f:
|
159 |
+
for i, line in enumerate(f):
|
160 |
+
if not line.strip():
|
161 |
+
continue
|
162 |
+
config = json.loads(line)
|
163 |
+
current_prompt = config.get('prompt', prompt)
|
164 |
+
current_cfg_scale = config.get('cfg_scale', cfg_scale)
|
165 |
+
current_num_steps = config.get('num_steps', num_steps)
|
166 |
+
current_threshold_noise = config.get('threshold_noise', threshold_noise)
|
167 |
+
current_linear_steps = config.get('linear_steps', linear_steps)
|
168 |
+
current_seed = config.get('seed', seed)
|
169 |
+
current_width = config.get('width', width)
|
170 |
+
current_height = config.get('height', height)
|
171 |
+
current_num_frames = config.get('num_frames', num_frames)
|
172 |
+
|
173 |
+
output_path = generate_video(
|
174 |
+
current_prompt,
|
175 |
+
negative_prompt,
|
176 |
+
current_width,
|
177 |
+
current_height,
|
178 |
+
current_num_frames,
|
179 |
+
current_seed,
|
180 |
+
current_cfg_scale,
|
181 |
+
current_num_steps,
|
182 |
+
threshold_noise=current_threshold_noise,
|
183 |
+
linear_steps=current_linear_steps,
|
184 |
+
output_dir=out_dir,
|
185 |
+
)
|
186 |
+
click.echo(f"Video {i+1} generated at: {output_path}")
|
187 |
+
else:
|
188 |
+
output_path = generate_video(
|
189 |
+
prompt,
|
190 |
+
negative_prompt,
|
191 |
+
width,
|
192 |
+
height,
|
193 |
+
num_frames,
|
194 |
+
seed,
|
195 |
+
cfg_scale,
|
196 |
+
num_steps,
|
197 |
+
threshold_noise=threshold_noise,
|
198 |
+
linear_steps=linear_steps,
|
199 |
+
output_dir=out_dir,
|
200 |
+
)
|
201 |
+
click.echo(f"Video generated at: {output_path}")
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == "__main__":
|
205 |
+
generate_cli()
|
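For reference, the `--sweep-file` option above expects a JSONL file with one JSON object per line; any of the keys read in the loop (`prompt`, `cfg_scale`, `num_steps`, `threshold_noise`, `linear_steps`, `seed`, `width`, `height`, `num_frames`) may be present, and missing keys fall back to the corresponding CLI flags. A small sketch that writes such a file (the filename and values are illustrative):

```python
# Write an example sweep file for: python3 demos/cli.py --sweep-file sweep.jsonl --model_dir weights/
import json

configs = [
    {"prompt": "A cat playing piano", "cfg_scale": 4.5, "num_steps": 64},
    {"prompt": "A cat playing piano", "cfg_scale": 6.0, "num_steps": 64, "seed": 1234},
]
with open("sweep.jsonl", "w") as f:
    for cfg in configs:
        f.write(json.dumps(cfg) + "\n")
```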
demos/comfyui_nodes.py
ADDED
File without changes
|
demos/fine_tuner/README.md
ADDED
@@ -0,0 +1,103 @@
# Mochi 1 LoRA Fine-tuner

![Mochi being made](../../assets/mochi-factory.webp)


This folder contains tools for fine-tuning the Mochi 1 model. It supports [LoRA](https://arxiv.org/abs/2106.09685) fine-tuning on a single GPU.

## Quick Start (Single GPU)
This section shows how to prepare your dataset and fine-tune Mochi 1 on a single GPU.

First, set up the inference code and download the Mochi 1 weights following [README.md](../../README.md).
All commands below assume you are in the top-level directory of the Mochi repo.

### 1. Collect your videos and captions
Collect your videos (supported formats: MP4, MOV) into a folder, e.g. `videos/`. Then write a detailed description of each video in a `.txt` file with the same name. For example,
```
videos/
  video_1.mp4
  video_1.txt -- One-paragraph description of video_1
  video_2.mp4
  video_2.txt -- One-paragraph description of video_2
  ...
```

### 2. Process videos and captions (about 2 minutes)
Update the paths in the command below to match your dataset. Videos are processed at 30 FPS, so make sure each video is at least `num_frames / 30` seconds long (a quick way to check this is sketched below).
```bash
bash demos/fine_tuner/preprocess.bash -v videos/ -o videos_prepared/ -w weights/ --num_frames 37
```
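The check below is optional and not part of the pipeline; it just uses `moviepy` (which the preprocessing step already depends on), and the `videos/` path and `num_frames` value are illustrative.

```python
# Warn about clips that are too short for the chosen --num_frames.
from pathlib import Path

from moviepy.editor import VideoFileClip

num_frames = 37
min_duration = num_frames / 30  # preprocessing samples videos at 30 FPS

for mp4 in sorted(Path("videos").glob("*.mp4")):
    with VideoFileClip(str(mp4)) as clip:
        if clip.duration < min_duration:
            print(f"{mp4.name}: {clip.duration:.2f}s is shorter than {min_duration:.2f}s")
```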
### 3. Fine-tune the model
Update `./demos/fine_tuner/configs/lora.yaml` to customize the fine-tuning process,
including the prompts to sample at various points during fine-tuning and the path to your prepared videos.

Launch LoRA fine-tuning on a single GPU:
```bash
bash ./demos/fine_tuner/run.bash -c ./demos/fine_tuner/configs/lora.yaml -n 1
```

Samples will be generated in `finetunes/my_mochi_lora/samples` every 200 steps.

### 4. Use your fine-tuned weights to generate videos!
Update `--lora_path` to the path of your fine-tuned weights and run:
```bash
python3 ./demos/cli.py --model_dir weights/ --lora_path finetunes/my_mochi_lora/model_2000.lora.safetensors --num_frames 37 --cpu_offload --prompt "A delicate porcelain teacup sits on a marble countertop. The teacup suddenly shatters into hundreds of white ceramic shards that scatter through the air. The scene is bright and crisp with dramatic lighting."
```

You can increase the number of frames to generate a longer video. Finally, share your creations with the community by uploading your LoRA and sample videos to Hugging Face.
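If you want to confirm what a saved LoRA file contains before sharing or using it, it can be opened with the `safetensors` library; the trainer stores its LoRA `kwargs` as a JSON string in the file metadata. A small sketch (the checkpoint path is an example):

```python
# Peek inside a saved LoRA checkpoint.
import json

from safetensors import safe_open

path = "finetunes/my_mochi_lora/model_2000.lora.safetensors"
with safe_open(path, framework="pt") as f:
    print(json.loads(f.metadata()["kwargs"]))  # e.g. qkv_proj_lora_rank, out_proj_lora_rank, ...
    print(list(f.keys())[:5])                  # names of the stored LoRA tensors
```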
## System Requirements

**Single GPU:**
- 1x H100 or A100 (80 GB VRAM is recommended)
- Less VRAM is required when training on videos shorter than 1 second.

**Supported video lengths:** Up to 85 frames (~2.8 seconds at 30 FPS)
- Choose a frame count of the form 6k + 1, i.e. in increments of 6: 25, 31, 37, ..., 79, 85.
- Training on 37 frames uses 50 GB of VRAM. On 1x H100, training runs at about 1.67 s/it,
  and you'll start seeing changes to your videos within 200-400 steps. Training for 1,000 steps takes about 30 minutes (1,000 × 1.67 s ≈ 28 minutes).

Settings tested on 1x H100 SXM:

| Frames    | Video Length      | VRAM          | Time/step | num_qkv_checkpoint | num_ff_checkpoint | num_post_attn_checkpoint |
|-----------|-------------------|---------------|-----------|--------------------|-------------------|--------------------------|
| 37 frames | 1.2 second videos | 50 GB VRAM    | 1.67 s/it | 48                 | 48†               | 48                       |
| 61 frames | 2.0 second videos | 64 GB VRAM    | 3.35 s/it | 48                 | 48†               | 48                       |
| 79 frames | 2.6 second videos | 69-78 GB VRAM | 4.92 s/it | 48                 | 48†               | 48                       |
| 85 frames | 2.8 second videos | 80 GB VRAM    | 5.44 s/it | 48                 | 48                | 48                       |

*† As the VRAM is not fully used, you can lower `num_ff_checkpoint` to speed up training.*

## Technical Details

- LoRA fine-tuning updates the query, key, and value projection matrices, as well as the output projection matrix, of each attention block (a minimal sketch of the idea follows this list).
  These settings are configurable in `./demos/fine_tuner/configs/lora.yaml`.
- We welcome contributions and suggestions for improved settings.
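For readers new to LoRA, here is a minimal, self-contained sketch of the idea applied to a single linear layer. It is not the repo's implementation (see `src/genmo/mochi_preview/dit/joint_model/lora.py` for that); it only illustrates what `qkv_proj_lora_rank` and `qkv_proj_lora_alpha` in the YAML control.

```python
# Sketch: a frozen linear projection plus a trainable low-rank update.
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 16, alpha: int = 16):
        super().__init__()
        self.base = base                      # frozen pretrained projection
        self.base.requires_grad_(False)
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scale = alpha / rank             # standard LoRA scaling

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output of the frozen layer plus the scaled low-rank correction.
        return self.base(x) + self.scale * ((x @ self.lora_A.T) @ self.lora_B.T)
```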
## Known Limitations

- No support for training on multiple GPUs
- LoRA inference is restricted to a single GPU (for now)

## Tips

- Be as descriptive as possible in your captions.
- A learning rate around 1e-4 or 2e-4 seems effective for LoRA fine-tuning.
- For larger datasets, or to customize the model more aggressively, increase `num_steps` in the YAML.
- To monitor training loss, uncomment the `wandb` section in the YAML and run `wandb login`, or set the `WANDB_API_KEY` environment variable.
- Videos are trimmed to the **first** `num_frames` frames. Make sure your clips contain the content you care about near the beginning.
  You can check the trimmed versions after running `preprocess.bash` to make sure they look good.
- When capturing HDR videos on an iPhone, convert your .mov files to .mp4 using the Handbrake application. Our preprocessing script won't produce the correct colorspace otherwise, and your fine-tuned videos may look overly bright.
### If you are running out of GPU memory, make sure:

- `COMPILE_DIT=1` is set in `demos/fine_tuner/run.bash`.
  This enables model compilation, which saves memory and speeds up training!
- `num_post_attn_checkpoint`, `num_ff_checkpoint`, and `num_qkv_checkpoint` are set to 48 in your YAML.
  You can checkpoint up to 48 layers, saving memory at the cost of slower training.
- If all else fails, reduce `num_frames` when processing your videos and in your YAML.
  You can fine-tune Mochi on shorter videos and still generate longer videos at inference time.

## Diffusers trainer

The [Diffusers Python library](https://github.com/huggingface/diffusers) supports LoRA fine-tuning of Mochi 1 as well. Check out [this link](https://github.com/a-r-r-o-w/cogvideox-factory/tree/80d1150a0e233a1b2b98dd0367c06276989d049c/training/mochi-1) for more details.
demos/fine_tuner/configs/lora.yaml
ADDED
@@ -0,0 +1,58 @@
init_checkpoint_path: weights/dit.safetensors
checkpoint_dir: finetunes/my_mochi_lora
train_data_dir: videos_prepared
attention_mode: sdpa
single_video_mode: false # Useful for debugging whether your model can learn a single video

# You only need this if you're using wandb
wandb:
  # project: mochi_1_lora
  # name: ${checkpoint_dir}
  # group: null

optimizer:
  lr: 2e-4
  weight_decay: 0.01

model:
  type: lora
  kwargs:
    # Apply LoRA to the QKV projection and the output projection of the attention block.
    qkv_proj_lora_rank: 16
    qkv_proj_lora_alpha: 16
    qkv_proj_lora_dropout: 0.
    out_proj_lora_rank: 16
    out_proj_lora_alpha: 16
    out_proj_lora_dropout: 0.

training:
  model_dtype: bf16
  warmup_steps: 200
  num_qkv_checkpoint: 48
  num_ff_checkpoint: 48
  num_post_attn_checkpoint: 48
  num_steps: 2000
  save_interval: 200
  caption_dropout: 0.1
  grad_clip: 0.0
  save_safetensors: true

# Used for generating samples during training to monitor progress ...
sample:
  interval: 200
  output_dir: ${checkpoint_dir}/samples
  decoder_path: weights/decoder.safetensors
  prompts:
    - A pristine snowglobe featuring a winter scene sits peacefully. The globe violently explodes, sending glass, water, and glittering fake snow in all directions. The scene is captured with high-speed photography.
    - A vintage pocket watch ticks quietly on an antique desk. Suddenly, it explodes into gears, springs and metal fragments that scatter through the air. The scene is richly detailed with warm, brass tones.
    - A cello is propped up against a wall, a single spotlight illuminating it. The cello explodes into wooden fragments, sending debris everywhere. The scene is vibrant and colorful.
    - A graphics card sits inside an oven, heatwaves around it. Suddenly, the graphics card explodes into numerous fragments, sending debris everywhere. The scene is darkly lit, high contrast, with a focus on the shattered pieces.
    - A delicate porcelain teacup sits on a marble countertop. The teacup suddenly shatters into hundreds of white ceramic shards that scatter through the air. The scene is bright and crisp with dramatic lighting.
  seed: 12345
  kwargs:
    height: 480
    width: 848
    num_frames: 37
    num_inference_steps: 64
    sigma_schedule_python_code: "linear_quadratic_schedule(64, 0.025)"
    cfg_schedule_python_code: "[6.0] * 64"
demos/fine_tuner/dataset.py
ADDED
@@ -0,0 +1,45 @@
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import click
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import DataLoader, Dataset
|
6 |
+
|
7 |
+
|
8 |
+
def load_to_cpu(x):
|
9 |
+
return torch.load(x, map_location=torch.device("cpu"), weights_only=True)
|
10 |
+
|
11 |
+
|
12 |
+
class LatentEmbedDataset(Dataset):
|
13 |
+
def __init__(self, file_paths, repeat=1):
|
14 |
+
self.items = [
|
15 |
+
(Path(p).with_suffix(".latent.pt"), Path(p).with_suffix(".embed.pt"))
|
16 |
+
for p in file_paths
|
17 |
+
if Path(p).with_suffix(".latent.pt").is_file() and Path(p).with_suffix(".embed.pt").is_file()
|
18 |
+
]
|
19 |
+
self.items = self.items * repeat
|
20 |
+
print(f"Loaded {len(self.items)}/{len(file_paths)} valid file pairs.")
|
21 |
+
|
22 |
+
def __len__(self):
|
23 |
+
return len(self.items)
|
24 |
+
|
25 |
+
def __getitem__(self, idx):
|
26 |
+
latent_path, embed_path = self.items[idx]
|
27 |
+
return load_to_cpu(latent_path), load_to_cpu(embed_path)
|
28 |
+
|
29 |
+
|
30 |
+
@click.command()
|
31 |
+
@click.argument("directory", type=click.Path(exists=True, file_okay=False))
|
32 |
+
def process_videos(directory):
|
33 |
+
dir_path = Path(directory)
|
34 |
+
mp4_files = [str(f) for f in dir_path.glob("**/*.mp4") if not f.name.endswith(".recon.mp4")]
|
35 |
+
assert mp4_files, f"No mp4 files found in {directory}"
|
36 |
+
|
37 |
+
dataset = LatentEmbedDataset(mp4_files)
|
38 |
+
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
|
39 |
+
|
40 |
+
for latents, embeds in dataloader:
|
41 |
+
print([(k, v.shape) for k, v in latents.items()])
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
process_videos()
|
demos/fine_tuner/embed_captions.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
#! /usr/bin/env python3
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import click
|
5 |
+
import torch
|
6 |
+
from tqdm import tqdm
|
7 |
+
from transformers import T5Tokenizer
|
8 |
+
|
9 |
+
from genmo.mochi_preview.pipelines import T5_MODEL, T5ModelFactory, get_conditioning_for_prompts
|
10 |
+
|
11 |
+
|
12 |
+
@click.command()
|
13 |
+
@click.argument("captions_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
|
14 |
+
@click.option("--device_id", default=0, help="GPU device ID to use")
|
15 |
+
@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing embeddings")
|
16 |
+
def process_captions(captions_dir: Path, device_id: int, overwrite: bool = False) -> None:
|
17 |
+
"""Process all text files in a directory using T5 encoder.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
captions_dir: Directory containing input text files
|
21 |
+
device_id: GPU device ID to use
|
22 |
+
"""
|
23 |
+
|
24 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
25 |
+
torch.backends.cudnn.allow_tf32 = True
|
26 |
+
|
27 |
+
# Get all text file paths
|
28 |
+
text_paths = list(captions_dir.glob("**/*.txt"))
|
29 |
+
if not text_paths:
|
30 |
+
print(f"No text files found in {captions_dir}")
|
31 |
+
return
|
32 |
+
|
33 |
+
# Initialize model and tokenizer
|
34 |
+
model_factory = T5ModelFactory()
|
35 |
+
device = f"cuda:{device_id}"
|
36 |
+
model = model_factory.get_model(local_rank=0, device_id=device_id, world_size=1)
|
37 |
+
tokenizer = T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
|
38 |
+
|
39 |
+
with tqdm(total=len(text_paths)) as pbar:
|
40 |
+
for text_path in text_paths:
|
41 |
+
embed_path = text_path.with_suffix(".embed.pt")
|
42 |
+
if embed_path.exists() and not overwrite:
|
43 |
+
pbar.write(f"Skipping {text_path} - embeddings already exist")
|
44 |
+
continue
|
45 |
+
|
46 |
+
pbar.write(f"Processing {text_path}")
|
47 |
+
try:
|
48 |
+
with open(text_path) as f:
|
49 |
+
text = f.read().strip()
|
50 |
+
|
51 |
+
with torch.inference_mode():
|
52 |
+
conditioning = get_conditioning_for_prompts(tokenizer, model, device, [text])
|
53 |
+
|
54 |
+
torch.save(conditioning, embed_path)
|
55 |
+
|
56 |
+
except Exception as e:
|
57 |
+
import traceback
|
58 |
+
|
59 |
+
traceback.print_exc()
|
60 |
+
pbar.write(f"Error processing {text_path}: {str(e)}")
|
61 |
+
|
62 |
+
pbar.update(1)
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
process_captions()
|
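After this script runs, each caption has a sibling `.embed.pt` file next to it; the trainer later loads it and indexes `y_feat[0]` / `y_mask[0]`. An optional sanity check is sketched below (the path is illustrative):

```python
# Inspect one saved caption embedding.
import torch

cond = torch.load("videos_prepared/video_1.embed.pt", map_location="cpu", weights_only=True)
print(list(cond.keys()))                                 # expect y_feat and y_mask
print(cond["y_feat"][0].shape, cond["y_mask"][0].shape)  # one entry per (single-item) batch
```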
demos/fine_tuner/encode_videos.py
ADDED
@@ -0,0 +1,142 @@
1 |
+
#! /usr/bin/env python3
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import traceback
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import click
|
8 |
+
import ray
|
9 |
+
import torch
|
10 |
+
import torchvision
|
11 |
+
from einops import rearrange
|
12 |
+
|
13 |
+
import genmo.mochi_preview.dit.joint_model.context_parallel as cp
|
14 |
+
import genmo.mochi_preview.vae.cp_conv as cp_conv
|
15 |
+
from genmo.lib.progress import get_new_progress_bar, progress_bar
|
16 |
+
from genmo.lib.utils import Timer, save_video
|
17 |
+
from genmo.mochi_preview.pipelines import DecoderModelFactory, EncoderModelFactory
|
18 |
+
from genmo.mochi_preview.vae.models import add_fourier_features, decode_latents
|
19 |
+
|
20 |
+
|
21 |
+
class GPUContext:
|
22 |
+
def __init__(
|
23 |
+
self,
|
24 |
+
*,
|
25 |
+
encoder_factory: Optional[EncoderModelFactory] = None,
|
26 |
+
decoder_factory: Optional[DecoderModelFactory] = None,
|
27 |
+
):
|
28 |
+
t = Timer()
|
29 |
+
self.device = torch.device("cuda")
|
30 |
+
if encoder_factory is not None:
|
31 |
+
with t("load_encoder"):
|
32 |
+
self.encoder = encoder_factory.get_model()
|
33 |
+
if decoder_factory is not None:
|
34 |
+
with t("load_decoder"):
|
35 |
+
self.decoder = decoder_factory.get_model()
|
36 |
+
t.print_stats()
|
37 |
+
|
38 |
+
|
39 |
+
def preprocess(ctx: GPUContext, vid_path: Path, shape: str, reconstruct: bool):
|
40 |
+
T, H, W = [int(s) for s in shape.split("x")]
|
41 |
+
assert (T - 1) % 6 == 0, "Expected T to be 1 mod 6"
|
42 |
+
video, _, metadata = torchvision.io.read_video(
|
43 |
+
str(vid_path), output_format="THWC", pts_unit="secs")
|
44 |
+
fps = metadata["video_fps"]
|
45 |
+
video = rearrange(video, "t h w c -> c t h w")
|
46 |
+
og_shape = video.shape
|
47 |
+
assert video.shape[2] == H, f"Expected {vid_path} to have height {H}, got {video.shape}"
|
48 |
+
assert video.shape[3] == W, f"Expected {vid_path} to have width {W}, got {video.shape}"
|
49 |
+
assert video.shape[1] >= T, f"Expected {vid_path} to have at least {T} frames, got {video.shape}"
|
50 |
+
if video.shape[1] > T:
|
51 |
+
video = video[:, :T]
|
52 |
+
print(f"Trimmed video from {og_shape[1]} to first {T} frames")
|
53 |
+
video = video.unsqueeze(0)
|
54 |
+
video = video.float() / 127.5 - 1.0
|
55 |
+
video = video.to(ctx.device)
|
56 |
+
video = add_fourier_features(video)
|
57 |
+
|
58 |
+
assert video.ndim == 5
|
59 |
+
video = cp.local_shard(video, dim=2) # split along time dimension
|
60 |
+
|
61 |
+
with torch.inference_mode():
|
62 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
63 |
+
ldist = ctx.encoder(video)
|
64 |
+
|
65 |
+
print(f"{og_shape} -> {ldist.mean.shape}")
|
66 |
+
torch.save(
|
67 |
+
dict(mean=ldist.mean, logvar=ldist.logvar),
|
68 |
+
vid_path.with_suffix(".latent.pt"),
|
69 |
+
)
|
70 |
+
|
71 |
+
if reconstruct:
|
72 |
+
latents = ldist.sample()
|
73 |
+
frames = decode_latents(ctx.decoder, latents)
|
74 |
+
frames = frames.cpu().numpy()
|
75 |
+
save_video(frames[0], str(vid_path.with_suffix(".recon.mp4")), fps=fps)
|
76 |
+
|
77 |
+
|
78 |
+
@click.command()
|
79 |
+
@click.argument("videos_dir", type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path))
|
80 |
+
@click.option(
|
81 |
+
"--model_dir",
|
82 |
+
type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path),
|
83 |
+
help="Path to folder containing Mochi's VAE encoder and decoder weights. Download from Hugging Face: https://huggingface.co/genmo/mochi-1-preview/blob/main/encoder.safetensors and https://huggingface.co/genmo/mochi-1-preview/blob/main/decoder.safetensors",
|
84 |
+
default="weights/",
|
85 |
+
)
|
86 |
+
@click.option("--num_gpus", default=1, help="Number of GPUs to split the encoder over")
|
87 |
+
@click.option(
|
88 |
+
"--recon_interval", default=10, help="Reconstruct one out of every N videos (0 to disable reconstruction)"
|
89 |
+
)
|
90 |
+
@click.option("--shape", default="163x480x848", help="Shape of the video to encode")
|
91 |
+
@click.option("--overwrite", "-ow", is_flag=True, help="Overwrite existing latents")
|
92 |
+
def batch_process(
|
93 |
+
videos_dir: Path, model_dir: Path, num_gpus: int, recon_interval: int, shape: str, overwrite: bool
|
94 |
+
) -> None:
|
95 |
+
"""Process all videos in a directory using multiple GPUs.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
videos_dir: Directory containing input videos
|
99 |
+
model_dir: Path to folder containing the VAE encoder and decoder weights
|
100 |
+
shape: Video shape to encode, formatted as TxHxW (e.g. 163x480x848)
|
101 |
+
num_gpus: Number of GPUs to use for parallel processing
|
102 |
+
recon_interval: Frequency of video reconstructions (0 to disable)
|
103 |
+
"""
|
104 |
+
|
105 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
106 |
+
torch.backends.cudnn.allow_tf32 = True
|
107 |
+
|
108 |
+
# Get all video paths
|
109 |
+
video_paths = list(videos_dir.glob("**/*.mp4"))
|
110 |
+
if not video_paths:
|
111 |
+
print(f"No MP4 files found in {videos_dir}")
|
112 |
+
return
|
113 |
+
|
114 |
+
preproc = GPUContext(
|
115 |
+
encoder_factory=EncoderModelFactory(model_path=os.path.join(model_dir, "encoder.safetensors")),
|
116 |
+
decoder_factory=DecoderModelFactory(model_path=os.path.join(model_dir, "decoder.safetensors")),
|
117 |
+
)
|
118 |
+
with progress_bar(type="ray_tqdm"):
|
119 |
+
for idx, video_path in get_new_progress_bar((list(enumerate(sorted(video_paths))))):
|
120 |
+
if str(video_path).endswith(".recon.mp4"):
|
121 |
+
print(f"Skipping {video_path} b/c it is a reconstruction")
|
122 |
+
continue
|
123 |
+
|
124 |
+
print(f"Processing {video_path}")
|
125 |
+
try:
|
126 |
+
if video_path.with_suffix(".latent.pt").exists() and not overwrite:
|
127 |
+
print(f"Skipping {video_path}")
|
128 |
+
continue
|
129 |
+
|
130 |
+
preprocess(
|
131 |
+
ctx=preproc,
|
132 |
+
vid_path=video_path,
|
133 |
+
shape=shape,
|
134 |
+
reconstruct=recon_interval != 0 and idx % recon_interval == 0,
|
135 |
+
)
|
136 |
+
except Exception as e:
|
137 |
+
traceback.print_exc()
|
138 |
+
print(f"Error processing {video_path}: {str(e)}")
|
139 |
+
|
140 |
+
|
141 |
+
if __name__ == "__main__":
|
142 |
+
batch_process()
|
demos/fine_tuner/preprocess.bash
ADDED
@@ -0,0 +1,87 @@
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
# Enable job control and set process group
|
4 |
+
set -eo pipefail
|
5 |
+
set -x
|
6 |
+
|
7 |
+
# Function to check if a command exists
|
8 |
+
command_exists() {
|
9 |
+
command -v "$1" >/dev/null 2>&1
|
10 |
+
}
|
11 |
+
|
12 |
+
# Function to install bc using the appropriate package manager
|
13 |
+
install_bc() {
|
14 |
+
if command_exists apt-get; then
|
15 |
+
sudo apt-get update && sudo apt-get install -y bc
|
16 |
+
elif command_exists yum; then
|
17 |
+
sudo yum install -y bc
|
18 |
+
else
|
19 |
+
echo "Error: Could not find package manager to install bc"
|
20 |
+
exit 1
|
21 |
+
fi
|
22 |
+
}
|
23 |
+
|
24 |
+
# Check and install bc if necessary
|
25 |
+
if ! command_exists bc; then
|
26 |
+
echo "bc is not installed. Installing bc..."
|
27 |
+
install_bc
|
28 |
+
fi
|
29 |
+
|
30 |
+
# Function to display help
|
31 |
+
usage() {
|
32 |
+
echo "Usage: $0 -v|--videos_dir videos_dir -o|--output_dir output_dir -w|--weights_dir weights_dir -n|--num_frames num_frames"
|
33 |
+
echo " -v, --videos_dir Path to the videos directory"
|
34 |
+
echo " -o, --output_dir Path to the output directory"
|
35 |
+
echo " -w, --weights_dir Path to the weights directory"
|
36 |
+
echo " -n, --num_frames Number of frames"
|
37 |
+
exit 1
|
38 |
+
}
|
39 |
+
|
40 |
+
# Function to check if the next argument is missing
|
41 |
+
check_argument() {
|
42 |
+
if [[ -z "$2" || "$2" == -* ]]; then
|
43 |
+
echo "Error: Argument for $1 is missing"
|
44 |
+
usage
|
45 |
+
fi
|
46 |
+
}
|
47 |
+
|
48 |
+
# Parse command-line arguments
|
49 |
+
while [[ "$#" -gt 0 ]]; do
|
50 |
+
case $1 in
|
51 |
+
-v|--videos_dir) check_argument "$1" "$2"; VIDEOS_DIR="$2"; shift ;;
|
52 |
+
-o|--output_dir) check_argument "$1" "$2"; OUTPUT_DIR="$2"; shift ;;
|
53 |
+
-w|--weights_dir) check_argument "$1" "$2"; WEIGHTS_DIR="$2"; shift ;;
|
54 |
+
-n|--num_frames) check_argument "$1" "$2"; NUM_FRAMES="$2"; shift ;;
|
55 |
+
-h|--help) usage ;;
|
56 |
+
*) echo "Unknown parameter passed: $1"; usage ;;
|
57 |
+
esac
|
58 |
+
shift
|
59 |
+
done
|
60 |
+
|
61 |
+
# Check if all required arguments are provided
|
62 |
+
if [[ -z "$VIDEOS_DIR" || -z "$OUTPUT_DIR" || -z "$WEIGHTS_DIR" || -z "$NUM_FRAMES" ]]; then
|
63 |
+
echo "Error: All arguments are required."
|
64 |
+
usage
|
65 |
+
fi
|
66 |
+
|
67 |
+
# Get the directory where this script is located
|
68 |
+
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
69 |
+
echo "Using script directory: ${SCRIPT_DIR}"
|
70 |
+
|
71 |
+
##### Step 1: Trim and resize videos
|
72 |
+
echo -e "\n\e[1;35m🎬 **Step 1: Trim and resize videos** \e[0m"
|
73 |
+
# Calculate duration to trim videos
|
74 |
+
DURATION=$(printf "%.1f" "$(echo "($NUM_FRAMES / 30) + 0.09" | bc -l)")
|
75 |
+
echo "Trimming videos to duration: ${DURATION} seconds"
|
76 |
+
python3 ${SCRIPT_DIR}/trim_and_crop_videos.py ${VIDEOS_DIR} ${OUTPUT_DIR} -d ${DURATION}
|
77 |
+
|
78 |
+
##### Step 2: Run the VAE encoder on each video.
|
79 |
+
echo -e "\n\e[1;35m🎥 **Step 2: Run the VAE encoder on each video** \e[0m"
|
80 |
+
python3 ${SCRIPT_DIR}/encode_videos.py ${OUTPUT_DIR} \
|
81 |
+
--model_dir ${WEIGHTS_DIR} --num_gpus 1 --shape "${NUM_FRAMES}x480x848" --overwrite
|
82 |
+
|
83 |
+
##### Step 3: Compute T5 embeddings
|
84 |
+
echo -e "\n\e[1;35m🧠 **Step 3: Compute T5 embeddings** \e[0m"
|
85 |
+
python3 ${SCRIPT_DIR}/embed_captions.py --overwrite ${OUTPUT_DIR}
|
86 |
+
|
87 |
+
echo -e "\n\e[1;32m✓ Done!\e[0m"
|
demos/fine_tuner/run.bash
ADDED
@@ -0,0 +1,92 @@
1 |
+
#! /bin/bash
|
2 |
+
|
3 |
+
# Enable job control and set process group
|
4 |
+
set -m
|
5 |
+
trap 'kill $(jobs -p)' EXIT INT TERM
|
6 |
+
|
7 |
+
# Get the directory where this script is located
|
8 |
+
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
|
9 |
+
DEFAULT_CONFIG="${SCRIPT_DIR}/configs/finetune.yaml"
|
10 |
+
|
11 |
+
# Parse command line arguments
|
12 |
+
usage() {
|
13 |
+
echo "Usage: $0 [-c|--config <config_path>] [-n|--num-gpus <num_gpus>]"
|
14 |
+
echo " -c, --config Path to config file (default: ${DEFAULT_CONFIG})"
|
15 |
+
echo " -n, --num-gpus Number of GPUs to use (default: 8)"
|
16 |
+
exit 1
|
17 |
+
}
|
18 |
+
|
19 |
+
# Default values
|
20 |
+
CONFIG_PATH="${DEFAULT_CONFIG}"
|
21 |
+
NUM_GPUS=8
|
22 |
+
|
23 |
+
# Parse arguments
|
24 |
+
while [[ $# -gt 0 ]]; do
|
25 |
+
case $1 in
|
26 |
+
-c|--config)
|
27 |
+
CONFIG_PATH="$2"
|
28 |
+
shift 2
|
29 |
+
;;
|
30 |
+
-n|--num-gpus)
|
31 |
+
NUM_GPUS="$2"
|
32 |
+
shift 2
|
33 |
+
;;
|
34 |
+
-h|--help)
|
35 |
+
usage
|
36 |
+
;;
|
37 |
+
*)
|
38 |
+
echo "Unknown option: $1"
|
39 |
+
usage
|
40 |
+
;;
|
41 |
+
esac
|
42 |
+
done
|
43 |
+
|
44 |
+
# Validate config file exists
|
45 |
+
if [ ! -f "${CONFIG_PATH}" ]; then
|
46 |
+
echo "Config file not found at ${CONFIG_PATH}"
|
47 |
+
exit 1
|
48 |
+
fi
|
49 |
+
|
50 |
+
# Validate num_gpus is a positive integer
|
51 |
+
if ! [[ "$NUM_GPUS" =~ ^[1-9][0-9]*$ ]]; then
|
52 |
+
echo "Number of GPUs must be a positive integer"
|
53 |
+
exit 1
|
54 |
+
fi
|
55 |
+
|
56 |
+
# Set distributed training environment variables
|
57 |
+
export MASTER_PORT=29500
|
58 |
+
export MASTER_ADDR="localhost"
|
59 |
+
export WORLD_SIZE=$NUM_GPUS
|
60 |
+
export TF_CPP_MIN_LOG_LEVEL=3
|
61 |
+
export COMPILE_DIT=1
|
62 |
+
|
63 |
+
# Set IS_DISTRIBUTED based on NUM_GPUS
|
64 |
+
if [ "$NUM_GPUS" -gt 1 ]; then
|
65 |
+
export IS_DISTRIBUTED=true
|
66 |
+
fi
|
67 |
+
|
68 |
+
# Load .env file (if it exists)
|
69 |
+
if [ -f ".env" ]; then
|
70 |
+
export $(grep -v '^#' .env | xargs)
|
71 |
+
fi
|
72 |
+
|
73 |
+
echo "Starting training with ${NUM_GPUS} GPU(s), mode: ${IS_DISTRIBUTED:+distributed}${IS_DISTRIBUTED:-single_gpu}"
|
74 |
+
echo "Using config: ${CONFIG_PATH}"
|
75 |
+
|
76 |
+
# Launch processes
|
77 |
+
if [ "$NUM_GPUS" -gt 1 ]; then
|
78 |
+
for RANK in $(seq 0 $((NUM_GPUS-1))); do
|
79 |
+
env RANK=$RANK CUDA_VISIBLE_DEVICES=$RANK python "${SCRIPT_DIR}/train.py" --config-path "${CONFIG_PATH}" &
|
80 |
+
done
|
81 |
+
else
|
82 |
+
python "${SCRIPT_DIR}/train.py" --config-path "${CONFIG_PATH}" &
|
83 |
+
fi
|
84 |
+
|
85 |
+
# Wait for all background processes to complete
|
86 |
+
wait
|
87 |
+
|
88 |
+
# Check if any process failed
|
89 |
+
if [ $? -ne 0 ]; then
|
90 |
+
echo "One or more training processes failed"
|
91 |
+
exit 1
|
92 |
+
fi
|
demos/fine_tuner/train.py
ADDED
@@ -0,0 +1,398 @@
1 |
+
import json
|
2 |
+
import multiprocessing as mp
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
import time
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from glob import glob
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Any, Dict, Tuple, cast
|
12 |
+
|
13 |
+
import click
|
14 |
+
import numpy as np
|
15 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
16 |
+
from safetensors.torch import save_file
|
17 |
+
import torch
|
18 |
+
from torch import Tensor
|
19 |
+
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
torch._dynamo.config.cache_size_limit = 32
|
24 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
25 |
+
torch.backends.cudnn.allow_tf32 = True
|
26 |
+
torch.use_deterministic_algorithms(False)
|
27 |
+
|
28 |
+
import genmo.mochi_preview.dit.joint_model.lora as lora
|
29 |
+
from genmo.lib.progress import progress_bar
|
30 |
+
from genmo.lib.utils import Timer, save_video
|
31 |
+
from genmo.mochi_preview.vae.vae_stats import vae_latents_to_dit_latents
|
32 |
+
from genmo.mochi_preview.pipelines import (
|
33 |
+
DecoderModelFactory,
|
34 |
+
DitModelFactory,
|
35 |
+
ModelFactory,
|
36 |
+
T5ModelFactory,
|
37 |
+
cast_dit,
|
38 |
+
compute_packed_indices,
|
39 |
+
get_conditioning,
|
40 |
+
linear_quadratic_schedule, # used in eval'd Python code in lora.yaml
|
41 |
+
load_to_cpu,
|
42 |
+
move_to_device,
|
43 |
+
sample_model,
|
44 |
+
t5_tokenizer,
|
45 |
+
)
|
46 |
+
from genmo.mochi_preview.vae.latent_dist import LatentDistribution
|
47 |
+
from genmo.mochi_preview.vae.models import decode_latents_tiled_spatial
|
48 |
+
|
49 |
+
sys.path.append("..")
|
50 |
+
|
51 |
+
from dataset import LatentEmbedDataset
|
52 |
+
|
53 |
+
|
54 |
+
class MochiTorchRunEvalPipeline:
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
*,
|
58 |
+
device_id,
|
59 |
+
dit,
|
60 |
+
text_encoder_factory: ModelFactory,
|
61 |
+
decoder_factory: ModelFactory,
|
62 |
+
):
|
63 |
+
self.device = torch.device(f"cuda:{device_id}")
|
64 |
+
self.tokenizer = t5_tokenizer()
|
65 |
+
t = Timer()
|
66 |
+
self.dit = dit
|
67 |
+
with t("load_text_encoder"):
|
68 |
+
self.text_encoder = text_encoder_factory.get_model(
|
69 |
+
local_rank=0,
|
70 |
+
world_size=1,
|
71 |
+
device_id="cpu",
|
72 |
+
)
|
73 |
+
with t("load_vae"):
|
74 |
+
self.decoder = decoder_factory.get_model(local_rank=0, device_id="cpu", world_size=1)
|
75 |
+
t.print_stats() # type: ignore
|
76 |
+
|
77 |
+
def __call__(self, prompt, save_path, **kwargs):
|
78 |
+
with progress_bar(type="tqdm", enabled=True), torch.inference_mode():
|
79 |
+
# Encode prompt with T5 XXL.
|
80 |
+
with move_to_device(self.text_encoder, self.device, enabled=True):
|
81 |
+
conditioning = get_conditioning(
|
82 |
+
self.tokenizer,
|
83 |
+
self.text_encoder,
|
84 |
+
self.device,
|
85 |
+
batch_inputs=False,
|
86 |
+
prompt=prompt,
|
87 |
+
negative_prompt="",
|
88 |
+
)
|
89 |
+
|
90 |
+
# Sample video latents from Mochi.
|
91 |
+
with move_to_device(self.dit, self.device, enabled=True):
|
92 |
+
latents = sample_model(self.device, self.dit, conditioning, **kwargs)
|
93 |
+
|
94 |
+
# Decode video latents to frames.
|
95 |
+
with move_to_device(self.decoder, self.device, enabled=True):
|
96 |
+
frames = decode_latents_tiled_spatial(
|
97 |
+
self.decoder, latents, num_tiles_w=2, num_tiles_h=2, overlap=8)
|
98 |
+
frames = frames.cpu().numpy() # b t h w c
|
99 |
+
assert isinstance(frames, np.ndarray)
|
100 |
+
|
101 |
+
save_video(frames[0], save_path)
|
102 |
+
|
103 |
+
|
104 |
+
def map_to_device(x, device: torch.device):
|
105 |
+
if isinstance(x, dict):
|
106 |
+
return {k: map_to_device(v, device) for k, v in x.items()}
|
107 |
+
elif isinstance(x, list):
|
108 |
+
return [map_to_device(y, device) for y in x]
|
109 |
+
elif isinstance(x, tuple):
|
110 |
+
return tuple(map_to_device(y, device) for y in x)
|
111 |
+
elif isinstance(x, torch.Tensor):
|
112 |
+
return x.to(device, non_blocking=True)
|
113 |
+
else:
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
EPOCH_IDX = 0
|
118 |
+
|
119 |
+
|
120 |
+
def infinite_dl(dl):
|
121 |
+
global EPOCH_IDX
|
122 |
+
while True:
|
123 |
+
EPOCH_IDX += 1
|
124 |
+
for batch in dl:
|
125 |
+
yield batch
|
126 |
+
|
127 |
+
|
128 |
+
@contextmanager
|
129 |
+
def timer(description="Task", enabled=True):
|
130 |
+
if enabled:
|
131 |
+
start = time.perf_counter()
|
132 |
+
try:
|
133 |
+
yield
|
134 |
+
finally:
|
135 |
+
if enabled:
|
136 |
+
elapsed = time.perf_counter() - start # type: ignore
|
137 |
+
print(f"{description} took {elapsed:.4f} seconds")
|
138 |
+
|
139 |
+
|
140 |
+
def get_cosine_annealing_lr_scheduler(
|
141 |
+
optimizer: torch.optim.Optimizer,
|
142 |
+
warmup_steps: int,
|
143 |
+
total_steps: int,
|
144 |
+
):
|
145 |
+
def lr_lambda(step):
|
146 |
+
if step < warmup_steps:
|
147 |
+
return float(step) / float(max(1, warmup_steps))
|
148 |
+
else:
|
149 |
+
return 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
|
150 |
+
|
151 |
+
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
152 |
+
|
153 |
+
|
154 |
+
@click.command()
|
155 |
+
@click.option("--config-path", type=click.Path(exists=True), required=True, help="Path to YAML config file")
|
156 |
+
def main(config_path):
|
157 |
+
mp.set_start_method("spawn", force=True)
|
158 |
+
cfg = cast(DictConfig, OmegaConf.load(config_path))
|
159 |
+
|
160 |
+
device_id = 0
|
161 |
+
device_str = f"cuda:0"
|
162 |
+
device = torch.device(device_str)
|
163 |
+
|
164 |
+
# Verify checkpoint path exists
|
165 |
+
checkpoint_path = Path(cfg.init_checkpoint_path)
|
166 |
+
assert checkpoint_path.exists(), f"Checkpoint file not found: {checkpoint_path}"
|
167 |
+
|
168 |
+
# Create checkpoint directory if it doesn't exist
|
169 |
+
checkpoint_dir = Path(cfg.checkpoint_dir)
|
170 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
171 |
+
|
172 |
+
# Get step number from checkpoint filename
|
173 |
+
pattern = r"model_(\d+)\.(lora|checkpoint)\.(safetensors|pt)"
|
174 |
+
match = re.search(pattern, str(checkpoint_path))
|
175 |
+
if match:
|
176 |
+
start_step_num = int(match.group(1))
|
177 |
+
opt_path = str(checkpoint_path).replace("model_", "optimizer_")
|
178 |
+
else:
|
179 |
+
start_step_num = 0
|
180 |
+
opt_path = ""
|
181 |
+
|
182 |
+
print(
|
183 |
+
f"model={checkpoint_path}, optimizer={opt_path}, start_step_num={start_step_num}"
|
184 |
+
)
|
185 |
+
|
186 |
+
wandb_run = None
|
187 |
+
sample_prompts = cfg.sample.prompts
|
188 |
+
|
189 |
+
train_vids = list(sorted(glob(f"{cfg.train_data_dir}/*.mp4")))
|
190 |
+
train_vids = [v for v in train_vids if not v.endswith(".recon.mp4")]
|
191 |
+
print(f"Found {len(train_vids)} training videos in {cfg.train_data_dir}")
|
192 |
+
assert len(train_vids) > 0, f"No training data found in {cfg.train_data_dir}"
|
193 |
+
if cfg.single_video_mode:
|
194 |
+
train_vids = train_vids[:1]
|
195 |
+
sample_prompts = [Path(train_vids[0]).with_suffix(".txt").read_text()]
|
196 |
+
print(f"Training on video: {train_vids[0]}")
|
197 |
+
|
198 |
+
train_dataset = LatentEmbedDataset(
|
199 |
+
train_vids,
|
200 |
+
repeat=1_000 if cfg.single_video_mode else 1,
|
201 |
+
)
|
202 |
+
train_dl = torch.utils.data.DataLoader(
|
203 |
+
train_dataset,
|
204 |
+
batch_size=None,
|
205 |
+
num_workers=4,
|
206 |
+
shuffle=True,
|
207 |
+
pin_memory=True,
|
208 |
+
)
|
209 |
+
train_dl_iter = infinite_dl(train_dl)
|
210 |
+
|
211 |
+
if cfg.get("wandb"):
|
212 |
+
import wandb
|
213 |
+
|
214 |
+
wandb_run = wandb.init(
|
215 |
+
project=cfg.wandb.project,
|
216 |
+
name=f"{cfg.wandb.name}-{int(time.time())}",
|
217 |
+
config=OmegaConf.to_container(cfg), # type: ignore
|
218 |
+
)
|
219 |
+
print(f"🚀 Weights & Biases run URL: {wandb_run.get_url()}")
|
220 |
+
|
221 |
+
print("Loading model")
|
222 |
+
patch_model_fns = []
|
223 |
+
model_kwargs = {}
|
224 |
+
is_lora = cfg.model.type == "lora"
|
225 |
+
print(f"Training type: {'LoRA' if is_lora else 'Full'}")
|
226 |
+
if is_lora:
|
227 |
+
def mark_lora_params(m):
|
228 |
+
lora.mark_only_lora_as_trainable(m, bias="none")
|
229 |
+
return m
|
230 |
+
|
231 |
+
patch_model_fns.append(mark_lora_params)
|
232 |
+
model_kwargs = dict(**cfg.model.kwargs)
|
233 |
+
# Replace ListConfig with list to allow serialization to JSON.
|
234 |
+
for k, v in model_kwargs.items():
|
235 |
+
if isinstance(v, ListConfig):
|
236 |
+
model_kwargs[k] = list(v)
|
237 |
+
|
238 |
+
if cfg.training.get("model_dtype"):
|
239 |
+
assert cfg.training.model_dtype == "bf16", f"Only bf16 is supported"
|
240 |
+
patch_model_fns.append(lambda m: cast_dit(m, torch.bfloat16))
|
241 |
+
|
242 |
+
model = (
|
243 |
+
DitModelFactory(
|
244 |
+
model_path=str(checkpoint_path),
|
245 |
+
model_dtype="bf16",
|
246 |
+
attention_mode=cfg.attention_mode
|
247 |
+
).get_model(
|
248 |
+
local_rank=0,
|
249 |
+
device_id=device_id,
|
250 |
+
model_kwargs=model_kwargs,
|
251 |
+
patch_model_fns=patch_model_fns,
|
252 |
+
world_size=1,
|
253 |
+
strict_load=not is_lora,
|
254 |
+
fast_init=not is_lora, # fast_init not supported for LoRA (please someone fix this !!!)
|
255 |
+
)
|
256 |
+
.train() # calling train() makes sure LoRA weights are not merged
|
257 |
+
)
|
258 |
+
|
259 |
+
optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer)
|
260 |
+
if os.path.exists(opt_path):
|
261 |
+
print("Loading optimizer")
|
262 |
+
optimizer.load_state_dict(load_to_cpu(opt_path))
|
263 |
+
|
264 |
+
scheduler = get_cosine_annealing_lr_scheduler(
|
265 |
+
optimizer,
|
266 |
+
warmup_steps=cfg.training.warmup_steps,
|
267 |
+
total_steps=cfg.training.num_steps
|
268 |
+
)
|
269 |
+
|
270 |
+
print("Loading eval pipeline ...")
|
271 |
+
eval_pipeline = MochiTorchRunEvalPipeline(
|
272 |
+
device_id=device_id,
|
273 |
+
dit=model,
|
274 |
+
text_encoder_factory=T5ModelFactory(),
|
275 |
+
decoder_factory=DecoderModelFactory(model_path=cfg.sample.decoder_path),
|
276 |
+
)
|
277 |
+
|
278 |
+
def get_batch() -> Tuple[Dict[str, Any], Tensor, Tensor, Tensor]:
|
279 |
+
nonlocal train_dl_iter
|
280 |
+
batch = next(train_dl_iter) # type: ignore
|
281 |
+
latent, embed = cast(Tuple[Dict[str, Any], Dict[str, Any]], batch)
|
282 |
+
assert len(embed["y_feat"]) == 1 and len(embed["y_mask"]) == 1, f"Only batch size 1 is supported"
|
283 |
+
|
284 |
+
ldist = LatentDistribution(latent["mean"], latent["logvar"])
|
285 |
+
z = ldist.sample()
|
286 |
+
assert torch.isfinite(z).all()
|
287 |
+
assert z.shape[0] == 1, f"Only batch size 1 is supported"
|
288 |
+
|
289 |
+
eps = torch.randn_like(z)
|
290 |
+
sigma = torch.rand(z.shape[:1], device="cpu", dtype=torch.float32)
|
291 |
+
|
292 |
+
if random.random() < cfg.training.caption_dropout:
|
293 |
+
embed["y_mask"][0].zero_()
|
294 |
+
embed["y_feat"][0].zero_()
|
295 |
+
return embed, z, eps, sigma
|
296 |
+
|
297 |
+
pbar = tqdm(
|
298 |
+
range(start_step_num, cfg.training.num_steps),
|
299 |
+
total=cfg.training.num_steps,
|
300 |
+
initial=start_step_num,
|
301 |
+
)
|
302 |
+
for step in pbar:
|
303 |
+
if cfg.sample.interval and step % cfg.sample.interval == 0 and step > 0:
|
304 |
+
sample_dir = Path(cfg.sample.output_dir)
|
305 |
+
sample_dir.mkdir(exist_ok=True)
|
306 |
+
model.eval()
|
307 |
+
for eval_idx, prompt in enumerate(sample_prompts):
|
308 |
+
save_path = sample_dir / f"{eval_idx}_{step}.mp4"
|
309 |
+
if save_path.exists():
|
310 |
+
print(f"Skipping {save_path} as it already exists")
|
311 |
+
continue
|
312 |
+
|
313 |
+
sample_kwargs = {
|
314 |
+
k.removesuffix("_python_code"): (eval(v) if k.endswith("_python_code") else v)
|
315 |
+
for k, v in cfg.sample.kwargs.items()
|
316 |
+
}
|
317 |
+
eval_pipeline(
|
318 |
+
prompt=prompt,
|
319 |
+
save_path=str(save_path),
|
320 |
+
seed=cfg.sample.seed + eval_idx,
|
321 |
+
**sample_kwargs,
|
322 |
+
)
|
323 |
+
Path(sample_dir / f"{eval_idx}_{step}.txt").write_text(prompt)
|
324 |
+
model.train()
|
325 |
+
|
326 |
+
if cfg.training.save_interval and step > 0 and step % cfg.training.save_interval == 0:
|
327 |
+
with timer("get_state_dict"):
|
328 |
+
if is_lora:
|
329 |
+
model_sd = lora.lora_state_dict(model, bias="none")
|
330 |
+
else:
|
331 |
+
# NOTE: Not saving optimizer state dict to save space.
|
332 |
+
model_sd, _optimizer_sd = get_state_dict(
|
333 |
+
model, [], options=StateDictOptions(cpu_offload=True, full_state_dict=True)
|
334 |
+
)
|
335 |
+
|
336 |
+
checkpoint_filename = f"model_{step}.{'lora' if is_lora else 'checkpoint'}.pt"
|
337 |
+
save_path = checkpoint_dir / checkpoint_filename
|
338 |
+
if cfg.training.get("save_safetensors", True):
|
339 |
+
save_path = save_path.with_suffix(".safetensors")
|
340 |
+
save_file(
|
341 |
+
model_sd, save_path,
|
342 |
+
# `safetensors` only supports string-to-string metadata,
|
343 |
+
# so we serialize the kwargs to a JSON string.
|
344 |
+
metadata=dict(kwargs=json.dumps(model_kwargs)),
|
345 |
+
)
|
346 |
+
else:
|
347 |
+
torch.save(model_sd, save_path)
|
348 |
+
|
349 |
+
with torch.no_grad(), timer("load_batch", enabled=False):
|
350 |
+
batch = get_batch()
|
351 |
+
embed, z, eps, sigma = map_to_device(batch, device)
|
352 |
+
embed = cast(Dict[str, Any], embed)
|
353 |
+
|
354 |
+
num_latent_toks = np.prod(z.shape[-3:])
|
355 |
+
indices = compute_packed_indices(device, cast(Tensor, embed["y_mask"][0]), int(num_latent_toks))
|
356 |
+
|
357 |
+
sigma_bcthw = sigma[:, None, None, None, None] # [B, 1, 1, 1, 1]
|
358 |
+
z_sigma = (1 - sigma_bcthw) * z + sigma_bcthw * eps
|
359 |
+
ut = z - eps
|
360 |
+
|
361 |
+
with torch.autocast("cuda", dtype=torch.bfloat16):
|
362 |
+
preds = model(
|
363 |
+
x=z_sigma,
|
364 |
+
sigma=sigma,
|
365 |
+
packed_indices=indices,
|
366 |
+
**embed,
|
367 |
+
num_ff_checkpoint=cfg.training.num_ff_checkpoint,
|
368 |
+
num_qkv_checkpoint=cfg.training.num_qkv_checkpoint,
|
369 |
+
)
|
370 |
+
assert preds.shape == z.shape
|
371 |
+
|
372 |
+
ut_dit_space = vae_latents_to_dit_latents(ut.float())
|
373 |
+
loss = F.mse_loss(preds.float(), ut_dit_space)
|
374 |
+
loss.backward()
|
375 |
+
|
376 |
+
log_kwargs = {
|
377 |
+
"train/loss": loss.item(),
|
378 |
+
"train/epoch": EPOCH_IDX,
|
379 |
+
"train/lr": scheduler.get_last_lr()[0],
|
380 |
+
}
|
381 |
+
|
382 |
+
if cfg.training.get("grad_clip"):
|
383 |
+
assert not is_lora, "Gradient clipping not supported for LoRA"
|
384 |
+
gnorm_before_clip = torch.nn.utils.clip_grad_norm_(
|
385 |
+
model.parameters(), max_norm=cfg.training.grad_clip)
|
386 |
+
log_kwargs["train/gnorm"] = gnorm_before_clip.item()
|
387 |
+
pbar.set_postfix(**log_kwargs)
|
388 |
+
|
389 |
+
if wandb_run:
|
390 |
+
wandb_run.log(log_kwargs, step=step)
|
391 |
+
|
392 |
+
optimizer.step()
|
393 |
+
scheduler.step()
|
394 |
+
optimizer.zero_grad()
|
395 |
+
|
396 |
+
|
397 |
+
if __name__ == "__main__":
|
398 |
+
main()
|
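One detail worth knowing: the starting step is inferred from the filename of `init_checkpoint_path`. A name matching `model_<step>.(lora|checkpoint).(safetensors|pt)` sets `start_step_num` to `<step>` (and, if a sibling `optimizer_<step>` file exists, its state is loaded); any other name, such as `weights/dit.safetensors`, starts from step 0. A tiny illustration of the pattern:

```python
# How train.py parses the starting step out of init_checkpoint_path.
import re

pattern = r"model_(\d+)\.(lora|checkpoint)\.(safetensors|pt)"
m = re.search(pattern, "finetunes/my_mochi_lora/model_1000.lora.safetensors")
assert m is not None and int(m.group(1)) == 1000  # the step counter would resume at 1000
```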
demos/fine_tuner/trim_and_crop_videos.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
#! /usr/bin/env python3
|
2 |
+
from pathlib import Path
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import click
|
6 |
+
from moviepy.editor import VideoFileClip
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
@click.command()
|
11 |
+
@click.argument("folder", type=click.Path(exists=True, dir_okay=True))
|
12 |
+
@click.argument("output_folder", type=click.Path(dir_okay=True))
|
13 |
+
@click.option("--duration", "-d", type=float, default=5.4, help="Duration in seconds")
|
14 |
+
@click.option("--resolution", "-r", type=str, default="848x480", help="Video resolution")
|
15 |
+
def truncate_videos(folder, output_folder, duration, resolution):
|
16 |
+
"""Truncate all MP4 and MOV files in FOLDER to specified duration and resolution"""
|
17 |
+
input_path = Path(folder)
|
18 |
+
output_path = Path(output_folder)
|
19 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
20 |
+
|
21 |
+
# Parse target resolution
|
22 |
+
target_width, target_height = map(int, resolution.split("x"))
|
23 |
+
|
24 |
+
# Find all MP4 and MOV files
|
25 |
+
video_files = (
|
26 |
+
list(input_path.rglob("*.mp4"))
|
27 |
+
+ list(input_path.rglob("*.MOV"))
|
28 |
+
+ list(input_path.rglob("*.mov"))
|
29 |
+
+ list(input_path.rglob("*.MP4"))
|
30 |
+
)
|
31 |
+
|
32 |
+
for file_path in tqdm(video_files):
|
33 |
+
try:
|
34 |
+
relative_path = file_path.relative_to(input_path)
|
35 |
+
output_file = output_path / relative_path.with_suffix(".mp4")
|
36 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
37 |
+
|
38 |
+
click.echo(f"Processing: {file_path}")
|
39 |
+
video = VideoFileClip(str(file_path))
|
40 |
+
|
41 |
+
# Skip if video is too short
|
42 |
+
if video.duration < duration:
|
43 |
+
click.echo(f"Skipping {file_path} as it is too short")
|
44 |
+
continue
|
45 |
+
|
46 |
+
# Skip if target resolution is larger than input
|
47 |
+
if target_width > video.w or target_height > video.h:
|
48 |
+
click.echo(
|
49 |
+
f"Skipping {file_path} as target resolution {resolution} is larger than input {video.w}x{video.h}"
|
50 |
+
)
|
51 |
+
continue
|
52 |
+
|
53 |
+
# First truncate duration
|
54 |
+
truncated = video.subclip(0, duration)
|
55 |
+
|
56 |
+
# Calculate crop dimensions to maintain aspect ratio
|
57 |
+
target_ratio = target_width / target_height
|
58 |
+
current_ratio = truncated.w / truncated.h
|
59 |
+
|
60 |
+
if current_ratio > target_ratio:
|
61 |
+
# Video is wider than target ratio - crop width
|
62 |
+
new_width = int(truncated.h * target_ratio)
|
63 |
+
x1 = (truncated.w - new_width) // 2
|
64 |
+
final = truncated.crop(x1=x1, width=new_width).resize((target_width, target_height))
|
65 |
+
else:
|
66 |
+
# Video is taller than target ratio - crop height
|
67 |
+
new_height = int(truncated.w / target_ratio)
|
68 |
+
y1 = (truncated.h - new_height) // 2
|
69 |
+
final = truncated.crop(y1=y1, height=new_height).resize((target_width, target_height))
|
70 |
+
|
71 |
+
# Set output parameters for consistent MP4 encoding
|
72 |
+
output_params = {
|
73 |
+
"codec": "libx264",
|
74 |
+
"audio": False, # Disable audio
|
75 |
+
"preset": "medium", # Balance between speed and quality
|
76 |
+
"bitrate": "5000k", # Adjust as needed
|
77 |
+
}
|
78 |
+
|
79 |
+
# Set FPS to 30
|
80 |
+
final = final.set_fps(30)
|
81 |
+
|
82 |
+
# Check for a corresponding .txt file
|
83 |
+
txt_file_path = file_path.with_suffix('.txt')
|
84 |
+
if txt_file_path.exists():
|
85 |
+
output_txt_file = output_path / relative_path.with_suffix('.txt')
|
86 |
+
output_txt_file.parent.mkdir(parents=True, exist_ok=True)
|
87 |
+
shutil.copy(txt_file_path, output_txt_file)
|
88 |
+
click.echo(f"Copied {txt_file_path} to {output_txt_file}")
|
89 |
+
else:
|
90 |
+
# Print warning in bold yellow with a warning emoji
|
91 |
+
click.echo(f"\033[1;33m⚠️ Warning: No caption found for {file_path}, using an empty caption. This may hurt fine-tuning quality.\033[0m")
|
92 |
+
output_txt_file = output_path / relative_path.with_suffix('.txt')
|
93 |
+
output_txt_file.parent.mkdir(parents=True, exist_ok=True)
|
94 |
+
output_txt_file.touch()
|
95 |
+
|
96 |
+
# Write the output file
|
97 |
+
final.write_videofile(str(output_file), **output_params)
|
98 |
+
|
99 |
+
# Clean up
|
100 |
+
video.close()
|
101 |
+
truncated.close()
|
102 |
+
final.close()
|
103 |
+
|
104 |
+
except Exception as e:
|
105 |
+
click.echo(f"\033[1;31m Error processing {file_path}: {str(e)}\033[0m", err=True)
|
106 |
+
raise
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
truncate_videos()
|
demos/gradio_ui.py
ADDED
@@ -0,0 +1,57 @@
1 |
+
#! /usr/bin/env python
|
2 |
+
|
3 |
+
|
4 |
+
import sys
|
5 |
+
|
6 |
+
import click
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
sys.path.append("..")
|
10 |
+
from cli import configure_model, generate_video
|
11 |
+
|
12 |
+
with gr.Blocks() as demo:
|
13 |
+
gr.Markdown("Video Generator")
|
14 |
+
with gr.Row():
|
15 |
+
prompt = gr.Textbox(
|
16 |
+
label="Prompt",
|
17 |
+
value="A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it, showcasing its smooth texture. A beige string bag sits beside the bowl, adding a rustic touch to the scene. Additional lemons, one halved, are scattered around the base of the bowl. The even lighting enhances the vibrant colors and creates a fresh, inviting atmosphere.",
|
18 |
+
)
|
19 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", value="")
|
20 |
+
seed = gr.Number(label="Seed", value=1710977262, precision=0)
|
21 |
+
with gr.Row():
|
22 |
+
width = gr.Number(label="Width", value=848, precision=0)
|
23 |
+
height = gr.Number(label="Height", value=480, precision=0)
|
24 |
+
num_frames = gr.Number(label="Number of Frames", value=163, precision=0)
|
25 |
+
with gr.Row():
|
26 |
+
cfg_scale = gr.Number(label="CFG Scale", value=6.0)
|
27 |
+
num_inference_steps = gr.Number(label="Number of Inference Steps", value=100, precision=0)
|
28 |
+
btn = gr.Button("Generate Video")
|
29 |
+
output = gr.Video()
|
30 |
+
|
31 |
+
btn.click(
|
32 |
+
generate_video,
|
33 |
+
inputs=[
|
34 |
+
prompt,
|
35 |
+
negative_prompt,
|
36 |
+
width,
|
37 |
+
height,
|
38 |
+
num_frames,
|
39 |
+
seed,
|
40 |
+
cfg_scale,
|
41 |
+
num_inference_steps,
|
42 |
+
],
|
43 |
+
outputs=output,
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
@click.command()
|
48 |
+
@click.option("--model_dir", required=True, help="Path to the model directory.")
|
49 |
+
@click.option("--lora_path", required=False, help="Path to the lora file.")
|
50 |
+
@click.option("--cpu_offload", is_flag=True, help="Whether to offload model to CPU")
|
51 |
+
def launch(model_dir, lora_path, cpu_offload):
|
52 |
+
configure_model(model_dir, lora_path, cpu_offload)
|
53 |
+
demo.launch()
|
54 |
+
|
55 |
+
|
56 |
+
if __name__ == "__main__":
|
57 |
+
launch()
|
demos/gradio_ui_adapted.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
|
2 |
+
#! /usr/bin/env python
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import torch # Ensure PyTorch is imported to configure MPS backend
|
6 |
+
import click
|
7 |
+
import gradio as gr
|
8 |
+
|
9 |
+
# Configure PyTorch to use the MPS backend if available (for Apple Silicon)
|
10 |
+
if torch.backends.mps.is_available():
|
11 |
+
device = torch.device("mps")
|
12 |
+
print("Using MPS backend for Apple Silicon")
|
13 |
+
else:
|
14 |
+
device = torch.device("cpu")
|
15 |
+
print("MPS backend not available. Using CPU.")
|
16 |
+
|
17 |
+
sys.path.append("..")
|
18 |
+
from cli import configure_model, generate_video
|
19 |
+
|
20 |
+
# Pass the configured device (MPS or CPU) to your model
|
21 |
+
def generate_with_device(prompt, *args):
|
22 |
+
# Modify the model initialization in 'configure_model' to use the device
|
23 |
+
model = configure_model(device=device)
|
24 |
+
return generate_video(prompt, model, *args)
|
25 |
+
|
26 |
+
with gr.Blocks() as demo:
|
27 |
+
gr.Markdown("Video Generator")
|
28 |
+
with gr.Row():
|
29 |
+
prompt = gr.Textbox(
|
30 |
+
label="Prompt",
|
31 |
+
value="A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it."
|
32 |
+
)
|
33 |
+
with gr.Row():
|
34 |
+
generate_button = gr.Button("Generate Video")
|
35 |
+
output = gr.Video(label="Generated Video")
|
36 |
+
generate_button.click(generate_with_device, inputs=[prompt], outputs=[output])
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
demo.launch()
|
demos/gradio_ui_fixed.py
ADDED
@@ -0,0 +1,48 @@
1 |
+
#! /usr/bin/env python
|
2 |
+
|
3 |
+
import sys
|
4 |
+
import torch # Ensure PyTorch is imported to configure MPS backend
|
5 |
+
import click
|
6 |
+
import gradio as gr
|
7 |
+
|
8 |
+
# Configure PyTorch to use the MPS backend if available (for Apple Silicon)
|
9 |
+
if torch.backends.mps.is_available():
|
10 |
+
device = torch.device("mps")
|
11 |
+
print("Using MPS backend for Apple Silicon")
|
12 |
+
else:
|
13 |
+
device = torch.device("cpu")
|
14 |
+
print("MPS backend not available. Using CPU.")
|
15 |
+
|
16 |
+
sys.path.append("..")
|
17 |
+
from cli import configure_model, generate_video
|
18 |
+
|
19 |
+
# Set the required arguments for configure_model
|
20 |
+
model_dir_path_ = "/path/to/model/dir" # Replace with the actual path
|
21 |
+
lora_path_ = "/path/to/lora" # Replace with the actual path
|
22 |
+
cpu_offload_ = False # Set True or False based on your needs
|
23 |
+
|
24 |
+
# Adjust model loading to set the device directly after configuration
|
25 |
+
def generate_with_device(prompt, *args):
|
26 |
+
# Load the model
|
27 |
+
model = configure_model(model_dir_path_, lora_path_, cpu_offload_)
|
28 |
+
|
29 |
+
# Move the model to the specified device
|
30 |
+
model.to(device)
|
31 |
+
|
32 |
+
# Generate video with the model
|
33 |
+
return generate_video(prompt, model, *args)
|
34 |
+
|
35 |
+
with gr.Blocks() as demo:
|
36 |
+
gr.Markdown("Video Generator")
|
37 |
+
with gr.Row():
|
38 |
+
prompt = gr.Textbox(
|
39 |
+
label="Prompt",
|
40 |
+
value="A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it."
|
41 |
+
)
|
42 |
+
with gr.Row():
|
43 |
+
generate_button = gr.Button("Generate Video")
|
44 |
+
output = gr.Video(label="Generated Video")
|
45 |
+
generate_button.click(generate_with_device, inputs=[prompt], outputs=[output])
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
demo.launch(share=True) # Enable public link with share=True
|
demos/gradio_ui_fixed.py~
ADDED
@@ -0,0 +1,48 @@
#! /usr/bin/env python

import sys
import torch  # Ensure PyTorch is imported to configure MPS backend
import click
import gradio as gr

# Configure PyTorch to use the MPS backend if available (for Apple Silicon)
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS backend for Apple Silicon")
else:
    device = torch.device("cpu")
    print("MPS backend not available. Using CPU.")

sys.path.append("..")
from cli import configure_model, generate_video

# Set the required arguments for configure_model
model_dir_path_ = "/path/to/model/dir"  # Replace with the actual path
lora_path_ = "/path/to/lora"  # Replace with the actual path
cpu_offload_ = False  # Set True or False based on your needs

# Adjust model loading to set the device directly after configuration
def generate_with_device(prompt, *args):
    # Load the model
    model = configure_model(model_dir_path_, lora_path_, cpu_offload_)

    # Move the model to the specified device
    model.to(device)

    # Generate video with the model
    return generate_video(prompt, model, *args)

with gr.Blocks() as demo:
    gr.Markdown("Video Generator")
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A hand with delicate fingers picks up a bright yellow lemon from a wooden bowl filled with lemons and sprigs of mint against a peach-colored background. The hand gently tosses the lemon up and catches it."
        )
    with gr.Row():
        generate_button = gr.Button("Generate Video")
        output = gr.Video(label="Generated Video")
    generate_button.click(generate_with_device, inputs=[prompt], outputs=[output])

if __name__ == "__main__":
    demo.launch(share=True)  # Enable public link with share=True
demos/test_encoder_decoder.py
ADDED
@@ -0,0 +1,79 @@
import time

import click
import torch
import torchvision
from einops import rearrange
from safetensors.torch import load_file

from genmo.lib.utils import save_video
from genmo.mochi_preview.pipelines import DecoderModelFactory, decode_latents_tiled_spatial
from genmo.mochi_preview.vae.models import Encoder, add_fourier_features


@click.command()
@click.argument("mochi_dir", type=str)
@click.argument("video_path", type=click.Path(exists=True))
def reconstruct(mochi_dir, video_path):
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    decoder_factory = DecoderModelFactory(
        model_path=f"{mochi_dir}/decoder.safetensors",
    )
    decoder = decoder_factory.get_model(world_size=1, device_id=0, local_rank=0)

    config = dict(
        prune_bottlenecks=[False, False, False, False, False],
        has_attentions=[False, True, True, True, True],
        affine=True,
        bias=True,
        input_is_conv_1x1=True,
        padding_mode="replicate",
    )

    # Create VAE encoder
    encoder = Encoder(
        in_channels=15,
        base_channels=64,
        channel_multipliers=[1, 2, 4, 6],
        num_res_blocks=[3, 3, 4, 6, 3],
        latent_dim=12,
        temporal_reductions=[1, 2, 3],
        spatial_reductions=[2, 2, 2],
        **config,
    )
    device = torch.device("cuda:0")
    encoder = encoder.to(device, memory_format=torch.channels_last_3d)
    encoder.load_state_dict(load_file(f"{mochi_dir}/encoder.safetensors"))
    encoder.eval()

    video, _, metadata = torchvision.io.read_video(video_path, output_format="THWC")
    fps = metadata["video_fps"]
    video = rearrange(video, "t h w c -> c t h w")
    video = video.unsqueeze(0)
    assert video.dtype == torch.uint8
    # Convert to float in [-1, 1] range.
    video = video.float() / 127.5 - 1.0
    video = video.to(device)
    video = add_fourier_features(video)
    torch.cuda.synchronize()

    # Encode video to latent
    with torch.inference_mode():
        with torch.autocast("cuda", dtype=torch.bfloat16):
            t0 = time.time()
            ldist = encoder(video)
            torch.cuda.synchronize()
            print(f"Time to encode: {time.time() - t0:.2f}s")
            t0 = time.time()
            frames = decode_latents_tiled_spatial(decoder, ldist.sample(), num_tiles_w=2, num_tiles_h=2)
            torch.cuda.synchronize()
            print(f"Time to decode: {time.time() - t0:.2f}s")
            t0 = time.time()
            save_video(frames.cpu().numpy()[0], f"{video_path}.recon.mp4", fps=fps)
            print(f"Time to save: {time.time() - t0:.2f}s")


if __name__ == "__main__":
    reconstruct()
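Usage sketch for demos/test_encoder_decoder.py, equivalent to running `python demos/test_encoder_decoder.py <mochi_dir> <video_path>`; the paths below are hypothetical, and a CUDA GPU plus the downloaded VAE weights are assumed.

# Illustrative invocation of the click command defined above.
# standalone_mode=False keeps click from calling sys.exit() afterwards.
reconstruct(args=["weights", "clip.mp4"], standalone_mode=False)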
pyproject.toml
ADDED
@@ -0,0 +1,37 @@
[project]
name = "genmo"
version = "0.1.0"
description = "Genmo models"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "addict>=2.4.0",
    "av==13.1.0",
    "click>=8.1.7",
    "einops>=0.8.0",
    "gradio>=3.36.1",
    "moviepy==1.0.3",
    "omegaconf>=2.3.0",
    "pillow==9.5.0",
    "pyyaml>=6.0.2",
    "ray>=2.37.0",
    "sentencepiece>=0.2.0",
    "setuptools>=75.2.0",
    "torch>=2.4.1",
    "torchvision>=0.19.1",
    "transformers>=4.45.2",
]

[project.optional-dependencies]
flash = [
    "flash-attn>=2.6.3"
]

torchvision = [
    "torchvision>=0.15.0",
    "pyav>=13.1.0"
]

[tool.ruff]
# Allow lines to be as long as 120.
line-length = 120
scripts/download_weights.py
ADDED
@@ -0,0 +1,66 @@
#! /usr/bin/env python3
import os
import tempfile

import click
from huggingface_hub import hf_hub_download, snapshot_download
import shutil

BASE_MODEL_FILES = [
    # (repo_id, remote_file_path, local_file_path)
    ("genmo/mochi-1-preview", "decoder.safetensors", "decoder.safetensors"),
    ("genmo/mochi-1-preview", "encoder.safetensors", "encoder.safetensors"),
    ("genmo/mochi-1-preview", "dit.safetensors", "dit.safetensors"),
]

FAST_MODEL_FILE = ("FastVideo/FastMochi", "dit.safetensors", "dit.fast.safetensors")


@click.command()
@click.argument('output_dir', required=True)
@click.option('--fast_model', is_flag=True, help='Download FastMochi model instead of standard model')
@click.option('--hf_transfer', is_flag=True, help='Enable faster downloads using hf_transfer (requires: pip install "huggingface_hub[hf_transfer]")')
def download_weights(output_dir, fast_model, hf_transfer):
    if not os.path.exists(output_dir):
        print(f"Creating output directory: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

    if hf_transfer:
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        print("Using hf_transfer for faster downloads (requires: pip install 'huggingface_hub[hf_transfer]')")

    model_files = BASE_MODEL_FILES
    if fast_model:
        # Replace the standard DIT model with the fast model
        model_files = [f for f in model_files if not f[2].startswith("dit.")]
        model_files.append(FAST_MODEL_FILE)

    for repo_id, remote_path, local_path in model_files:
        local_file_path = os.path.join(output_dir, local_path)
        if not os.path.exists(local_file_path):
            if hf_transfer:
                # I don't know if `hf_transfer` works with `snapshot_download`
                print(f"Downloading {local_path} from {repo_id} to: {local_file_path}")
                out_path = hf_hub_download(
                    repo_id=repo_id,
                    filename=remote_path,
                    local_dir=output_dir,
                )
                print(f"Copying {out_path} to {local_file_path}")
                # copy instead of mv to avoid destroying huggingface cache
                shutil.copy2(out_path, local_file_path)
            else:
                with tempfile.TemporaryDirectory() as tmp_dir:
                    snapshot_download(
                        repo_id=repo_id,
                        allow_patterns=[f"*{remote_path}*"],
                        local_dir=tmp_dir,
                        local_dir_use_symlinks=False,
                    )
                    shutil.move(os.path.join(tmp_dir, remote_path), local_file_path)
        else:
            print(f"{local_path} already exists in: {local_file_path}")
        assert os.path.exists(local_file_path), f"File {local_file_path} does not exist"

if __name__ == "__main__":
    download_weights()
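For a single file, the same download the script performs can be done directly with `hf_hub_download`; a minimal sketch (the local output directory name is arbitrary).

from huggingface_hub import hf_hub_download

# Fetches just the decoder from the same repo the script pulls from.
path = hf_hub_download(
    repo_id="genmo/mochi-1-preview",
    filename="decoder.safetensors",
    local_dir="weights",  # arbitrary local output directory
)
print(path)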
scripts/format.bash
ADDED
@@ -0,0 +1,5 @@
#! /bin/bash
set -euxo pipefail
ruff format src demos
ruff check --fix --select I src
ruff check --fix --select I demos
scripts/pytorch_to_safe_tensors.py
ADDED
@@ -0,0 +1,24 @@
#! /usr/bin/env python3
from pathlib import Path

import click
import torch
from safetensors.torch import save_file


@click.command()
@click.argument("input_path", type=click.Path(exists=True))
def convert_to_safetensors(input_path):
    model = torch.load(input_path)
    model = {
        k: v.contiguous() for k, v in model.items()
    }
    assert 'vae_ema' not in model
    input_path = Path(input_path)
    output_path = input_path.with_suffix(".safetensors")
    save_file(model, str(output_path))
    click.echo(f"Converted {input_path} to {output_path}")


if __name__ == "__main__":
    convert_to_safetensors()
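A toy round trip showing what the conversion above relies on: tensors must be contiguous before `save_file`, and `load_file` returns an equivalent state dict (the file name here is arbitrary).

import torch
from safetensors.torch import load_file, save_file

state = {"w": torch.randn(4, 4).t()}  # .t() yields a non-contiguous tensor on purpose
state = {k: v.contiguous() for k, v in state.items()}
save_file(state, "toy.safetensors")
assert torch.equal(load_file("toy.safetensors")["w"], state["w"])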
scripts/typecheck.bash
ADDED
@@ -0,0 +1,2 @@
#! /bin/bash
npx pyright
scripts/weights_to_fp8.py
ADDED
File without changes
src/genmo/lib/attn_imports.py
ADDED
@@ -0,0 +1,29 @@
from contextlib import contextmanager

import torch


try:
    from flash_attn import flash_attn_varlen_func as flash_varlen_attn
except ImportError:
    flash_varlen_attn = None

try:
    from sageattention import sageattn as sage_attn
except ImportError:
    sage_attn = None

from torch.nn.attention import SDPBackend, sdpa_kernel

training_backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
eval_backends = list(training_backends)
if torch.cuda.get_device_properties(0).major >= 9.0:
    # Enable fast CuDNN attention on Hopper.
    # This gives NaN on the backward pass for some reason,
    # so only use it for evaluation.
    eval_backends.append(SDPBackend.CUDNN_ATTENTION)

@contextmanager
def sdpa_attn_ctx(training: bool = False):
    with sdpa_kernel(training_backends if training else eval_backends):
        yield
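A minimal sketch of wrapping PyTorch SDPA in `sdpa_attn_ctx`; note the module queries CUDA device properties at import time, so a GPU is assumed here.

import torch
import torch.nn.functional as F
from genmo.lib.attn_imports import sdpa_attn_ctx

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
with sdpa_attn_ctx(training=False):  # eval also permits CuDNN attention on Hopper GPUs
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])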
src/genmo/lib/progress.py
ADDED
@@ -0,0 +1,87 @@
import contextlib
from typing import Any, Iterable, Iterator, Optional

try:
    from tqdm import tqdm
except ImportError:
    tqdm = None

try:
    from ray.experimental.tqdm_ray import tqdm as ray_tqdm
except:
    ray_tqdm = None

# Global state
_current_progress_type = "none"
_is_progress_bar_active = False


class DummyProgressBar:
    """A no-op progress bar that mimics tqdm interface"""

    def __init__(self, iterable=None, **kwargs):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def update(self, n=1):
        pass

    def close(self):
        pass

    def set_description(self, desc):
        pass


def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
    if not _is_progress_bar_active:
        return DummyProgressBar(iterable=iterable, **kwargs)

    if _current_progress_type == "tqdm":
        if tqdm is None:
            raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
        return tqdm(iterable=iterable, **kwargs)
    elif _current_progress_type == "ray_tqdm":
        if ray_tqdm is None:
            raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
        return ray_tqdm(iterable=iterable, **kwargs)
    return DummyProgressBar(iterable=iterable, **kwargs)


@contextlib.contextmanager
def progress_bar(type: str = "none", enabled=True):
    """
    Context manager for setting progress bar type and options.

    Args:
        type: Type of progress bar ("none" or "tqdm")
        **options: Options to pass to the progress bar (e.g., total, desc)

    Raises:
        ValueError: If progress bar type is invalid
        RuntimeError: If progress bars are nested

    Example:
        with progress_bar(type="tqdm", total=100):
            for i in get_new_progress_bar(range(100)):
                process(i)
    """
    if type not in ("none", "tqdm", "ray_tqdm"):
        raise ValueError("Progress bar type must be 'none' or 'tqdm' or 'ray_tqdm'")
    if not enabled:
        type = "none"
    global _current_progress_type, _is_progress_bar_active

    if _is_progress_bar_active:
        raise RuntimeError("Nested progress bars are not supported")

    _is_progress_bar_active = True
    _current_progress_type = type

    try:
        yield
    finally:
        _is_progress_bar_active = False
        _current_progress_type = "none"
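Usage sketch for the progress helpers, mirroring the docstring above: enable a tqdm bar inside the context, or pass enabled=False to make every bar a no-op.

from genmo.lib.progress import get_new_progress_bar, progress_bar

with progress_bar(type="tqdm", enabled=True):
    for _ in get_new_progress_bar(range(100), desc="steps"):
        pass  # do work here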
src/genmo/lib/utils.py
ADDED
@@ -0,0 +1,67 @@
import os
import subprocess
import tempfile
import time

import numpy as np
from moviepy.editor import ImageSequenceClip
from PIL import Image

from genmo.lib.progress import get_new_progress_bar


class Timer:
    def __init__(self):
        self.times = {}  # Dictionary to store times per stage

    def __call__(self, name):
        print(f"Timing {name}")
        return self.TimerContextManager(self, name)

    def print_stats(self):
        total_time = sum(self.times.values())
        # Print table header
        print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
        for name, t in self.times.items():
            percent = (t / total_time) * 100 if total_time > 0 else 0
            print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))

    class TimerContextManager:
        def __init__(self, outer, name):
            self.outer = outer  # Reference to the Timer instance
            self.name = name
            self.start_time = None

        def __enter__(self):
            self.start_time = time.perf_counter()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            end_time = time.perf_counter()
            elapsed = end_time - self.start_time
            self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed


def save_video(final_frames, output_path, fps=30):
    assert final_frames.ndim == 4 and final_frames.shape[3] == 3, f"invalid shape: {final_frames} (need t h w c)"
    if final_frames.dtype != np.uint8:
        final_frames = (final_frames * 255).astype(np.uint8)
    ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path)


def create_memory_tracker():
    import torch

    previous = [None]  # Use list for mutable closure state

    def track(label="all2all"):
        current = torch.cuda.memory_allocated() / 1e9
        if previous[0] is not None:
            diff = current - previous[0]
            sign = "+" if diff >= 0 else ""
            print(f"GPU memory ({label}): {current:.2f} GB ({sign}{diff:.2f} GB)")
        else:
            print(f"GPU memory ({label}): {current:.2f} GB")
        previous[0] = current  # type: ignore

    return track
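A small sketch of the Timer and save_video helpers on synthetic frames (the output file name is arbitrary).

import numpy as np
from genmo.lib.utils import Timer, save_video

timer = Timer()
with timer("make_frames"):
    frames = np.random.rand(16, 64, 64, 3).astype(np.float32)  # (t, h, w, c), values in [0, 1]
with timer("save_video"):
    save_video(frames, "noise.mp4", fps=8)
timer.print_stats()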
src/genmo/mochi_preview/__init__.py
ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/__init__.py
ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py
ADDED
@@ -0,0 +1,737 @@
import os
from typing import Dict, List, Optional, Tuple
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.attention import sdpa_kernel

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
from genmo.lib.attn_imports import flash_varlen_attn, sage_attn, sdpa_attn_ctx
from genmo.mochi_preview.dit.joint_model.layers import (
    FeedForward,
    PatchEmbed,
    RMSNorm,
    TimestepEmbedder,
)
from genmo.mochi_preview.dit.joint_model.lora import LoraLinear
from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm
from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import (
    residual_tanh_gated_rmsnorm,
)
from genmo.mochi_preview.dit.joint_model.rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)
from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real
from genmo.mochi_preview.dit.joint_model.utils import (
    AttentionPool,
    modulate,
    pad_and_split_xy,
)

COMPILE_FINAL_LAYER = os.environ.get("COMPILE_DIT") == "1"
COMPILE_MMDIT_BLOCK = os.environ.get("COMPILE_DIT") == "1"


def ck(fn, *args, enabled=True, **kwargs) -> torch.Tensor:
    if enabled:
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs, use_reentrant=False)

    return fn(*args, **kwargs)


class AsymmetricAttention(nn.Module):
    def __init__(
        self,
        dim_x: int,
        dim_y: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        update_y: bool = True,
        out_bias: bool = True,
        attention_mode: str = "flash",
        softmax_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
        # Disable LoRA by default ...
        qkv_proj_lora_rank: int = 0,
        qkv_proj_lora_alpha: int = 0,
        qkv_proj_lora_dropout: float = 0.0,
        out_proj_lora_rank: int = 0,
        out_proj_lora_alpha: int = 0,
        out_proj_lora_dropout: float = 0.0,
    ):
        super().__init__()
        self.attention_mode = attention_mode
        self.dim_x = dim_x
        self.dim_y = dim_y
        self.num_heads = num_heads
        self.head_dim = dim_x // num_heads
        self.update_y = update_y
        self.softmax_scale = softmax_scale
        if dim_x % num_heads != 0:
            raise ValueError(f"dim_x={dim_x} should be divisible by num_heads={num_heads}")

        # Input layers.
        self.qkv_bias = qkv_bias
        qkv_lora_kwargs = dict(
            bias=qkv_bias,
            device=device,
            r=qkv_proj_lora_rank,
            lora_alpha=qkv_proj_lora_alpha,
            lora_dropout=qkv_proj_lora_dropout,
        )
        self.qkv_x = LoraLinear(dim_x, 3 * dim_x, **qkv_lora_kwargs)
        # Project text features to match visual features (dim_y -> dim_x)
        self.qkv_y = LoraLinear(dim_y, 3 * dim_x, **qkv_lora_kwargs)

        # Query and key normalization for stability.
        assert qk_norm
        self.q_norm_x = RMSNorm(self.head_dim, device=device)
        self.k_norm_x = RMSNorm(self.head_dim, device=device)
        self.q_norm_y = RMSNorm(self.head_dim, device=device)
        self.k_norm_y = RMSNorm(self.head_dim, device=device)

        # Output layers. y features go back down from dim_x -> dim_y.
        proj_lora_kwargs = dict(
            bias=out_bias,
            device=device,
            r=out_proj_lora_rank,
            lora_alpha=out_proj_lora_alpha,
            lora_dropout=out_proj_lora_dropout,
        )
        self.proj_x = LoraLinear(dim_x, dim_x, **proj_lora_kwargs)
        self.proj_y = LoraLinear(dim_x, dim_y, **proj_lora_kwargs) if update_y else nn.Identity()

    def run_qkv_y(self, y):
        cp_rank, cp_size = cp.get_cp_rank_size()
        local_heads = self.num_heads // cp_size

        if cp.is_cp_active():
            # Only predict local heads.
            assert not self.qkv_bias
            W_qkv_y = self.qkv_y.weight.view(3, self.num_heads, self.head_dim, self.dim_y)
            W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
            W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
            qkv_y = F.linear(y, W_qkv_y, None)  # (B, L, 3 * local_h * head_dim)
        else:
            qkv_y = self.qkv_y(y)  # (B, L, 3 * dim)

        qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
        q_y, k_y, v_y = qkv_y.unbind(2)

        q_y = self.q_norm_y(q_y)
        k_y = self.k_norm_y(k_y)
        return q_y, k_y, v_y

    def prepare_qkv(
        self,
        x: torch.Tensor,  # (B, M, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,
        scale_y: torch.Tensor,
        rope_cos: torch.Tensor,
        rope_sin: torch.Tensor,
        valid_token_indices: torch.Tensor,
        max_seqlen_in_batch: int,
    ):
        # Process visual features
        x = modulated_rmsnorm(x, scale_x)  # (B, M, dim_x) where M = N / cp_group_size
        qkv_x = self.qkv_x(x)  # (B, M, 3 * dim_x)
        assert qkv_x.dtype == torch.bfloat16

        qkv_x = cp.all_to_all_collect_tokens(qkv_x, self.num_heads)  # (3, B, N, local_h, head_dim)

        # Split qkv_x into q, k, v
        q_x, k_x, v_x = qkv_x.unbind(0)  # (B, N, local_h, head_dim)
        q_x = self.q_norm_x(q_x)
        q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
        k_x = self.k_norm_x(k_x)
        k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

        # Concatenate streams
        B, N, num_heads, head_dim = q_x.size()
        D = num_heads * head_dim

        # Process text features
        if B == 1:
            text_seqlen = max_seqlen_in_batch - N
            if text_seqlen > 0:
                y = y[:, :text_seqlen]  # Remove padding tokens.
                y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
                q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)

                q = torch.cat([q_x, q_y], dim=1)
                k = torch.cat([k_x, k_y], dim=1)
                v = torch.cat([v_x, v_y], dim=1)
            else:
                q, k, v = q_x, k_x, v_x
        else:
            y = modulated_rmsnorm(y, scale_y)  # (B, L, dim_y)
            q_y, k_y, v_y = self.run_qkv_y(y)  # (B, L, local_heads, head_dim)

            indices = valid_token_indices[:, None].expand(-1, D)
            q = torch.cat([q_x, q_y], dim=1).view(-1, D).gather(0, indices)  # (total, D)
            k = torch.cat([k_x, k_y], dim=1).view(-1, D).gather(0, indices)  # (total, D)
            v = torch.cat([v_x, v_y], dim=1).view(-1, D).gather(0, indices)  # (total, D)

        q = q.view(-1, num_heads, head_dim)
        k = k.view(-1, num_heads, head_dim)
        v = v.view(-1, num_heads, head_dim)
        return q, k, v

    @torch.autocast("cuda", enabled=False)
    def flash_attention(self, q, k, v, cu_seqlens, max_seqlen_in_batch, total, local_dim):
        out: torch.Tensor = flash_varlen_attn(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen_in_batch,
            max_seqlen_k=max_seqlen_in_batch,
            dropout_p=0.0,
            softmax_scale=self.softmax_scale,
        )  # (total, local_heads, head_dim)
        return out.view(total, local_dim)

    def sdpa_attention(self, q, k, v):
        with sdpa_attn_ctx(training=self.training):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False,
            )
            return out

    @torch.autocast("cuda", enabled=False)
    def sage_attention(self, q, k, v):
        return sage_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

    def run_attention(
        self,
        q: torch.Tensor,  # (total <= B * (N + L), num_heads, head_dim)
        k: torch.Tensor,  # (total <= B * (N + L), num_heads, head_dim)
        v: torch.Tensor,  # (total <= B * (N + L), num_heads, head_dim)
        *,
        B: int,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen_in_batch: Optional[int] = None,
    ):
        _, cp_size = cp.get_cp_rank_size()
        assert self.num_heads % cp_size == 0
        local_heads = self.num_heads // cp_size
        local_dim = local_heads * self.head_dim

        # Check shapes
        assert q.ndim == 3 and k.ndim == 3 and v.ndim == 3
        total = q.size(0)
        assert k.size(0) == total and v.size(0) == total

        if self.attention_mode == "flash":
            out = self.flash_attention(
                q, k, v, cu_seqlens, max_seqlen_in_batch, total, local_dim)  # (total, local_dim)
        else:
            assert B == 1, \
                f"Non-flash attention mode {self.attention_mode} only supports batch size 1, got {B}"

            q = rearrange(q, "(b s) h d -> b h s d", b=B)
            k = rearrange(k, "(b s) h d -> b h s d", b=B)
            v = rearrange(v, "(b s) h d -> b h s d", b=B)

            if self.attention_mode == "sdpa":
                out = self.sdpa_attention(q, k, v)  # (B, local_heads, seq_len, head_dim)
            elif self.attention_mode == "sage":
                out = self.sage_attention(q, k, v)  # (B, local_heads, seq_len, head_dim)
            else:
                raise ValueError(f"Unknown attention mode: {self.attention_mode}")

            out = rearrange(out, "b h s d -> (b s) (h d)")

        return out

    def post_attention(
        self,
        out: torch.Tensor,
        B: int,
        M: int,
        L: int,
        dtype: torch.dtype,
        valid_token_indices: torch.Tensor,
    ):
        """
        Args:
            out: (total <= B * (N + L), local_dim)
            valid_token_indices: (total <= B * (N + L),)
            B: Batch size
            M: Number of visual tokens per context parallel rank
            L: Number of text tokens
            dtype: Data type of the input and output tensors

        Returns:
            x: (B, N, dim_x) tensor of visual tokens where N = M * cp_size
            y: (B, L, dim_y) tensor of text token features
        """
        _, cp_size = cp.get_cp_rank_size()
        local_heads = self.num_heads // cp_size
        local_dim = local_heads * self.head_dim
        N = M * cp_size

        # Split sequence into visual and text tokens, adding back padding.
        if B == 1:
            out = out.view(B, -1, local_dim)
            if out.size(1) > N:
                x, y = torch.tensor_split(out, (N,), dim=1)  # (B, N, local_dim), (B, <= L, local_dim)
                y = F.pad(y, (0, 0, 0, L - y.size(1)))  # (B, L, local_dim)
            else:
                # Empty prompt.
                x, y = out, out.new_zeros(B, L, local_dim)
        else:
            x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, dtype)
        assert x.size() == (B, N, local_dim)
        assert y.size() == (B, L, local_dim)

        # Communicate across context parallel ranks.
        x = x.view(B, N, local_heads, self.head_dim)
        x = cp.all_to_all_collect_heads(x)  # (B, M, dim_x = num_heads * head_dim)
        if cp.is_cp_active():
            y = cp.all_gather(y)  # (cp_size * B, L, local_heads * head_dim)
            y = rearrange(y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim)  # (B, L, dim_x)

        x = self.proj_x(x)
        y = self.proj_y(y)
        return x, y

    def forward(
        self,
        x: torch.Tensor,  # (B, M, dim_x)
        y: torch.Tensor,  # (B, L, dim_y)
        *,
        scale_x: torch.Tensor,  # (B, dim_x), modulation for pre-RMSNorm.
        scale_y: torch.Tensor,  # (B, dim_y), modulation for pre-RMSNorm.
        packed_indices: Dict[str, torch.Tensor] = None,
        checkpoint_qkv: bool = False,
        checkpoint_post_attn: bool = False,
        **rope_rotation,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of asymmetric multi-modal attention.

        Args:
            x: (B, M, dim_x) tensor of visual tokens
            y: (B, L, dim_y) tensor of text token features
            packed_indices: Dict with keys for Flash Attention
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

        Returns:
            x: (B, M, dim_x) tensor of visual tokens after multi-modal attention
            y: (B, L, dim_y) tensor of text token features after multi-modal attention
        """
        B, L, _ = y.shape
        _, M, _ = x.shape

        # Predict a packed QKV tensor from visual and text features.
        q, k, v = ck(self.prepare_qkv,
            x=x,
            y=y,
            scale_x=scale_x,
            scale_y=scale_y,
            rope_cos=rope_rotation.get("rope_cos"),
            rope_sin=rope_rotation.get("rope_sin"),
            valid_token_indices=packed_indices["valid_token_indices_kv"],
            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
            enabled=checkpoint_qkv,
        )  # (total <= B * (N + L), 3, local_heads, head_dim)

        # Self-attention is expensive, so don't checkpoint it.
        out = self.run_attention(
            q, k, v, B=B,
            cu_seqlens=packed_indices["cu_seqlens_kv"],
            max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
        )

        x, y = ck(self.post_attention,
            out,
            B=B, M=M, L=L,
            dtype=v.dtype,
            valid_token_indices=packed_indices["valid_token_indices_kv"],
            enabled=checkpoint_post_attn,
        )

        return x, y


@torch.compile(disable=not COMPILE_MMDIT_BLOCK)
class AsymmetricJointBlock(nn.Module):
    def __init__(
        self,
        hidden_size_x: int,
        hidden_size_y: int,
        num_heads: int,
        *,
        mlp_ratio_x: float = 8.0,  # Ratio of hidden size to d_model for MLP for visual tokens.
        mlp_ratio_y: float = 4.0,  # Ratio of hidden size to d_model for MLP for text tokens.
        update_y: bool = True,  # Whether to update text tokens in this block.
        device: Optional[torch.device] = None,
        **block_kwargs,
    ):
        super().__init__()
        self.update_y = update_y
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
        if self.update_y:
            self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
        else:
            self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)

        # Self-attention:
        self.attn = AsymmetricAttention(
            hidden_size_x,
            hidden_size_y,
            num_heads=num_heads,
            update_y=update_y,
            device=device,
            **block_kwargs,
        )

        # MLP.
        mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
        assert mlp_hidden_dim_x == int(1536 * 8)
        self.mlp_x = FeedForward(
            in_features=hidden_size_x,
            hidden_size=mlp_hidden_dim_x,
            multiple_of=256,
            ffn_dim_multiplier=None,
            device=device,
        )

        # MLP for text not needed in last block.
        if self.update_y:
            mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
            self.mlp_y = FeedForward(
                in_features=hidden_size_y,
                hidden_size=mlp_hidden_dim_y,
                multiple_of=256,
                ffn_dim_multiplier=None,
                device=device,
            )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        y: torch.Tensor,
        # TODO: These could probably just go into attn_kwargs
        checkpoint_ff: bool = False,
        checkpoint_qkv: bool = False,
        checkpoint_post_attn: bool = False,
        **attn_kwargs,
    ):
        """Forward pass of a block.

        Args:
            x: (B, N, dim) tensor of visual tokens
            c: (B, dim) tensor of conditioned features
            y: (B, L, dim) tensor of text tokens
            num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

        Returns:
            x: (B, N, dim) tensor of visual tokens after block
            y: (B, L, dim) tensor of text tokens after block
        """
        N = x.size(1)

        c = F.silu(c)
        mod_x = self.mod_x(c)
        scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
        mod_y = self.mod_y(c)

        if self.update_y:
            scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
        else:
            scale_msa_y = mod_y

        # Self-attention block.
        x_attn, y_attn = self.attn(
            x,
            y,
            scale_x=scale_msa_x,
            scale_y=scale_msa_y,
            checkpoint_qkv=checkpoint_qkv,
            checkpoint_post_attn=checkpoint_post_attn,
            **attn_kwargs,
        )

        assert x_attn.size(1) == N
        x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)

        if self.update_y:
            y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)

        # MLP block.
        x = ck(self.ff_block_x, x, scale_mlp_x, gate_mlp_x, enabled=checkpoint_ff)
        if self.update_y:
            y = ck(self.ff_block_y, y, scale_mlp_y, gate_mlp_y, enabled=checkpoint_ff)  # type: ignore
        return x, y

    def ff_block_x(self, x, scale_x, gate_x):
        x_mod = modulated_rmsnorm(x, scale_x)
        x_res = self.mlp_x(x_mod)
        x = residual_tanh_gated_rmsnorm(x, x_res, gate_x)  # Sandwich norm
        return x

    def ff_block_y(self, y, scale_y, gate_y):
        y_mod = modulated_rmsnorm(y, scale_y)
        y_res = self.mlp_y(y_mod)
        y = residual_tanh_gated_rmsnorm(y, y_res, gate_y)  # Sandwich norm
        return y


@torch.compile(disable=not COMPILE_FINAL_LAYER)
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """

    def __init__(
        self,
        hidden_size,
        patch_size,
        out_channels,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device)
        self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, device=device)

    def forward(self, x, c):
        c = F.silu(c)
        shift, scale = self.mod(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class AsymmDiTJoint(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Ingests text embeddings instead of a label.
    """

    def __init__(
        self,
        *,
        patch_size=2,
        in_channels=4,
        hidden_size_x=1152,
        hidden_size_y=1152,
        depth=48,
        num_heads=16,
        mlp_ratio_x=8.0,
        mlp_ratio_y=4.0,
        t5_feat_dim: int = 4096,
        t5_token_length: int = 256,
        patch_embed_bias: bool = True,
        timestep_mlp_bias: bool = True,
        timestep_scale: Optional[float] = None,
        use_extended_posenc: bool = False,
        rope_theta: float = 10000.0,
        device: Optional[torch.device] = None,
        **block_kwargs,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size_x = hidden_size_x
        self.hidden_size_y = hidden_size_y
        self.head_dim = hidden_size_x // num_heads  # Head dimension and count is determined by visual.
        self.use_extended_posenc = use_extended_posenc
        self.t5_token_length = t5_token_length
        self.t5_feat_dim = t5_feat_dim
        self.rope_theta = rope_theta  # Scaling factor for frequency computation for temporal RoPE.

        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size_x,
            bias=patch_embed_bias,
            device=device,
        )
        # Conditionings
        # Timestep
        self.t_embedder = TimestepEmbedder(hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale)

        # Caption Pooling (T5)
        self.t5_y_embedder = AttentionPool(t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device)

        # Dense Embedding Projection (T5)
        self.t5_yproj = nn.Linear(t5_feat_dim, hidden_size_y, bias=True, device=device)

        # Initialize pos_frequencies as an empty parameter.
        self.pos_frequencies = nn.Parameter(torch.empty(3, self.num_heads, self.head_dim // 2, device=device))

        # for depth 48:
        #   b = 0: AsymmetricJointBlock, update_y=True
        #   b = 1: AsymmetricJointBlock, update_y=True
        #   ...
        #   b = 46: AsymmetricJointBlock, update_y=True
        #   b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
        blocks = []
        for b in range(depth):
            # Joint multi-modal block
            update_y = b < depth - 1
            block = AsymmetricJointBlock(
                hidden_size_x,
                hidden_size_y,
                num_heads,
                mlp_ratio_x=mlp_ratio_x,
                mlp_ratio_y=mlp_ratio_y,
                update_y=update_y,
                device=device,
                **block_kwargs,
            )

            blocks.append(block)
        self.blocks = nn.ModuleList(blocks)

        self.final_layer = FinalLayer(hidden_size_x, patch_size, self.out_channels, device=device)

    def embed_x(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, C=12, T, H, W) tensor of visual tokens

        Returns:
            x: (B, C=3072, N) tensor of visual tokens with positional embedding.
        """
        return self.x_embedder(x)  # Convert BcTHW to BCN

    @torch.compile(disable=not COMPILE_MMDIT_BLOCK)
    def prepare(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        t5_feat: torch.Tensor,
        t5_mask: torch.Tensor,
    ):
        """Prepare input and conditioning embeddings."""

        # Visual patch embeddings with positional encoding.
        T, H, W = x.shape[-3:]
        pH, pW = H // self.patch_size, W // self.patch_size
        x = self.embed_x(x)  # (B, N, D), where N = T * H * W / patch_size ** 2
        assert x.ndim == 3
        B = x.size(0)

        # Construct position array of size [N, 3].
        # pos[:, 0] is the frame index for each location,
        # pos[:, 1] is the row index for each location, and
        # pos[:, 2] is the column index for each location.
        N = T * pH * pW
        assert x.size(1) == N
        pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32)  # (N, 3)
        rope_cos, rope_sin = compute_mixed_rotation(
            freqs=self.pos_frequencies, pos=pos
        )  # Each are (N, num_heads, dim // 2)

        # Global vector embedding for conditionings.
        c_t = self.t_embedder(1 - sigma)  # (B, D)

        # Pool T5 tokens using attention pooler
        # Note y_feat[1] contains T5 token features.
        assert (
            t5_feat.size(1) == self.t5_token_length
        ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
        t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask)  # (B, D)
        assert t5_y_pool.size(0) == B, f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."

        c = c_t + t5_y_pool

        y_feat = self.t5_yproj(t5_feat)  # (B, L, t5_feat_dim) --> (B, L, D)

        return x, c, y_feat, rope_cos, rope_sin

    def forward(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        y_feat: List[torch.Tensor],
        y_mask: List[torch.Tensor],
        packed_indices: Dict[str, torch.Tensor] = None,
        rope_cos: torch.Tensor = None,
        rope_sin: torch.Tensor = None,
        num_ff_checkpoint: int = 0,
        num_qkv_checkpoint: int = 0,
        num_post_attn_checkpoint: int = 0,
    ):
        """Forward pass of DiT.

        Args:
            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
            sigma: (B,) tensor of noise standard deviations
            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
        _, _, T, H, W = x.shape

        if self.pos_frequencies.dtype != torch.float32:
            warnings.warn(f"pos_frequencies dtype {self.pos_frequencies.dtype} != torch.float32")

        # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
        # Have to call sdpa_kernel outside of a torch.compile region.
        with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
            x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0])
        del y_mask

        cp_rank, cp_size = cp.get_cp_rank_size()
        N = x.size(1)
        M = N // cp_size
        assert N % cp_size == 0, f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."

        if cp_size > 1:
            x = x.narrow(1, cp_rank * M, M)

            assert self.num_heads % cp_size == 0
            local_heads = self.num_heads // cp_size
            rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
            rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)

        for i, block in enumerate(self.blocks):
            x, y_feat = block(
                x,
                c,
                y_feat,
                rope_cos=rope_cos,
                rope_sin=rope_sin,
                packed_indices=packed_indices,
                checkpoint_ff=i < num_ff_checkpoint,
                checkpoint_qkv=i < num_qkv_checkpoint,
                checkpoint_post_attn=i < num_post_attn_checkpoint,
            )  # (B, M, D), (B, L, D)
        del y_feat  # Final layers don't use dense text features.

        x = self.final_layer(x, c)  # (B, M, patch_size ** 2 * out_channels)

        patch = x.size(2)
        x = cp.all_gather(x)
        x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
        x = rearrange(
            x,
            "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
            T=T,
            hp=H // self.patch_size,
            wp=W // self.patch_size,
            p1=self.patch_size,
            p2=self.patch_size,
            c=self.out_channels,
        )

        return x
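A toy sketch of the `ck` helper defined at the top of this file: with enabled=True the callable runs under non-reentrant activation checkpointing (activations recomputed during backward); with enabled=False it is a plain call. The `mlp` function below is made up for illustration.

import torch

def mlp(x, w):
    return torch.tanh(x @ w)

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)
y_plain = ck(mlp, x, w, enabled=False)  # direct call
y_ckpt = ck(mlp, x, w, enabled=True)    # recomputes mlp during backward
y_ckpt.sum().backward()
print(torch.allclose(y_plain, y_ckpt))  # True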
src/genmo/mochi_preview/dit/joint_model/context_parallel.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
_CONTEXT_PARALLEL_GROUP = None
|
8 |
+
_CONTEXT_PARALLEL_RANK = None
|
9 |
+
_CONTEXT_PARALLEL_GROUP_SIZE = None
|
10 |
+
_CONTEXT_PARALLEL_GROUP_RANKS = None
|
11 |
+
|
12 |
+
|
13 |
+
def get_cp_rank_size() -> Tuple[int, int]:
|
14 |
+
if _CONTEXT_PARALLEL_GROUP:
|
15 |
+
assert isinstance(_CONTEXT_PARALLEL_RANK, int) and isinstance(_CONTEXT_PARALLEL_GROUP_SIZE, int)
|
16 |
+
return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
|
17 |
+
else:
|
18 |
+
        return 0, 1


def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return x

    cp_rank, cp_size = get_cp_rank_size()
    return x.tensor_split(cp_size, dim=dim)[cp_rank]


def set_cp_group(cp_group, ranks, global_rank):
    global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE, _CONTEXT_PARALLEL_GROUP_RANKS
    if _CONTEXT_PARALLEL_GROUP is not None:
        raise RuntimeError("CP group already initialized.")
    _CONTEXT_PARALLEL_GROUP = cp_group
    _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
    _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
    _CONTEXT_PARALLEL_GROUP_RANKS = ranks

    assert _CONTEXT_PARALLEL_RANK == ranks.index(
        global_rank
    ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
    assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
        ranks
    ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"


def get_cp_group():
    if _CONTEXT_PARALLEL_GROUP is None:
        raise RuntimeError("CP group not initialized")
    return _CONTEXT_PARALLEL_GROUP


def is_cp_active():
    return _CONTEXT_PARALLEL_GROUP is not None


class AllGatherIntoTensorFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
        ctx.reduce_dtype = reduce_dtype
        ctx.group = group
        ctx.batch_size = x.size(0)
        group_size = dist.get_world_size(group)

        x = x.contiguous()
        output = torch.empty(group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device)
        dist.all_gather_into_tensor(output, x, group=group)
        return output


def all_gather(tensor: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        return tensor

    return AllGatherIntoTensorFunction.apply(tensor, torch.float32, _CONTEXT_PARALLEL_GROUP)


@torch.compiler.disable()
def _all_to_all_single(output, input, group):
    # Disable compilation since torch compile changes contiguity.
    assert input.is_contiguous(), "Input tensor must be contiguous."
    assert output.is_contiguous(), "Output tensor must be contiguous."
    return dist.all_to_all_single(output, input, group=group)


class CollectTokens(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
        """Redistribute heads and receive tokens.

        Args:
            qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]

        Returns:
            qkv: shape: [3, B, N, local_heads, head_dim]

        where M is the number of local tokens,
        N = cp_size * M is the number of global tokens,
        local_heads = num_heads // cp_size is the number of local heads.
        """
        ctx.group = group
        ctx.num_heads = num_heads
        cp_size = dist.get_world_size(group)
        assert num_heads % cp_size == 0
        ctx.local_heads = num_heads // cp_size

        qkv = rearrange(
            qkv,
            "B M (qkv G h d) -> G M h B (qkv d)",
            qkv=3,
            G=cp_size,
            h=ctx.local_heads,
        ).contiguous()

        output_chunks = torch.empty_like(qkv)
        _all_to_all_single(output_chunks, qkv, group=group)

        return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)


def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Move QKV dimension to the front.
        # B M (3 H d) -> 3 B M H d
        B, M, _ = x.size()
        x = x.view(B, M, 3, num_heads, -1)
        return x.permute(2, 0, 1, 3, 4)

    return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)


class CollectHeads(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
        """Redistribute tokens and receive heads.

        Args:
            x: Output of attention. Shape: [B, N, local_heads, head_dim]

        Returns:
            Shape: [B, M, num_heads * head_dim]
        """
        ctx.group = group
        ctx.local_heads = x.size(2)
        ctx.head_dim = x.size(3)
        group_size = dist.get_world_size(group)
        x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
        output = torch.empty_like(x)
        _all_to_all_single(output, x, group=group)
        del x
        return rearrange(output, "G h M B D -> B M (G h D)")


def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
    if not _CONTEXT_PARALLEL_GROUP:
        # Merge heads.
        return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))

    return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
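A minimal, single-process shape check for the token/head redistribution helpers above (a sketch, not part of the diff; it assumes this repo's package is importable, e.g. after `pip install -e .`). With no context-parallel group initialized, both helpers fall back to pure reshapes, so the round-trip can be verified on CPU.

# Sketch: verify the single-GPU fallback path of the CP helpers.
import torch

from genmo.mochi_preview.dit.joint_model.context_parallel import (
    all_to_all_collect_heads,
    all_to_all_collect_tokens,
    is_cp_active,
)

B, M, num_heads, head_dim = 2, 16, 4, 8
assert not is_cp_active()  # no CP group set in this sketch

qkv = torch.randn(B, M, 3 * num_heads * head_dim)
q, k, v = all_to_all_collect_tokens(qkv, num_heads)  # each: [B, M, num_heads, head_dim]
print(q.shape)  # torch.Size([2, 16, 4, 8])

out = all_to_all_collect_heads(v)  # merge heads back: [B, M, num_heads * head_dim]
print(out.shape)  # torch.Size([2, 16, 32])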
src/genmo/mochi_preview/dit/joint_model/layers.py
ADDED
@@ -0,0 +1,179 @@
import collections.abc
import math
from itertools import repeat
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class TimestepEmbedder(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        frequency_embedding_size: int = 256,
        *,
        bias: bool = True,
        timestep_scale: Optional[float] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.timestep_scale = timestep_scale

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
        freqs.mul_(-math.log(max_period) / half).exp_()
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        if self.timestep_scale is not None:
            t = t * self.timestep_scale
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class PooledCaptionEmbedder(nn.Module):
    def __init__(
        self,
        caption_feature_dim: int,
        hidden_size: int,
        *,
        bias: bool = True,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.caption_feature_dim = caption_feature_dim
        self.hidden_size = hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
        )

    def forward(self, x):
        return self.mlp(x)


class FeedForward(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # keep parameter count and computation constant compared to standard FFN
        hidden_size = int(2 * hidden_size / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_size = int(ffn_dim_multiplier * hidden_size)
        hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)

        self.hidden_dim = hidden_size
        self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device)
        self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device)

    def forward(self, x):
        # assert self.w1.weight.dtype == torch.bfloat16, f"FFN weight dtype {self.w1.weight.dtype} != bfloat16"
        x, gate = self.w1(x).chunk(2, dim=-1)
        x = self.w2(F.silu(x) * gate)
        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = True,
        bias: bool = True,
        dynamic_img_pad: bool = False,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.flatten = flatten
        self.dynamic_img_pad = dynamic_img_pad

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
            device=device,
        )
        assert norm_layer is None
        self.norm = norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()

    def forward(self, x):
        B, _C, T, H, W = x.shape
        if not self.dynamic_img_pad:
            assert (
                H % self.patch_size[0] == 0
            ), f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
            assert (
                W % self.patch_size[1] == 0
            ), f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
        else:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))

        x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
        x = self.proj(x)

        # Flatten temporal and spatial dimensions.
        if not self.flatten:
            raise NotImplementedError("Must flatten output.")
        x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)

        x = self.norm(x)
        return x


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device))
        self.register_parameter("bias", None)

    def forward(self, x):
        # assert self.weight.dtype == torch.float32, f"RMSNorm weight dtype {self.weight.dtype} != float32"

        x_fp32 = x.float()
        x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x_normed * self.weight).type_as(x)
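A small shape check for the layers above (a sketch, not part of the diff; assumes the repo's package is importable). FeedForward shrinks the requested hidden size by 2/3 and rounds it up to `multiple_of`, SwiGLU-style, and TimestepEmbedder maps a 1-D timestep tensor to `hidden_size` features.

# Sketch: FeedForward hidden-size rounding and TimestepEmbedder output shape.
import torch

from genmo.mochi_preview.dit.joint_model.layers import FeedForward, TimestepEmbedder

ff = FeedForward(in_features=128, hidden_size=512, multiple_of=64, ffn_dim_multiplier=None)
print(ff.hidden_dim)  # int(2 * 512 / 3) = 341, rounded up to 384
print(ff(torch.randn(2, 10, 128)).shape)  # torch.Size([2, 10, 128])

emb = TimestepEmbedder(hidden_size=128, timestep_scale=1000.0)
print(emb(torch.tensor([0.0, 0.5, 1.0])).shape)  # torch.Size([3, 128])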
src/genmo/mochi_preview/dit/joint_model/lora.py
ADDED
@@ -0,0 +1,112 @@
#! /usr/bin/env python3
import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class LoRALayer:
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.merged = False
        self.merge_weights = merge_weights


def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    assert bias == "none", f"Only bias='none' is supported"
    for n, p in model.named_parameters():
        if "lora_" not in n:
            p.requires_grad = False


def lora_state_dict(model: nn.Module, bias: str = "none") -> Dict[str, torch.Tensor]:
    assert bias == "none", f"Only bias='none' is supported"
    my_state_dict = model.state_dict()
    return {k: my_state_dict[k] for k in my_state_dict if "lora_" in k}


class LoraLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)).to(torch.float32))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)).to(torch.float32))
            self.scaling = self.lora_alpha / self.r

            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False

        self.reset_parameters()

        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # initialize B the same way as the default for nn.Linear and A to zero
            # this is different than what is described in the paper but should not affect performance
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)

            x = self.lora_dropout(x)
            x = x @ self.lora_A.transpose(0, 1)
            x = x @ self.lora_B.transpose(0, 1)
            x = x * self.scaling

            return result + x
        else:
            return F.linear(x, T(self.weight), bias=self.bias)
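A sketch of how LoraLinear behaves around train()/eval() (not part of the diff; assumes the repo's package is importable). With merge_weights=True, eval() folds B @ A into the frozen base weight, so inference uses a single matmul; right after initialization lora_B is zero, so both paths agree.

# Sketch: LoRA merge/unmerge behavior of LoraLinear.
import torch

from genmo.mochi_preview.dit.joint_model.lora import LoraLinear, mark_only_lora_as_trainable

layer = LoraLinear(64, 64, r=8, lora_alpha=16, merge_weights=True)
mark_only_lora_as_trainable(layer)  # freeze everything except lora_A / lora_B
x = torch.randn(4, 64)

layer.train()
y_train = layer(x)  # base path + (x @ A^T @ B^T) * alpha / r
assert not layer.merged

layer.eval()        # merges the LoRA delta into layer.weight
y_eval = layer(x)
assert layer.merged
print(torch.allclose(y_train, y_eval, atol=1e-5))  # True: same function either way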
src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py
ADDED
@@ -0,0 +1,15 @@
import torch


def modulated_rmsnorm(x, scale, eps=1e-6):
    dtype = x.dtype
    x = x.float()

    # Compute RMS
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + eps)

    # Normalize and modulate
    x_normed = x * inv_rms
    x_modulated = x_normed * (1 + scale.unsqueeze(1).float())
    return x_modulated.to(dtype)
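A shape sketch for modulated_rmsnorm (not part of the diff; assumes the repo's package is importable): x is a token sequence [B, L, D] and scale is a per-sample modulation [B, D] that is broadcast over the sequence via unsqueeze(1); the result is cast back to the input dtype.

# Sketch: zero modulation reduces to plain RMS normalization.
import torch

from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm

x = torch.randn(2, 5, 64, dtype=torch.bfloat16)
scale = torch.zeros(2, 64)
out = modulated_rmsnorm(x, scale)
print(out.shape, out.dtype)  # torch.Size([2, 5, 64]) torch.bfloat16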
src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py
ADDED
@@ -0,0 +1,20 @@
import torch


def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
    # Convert to fp32 for precision
    x_res = x_res.float()

    # Compute RMS
    mean_square = x_res.pow(2).mean(-1, keepdim=True)
    scale = torch.rsqrt(mean_square + eps)

    # Apply tanh to gate
    tanh_gate = torch.tanh(gate).unsqueeze(1)

    # Normalize and apply gated scaling
    x_normed = x_res * scale * tanh_gate

    # Apply residual connection
    output = x + x_normed.type_as(x)
    return output
src/genmo/mochi_preview/dit/joint_model/rope_mixed.py
ADDED
@@ -0,0 +1,88 @@
import functools
import math

import torch


def centers(start: float, stop, num, dtype=None, device=None):
    """linspace through bin centers.

    Args:
        start (float): Start of the range.
        stop (float): End of the range.
        num (int): Number of points.
        dtype (torch.dtype): Data type of the points.
        device (torch.device): Device of the points.

    Returns:
        centers (Tensor): Centers of the bins. Shape: (num,).
    """
    edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    return (edges[:-1] + edges[1:]) / 2


@functools.lru_cache(maxsize=1)
def create_position_matrix(
    T: int,
    pH: int,
    pW: int,
    device: torch.device,
    dtype: torch.dtype,
    *,
    target_area: float = 36864,
):
    """
    Args:
        T: int - Temporal dimension
        pH: int - Height dimension after patchify
        pW: int - Width dimension after patchify

    Returns:
        pos: [T * pH * pW, 3] - position matrix
    """
    with torch.no_grad():
        # Create 1D tensors for each dimension
        t = torch.arange(T, dtype=dtype)

        # Positionally interpolate to area 36864.
        # (3072x3072 frame with 16x16 patches = 192x192 latents).
        # This automatically scales rope positions when the resolution changes.
        # We use a large target area so the model is more sensitive
        # to changes in the learned pos_frequencies matrix.
        scale = math.sqrt(target_area / (pW * pH))
        w = centers(-pW * scale / 2, pW * scale / 2, pW)
        h = centers(-pH * scale / 2, pH * scale / 2, pH)

        # Use meshgrid to create 3D grids
        grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

        # Stack and reshape the grids.
        pos = torch.stack([grid_t, grid_h, grid_w], dim=-1)  # [T, pH, pW, 3]
        pos = pos.view(-1, 3)  # [T * pH * pW, 3]
        pos = pos.to(dtype=dtype, device=device)

    return pos


def compute_mixed_rotation(
    freqs: torch.Tensor,
    pos: torch.Tensor,
):
    """
    Project each 3-dim position into per-head, per-head-dim 1D frequencies.

    Args:
        freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
        pos: [N, 3] - position of each token
        num_heads: int

    Returns:
        freqs_cos: [N, num_heads, num_freqs] - cosine components
        freqs_sin: [N, num_heads, num_freqs] - sine components
    """
    with torch.autocast("cuda", enabled=False):
        assert freqs.ndim == 3
        freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
        freqs_cos = torch.cos(freqs_sum)
        freqs_sin = torch.sin(freqs_sum)
    return freqs_cos, freqs_sin
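A shape walk-through for the mixed RoPE above (a sketch, not part of the diff; assumes the repo's package is importable). The per-head frequency tensor below is a random stand-in for the model's learned position-frequency parameter, used only to show the shapes.

# Sketch: position matrix and per-head rotation components.
import torch

from genmo.mochi_preview.dit.joint_model.rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)

T, pH, pW = 7, 30, 53          # latent frames and patchified height/width
num_heads, head_dim = 24, 128  # head_dim // 2 rotation frequencies per head

pos = create_position_matrix(T, pH, pW, device=torch.device("cpu"), dtype=torch.float32)
print(pos.shape)  # torch.Size([11130, 3]) == [T * pH * pW, 3]

freqs = torch.randn(3, num_heads, head_dim // 2)  # stand-in for learned frequencies
freqs_cos, freqs_sin = compute_mixed_rotation(freqs, pos)
print(freqs_cos.shape)  # torch.Size([11130, 24, 64])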
src/genmo/mochi_preview/dit/joint_model/temporal_rope.py
ADDED
@@ -0,0 +1,34 @@
# Based on Llama3 Implementation.
import torch


def apply_rotary_emb_qk_real(
    xqk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
) -> torch.Tensor:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.

    Args:
        xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
            Can be either just query or just key, or both stacked along some batch or * dim.
        freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
        freqs_sin (torch.Tensor): Precomputed sine frequency tensor.

    Returns:
        torch.Tensor: The input tensor with rotary embeddings applied.
    """
    assert xqk.dtype == torch.bfloat16
    # Split the last dimension into even and odd parts
    xqk_even = xqk[..., 0::2]
    xqk_odd = xqk[..., 1::2]

    # Apply rotation
    cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
    sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)

    # Interleave the results back into the original shape
    out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
    assert out.dtype == torch.bfloat16
    return out
src/genmo/mochi_preview/dit/joint_model/utils.py
ADDED
@@ -0,0 +1,109 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
    """
    Pool tokens in x using mask.

    NOTE: We assume x does not require gradients.

    Args:
        x: (B, L, D) tensor of tokens.
        mask: (B, L) boolean tensor indicating which tokens are not padding.

    Returns:
        pooled: (B, D) tensor of pooled tokens.
    """
    assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
    assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
    mask = mask[:, :, None].to(dtype=x.dtype)
    mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
    pooled = (x * mask).sum(dim=1, keepdim=keepdim)
    return pooled


class AttentionPool(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        output_dim: int = None,
        device: Optional[torch.device] = None,
    ):
        """
        Args:
            spatial_dim (int): Number of tokens in sequence length.
            embed_dim (int): Dimensionality of input tokens.
            num_heads (int): Number of attention heads.
            output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
        """
        super().__init__()
        self.num_heads = num_heads
        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
        self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
        self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)

    def forward(self, x, mask):
        """
        Args:
            x (torch.Tensor): (B, L, D) tensor of input tokens.
            mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.

        NOTE: We assume x does not require gradients.

        Returns:
            x (torch.Tensor): (B, D) tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_heads
        kv = kv.unflatten(2, (2, self.num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x


def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
    D = xy.size(1)

    # Pad sequences to (B, N + L, dim).
    assert indices.ndim == 1
    indices = indices.unsqueeze(1).expand(-1, D)  # (total,) -> (total, num_heads * head_dim)
    output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
    output = torch.scatter(output, 0, indices, xy)
    xy = output.view(B, N + L, D)

    # Split visual and text tokens along the sequence length.
    return torch.tensor_split(xy, (N,), dim=1)
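A sketch of AttentionPool on a padded caption sequence (not part of the diff; assumes the repo's package is importable): the masked-mean token is used as the single query, attends over [mean, tokens], and is projected to the output width.

# Sketch: masked mean pooling vs. attention pooling of padded captions.
import torch

from genmo.mochi_preview.dit.joint_model.utils import AttentionPool, pool_tokens

B, L, D = 2, 256, 4096
x = torch.randn(B, L, D)
mask = torch.zeros(B, L, dtype=torch.bool)
mask[0, :17] = True   # first caption has 17 real tokens
mask[1, :42] = True   # second caption has 42 real tokens

print(pool_tokens(x, mask).shape)  # torch.Size([2, 4096]) masked mean

pool = AttentionPool(embed_dim=D, num_heads=8, output_dim=1536)
print(pool(x, mask).shape)  # torch.Size([2, 1536])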
src/genmo/mochi_preview/pipelines.py
ADDED
@@ -0,0 +1,682 @@
import json
import os
import random
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import partial
from typing import Any, Dict, List, Literal, Optional, Union, cast

import numpy as np
import ray
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from safetensors import safe_open
from safetensors.torch import load_file
from torch import nn
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers import T5EncoderModel, T5Tokenizer
from transformers.models.t5.modeling_t5 import T5Block

import genmo.mochi_preview.dit.joint_model.context_parallel as cp
from genmo.lib.progress import get_new_progress_bar, progress_bar
from genmo.lib.utils import Timer
from genmo.mochi_preview.vae.models import (
    Decoder,
    Encoder,
    decode_latents,
    decode_latents_tiled_full,
    decode_latents_tiled_spatial,
)
from genmo.mochi_preview.vae.vae_stats import dit_latents_to_vae_latents


def load_to_cpu(p, weights_only=True):
    if p.endswith(".safetensors"):
        return load_file(p)
    else:
        assert p.endswith(".pt")
        return torch.load(p, map_location="cpu", weights_only=weights_only)


def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
    if linear_steps is None:
        linear_steps = num_steps // 2
    linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
    threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
    quadratic_steps = num_steps - linear_steps
    quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
    linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
    const = quadratic_coef * (linear_steps**2)
    quadratic_sigma_schedule = [
        quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
    ]
    sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
    sigma_schedule = [1.0 - x for x in sigma_schedule]
    return sigma_schedule


T5_MODEL = "google/t5-v1_1-xxl"
MAX_T5_TOKEN_LENGTH = 256


def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        ),
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=device_id,
        sync_module_states=True,
        use_orig_params=True,
    )
    torch.cuda.synchronize()
    return model


class ModelFactory(ABC):
    def __init__(self, **kwargs):
        self.kwargs = kwargs

    @abstractmethod
    def get_model(self, *, local_rank: int, device_id: Union[int, Literal["cpu"]], world_size: int) -> Any:
        assert isinstance(device_id, int) or device_id == "cpu", "device_id must be an integer or 'cpu'"
        # FSDP does not work when the model is on the CPU
        if device_id == "cpu":
            assert world_size == 1, "CPU offload only supports single-GPU inference"


class T5ModelFactory(ModelFactory):
    def __init__(self, model_dir=None):
        super().__init__()
        self.model_dir = model_dir or T5_MODEL

    def get_model(self, *, local_rank, device_id, world_size):
        super().get_model(local_rank=local_rank, device_id=device_id, world_size=world_size)
        model = T5EncoderModel.from_pretrained(self.model_dir)
        if world_size > 1:
            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.float32,
                auto_wrap_policy=partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls={
                        T5Block,
                    },
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))  # type: ignore
        return model.eval()


class DitModelFactory(ModelFactory):
    def __init__(
        self, *,
        model_path: str,
        model_dtype: str,
        lora_path: Optional[str] = None,
        attention_mode: Optional[str] = None
    ):
        # Infer attention mode if not specified
        if attention_mode is None:
            from genmo.lib.attn_imports import flash_varlen_attn  # type: ignore
            attention_mode = "sdpa" if flash_varlen_attn is None else "flash"
        print(f"Attention mode: {attention_mode}")

        super().__init__(
            model_path=model_path,
            lora_path=lora_path,
            model_dtype=model_dtype,
            attention_mode=attention_mode
        )

    def get_model(
        self,
        *,
        local_rank,
        device_id,
        world_size,
        model_kwargs=None,
        patch_model_fns=None,
        strict_load=True,
        load_checkpoint=True,
        fast_init=True,
    ):
        from genmo.mochi_preview.dit.joint_model.asymm_models_joint import AsymmDiTJoint

        if not model_kwargs:
            model_kwargs = {}

        lora_sd = None
        lora_path = self.kwargs["lora_path"]
        if lora_path is not None:
            if lora_path.endswith(".safetensors"):
                lora_sd = {}
                with safe_open(lora_path, framework="pt") as f:
                    for k in f.keys():
                        lora_sd[k] = f.get_tensor(k)
                    lora_kwargs = json.loads(f.metadata()["kwargs"])
                    print(f"Loaded LoRA kwargs: {lora_kwargs}")
            else:
                lora = load_to_cpu(lora_path, weights_only=False)
                lora_sd, lora_kwargs = lora["state_dict"], lora["kwargs"]

            model_kwargs.update(cast(dict, lora_kwargs))

        model_args = dict(
            depth=48,
            patch_size=2,
            num_heads=24,
            hidden_size_x=3072,
            hidden_size_y=1536,
            mlp_ratio_x=4.0,
            mlp_ratio_y=4.0,
            in_channels=12,
            qk_norm=True,
            qkv_bias=False,
            out_bias=True,
            patch_embed_bias=True,
            timestep_mlp_bias=True,
            timestep_scale=1000.0,
            t5_feat_dim=4096,
            t5_token_length=256,
            rope_theta=10000.0,
            attention_mode=self.kwargs["attention_mode"],
            **model_kwargs,
        )

        if fast_init:
            model: nn.Module = torch.nn.utils.skip_init(AsymmDiTJoint, **model_args)
        else:
            model: nn.Module = AsymmDiTJoint(**model_args)

        for fn in patch_model_fns or []:
            model = fn(model)

        # FSDP syncs weights from rank 0 to all other ranks
        if local_rank == 0 and load_checkpoint:
            model_path = self.kwargs["model_path"]
            sd = load_to_cpu(model_path)

            # Load the state dictionary and capture the return value
            load_result = model.load_state_dict(sd, strict=strict_load)
            if not strict_load:
                # Print mismatched keys
                missing_keys = [k for k in load_result.missing_keys if ".lora_" not in k]
                if missing_keys:
                    print(f"Missing keys from {model_path}: {missing_keys}")
                if load_result.unexpected_keys:
                    print(f"Unexpected keys from {model_path}: {load_result.unexpected_keys}")

            if lora_sd:
                model.load_state_dict(lora_sd, strict=strict_load)  # type: ignore

        if world_size > 1:
            assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"

            model = setup_fsdp_sync(
                model,
                device_id=device_id,
                param_dtype=torch.float32,
                auto_wrap_policy=partial(
                    lambda_auto_wrap_policy,
                    lambda_fn=lambda m: m in model.blocks,
                ),
            )
        elif isinstance(device_id, int):
            model = model.to(torch.device(f"cuda:{device_id}"))
        return model.eval()


class DecoderModelFactory(ModelFactory):
    def __init__(self, *, model_path: str):
        super().__init__(model_path=model_path)

    def get_model(self, *, local_rank=0, device_id=0, world_size=1):
        # TODO(ved): Set flag for torch.compile
        # TODO(ved): Use skip_init

        decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=[1, 2, 4, 6],
            temporal_expansions=[1, 2, 3],
            spatial_expansions=[2, 2, 2],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            has_attention=[False, False, False, False, False],
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )
        # VAE is not FSDP-wrapped
        state_dict = load_file(self.kwargs["model_path"])
        decoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        decoder.eval().to(device)
        return decoder


class EncoderModelFactory(ModelFactory):
    def __init__(self, *, model_path: str):
        super().__init__(model_path=model_path)

    def get_model(self, *, local_rank=0, device_id=0, world_size=1):
        # TODO(ved): Set flag for torch.compile
        # TODO(ved): Use skip_init

        # We don't FSDP the encoder b/c it is small
        encoder = Encoder(
            in_channels=15,
            base_channels=64,
            channel_multipliers=[1, 2, 4, 6],
            num_res_blocks=[3, 3, 4, 6, 3],
            latent_dim=12,
            temporal_reductions=[1, 2, 3],
            spatial_reductions=[2, 2, 2],
            prune_bottlenecks=[False, False, False, False, False],
            has_attentions=[False, True, True, True, True],
            affine=True,
            bias=True,
            input_is_conv_1x1=True,
            padding_mode="replicate",
        )
        state_dict = load_file(self.kwargs["model_path"])
        encoder.load_state_dict(state_dict, strict=True)
        device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
        encoder.eval().to(device)
        return encoder


def get_conditioning(
    tokenizer: T5Tokenizer,
    encoder: Encoder,
    device: torch.device,
    batch_inputs: bool,
    *,
    prompt: str,
    negative_prompt: str,
):
    if batch_inputs:
        return dict(
            batched=get_conditioning_for_prompts(
                tokenizer, encoder, device, [prompt, negative_prompt]
            )
        )
    else:
        cond_input = get_conditioning_for_prompts(tokenizer, encoder, device, [prompt])
        null_input = get_conditioning_for_prompts(tokenizer, encoder, device, [negative_prompt])
        return dict(cond=cond_input, null=null_input)


def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str]):
    assert len(prompts) in [1, 2]  # [neg] or [pos] or [pos, neg]
    B = len(prompts)
    t5_toks = tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=MAX_T5_TOKEN_LENGTH,
        return_tensors="pt",
        return_attention_mask=True,
    )
    caption_input_ids_t5 = t5_toks["input_ids"]
    caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
    del t5_toks

    assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
    assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)

    # Special-case empty negative prompt by zero-ing it
    if prompts[-1] == "":
        caption_input_ids_t5[-1] = 0
        caption_attention_mask_t5[-1] = False

    caption_input_ids_t5 = caption_input_ids_t5.to(device, non_blocking=True)
    caption_attention_mask_t5 = caption_attention_mask_t5.to(device, non_blocking=True)

    y_mask = [caption_attention_mask_t5]
    y_feat = [encoder(caption_input_ids_t5, caption_attention_mask_t5).last_hidden_state.detach()]
    # Sometimes returns a tensor, other times a tuple, not sure why
    # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
    assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
    assert y_feat[-1].dtype == torch.float32

    return dict(y_mask=y_mask, y_feat=y_feat)


def compute_packed_indices(
    device: torch.device, text_mask: torch.Tensor, num_latents: int
) -> Dict[str, Union[torch.Tensor, int]]:
    """
    Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80

    Args:
        num_latents: Number of latent tokens
        text_mask: (B, L) List of boolean tensor indicating which text tokens are not padding.

    Returns:
        packed_indices: Dict with keys for Flash Attention:
            - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
              in the packed sequence.
            - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
            - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
    """
    # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
    PATCH_SIZE = 2
    num_visual_tokens = num_latents // (PATCH_SIZE**2)
    assert num_visual_tokens > 0

    mask = F.pad(text_mask, (num_visual_tokens, 0), value=True)  # (B, N + L)
    seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
    valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()  # up to (B * (N + L),)
    assert valid_token_indices.size(0) >= text_mask.size(0) * num_visual_tokens  # At least (B * N,)
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen_in_batch = seqlens_in_batch.max().item()

    return {
        "cu_seqlens_kv": cu_seqlens.to(device, non_blocking=True),
        "max_seqlen_in_batch_kv": cast(int, max_seqlen_in_batch),
        "valid_token_indices_kv": valid_token_indices.to(device, non_blocking=True),
    }


def assert_eq(x, y, msg=None):
    assert x == y, f"{msg or 'Assertion failed'}: {x} != {y}"


def sample_model(device, dit, conditioning, **args):
    random.seed(args["seed"])
    np.random.seed(args["seed"])
    torch.manual_seed(args["seed"])

    generator = torch.Generator(device=device)
    generator.manual_seed(args["seed"])

    w, h, t = args["width"], args["height"], args["num_frames"]
    sample_steps = args["num_inference_steps"]
    cfg_schedule = args["cfg_schedule"]
    sigma_schedule = args["sigma_schedule"]

    assert_eq(len(cfg_schedule), sample_steps, "cfg_schedule must have length sample_steps")
    assert_eq((t - 1) % 6, 0, "t - 1 must be divisible by 6")
    assert_eq(
        len(sigma_schedule),
        sample_steps + 1,
        "sigma_schedule must have length sample_steps + 1",
    )

    B = 1
    SPATIAL_DOWNSAMPLE = 8
    TEMPORAL_DOWNSAMPLE = 6
    IN_CHANNELS = 12
    latent_t = ((t - 1) // TEMPORAL_DOWNSAMPLE) + 1
    latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE

    z = torch.randn(
        (B, IN_CHANNELS, latent_t, latent_h, latent_w),
        device=device,
        dtype=torch.float32,
    )

    num_latents = latent_t * latent_h * latent_w
    cond_batched = cond_text = cond_null = None
    if "cond" in conditioning:
        cond_text = conditioning["cond"]
        cond_null = conditioning["null"]
        cond_text["packed_indices"] = compute_packed_indices(device, cond_text["y_mask"][0], num_latents)
        cond_null["packed_indices"] = compute_packed_indices(device, cond_null["y_mask"][0], num_latents)
    else:
        cond_batched = conditioning["batched"]
        cond_batched["packed_indices"] = compute_packed_indices(device, cond_batched["y_mask"][0], num_latents)
        z = repeat(z, "b ... -> (repeat b) ...", repeat=2)

    def model_fn(*, z, sigma, cfg_scale):
        if cond_batched:
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out = dit(z, sigma, **cond_batched)
            out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
        else:
            nonlocal cond_text, cond_null
            with torch.autocast("cuda", dtype=torch.bfloat16):
                out_cond = dit(z, sigma, **cond_text)
                out_uncond = dit(z, sigma, **cond_null)
        assert out_cond.shape == out_uncond.shape
        out_uncond = out_uncond.to(z)
        out_cond = out_cond.to(z)
        return out_uncond + cfg_scale * (out_cond - out_uncond)

    # Euler sampler w/ customizable sigma schedule & cfg scale
    for i in get_new_progress_bar(range(0, sample_steps), desc="Sampling"):
        sigma = sigma_schedule[i]
        dsigma = sigma - sigma_schedule[i + 1]

        # `pred` estimates `z_0 - eps`.
        pred = model_fn(
            z=z,
            sigma=torch.full([B] if cond_text else [B * 2], sigma, device=z.device),
            cfg_scale=cfg_schedule[i],
        )
        assert pred.dtype == torch.float32
        z = z + dsigma * pred

    z = z[:B] if cond_batched else z
    return dit_latents_to_vae_latents(z)


@contextmanager
def move_to_device(model: nn.Module, target_device, *, enabled=True):
    if not enabled:
        yield
        return

    og_device = next(model.parameters()).device
    if og_device == target_device:
        print(f"move_to_device is a no-op, model is already on {target_device}")
    else:
        print(f"moving model from {og_device} -> {target_device}")

    model.to(target_device)
    yield
    if og_device != target_device:
        print(f"moving model from {target_device} -> {og_device}")
        model.to(og_device)


def t5_tokenizer(model_dir=None):
    return T5Tokenizer.from_pretrained(model_dir or T5_MODEL, legacy=False)


class MochiSingleGPUPipeline:
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        cpu_offload: Optional[bool] = False,
        decode_type: str = "full",
        decode_args: Optional[Dict[str, Any]] = None,
        fast_init=True,
        strict_load=True
    ):
        self.device = torch.device("cuda:0")
        self.tokenizer = t5_tokenizer(text_encoder_factory.model_dir)
        t = Timer()
        self.cpu_offload = cpu_offload
        self.decode_args = decode_args or {}
        self.decode_type = decode_type
        init_id = "cpu" if cpu_offload else 0
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(
                local_rank=0,
                device_id=init_id,
                world_size=1,
            )
        with t("load_dit"):
            self.dit = dit_factory.get_model(local_rank=0, device_id=init_id, world_size=1, fast_init=fast_init, strict_load=strict_load)  # type: ignore
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
        t.print_stats()

    def __call__(self, batch_cfg, prompt, negative_prompt, **kwargs):
        with torch.inference_mode():
            print_max_memory = lambda: print(
                f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB"
            )
            print_max_memory()

            with move_to_device(self.text_encoder, self.device):
                conditioning = get_conditioning(
                    tokenizer=self.tokenizer,
                    encoder=self.text_encoder,
                    device=self.device,
                    batch_inputs=batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
            print_max_memory()

            with move_to_device(self.dit, self.device):
                latents = sample_model(self.device, self.dit, conditioning, **kwargs)
            print_max_memory()

            with move_to_device(self.decoder, self.device):
                if self.decode_type == "tiled_full":
                    frames = decode_latents_tiled_full(
                        self.decoder, latents, **self.decode_args)
                elif self.decode_type == "tiled_spatial":
                    frames = decode_latents_tiled_spatial(
                        self.decoder, latents, **self.decode_args,
                        num_tiles_w=4, num_tiles_h=2)
                else:
                    frames = decode_latents(self.decoder, latents)
            print_max_memory()
            return frames.cpu().numpy()


def cast_dit(model, dtype):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            assert any(
                n in name for n in ["mlp", "t5", "mod_", "attn.qkv_", "attn.proj_", "final_layer"]
            ), f"Unexpected linear layer: {name}"
            module.to(dtype=dtype)
        elif isinstance(module, nn.Conv2d):
            assert "x_embedder.proj" in name, f"Unexpected conv2d layer: {name}"
            module.to(dtype=dtype)
    return model


### ALL CODE BELOW HERE IS FOR MULTI-GPU MODE ###


# In multi-gpu mode, all models must belong to a device which has a predefined context parallel group
# So it doesn't make sense to work with models individually
class MultiGPUContext:
    def __init__(
        self,
        *,
        text_encoder_factory,
        dit_factory,
        decoder_factory,
        device_id,
        local_rank,
        world_size,
    ):
        t = Timer()
        self.device = torch.device(f"cuda:{device_id}")
        print(f"Initializing rank {local_rank+1}/{world_size}")
        assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        with t("init_process_group"):
            dist.init_process_group(
                "nccl",
                rank=local_rank,
                world_size=world_size,
                device_id=self.device,  # force non-lazy init
            )
        pg = dist.group.WORLD
        cp.set_cp_group(pg, list(range(world_size)), local_rank)
        distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
        self.world_size = world_size
        self.tokenizer = t5_tokenizer(text_encoder_factory.model_dir)
        with t("load_text_encoder"):
            self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
        with t("load_dit"):
            self.dit = dit_factory.get_model(**distributed_kwargs)
        with t("load_vae"):
            self.decoder = decoder_factory.get_model(**distributed_kwargs)
        self.local_rank = local_rank
        t.print_stats()

    def run(self, *, fn, **kwargs):
        return fn(self, **kwargs)


class MochiMultiGPUPipeline:
    def __init__(
        self,
        *,
        text_encoder_factory: ModelFactory,
        dit_factory: ModelFactory,
        decoder_factory: ModelFactory,
        world_size: int,
    ):
        ray.init()
        RemoteClass = ray.remote(MultiGPUContext)
        self.ctxs = [
            RemoteClass.options(num_gpus=1).remote(
                text_encoder_factory=text_encoder_factory,
                dit_factory=dit_factory,
                decoder_factory=decoder_factory,
                world_size=world_size,
                device_id=0,
                local_rank=i,
            )
            for i in range(world_size)
        ]
        for ctx in self.ctxs:
            ray.get(ctx.__ray_ready__.remote())

    def __call__(self, **kwargs):
        def sample(ctx, *, batch_cfg, prompt, negative_prompt, **kwargs):
            with progress_bar(type="ray_tqdm", enabled=ctx.local_rank == 0), torch.inference_mode():
                conditioning = get_conditioning(
                    ctx.tokenizer,
                    ctx.text_encoder,
                    ctx.device,
                    batch_cfg,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                )
                latents = sample_model(ctx.device, ctx.dit, conditioning=conditioning, **kwargs)
                if ctx.local_rank == 0:
                    torch.save(latents, "latents.pt")
                frames = decode_latents(ctx.decoder, latents)
            return frames.cpu().numpy()

        return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[
            0
        ]
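A condensed single-GPU usage sketch for the pipeline above, in the spirit of the demos/api_example.py added elsewhere in this upload (the weight paths are placeholders, not real filenames). It shows the constraints the sampler asserts: (num_frames - 1) must be divisible by 6, the cfg schedule has num_steps entries, and the sigma schedule has num_steps + 1.

# Sketch: building and calling MochiSingleGPUPipeline (placeholder weight paths).
from genmo.mochi_preview.pipelines import (
    DecoderModelFactory,
    DitModelFactory,
    MochiSingleGPUPipeline,
    T5ModelFactory,
    linear_quadratic_schedule,
)

pipeline = MochiSingleGPUPipeline(
    text_encoder_factory=T5ModelFactory(),
    dit_factory=DitModelFactory(model_path="<weights>/dit.safetensors", model_dtype="bf16"),
    decoder_factory=DecoderModelFactory(model_path="<weights>/decoder.safetensors"),
    cpu_offload=True,
)

num_steps = 64
video = pipeline(
    height=480,
    width=848,
    num_frames=163,  # (163 - 1) % 6 == 0
    num_inference_steps=num_steps,
    sigma_schedule=linear_quadratic_schedule(num_steps, 0.025),  # num_steps + 1 values
    cfg_schedule=[4.5] * num_steps,
    batch_cfg=False,
    prompt="a hand with delicate fingers picks up a bright yellow lemon",
    negative_prompt="",
    seed=12345,
)
# `video` is a numpy array of decoded frames.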
src/genmo/mochi_preview/vae/__init__.py
ADDED
File without changes
src/genmo/mochi_preview/vae/cp_conv.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
import genmo.mochi_preview.dit.joint_model.context_parallel as cp
|
8 |
+
|
9 |
+
|
10 |
+
def cast_tuple(t, length=1):
|
11 |
+
return t if isinstance(t, tuple) else ((t,) * length)
|
12 |
+
|
13 |
+
|
def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
    """
    Forward pass that handles communication between ranks for inference.
    Args:
        x: Tensor of shape (B, C, T, H, W)
        frames_to_send: int, number of frames to communicate between ranks
    Returns:
        output: Tensor of shape (B, C, T', H, W)
    """
    cp_rank, cp_world_size = cp.get_cp_rank_size()
    if frames_to_send == 0 or cp_world_size == 1:
        return x

    group = cp.get_cp_group()
    global_rank = dist.get_rank()

    # Send to next rank
    if cp_rank < cp_world_size - 1:
        assert x.size(2) >= frames_to_send
        tail = x[:, :, -frames_to_send:].contiguous()
        dist.send(tail, global_rank + 1, group=group)

    # Receive from previous rank
    if cp_rank > 0:
        B, C, _, H, W = x.shape
        recv_buffer = torch.empty(
            (B, C, frames_to_send, H, W),
            dtype=x.dtype,
            device=x.device,
        )
        dist.recv(recv_buffer, global_rank - 1, group=group)
        x = torch.cat([recv_buffer, x], dim=2)

    return x
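A single-process sketch of what `cp_pass_frames` does across context-parallel ranks: every rank except the last sends its trailing `frames_to_send` frames to its right-hand neighbour, which prepends them along the time axis before convolving. Here plain tensor indexing stands in for `dist.send`/`dist.recv`, and `simulate_pass_frames` is a hypothetical helper name:

import torch

def simulate_pass_frames(chunks, frames_to_send):
    """Mimic cp_pass_frames on a list of per-rank (B, C, T, H, W) chunks."""
    out = []
    for rank, x in enumerate(chunks):
        if rank > 0:
            # "Receive" the previous rank's original tail and prepend it in time.
            tail = chunks[rank - 1][:, :, -frames_to_send:]
            x = torch.cat([tail, x], dim=2)
        out.append(x)
    return out

chunks = list(torch.arange(12.0).view(1, 1, 12, 1, 1).chunk(3, dim=2))  # 3 "ranks", 4 frames each
halo = simulate_pass_frames(chunks, frames_to_send=2)
print([c.shape[2] for c in halo])  # [4, 6, 6]: ranks 1 and 2 gained a 2-frame halo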
def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
    if max_T > x.size(2):
        pad_T = max_T - x.size(2)
        pad_dims = (0, 0, 0, 0, 0, pad_T)
        return F.pad(x, pad_dims)
    return x
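Note on the pad tuple in `_pad_to_max`: `F.pad` reads pairs for the trailing dimensions first, so `(0, 0, 0, 0, 0, pad_T)` leaves W and H untouched and appends `pad_T` zero frames at the end of T. A quick check:

import torch
import torch.nn.functional as F

x = torch.ones(1, 2, 3, 4, 5)           # (B, C, T, H, W)
y = F.pad(x, (0, 0, 0, 0, 0, 2))        # pairs read as (W_left, W_right, H_top, H_bottom, T_front, T_back)
print(y.shape)                          # torch.Size([1, 2, 5, 4, 5])
print(y[:, :, -2:].abs().sum().item())  # 0.0: the two appended frames are zeros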
def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
    """
    Gathers all frames from all processes for inference.
    Args:
        x: Tensor of shape (B, C, T, H, W)
    Returns:
        output: Tensor of shape (B, C, T_total, H, W)
    """
    cp_rank, cp_size = cp.get_cp_rank_size()
    if cp_size == 1:
        return x

    cp_group = cp.get_cp_group()

    # Ensure the tensor is contiguous for collective operations
    x = x.contiguous()

    # Get the local time dimension size
    local_T = x.size(2)
    local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)

    # Gather all T sizes from all processes
    all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
    dist.all_gather(all_T, local_T_tensor, group=cp_group)
    all_T = [t.item() for t in all_T]

    # Pad the tensor at the end of the time dimension to match max_T
    max_T = max(all_T)
    x = _pad_to_max(x, max_T).contiguous()

    # Prepare a list to hold the gathered tensors
    gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]

    # Perform the all_gather operation
    dist.all_gather(gathered_x, x, group=cp_group)

    # Slice each gathered tensor back to its original T size
    for idx, t_size in enumerate(all_T):
        gathered_x[idx] = gathered_x[idx][:, :, :t_size]

    return torch.cat(gathered_x, dim=2)
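A single-process sketch of the `gather_all_frames` logic: pad every rank's chunk to the longest local T so the all-gather sees equal shapes, then slice each gathered chunk back to its true length and concatenate in time. A Python list stands in for the process group, and `simulate_gather_all_frames` is a hypothetical helper name:

import torch
import torch.nn.functional as F

def simulate_gather_all_frames(chunks):
    """Mimic gather_all_frames on a list of per-rank (B, C, T, H, W) chunks."""
    all_T = [c.size(2) for c in chunks]                                      # what all_gather of local_T produces
    max_T = max(all_T)
    padded = [F.pad(c, (0, 0, 0, 0, 0, max_T - c.size(2))) for c in chunks]  # equal shapes for all_gather
    trimmed = [p[:, :, :t] for p, t in zip(padded, all_T)]                   # slice each chunk back to its true length
    return torch.cat(trimmed, dim=2)

chunks = [torch.randn(1, 2, t, 4, 4) for t in (5, 5, 3)]  # uneven T across three "ranks"
print(simulate_gather_all_frames(chunks).shape)           # torch.Size([1, 2, 13, 4, 4])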
def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
    """Estimate memory usage based on input tensor size and data type."""
    element_size = input.element_size()  # Size in bytes of each element
    memory_bytes = input.numel() * element_size
    memory_gb = memory_bytes / 1024**3
    return memory_gb > max_gb
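`excessive_memory_usage` only counts the input tensor itself (numel times element size), not intermediate buffers. A worked example, assuming the function above is in scope:

import torch

x = torch.empty(1, 3, 64, 1024, 1024, dtype=torch.float32)  # 201,326,592 elements at 4 bytes each
print(x.numel() * x.element_size() / 1024**3)               # 0.75 GB
print(excessive_memory_usage(x))                            # False: under the 2 GB default
print(excessive_memory_usage(x, max_gb=0.5))                # True: exceeds a tighter budget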
class ContextParallelCausalConv3d(torch.nn.Conv3d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        **kwargs,
    ):
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        height_pad = (kernel_size[1] - 1) // 2
        width_pad = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        cp_rank, cp_world_size = cp.get_cp_rank_size()

        context_size = self.kernel_size[0] - 1
        if cp_rank == 0:
            mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
            x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)

        if cp_world_size == 1:
            return super().forward(x)

        if all(s == 1 for s in self.stride):
            # Receive some frames from previous rank.
            x = cp_pass_frames(x, context_size)
            return super().forward(x)

        # Less efficient implementation for strided convs.
        # All gather x, infer and chunk.
        x = gather_all_frames(x)  # [B, C, k - 1 + global_T, H, W]
        x = super().forward(x)
        x_chunks = x.tensor_split(cp_world_size, dim=2)
        assert len(x_chunks) == cp_world_size
        return x_chunks[cp_rank]
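A single-process sketch of the causal padding scheme this class implements: the time axis receives `kernel_size[0] - 1` frames of front padding (applied on rank 0) and none at the back, so each output frame depends only on current and past input frames, while H and W keep symmetric "same" padding. Plain `torch.nn.Conv3d` stands in for the context-parallel class:

import torch
import torch.nn.functional as F

kT, kH, kW = 3, 3, 3
conv = torch.nn.Conv3d(8, 8, kernel_size=(kT, kH, kW), stride=1,
                       padding=(0, (kH - 1) // 2, (kW - 1) // 2))  # no built-in padding along T

x = torch.randn(1, 8, 16, 32, 32)             # (B, C, T, H, W)
x_causal = F.pad(x, (0, 0, 0, 0, kT - 1, 0))  # kT - 1 frames of context at the front only
y = conv(x_causal)
print(y.shape)                                # torch.Size([1, 8, 16, 32, 32]): T is preserved

# Causality check: changing a future frame leaves earlier outputs untouched.
x2 = x.clone()
x2[:, :, 10:] += 1.0
y2 = conv(F.pad(x2, (0, 0, 0, 0, kT - 1, 0)))
print(torch.allclose(y[:, :, :10], y2[:, :, :10]))  # True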