nruto committed
Commit d0bfdd6 · verified · 1 Parent(s): 9874f2a

Upload 31 files

LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2024 Genmo
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
pyproject.toml ADDED
@@ -0,0 +1,29 @@
+ [project]
+ name = "genmo"
+ version = "0.1.0"
+ description = "Genmo models"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "addict>=2.4.0",
+     "click>=8.1.7",
+     "einops>=0.8.0",
+     "gradio>=3.36.1",
+     "omegaconf>=2.3.0",
+     "pillow>=11.0.0",
+     "pyyaml>=6.0.2",
+     "ray>=2.37.0",
+     "sentencepiece>=0.2.0",
+     "setuptools>=75.2.0",
+     "torch>=2.4.1",
+     "transformers>=4.45.2",
+ ]
+
+ [project.optional-dependencies]
+ flash = [
+     "flash-attn>=2.6.3",
+ ]
+
+ [tool.ruff]
+ # Allow lines to be as long as 120.
+ line-length = 120
pyrightconfig.json ADDED
@@ -0,0 +1,4 @@
+ {
+     "include": ["src/genmo/mochi_preview/pipelines.py"]
+ }
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ addict>=2.4.0
+ click>=8.1.7
+ einops>=0.8.0
+ gradio>=3.36.1
+ omegaconf>=2.3.0
+ pillow>=11.0.0
+ pyyaml>=6.0.2
+ ray>=2.37.0
+ sentencepiece>=0.2.0
+ setuptools>=75.2.0
+ torch>=2.4.1
+ transformers>=4.45.2
scripts/format.bash ADDED
@@ -0,0 +1,4 @@
+ #! /bin/bash
+ set -euxo pipefail
+ ruff format src
+ ruff check --fix --select I src
scripts/pytorch_to_safe_tensors.py ADDED
@@ -0,0 +1,20 @@
+ #! /usr/bin/env python3
+ from pathlib import Path
+
+ import click
+ import torch
+ from safetensors.torch import save_file
+
+
+ @click.command()
+ @click.argument("input_path", type=click.Path(exists=True))
+ def convert_to_safetensors(input_path):
+     model = torch.load(input_path)
+     input_path = Path(input_path)
+     output_path = input_path.with_suffix(".safetensors")
+     save_file(model, str(output_path))
+     click.echo(f"Converted {input_path} to {output_path}")
+
+
+ if __name__ == "__main__":
+     convert_to_safetensors()
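A quick usage sketch (the checkpoint name `dit.pt` is illustrative); the converted file is written next to the input:

```bash
python3 scripts/pytorch_to_safe_tensors.py dit.pt
# -> Converted dit.pt to dit.safetensors
```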
scripts/typecheck.bash ADDED
@@ -0,0 +1,2 @@
+ #! /bin/bash
+ npx pyright
scripts/weights_to_fp8.py ADDED
File without changes
src/genmo.egg-info/PKG-INFO ADDED
@@ -0,0 +1,154 @@
+ Metadata-Version: 2.1
+ Name: genmo
+ Version: 0.1.0
+ Summary: Genmo models
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: addict>=2.4.0
+ Requires-Dist: click>=8.1.7
+ Requires-Dist: einops>=0.8.0
+ Requires-Dist: gradio>=3.36.1
+ Requires-Dist: omegaconf>=2.3.0
+ Requires-Dist: pillow>=11.0.0
+ Requires-Dist: pyyaml>=6.0.2
+ Requires-Dist: ray>=2.37.0
+ Requires-Dist: sentencepiece>=0.2.0
+ Requires-Dist: setuptools>=75.2.0
+ Requires-Dist: torch>=2.4.1
+ Requires-Dist: transformers>=4.45.2
+ Provides-Extra: flash
+ Requires-Dist: flash-attn>=2.6.3; extra == "flash"
+
+ # Mochi 1
+ [Blog](https://www.genmo.ai/blog) | [Hugging Face](https://huggingface.co/genmo/mochi-1-preview) | [Playground](https://www.genmo.ai/play) | [Careers](https://jobs.ashbyhq.com/genmo)
+
+ A state-of-the-art video generation model by [Genmo](https://genmo.ai).
+
+ https://github.com/user-attachments/assets/4d268d02-906d-4cb0-87cc-f467f1497108
+
+ ## Overview
+
+ Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. We're releasing the model under a permissive Apache 2.0 license. Try this model for free on [our playground](https://genmo.ai/play).
+
+ ## Installation
+
+ Install using [uv](https://github.com/astral-sh/uv):
+
+ ```bash
+ git clone https://github.com/genmoai/models
+ cd models
+ pip install uv
+ uv venv .venv
+ source .venv/bin/activate
+ uv pip install -e . --no-build-isolation
+ ```
+
+ To install Flash Attention as well, use:
+ ```bash
+ uv pip install -e .[flash] --no-build-isolation
+ ```
+
+ You will also need [FFmpeg](https://www.ffmpeg.org/) installed to turn your outputs into videos.
+
+ ## Download Weights
+
+ Download the weights from [Hugging Face](https://huggingface.co/genmo/mochi-1-preview/tree/main) or via `magnet:?xt=urn:btih:441da1af7a16bcaa4f556964f8028d7113d21cbb&dn=weights&tr=udp://tracker.opentrackr.org:1337/announce` to a folder on your computer.
+
+ ## Running
+
+ Start the Gradio UI with
+
+ ```bash
+ python3 ./demos/gradio_ui.py --model_dir "<path_to_downloaded_directory>"
+ ```
+
+ Or generate videos directly from the CLI with
+
+ ```bash
+ python3 ./demos/cli.py --model_dir "<path_to_downloaded_directory>"
+ ```
+
+ Replace `<path_to_downloaded_directory>` with the path to your model directory.
+
+ ## API
+
+ This repository comes with a simple, composable API, so you can call the model programmatically. A full example is [here](demos/api_example.py); roughly, it looks like this:
+
+ ```python
+ from genmo.mochi_preview.pipelines import (
+     DecoderModelFactory,
+     DitModelFactory,
+     MochiSingleGPUPipeline,
+     T5ModelFactory,
+     linear_quadratic_schedule,
+ )
+
+ pipeline = MochiSingleGPUPipeline(
+     text_encoder_factory=T5ModelFactory(),
+     dit_factory=DitModelFactory(
+         model_path=f"{MOCHI_DIR}/dit.safetensors", model_dtype="bf16"
+     ),
+     decoder_factory=DecoderModelFactory(
+         model_path=f"{MOCHI_DIR}/vae.safetensors",
+         model_stats_path=f"{MOCHI_DIR}/vae_stats.json",
+     ),
+     cpu_offload=True,
+     decode_type="tiled_full",
+ )
+
+ video = pipeline(
+     height=480,
+     width=848,
+     num_frames=31,
+     num_inference_steps=64,
+     sigma_schedule=linear_quadratic_schedule(64, 0.025),
+     cfg_schedule=[4.5] * 64,
+     batch_cfg=False,
+     prompt="your favorite prompt here ...",
+     negative_prompt="",
+     seed=12345,
+ )
+ ```
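The pipeline returns decoded frames. A minimal sketch for writing them out, assuming the output is a float frame array scaled to [0, 1] (which is what `save_video` in `src/genmo/lib/utils.py`, added later in this commit, expects):

```python
from genmo.lib.utils import save_video

# save_video dumps PNG frames to a temporary directory and invokes FFmpeg
# (which must be on PATH) to encode a 30 fps H.264 MP4.
save_video(video, "output.mp4")
```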
+
+ ## Model Architecture
+
+ Mochi 1 represents a significant advancement in open-source video generation, featuring a 10-billion-parameter diffusion model built on our novel Asymmetric Diffusion Transformer (AsymmDiT) architecture. Trained entirely from scratch, it is the largest video generative model ever openly released. And best of all, it's a simple, hackable architecture. Additionally, we are releasing an inference harness that includes an efficient context parallel implementation.
+
+ Alongside Mochi, we are open-sourcing our video AsymmVAE. We use an asymmetric encoder-decoder structure to build an efficient, high-quality compression model. Our AsymmVAE causally compresses videos to a 128x smaller size, with an 8x8 spatial and a 6x temporal compression to a 12-channel latent space.
+
+ ### AsymmVAE Model Specs
+ |Params <br> Count | Enc Base <br> Channels | Dec Base <br> Channels |Latent <br> Dim | Spatial <br> Compression | Temporal <br> Compression |
+ |:--:|:--:|:--:|:--:|:--:|:--:|
+ |362M | 64 | 128 | 12 | 8x8 | 6x |
+
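As a rough illustration of the compression factors in the table (a sketch; the exact latent frame count depends on the causal padding scheme, which is not shown here):

```python
# 8x8 spatial compression into 12 latent channels: a 480x848 frame maps to a
# 60x106 latent grid. The 6x temporal factor similarly reduces the frame count.
height, width = 480, 848
latent_h, latent_w = height // 8, width // 8
print(latent_h, latent_w)  # 60 106
```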
+ An AsymmDiT efficiently processes user prompts alongside compressed video tokens by streamlining text processing and focusing neural network capacity on visual reasoning. AsymmDiT jointly attends to text and visual tokens with multi-modal self-attention and learns separate MLP layers for each modality, similar to Stable Diffusion 3. However, our visual stream has nearly 4 times as many parameters as the text stream via a larger hidden dimension. To unify the modalities in self-attention, we use non-square QKV and output projection layers. This asymmetric design reduces inference memory requirements.
+ Many modern diffusion models use multiple pretrained language models to represent user prompts. In contrast, Mochi 1 simply encodes prompts with a single T5-XXL language model.
+
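A minimal sketch of the non-square projections described above, using the hidden sizes from the specs tables and mirroring the `AsymmetricAttention` layer shapes in `asymm_models_joint.py` later in this commit:

```python
import torch.nn as nn

dim_x, dim_y = 3072, 1536             # visual and text hidden sizes from the specs
qkv_x = nn.Linear(dim_x, 3 * dim_x)   # visual QKV: square projection
qkv_y = nn.Linear(dim_y, 3 * dim_x)   # text QKV: non-square, lifted to the visual width
proj_y = nn.Linear(dim_x, dim_y)      # output projection back down to the text width
```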
+ ### AsymmDiT Model Specs
+ |Params <br> Count | Num <br> Layers | Num <br> Heads | Visual <br> Dim | Text <br> Dim | Visual <br> Tokens | Text <br> Tokens |
+ |:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+ |10B | 48 | 24 | 3072 | 1536 | 44520 | 256 |
+
+ ## Hardware Requirements
+
+ The model requires at least 4 H100 GPUs to run. We welcome contributions from the community to reduce this requirement.
+
+ ## Safety
+ Genmo video models are general text-to-video diffusion models that inherently reflect the biases and preconceptions found in their training data. While steps have been taken to limit NSFW content, organizations should implement additional safety protocols and exercise careful consideration before deploying these model weights in any commercial services or products.
+
+ ## Limitations
+ Under the research preview, Mochi 1 is a living and evolving checkpoint. There are a few known limitations. The initial release generates videos at 480p today. In some edge cases with extreme motion, minor warping and distortions can also occur. Mochi 1 is also optimized for photorealistic styles, so it does not perform well with animated content. We also anticipate that the community will fine-tune the model to suit various aesthetic preferences.
+
+ ## Related Work
+ - [ComfyUI-MochiWrapper](https://github.com/kijai/ComfyUI-MochiWrapper) adds ComfyUI support for Mochi. The integration of PyTorch's SDPA attention was taken from their repository.
+
+ ## BibTeX
+ ```
+ @misc{genmo2024mochi,
+   title={Mochi},
+   author={Genmo Team},
+   year={2024}
+ }
+ ```
src/genmo.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,25 @@
+ LICENSE
+ README.md
+ pyproject.toml
+ src/genmo.egg-info/PKG-INFO
+ src/genmo.egg-info/SOURCES.txt
+ src/genmo.egg-info/dependency_links.txt
+ src/genmo.egg-info/requires.txt
+ src/genmo.egg-info/top_level.txt
+ src/genmo/lib/attn_imports.py
+ src/genmo/lib/progress.py
+ src/genmo/lib/utils.py
+ src/genmo/mochi_preview/__init__.py
+ src/genmo/mochi_preview/pipelines.py
+ src/genmo/mochi_preview/dit/joint_model/__init__.py
+ src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py
+ src/genmo/mochi_preview/dit/joint_model/context_parallel.py
+ src/genmo/mochi_preview/dit/joint_model/layers.py
+ src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py
+ src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py
+ src/genmo/mochi_preview/dit/joint_model/rope_mixed.py
+ src/genmo/mochi_preview/dit/joint_model/temporal_rope.py
+ src/genmo/mochi_preview/dit/joint_model/utils.py
+ src/genmo/mochi_preview/vae/__init__.py
+ src/genmo/mochi_preview/vae/cp_conv.py
+ src/genmo/mochi_preview/vae/model.py
src/genmo.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/genmo.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
+ addict>=2.4.0
+ click>=8.1.7
+ einops>=0.8.0
+ gradio>=3.36.1
+ omegaconf>=2.3.0
+ pillow>=11.0.0
+ pyyaml>=6.0.2
+ ray>=2.37.0
+ sentencepiece>=0.2.0
+ setuptools>=75.2.0
+ torch>=2.4.1
+ transformers>=4.45.2
+
+ [flash]
+ flash-attn>=2.6.3
src/genmo.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ genmo
src/genmo/lib/attn_imports.py ADDED
@@ -0,0 +1,35 @@
+ from contextlib import contextmanager
+
+ import torch
+
+ # Optional attention backends: each import degrades gracefully to None if the
+ # package is not installed.
+ try:
+     from flash_attn import flash_attn_varlen_qkvpacked_func as flash_varlen_qkvpacked_attn
+ except ImportError:
+     flash_varlen_qkvpacked_attn = None
+
+ try:
+     from sageattention import sageattn as sage_attn
+ except ImportError:
+     sage_attn = None
+
+ try:
+     from comfy.ldm.modules.attention import comfy_optimized_attention as comfy_attn
+ except ImportError:
+     comfy_attn = None
+
+
+ from torch.nn.attention import SDPBackend, sdpa_kernel
+
+ # Choose SDPA backends based on the GPU's compute capability.
+ backends = []
+ if torch.cuda.get_device_properties(0).major < 7:
+     backends.append(SDPBackend.MATH)
+ if torch.cuda.get_device_properties(0).major >= 9.0:
+     backends.append(SDPBackend.CUDNN_ATTENTION)
+ else:
+     backends.append(SDPBackend.EFFICIENT_ATTENTION)
+
+
+ @contextmanager
+ def sdpa_attn_ctx():
+     with sdpa_kernel(backends):
+         yield
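A usage sketch for `sdpa_attn_ctx` (mirroring how `sdpa_attention` in `asymm_models_joint.py`, later in this commit, wraps the call; the tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

from genmo.lib.attn_imports import sdpa_attn_ctx

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 24, 256, 128, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
with sdpa_attn_ctx():  # restrict SDPA to the backends selected at import time
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
```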
src/genmo/lib/progress.py ADDED
@@ -0,0 +1,87 @@
+ import contextlib
+ from typing import Any, Iterable, Iterator, Optional
+
+ try:
+     from tqdm import tqdm
+ except ImportError:
+     tqdm = None
+
+ try:
+     from ray.experimental.tqdm_ray import tqdm as ray_tqdm
+ except ImportError:
+     ray_tqdm = None
+
+ # Global state
+ _current_progress_type = "none"
+ _is_progress_bar_active = False
+
+
+ class DummyProgressBar:
+     """A no-op progress bar that mimics the tqdm interface."""
+
+     def __init__(self, iterable=None, **kwargs):
+         self.iterable = iterable
+
+     def __iter__(self):
+         return iter(self.iterable)
+
+     def update(self, n=1):
+         pass
+
+     def close(self):
+         pass
+
+     def set_description(self, desc):
+         pass
+
+
+ def get_new_progress_bar(iterable: Optional[Iterable] = None, **kwargs) -> Any:
+     if not _is_progress_bar_active:
+         return DummyProgressBar(iterable=iterable, **kwargs)
+
+     if _current_progress_type == "tqdm":
+         if tqdm is None:
+             raise ImportError("tqdm is required but not installed. Please install tqdm to use the tqdm progress bar.")
+         return tqdm(iterable=iterable, **kwargs)
+     elif _current_progress_type == "ray_tqdm":
+         if ray_tqdm is None:
+             raise ImportError("ray is required but not installed. Please install ray to use the ray_tqdm progress bar.")
+         return ray_tqdm(iterable=iterable, **kwargs)
+     return DummyProgressBar(iterable=iterable, **kwargs)
+
+
+ @contextlib.contextmanager
+ def progress_bar(type: str = "none", enabled=True):
+     """
+     Context manager for setting the progress bar type.
+
+     Args:
+         type: Type of progress bar ("none", "tqdm", or "ray_tqdm")
+         enabled: If False, forces the type to "none"
+
+     Raises:
+         ValueError: If the progress bar type is invalid
+         RuntimeError: If progress bars are nested
+
+     Example:
+         with progress_bar(type="tqdm"):
+             for i in get_new_progress_bar(range(100)):
+                 process(i)
+     """
+     if type not in ("none", "tqdm", "ray_tqdm"):
+         raise ValueError("Progress bar type must be 'none', 'tqdm', or 'ray_tqdm'")
+     if not enabled:
+         type = "none"
+     global _current_progress_type, _is_progress_bar_active
+
+     if _is_progress_bar_active:
+         raise RuntimeError("Nested progress bars are not supported")
+
+     _is_progress_bar_active = True
+     _current_progress_type = type
+
+     try:
+         yield
+     finally:
+         _is_progress_bar_active = False
+         _current_progress_type = "none"
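A small end-to-end sketch of the two entry points above:

```python
import time

from genmo.lib.progress import get_new_progress_bar, progress_bar

# Inside the context, get_new_progress_bar returns a real tqdm bar; outside it
# (or with enabled=False), it returns the no-op DummyProgressBar.
with progress_bar(type="tqdm"):
    for _ in get_new_progress_bar(range(10), desc="denoise"):
        time.sleep(0.01)
```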
src/genmo/lib/utils.py ADDED
@@ -0,0 +1,58 @@
+ import os
+ import subprocess
+ import tempfile
+ import time
+
+ import numpy as np
+ from PIL import Image
+
+ from genmo.lib.progress import get_new_progress_bar
+
+
+ class Timer:
+     def __init__(self):
+         self.times = {}  # Dictionary to store times per stage
+
+     def __call__(self, name):
+         print(f"Timing {name}")
+         return self.TimerContextManager(self, name)
+
+     def print_stats(self):
+         total_time = sum(self.times.values())
+         # Print table header
+         print("{:<20} {:>10} {:>10}".format("Stage", "Time(s)", "Percent"))
+         for name, t in self.times.items():
+             percent = (t / total_time) * 100 if total_time > 0 else 0
+             print("{:<20} {:>10.2f} {:>9.2f}%".format(name, t, percent))
+
+     class TimerContextManager:
+         def __init__(self, outer, name):
+             self.outer = outer  # Reference to the Timer instance
+             self.name = name
+             self.start_time = None
+
+         def __enter__(self):
+             self.start_time = time.perf_counter()
+             return self
+
+         def __exit__(self, exc_type, exc_value, traceback):
+             end_time = time.perf_counter()
+             elapsed = end_time - self.start_time
+             self.outer.times[self.name] = self.outer.times.get(self.name, 0) + elapsed
+
+
+ def save_video(final_frames, output_path):
+     with tempfile.TemporaryDirectory() as tmpdir:
+         frame_paths = []
+         for i, frame in enumerate(get_new_progress_bar(final_frames)):
+             frame = (frame * 255).astype(np.uint8)
+             frame_img = Image.fromarray(frame)
+             frame_path = os.path.join(tmpdir, f"frame_{i:04d}.png")
+             frame_img.save(frame_path)
+             frame_paths.append(frame_path)
+
+         frame_pattern = os.path.join(tmpdir, "frame_%04d.png")
+         ffmpeg_cmd = f"ffmpeg -y -r 30 -i {frame_pattern} -vcodec libx264 -pix_fmt yuv420p {output_path}"
+         try:
+             subprocess.run(ffmpeg_cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+         except subprocess.CalledProcessError as e:
+             print(f"Error occurred while running ffmpeg:\n{e.stderr.decode()}")
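A usage sketch for `Timer` (the stage names and sleeps are stand-ins for real work):

```python
import time

from genmo.lib.utils import Timer

timer = Timer()
with timer("text_encode"):
    time.sleep(0.1)   # stand-in for real work
with timer("denoise"):
    time.sleep(0.3)   # stand-in for real work
timer.print_stats()   # per-stage time and share of the total
```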
src/genmo/mochi_preview/__init__.py ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/__init__.py ADDED
File without changes
src/genmo/mochi_preview/dit/joint_model/asymm_models_joint.py ADDED
@@ -0,0 +1,629 @@
1
+ import os
2
+ from typing import Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.attention import sdpa_kernel
9
+
10
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
11
+ from genmo.mochi_preview.dit.joint_model.layers import (
12
+ FeedForward,
13
+ PatchEmbed,
14
+ RMSNorm,
15
+ TimestepEmbedder,
16
+ )
17
+ from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm
18
+ from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import (
19
+ residual_tanh_gated_rmsnorm,
20
+ )
21
+ from genmo.mochi_preview.dit.joint_model.rope_mixed import (
22
+ compute_mixed_rotation,
23
+ create_position_matrix,
24
+ )
25
+ from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real
26
+ from genmo.mochi_preview.dit.joint_model.utils import (
27
+ AttentionPool,
28
+ modulate,
29
+ pad_and_split_xy,
30
+ unify_streams,
31
+ )
32
+
33
+ COMPILE_FINAL_LAYER = os.environ.get("COMPILE_DIT") == "1"
34
+ COMPILE_MMDIT_BLOCK = os.environ.get("COMPILE_DIT") == "1"
35
+
36
+ from genmo.lib.attn_imports import comfy_attn, flash_varlen_qkvpacked_attn, sage_attn, sdpa_attn_ctx
37
+
38
+
39
+ class AsymmetricAttention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim_x: int,
43
+ dim_y: int,
44
+ num_heads: int = 8,
45
+ qkv_bias: bool = True,
46
+ qk_norm: bool = False,
47
+ update_y: bool = True,
48
+ out_bias: bool = True,
49
+ attention_mode: str = "flash",
50
+ softmax_scale: Optional[float] = None,
51
+ device: Optional[torch.device] = None,
52
+ ):
53
+ super().__init__()
54
+ self.attention_mode = attention_mode
55
+ self.dim_x = dim_x
56
+ self.dim_y = dim_y
57
+ self.num_heads = num_heads
58
+ self.head_dim = dim_x // num_heads
59
+ self.update_y = update_y
60
+ self.softmax_scale = softmax_scale
61
+ if dim_x % num_heads != 0:
62
+ raise ValueError(f"dim_x={dim_x} should be divisible by num_heads={num_heads}")
63
+
64
+ # Input layers.
65
+ self.qkv_bias = qkv_bias
66
+ self.qkv_x = nn.Linear(dim_x, 3 * dim_x, bias=qkv_bias, device=device)
67
+ # Project text features to match visual features (dim_y -> dim_x)
68
+ self.qkv_y = nn.Linear(dim_y, 3 * dim_x, bias=qkv_bias, device=device)
69
+
70
+ # Query and key normalization for stability.
71
+ assert qk_norm
72
+ self.q_norm_x = RMSNorm(self.head_dim, device=device)
73
+ self.k_norm_x = RMSNorm(self.head_dim, device=device)
74
+ self.q_norm_y = RMSNorm(self.head_dim, device=device)
75
+ self.k_norm_y = RMSNorm(self.head_dim, device=device)
76
+
77
+ # Output layers. y features go back down from dim_x -> dim_y.
78
+ self.proj_x = nn.Linear(dim_x, dim_x, bias=out_bias, device=device)
79
+ self.proj_y = nn.Linear(dim_x, dim_y, bias=out_bias, device=device) if update_y else nn.Identity()
80
+
81
+ def run_qkv_y(self, y):
82
+ cp_rank, cp_size = cp.get_cp_rank_size()
83
+ local_heads = self.num_heads // cp_size
84
+
85
+ if cp.is_cp_active():
86
+ # Only predict local heads.
87
+ assert not self.qkv_bias
88
+ W_qkv_y = self.qkv_y.weight.view(3, self.num_heads, self.head_dim, self.dim_y)
89
+ W_qkv_y = W_qkv_y.narrow(1, cp_rank * local_heads, local_heads)
90
+ W_qkv_y = W_qkv_y.reshape(3 * local_heads * self.head_dim, self.dim_y)
91
+ qkv_y = F.linear(y, W_qkv_y, None) # (B, L, 3 * local_h * head_dim)
92
+ else:
93
+ qkv_y = self.qkv_y(y) # (B, L, 3 * dim)
94
+
95
+ qkv_y = qkv_y.view(qkv_y.size(0), qkv_y.size(1), 3, local_heads, self.head_dim)
96
+ q_y, k_y, v_y = qkv_y.unbind(2)
97
+ return q_y, k_y, v_y
98
+
99
+ def prepare_qkv(
100
+ self,
101
+ x: torch.Tensor, # (B, N, dim_x)
102
+ y: torch.Tensor, # (B, L, dim_y)
103
+ *,
104
+ scale_x: torch.Tensor,
105
+ scale_y: torch.Tensor,
106
+ rope_cos: torch.Tensor,
107
+ rope_sin: torch.Tensor,
108
+ valid_token_indices: torch.Tensor,
109
+ ):
110
+ # Pre-norm for visual features
111
+ x = modulated_rmsnorm(x, scale_x) # (B, M, dim_x) where M = N / cp_group_size
112
+
113
+ # Process visual features
114
+ qkv_x = self.qkv_x(x) # (B, M, 3 * dim_x)
115
+ assert qkv_x.dtype == torch.bfloat16
116
+ qkv_x = cp.all_to_all_collect_tokens(qkv_x, self.num_heads) # (3, B, N, local_h, head_dim)
117
+
118
+ # Process text features
119
+ y = modulated_rmsnorm(y, scale_y) # (B, L, dim_y)
120
+ q_y, k_y, v_y = self.run_qkv_y(y) # (B, L, local_heads, head_dim)
121
+ q_y = self.q_norm_y(q_y)
122
+ k_y = self.k_norm_y(k_y)
123
+
124
+ # Split qkv_x into q, k, v
125
+ q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
126
+ q_x = self.q_norm_x(q_x)
127
+ q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
128
+ k_x = self.k_norm_x(k_x)
129
+ k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)
130
+
131
+ # Unite streams
132
+ qkv = unify_streams(
133
+ q_x,
134
+ k_x,
135
+ v_x,
136
+ q_y,
137
+ k_y,
138
+ v_y,
139
+ valid_token_indices,
140
+ )
141
+
142
+ return qkv
143
+
144
+ def flash_attention(self, qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim):
145
+ with torch.autocast("cuda", enabled=False):
146
+ out: torch.Tensor = flash_varlen_qkvpacked_attn(
147
+ qkv,
148
+ cu_seqlens=cu_seqlens,
149
+ max_seqlen=max_seqlen_in_batch,
150
+ dropout_p=0.0,
151
+ softmax_scale=self.softmax_scale,
152
+ ) # (total, local_heads, head_dim)
153
+ return out.view(total, local_dim)
154
+
155
+ def sdpa_attention(self, qkv):
156
+ q, k, v = rearrange(qkv, "(b s) t h d -> t b h s d", b=1)
157
+ with torch.autocast("cuda", enabled=False):
158
+ with sdpa_attn_ctx():
159
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
160
+ return rearrange(out, "b h s d -> s (b h d)")
161
+
162
+ def sage_attention(self, qkv):
163
+ q, k, v = rearrange(qkv, "(b s) t h d -> t b h s d", b=1)
164
+ with torch.autocast("cuda", enabled=False):
165
+ out = sage_attn(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
166
+ return rearrange(out, "b h s d -> s (b h d)")
167
+
168
+ def comfy_attention(self, qkv):
169
+ q, k, v = rearrange(qkv, "(b s) t h d -> t b h s d", b=1)
170
+ with torch.autocast("cuda", enabled=False):
171
+ out = comfy_attn(q, k, v, heads=self.num_heads, skip_reshape=True)
172
+ return out.squeeze(0)
173
+
174
+ @torch.compiler.disable()
175
+ def run_attention(
176
+ self,
177
+ qkv: torch.Tensor, # (total <= B * (N + L), 3, local_heads, head_dim)
178
+ *,
179
+ B: int,
180
+ L: int,
181
+ M: int,
182
+ cu_seqlens: torch.Tensor,
183
+ max_seqlen_in_batch: int,
184
+ valid_token_indices: torch.Tensor,
185
+ ):
186
+ _, cp_size = cp.get_cp_rank_size()
187
+ N = cp_size * M
188
+ assert self.num_heads % cp_size == 0
189
+ local_heads = self.num_heads // cp_size
190
+ local_dim = local_heads * self.head_dim
191
+ total = qkv.size(0)
192
+
193
+ if self.attention_mode != "flash":
194
+ assert B == 1, f"Non-flash attention only supports batch size 1, got {B}"
195
+
196
+ if self.attention_mode == "flash":
197
+ out = self.flash_attention(qkv, cu_seqlens, max_seqlen_in_batch, total, local_dim)
198
+ elif self.attention_mode == "sdpa":
199
+ out = self.sdpa_attention(qkv)
200
+ elif self.attention_mode == "sage":
201
+ out = self.sage_attention(qkv)
202
+ elif self.attention_mode == "comfy":
203
+ out = self.comfy_attention(qkv)
204
+
205
+ x, y = pad_and_split_xy(out, valid_token_indices, B, N, L, qkv.dtype)
206
+ assert x.size() == (B, N, local_dim)
207
+ assert y.size() == (B, L, local_dim)
208
+
209
+ x = x.view(B, N, local_heads, self.head_dim)
210
+ x = cp.all_to_all_collect_heads(x) # (B, M, dim_x = num_heads * head_dim)
211
+ x = self.proj_x(x) # (B, M, dim_x)
212
+
213
+ if cp.is_cp_active():
214
+ y = cp.all_gather(y) # (cp_size * B, L, local_heads * head_dim)
215
+ y = rearrange(y, "(G B) L D -> B L (G D)", G=cp_size, D=local_dim) # (B, L, dim_x)
216
+ y = self.proj_y(y) # (B, L, dim_y)
217
+ return x, y
218
+
219
+ def forward(
220
+ self,
221
+ x: torch.Tensor, # (B, N, dim_x)
222
+ y: torch.Tensor, # (B, L, dim_y)
223
+ *,
224
+ scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
225
+ scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
226
+ packed_indices: Dict[str, torch.Tensor] = None,
227
+ **rope_rotation,
228
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
229
+ """Forward pass of asymmetric multi-modal attention.
230
+
231
+ Args:
232
+ x: (B, N, dim_x) tensor for visual tokens
233
+ y: (B, L, dim_y) tensor of text token features
234
+ packed_indices: Dict with keys for Flash Attention
235
+ num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
236
+
237
+ Returns:
238
+ x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
239
+ y: (B, L, dim_y) tensor of text token features after multi-modal attention
240
+ """
241
+ B, L, _ = y.shape
242
+ _, M, _ = x.shape
243
+
244
+ # Predict a packed QKV tensor from visual and text features.
245
+ # Don't checkpoint the all_to_all.
246
+ qkv = self.prepare_qkv(
247
+ x=x,
248
+ y=y,
249
+ scale_x=scale_x,
250
+ scale_y=scale_y,
251
+ rope_cos=rope_rotation.get("rope_cos"),
252
+ rope_sin=rope_rotation.get("rope_sin"),
253
+ valid_token_indices=packed_indices["valid_token_indices_kv"],
254
+ ) # (total <= B * (N + L), 3, local_heads, head_dim)
255
+
256
+ x, y = self.run_attention(
257
+ qkv,
258
+ B=B,
259
+ L=L,
260
+ M=M,
261
+ cu_seqlens=packed_indices["cu_seqlens_kv"],
262
+ max_seqlen_in_batch=packed_indices["max_seqlen_in_batch_kv"],
263
+ valid_token_indices=packed_indices["valid_token_indices_kv"],
264
+ )
265
+ return x, y
266
+
267
+
268
+ @torch.compile(disable=not COMPILE_MMDIT_BLOCK)
269
+ class AsymmetricJointBlock(nn.Module):
270
+ def __init__(
271
+ self,
272
+ hidden_size_x: int,
273
+ hidden_size_y: int,
274
+ num_heads: int,
275
+ *,
276
+ mlp_ratio_x: float = 8.0, # Ratio of hidden size to d_model for MLP for visual tokens.
277
+ mlp_ratio_y: float = 4.0, # Ratio of hidden size to d_model for MLP for text tokens.
278
+ update_y: bool = True, # Whether to update text tokens in this block.
279
+ device: Optional[torch.device] = None,
280
+ **block_kwargs,
281
+ ):
282
+ super().__init__()
283
+ self.update_y = update_y
284
+ self.hidden_size_x = hidden_size_x
285
+ self.hidden_size_y = hidden_size_y
286
+ self.mod_x = nn.Linear(hidden_size_x, 4 * hidden_size_x, device=device)
287
+ if self.update_y:
288
+ self.mod_y = nn.Linear(hidden_size_x, 4 * hidden_size_y, device=device)
289
+ else:
290
+ self.mod_y = nn.Linear(hidden_size_x, hidden_size_y, device=device)
291
+
292
+ # Self-attention:
293
+ self.attn = AsymmetricAttention(
294
+ hidden_size_x,
295
+ hidden_size_y,
296
+ num_heads=num_heads,
297
+ update_y=update_y,
298
+ device=device,
299
+ **block_kwargs,
300
+ )
301
+
302
+ # MLP.
303
+ mlp_hidden_dim_x = int(hidden_size_x * mlp_ratio_x)
304
+ assert mlp_hidden_dim_x == int(1536 * 8)
305
+ self.mlp_x = FeedForward(
306
+ in_features=hidden_size_x,
307
+ hidden_size=mlp_hidden_dim_x,
308
+ multiple_of=256,
309
+ ffn_dim_multiplier=None,
310
+ device=device,
311
+ )
312
+
313
+ # MLP for text not needed in last block.
314
+ if self.update_y:
315
+ mlp_hidden_dim_y = int(hidden_size_y * mlp_ratio_y)
316
+ self.mlp_y = FeedForward(
317
+ in_features=hidden_size_y,
318
+ hidden_size=mlp_hidden_dim_y,
319
+ multiple_of=256,
320
+ ffn_dim_multiplier=None,
321
+ device=device,
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ x: torch.Tensor,
327
+ c: torch.Tensor,
328
+ y: torch.Tensor,
329
+ **attn_kwargs,
330
+ ):
331
+ """Forward pass of a block.
332
+
333
+ Args:
334
+ x: (B, N, dim) tensor of visual tokens
335
+ c: (B, dim) tensor of conditioned features
336
+ y: (B, L, dim) tensor of text tokens
337
+ num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens
338
+
339
+ Returns:
340
+ x: (B, N, dim) tensor of visual tokens after block
341
+ y: (B, L, dim) tensor of text tokens after block
342
+ """
343
+ N = x.size(1)
344
+
345
+ c = F.silu(c)
346
+ mod_x = self.mod_x(c)
347
+ scale_msa_x, gate_msa_x, scale_mlp_x, gate_mlp_x = mod_x.chunk(4, dim=1)
348
+
349
+ mod_y = self.mod_y(c)
350
+ if self.update_y:
351
+ scale_msa_y, gate_msa_y, scale_mlp_y, gate_mlp_y = mod_y.chunk(4, dim=1)
352
+ else:
353
+ scale_msa_y = mod_y
354
+
355
+ # Self-attention block.
356
+ x_attn, y_attn = self.attn(
357
+ x,
358
+ y,
359
+ scale_x=scale_msa_x,
360
+ scale_y=scale_msa_y,
361
+ **attn_kwargs,
362
+ )
363
+
364
+ assert x_attn.size(1) == N
365
+ x = residual_tanh_gated_rmsnorm(x, x_attn, gate_msa_x)
366
+ if self.update_y:
367
+ y = residual_tanh_gated_rmsnorm(y, y_attn, gate_msa_y)
368
+
369
+ # MLP block.
370
+ x = self.ff_block_x(x, scale_mlp_x, gate_mlp_x)
371
+ if self.update_y:
372
+ y = self.ff_block_y(y, scale_mlp_y, gate_mlp_y)
373
+
374
+ return x, y
375
+
376
+ def ff_block_x(self, x, scale_x, gate_x):
377
+ x_mod = modulated_rmsnorm(x, scale_x)
378
+ x_res = self.mlp_x(x_mod)
379
+ x = residual_tanh_gated_rmsnorm(x, x_res, gate_x) # Sandwich norm
380
+ return x
381
+
382
+ def ff_block_y(self, y, scale_y, gate_y):
383
+ y_mod = modulated_rmsnorm(y, scale_y)
384
+ y_res = self.mlp_y(y_mod)
385
+ y = residual_tanh_gated_rmsnorm(y, y_res, gate_y) # Sandwich norm
386
+ return y
387
+
388
+
389
+ @torch.compile(disable=not COMPILE_FINAL_LAYER)
390
+ class FinalLayer(nn.Module):
391
+ """
392
+ The final layer of DiT.
393
+ """
394
+
395
+ def __init__(
396
+ self,
397
+ hidden_size,
398
+ patch_size,
399
+ out_channels,
400
+ device: Optional[torch.device] = None,
401
+ ):
402
+ super().__init__()
403
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, device=device)
404
+ self.mod = nn.Linear(hidden_size, 2 * hidden_size, device=device)
405
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, device=device)
406
+
407
+ def forward(self, x, c):
408
+ c = F.silu(c)
409
+ shift, scale = self.mod(c).chunk(2, dim=1)
410
+ x = modulate(self.norm_final(x), shift, scale)
411
+ x = self.linear(x)
412
+ return x
413
+
414
+
415
+ class AsymmDiTJoint(nn.Module):
416
+ """
417
+ Diffusion model with a Transformer backbone.
418
+
419
+ Ingests text embeddings instead of a label.
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ *,
425
+ patch_size=2,
426
+ in_channels=4,
427
+ hidden_size_x=1152,
428
+ hidden_size_y=1152,
429
+ depth=48,
430
+ num_heads=16,
431
+ mlp_ratio_x=8.0,
432
+ mlp_ratio_y=4.0,
433
+ t5_feat_dim: int = 4096,
434
+ t5_token_length: int = 256,
435
+ patch_embed_bias: bool = True,
436
+ timestep_mlp_bias: bool = True,
437
+ timestep_scale: Optional[float] = None,
438
+ use_extended_posenc: bool = False,
439
+ rope_theta: float = 10000.0,
440
+ device: Optional[torch.device] = None,
441
+ **block_kwargs,
442
+ ):
443
+ super().__init__()
444
+ self.in_channels = in_channels
445
+ self.out_channels = in_channels
446
+ self.patch_size = patch_size
447
+ self.num_heads = num_heads
448
+ self.hidden_size_x = hidden_size_x
449
+ self.hidden_size_y = hidden_size_y
450
+ self.head_dim = hidden_size_x // num_heads # Head dimension and count is determined by visual.
451
+ self.use_extended_posenc = use_extended_posenc
452
+ self.t5_token_length = t5_token_length
453
+ self.t5_feat_dim = t5_feat_dim
454
+ self.rope_theta = rope_theta # Scaling factor for frequency computation for temporal RoPE.
455
+
456
+ self.x_embedder = PatchEmbed(
457
+ patch_size=patch_size,
458
+ in_chans=in_channels,
459
+ embed_dim=hidden_size_x,
460
+ bias=patch_embed_bias,
461
+ device=device,
462
+ )
463
+ # Conditionings
464
+ # Timestep
465
+ self.t_embedder = TimestepEmbedder(hidden_size_x, bias=timestep_mlp_bias, timestep_scale=timestep_scale)
466
+
467
+ # Caption Pooling (T5)
468
+ self.t5_y_embedder = AttentionPool(t5_feat_dim, num_heads=8, output_dim=hidden_size_x, device=device)
469
+
470
+ # Dense Embedding Projection (T5)
471
+ self.t5_yproj = nn.Linear(t5_feat_dim, hidden_size_y, bias=True, device=device)
472
+
473
+ # Initialize pos_frequencies as an empty parameter.
474
+ self.pos_frequencies = nn.Parameter(torch.empty(3, self.num_heads, self.head_dim // 2, device=device))
475
+
476
+ # for depth 48:
477
+ # b = 0: AsymmetricJointBlock, update_y=True
478
+ # b = 1: AsymmetricJointBlock, update_y=True
479
+ # ...
480
+ # b = 46: AsymmetricJointBlock, update_y=True
481
+ # b = 47: AsymmetricJointBlock, update_y=False. No need to update text features.
482
+ blocks = []
483
+ for b in range(depth):
484
+ # Joint multi-modal block
485
+ update_y = b < depth - 1
486
+ block = AsymmetricJointBlock(
487
+ hidden_size_x,
488
+ hidden_size_y,
489
+ num_heads,
490
+ mlp_ratio_x=mlp_ratio_x,
491
+ mlp_ratio_y=mlp_ratio_y,
492
+ update_y=update_y,
493
+ device=device,
494
+ **block_kwargs,
495
+ )
496
+
497
+ blocks.append(block)
498
+ self.blocks = nn.ModuleList(blocks)
499
+
500
+ self.final_layer = FinalLayer(hidden_size_x, patch_size, self.out_channels, device=device)
501
+
502
+ def embed_x(self, x: torch.Tensor) -> torch.Tensor:
503
+ """
504
+ Args:
505
+ x: (B, C=12, T, H, W) tensor of visual tokens
506
+
507
+ Returns:
508
+ x: (B, C=3072, N) tensor of visual tokens with positional embedding.
509
+ """
510
+ return self.x_embedder(x) # Convert BcTHW to BCN
511
+
512
+ @torch.compile(disable=not COMPILE_MMDIT_BLOCK)
513
+ def prepare(
514
+ self,
515
+ x: torch.Tensor,
516
+ sigma: torch.Tensor,
517
+ t5_feat: torch.Tensor,
518
+ t5_mask: torch.Tensor,
519
+ ):
520
+ """Prepare input and conditioning embeddings."""
521
+
522
+ with torch.profiler.record_function("x_emb_pe"):
523
+ # Visual patch embeddings with positional encoding.
524
+ T, H, W = x.shape[-3:]
525
+ pH, pW = H // self.patch_size, W // self.patch_size
526
+ x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
527
+ assert x.ndim == 3
528
+ B = x.size(0)
529
+
530
+ with torch.profiler.record_function("rope_cis"):
531
+ # Construct position array of size [N, 3].
532
+ # pos[:, 0] is the frame index for each location,
533
+ # pos[:, 1] is the row index for each location, and
534
+ # pos[:, 2] is the column index for each location.
535
+ pH, pW = H // self.patch_size, W // self.patch_size
536
+ N = T * pH * pW
537
+ assert x.size(1) == N
538
+ pos = create_position_matrix(T, pH=pH, pW=pW, device=x.device, dtype=torch.float32) # (N, 3)
539
+ rope_cos, rope_sin = compute_mixed_rotation(
540
+ freqs=self.pos_frequencies, pos=pos
541
+ ) # Each are (N, num_heads, dim // 2)
542
+
543
+ with torch.profiler.record_function("t_emb"):
544
+ # Global vector embedding for conditionings.
545
+ c_t = self.t_embedder(1 - sigma) # (B, D)
546
+
547
+ with torch.profiler.record_function("t5_pool"):
548
+ # Pool T5 tokens using attention pooler
549
+ # Note y_feat[1] contains T5 token features.
550
+ assert (
551
+ t5_feat.size(1) == self.t5_token_length
552
+ ), f"Expected L={self.t5_token_length}, got {t5_feat.shape} for y_feat."
553
+ t5_y_pool = self.t5_y_embedder(t5_feat, t5_mask) # (B, D)
554
+ assert t5_y_pool.size(0) == B, f"Expected B={B}, got {t5_y_pool.shape} for t5_y_pool."
555
+
556
+ c = c_t + t5_y_pool
557
+
558
+ y_feat = self.t5_yproj(t5_feat) # (B, L, t5_feat_dim) --> (B, L, D)
559
+
560
+ return x, c, y_feat, rope_cos, rope_sin
561
+
562
+ def forward(
563
+ self,
564
+ x: torch.Tensor,
565
+ sigma: torch.Tensor,
566
+ y_feat: List[torch.Tensor],
567
+ y_mask: List[torch.Tensor],
568
+ packed_indices: Dict[str, torch.Tensor] = None,
569
+ rope_cos: torch.Tensor = None,
570
+ rope_sin: torch.Tensor = None,
571
+ ):
572
+ """Forward pass of DiT.
573
+
574
+ Args:
575
+ x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
576
+ sigma: (B,) tensor of noise standard deviations
577
+ y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
578
+ y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
579
+ packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
580
+ """
581
+ B, _, T, H, W = x.shape
582
+
583
+ # Use EFFICIENT_ATTENTION backend for T5 pooling, since we have a mask.
584
+ # Have to call sdpa_kernel outside of a torch.compile region.
585
+ with sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
586
+ x, c, y_feat, rope_cos, rope_sin = self.prepare(x, sigma, y_feat[0], y_mask[0])
587
+ del y_mask
588
+
589
+ cp_rank, cp_size = cp.get_cp_rank_size()
590
+ N = x.size(1)
591
+ M = N // cp_size
592
+ assert N % cp_size == 0, f"Visual sequence length ({x.shape[1]}) must be divisible by cp_size ({cp_size})."
593
+
594
+ if cp_size > 1:
595
+ x = x.narrow(1, cp_rank * M, M)
596
+
597
+ assert self.num_heads % cp_size == 0
598
+ local_heads = self.num_heads // cp_size
599
+ rope_cos = rope_cos.narrow(1, cp_rank * local_heads, local_heads)
600
+ rope_sin = rope_sin.narrow(1, cp_rank * local_heads, local_heads)
601
+
602
+ for i, block in enumerate(self.blocks):
603
+ x, y_feat = block(
604
+ x,
605
+ c,
606
+ y_feat,
607
+ rope_cos=rope_cos,
608
+ rope_sin=rope_sin,
609
+ packed_indices=packed_indices,
610
+ ) # (B, M, D), (B, L, D)
611
+ del y_feat # Final layers don't use dense text features.
612
+
613
+ x = self.final_layer(x, c) # (B, M, patch_size ** 2 * out_channels)
614
+
615
+ patch = x.size(2)
616
+ x = cp.all_gather(x)
617
+ x = rearrange(x, "(G B) M P -> B (G M) P", G=cp_size, P=patch)
618
+ x = rearrange(
619
+ x,
620
+ "B (T hp wp) (p1 p2 c) -> B c T (hp p1) (wp p2)",
621
+ T=T,
622
+ hp=H // self.patch_size,
623
+ wp=W // self.patch_size,
624
+ p1=self.patch_size,
625
+ p2=self.patch_size,
626
+ c=self.out_channels,
627
+ )
628
+
629
+ return x
src/genmo/mochi_preview/dit/joint_model/context_parallel.py ADDED
@@ -0,0 +1,155 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ from einops import rearrange
4
+
5
+ _CONTEXT_PARALLEL_GROUP = None
6
+ _CONTEXT_PARALLEL_RANK = None
7
+ _CONTEXT_PARALLEL_GROUP_SIZE = None
8
+ _CONTEXT_PARALLEL_GROUP_RANKS = None
9
+
10
+
11
+ def get_cp_rank_size():
12
+ if _CONTEXT_PARALLEL_GROUP:
13
+ return _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE
14
+ else:
15
+ return 0, 1
16
+
17
+
18
+ def local_shard(x: torch.Tensor, dim: int = 2) -> torch.Tensor:
19
+ if not _CONTEXT_PARALLEL_GROUP:
20
+ return x
21
+
22
+ cp_rank, cp_size = get_cp_rank_size()
23
+ return x.tensor_split(cp_size, dim=dim)[cp_rank]
24
+
25
+
26
+ def set_cp_group(cp_group, ranks, global_rank):
27
+ global _CONTEXT_PARALLEL_GROUP, _CONTEXT_PARALLEL_RANK, _CONTEXT_PARALLEL_GROUP_SIZE, _CONTEXT_PARALLEL_GROUP_RANKS
28
+ if _CONTEXT_PARALLEL_GROUP is not None:
29
+ raise RuntimeError("CP group already initialized.")
30
+ _CONTEXT_PARALLEL_GROUP = cp_group
31
+ _CONTEXT_PARALLEL_RANK = dist.get_rank(cp_group)
32
+ _CONTEXT_PARALLEL_GROUP_SIZE = dist.get_world_size(cp_group)
33
+ _CONTEXT_PARALLEL_GROUP_RANKS = ranks
34
+
35
+ assert _CONTEXT_PARALLEL_RANK == ranks.index(
36
+ global_rank
37
+ ), f"Rank mismatch: {global_rank} in {ranks} does not have position {_CONTEXT_PARALLEL_RANK} "
38
+ assert _CONTEXT_PARALLEL_GROUP_SIZE == len(
39
+ ranks
40
+ ), f"Group size mismatch: {_CONTEXT_PARALLEL_GROUP_SIZE} != len({ranks})"
41
+
42
+
43
+ def get_cp_group():
44
+ if _CONTEXT_PARALLEL_GROUP is None:
45
+ raise RuntimeError("CP group not initialized")
46
+ return _CONTEXT_PARALLEL_GROUP
47
+
48
+
49
+ def is_cp_active():
50
+ return _CONTEXT_PARALLEL_GROUP is not None
51
+
52
+
53
+ class AllGatherIntoTensorFunction(torch.autograd.Function):
54
+ @staticmethod
55
+ def forward(ctx, x: torch.Tensor, reduce_dtype, group: dist.ProcessGroup):
56
+ ctx.reduce_dtype = reduce_dtype
57
+ ctx.group = group
58
+ ctx.batch_size = x.size(0)
59
+ group_size = dist.get_world_size(group)
60
+
61
+ x = x.contiguous()
62
+ output = torch.empty(group_size * x.size(0), *x.shape[1:], dtype=x.dtype, device=x.device)
63
+ dist.all_gather_into_tensor(output, x, group=group)
64
+ return output
65
+
66
+
67
+ def all_gather(tensor: torch.Tensor) -> torch.Tensor:
68
+ if not _CONTEXT_PARALLEL_GROUP:
69
+ return tensor
70
+
71
+ return AllGatherIntoTensorFunction.apply(tensor, torch.float32, _CONTEXT_PARALLEL_GROUP)
72
+
73
+
74
+ @torch.compiler.disable()
75
+ def _all_to_all_single(output, input, group):
76
+ # Disable compilation since torch compile changes contiguity.
77
+ assert input.is_contiguous(), "Input tensor must be contiguous."
78
+ assert output.is_contiguous(), "Output tensor must be contiguous."
79
+ return dist.all_to_all_single(output, input, group=group)
80
+
81
+
82
+ class CollectTokens(torch.autograd.Function):
83
+ @staticmethod
84
+ def forward(ctx, qkv: torch.Tensor, group: dist.ProcessGroup, num_heads: int):
85
+ """Redistribute heads and receive tokens.
86
+
87
+ Args:
88
+ qkv: query, key or value. Shape: [B, M, 3 * num_heads * head_dim]
89
+
90
+ Returns:
91
+ qkv: shape: [3, B, N, local_heads, head_dim]
92
+
93
+ where M is the number of local tokens,
94
+ N = cp_size * M is the number of global tokens,
95
+ local_heads = num_heads // cp_size is the number of local heads.
96
+ """
97
+ ctx.group = group
98
+ ctx.num_heads = num_heads
99
+ cp_size = dist.get_world_size(group)
100
+ assert num_heads % cp_size == 0
101
+ ctx.local_heads = num_heads // cp_size
102
+
103
+ qkv = rearrange(
104
+ qkv,
105
+ "B M (qkv G h d) -> G M h B (qkv d)",
106
+ qkv=3,
107
+ G=cp_size,
108
+ h=ctx.local_heads,
109
+ ).contiguous()
110
+
111
+ output_chunks = torch.empty_like(qkv)
112
+ _all_to_all_single(output_chunks, qkv, group=group)
113
+
114
+ return rearrange(output_chunks, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
115
+
116
+
117
+ def all_to_all_collect_tokens(x: torch.Tensor, num_heads: int) -> torch.Tensor:
118
+ if not _CONTEXT_PARALLEL_GROUP:
119
+ # Move QKV dimension to the front.
120
+ # B M (3 H d) -> 3 B M H d
121
+ B, M, _ = x.size()
122
+ x = x.view(B, M, 3, num_heads, -1)
123
+ return x.permute(2, 0, 1, 3, 4)
124
+
125
+ return CollectTokens.apply(x, _CONTEXT_PARALLEL_GROUP, num_heads)
126
+
127
+
128
+ class CollectHeads(torch.autograd.Function):
129
+ @staticmethod
130
+ def forward(ctx, x: torch.Tensor, group: dist.ProcessGroup):
131
+ """Redistribute tokens and receive heads.
132
+
133
+ Args:
134
+ x: Output of attention. Shape: [B, N, local_heads, head_dim]
135
+
136
+ Returns:
137
+ Shape: [B, M, num_heads * head_dim]
138
+ """
139
+ ctx.group = group
140
+ ctx.local_heads = x.size(2)
141
+ ctx.head_dim = x.size(3)
142
+ group_size = dist.get_world_size(group)
143
+ x = rearrange(x, "B (G M) h D -> G h M B D", G=group_size).contiguous()
144
+ output = torch.empty_like(x)
145
+ _all_to_all_single(output, x, group=group)
146
+ del x
147
+ return rearrange(output, "G h M B D -> B M (G h D)")
148
+
149
+
150
+ def all_to_all_collect_heads(x: torch.Tensor) -> torch.Tensor:
151
+ if not _CONTEXT_PARALLEL_GROUP:
152
+ # Merge heads.
153
+ return x.view(x.size(0), x.size(1), x.size(2) * x.size(3))
154
+
155
+ return CollectHeads.apply(x, _CONTEXT_PARALLEL_GROUP)
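A minimal, single-process shape walkthrough of the token/head exchange implemented by CollectTokens and CollectHeads above. The sizes are illustrative and the all-to-all is stood in by an identity copy, so no process group is needed.

import torch
from einops import rearrange

B, M, num_heads, head_dim, cp_size = 1, 16, 8, 4, 2   # M = tokens held locally by one rank
local_heads = num_heads // cp_size

qkv = torch.randn(B, M, 3 * num_heads * head_dim)      # packed QKV from a fused projection

# CollectTokens: after the exchange, each rank holds all cp_size * M tokens
# but only num_heads // cp_size of the heads.
send = rearrange(qkv, "B M (qkv G h d) -> G M h B (qkv d)", qkv=3, G=cp_size, h=local_heads)
recv = send.contiguous()          # stand-in for _all_to_all_single on a single process
out = rearrange(recv, "G M h B (qkv d) -> qkv B (G M) h d", qkv=3)
assert out.shape == (3, B, cp_size * M, local_heads, head_dim)

# CollectHeads reverses the trade: all heads come back, tokens are split again.
attn_out = torch.randn(B, cp_size * M, local_heads, head_dim)
back = rearrange(attn_out, "B (G M) h D -> G h M B D", G=cp_size).contiguous()
merged = rearrange(back, "G h M B D -> B M (G h D)")
assert merged.shape == (B, M, num_heads * head_dim)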
src/genmo/mochi_preview/dit/joint_model/layers.py ADDED
@@ -0,0 +1,176 @@
1
+ import collections.abc
2
+ import math
3
+ from itertools import repeat
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+
12
+ # From PyTorch internals
13
+ def _ntuple(n):
14
+ def parse(x):
15
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
16
+ return tuple(x)
17
+ return tuple(repeat(x, n))
18
+
19
+ return parse
20
+
21
+
22
+ to_2tuple = _ntuple(2)
23
+
24
+
25
+ class TimestepEmbedder(nn.Module):
26
+ def __init__(
27
+ self,
28
+ hidden_size: int,
29
+ frequency_embedding_size: int = 256,
30
+ *,
31
+ bias: bool = True,
32
+ timestep_scale: Optional[float] = None,
33
+ device: Optional[torch.device] = None,
34
+ ):
35
+ super().__init__()
36
+ self.mlp = nn.Sequential(
37
+ nn.Linear(frequency_embedding_size, hidden_size, bias=bias, device=device),
38
+ nn.SiLU(),
39
+ nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
40
+ )
41
+ self.frequency_embedding_size = frequency_embedding_size
42
+ self.timestep_scale = timestep_scale
43
+
44
+ @staticmethod
45
+ def timestep_embedding(t, dim, max_period=10000):
46
+ half = dim // 2
47
+ freqs = torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
48
+ freqs.mul_(-math.log(max_period) / half).exp_()
49
+ args = t[:, None].float() * freqs[None]
50
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
51
+ if dim % 2:
52
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
53
+ return embedding
54
+
55
+ def forward(self, t):
56
+ if self.timestep_scale is not None:
57
+ t = t * self.timestep_scale
58
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
59
+ t_emb = self.mlp(t_freq)
60
+ return t_emb
61
+
62
+
63
+ class PooledCaptionEmbedder(nn.Module):
64
+ def __init__(
65
+ self,
66
+ caption_feature_dim: int,
67
+ hidden_size: int,
68
+ *,
69
+ bias: bool = True,
70
+ device: Optional[torch.device] = None,
71
+ ):
72
+ super().__init__()
73
+ self.caption_feature_dim = caption_feature_dim
74
+ self.hidden_size = hidden_size
75
+ self.mlp = nn.Sequential(
76
+ nn.Linear(caption_feature_dim, hidden_size, bias=bias, device=device),
77
+ nn.SiLU(),
78
+ nn.Linear(hidden_size, hidden_size, bias=bias, device=device),
79
+ )
80
+
81
+ def forward(self, x):
82
+ return self.mlp(x)
83
+
84
+
85
+ class FeedForward(nn.Module):
86
+ def __init__(
87
+ self,
88
+ in_features: int,
89
+ hidden_size: int,
90
+ multiple_of: int,
91
+ ffn_dim_multiplier: Optional[float],
92
+ device: Optional[torch.device] = None,
93
+ ):
94
+ super().__init__()
95
+ # keep parameter count and computation constant compared to standard FFN
96
+ hidden_size = int(2 * hidden_size / 3)
97
+ # custom dim factor multiplier
98
+ if ffn_dim_multiplier is not None:
99
+ hidden_size = int(ffn_dim_multiplier * hidden_size)
100
+ hidden_size = multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)
101
+
102
+ self.hidden_dim = hidden_size
103
+ self.w1 = nn.Linear(in_features, 2 * hidden_size, bias=False, device=device)
104
+ self.w2 = nn.Linear(hidden_size, in_features, bias=False, device=device)
105
+
106
+ def forward(self, x):
107
+ x, gate = self.w1(x).chunk(2, dim=-1)
108
+ x = self.w2(F.silu(x) * gate)
109
+ return x
110
+
111
+
112
+ class PatchEmbed(nn.Module):
113
+ def __init__(
114
+ self,
115
+ patch_size: int = 16,
116
+ in_chans: int = 3,
117
+ embed_dim: int = 768,
118
+ norm_layer: Optional[Callable] = None,
119
+ flatten: bool = True,
120
+ bias: bool = True,
121
+ dynamic_img_pad: bool = False,
122
+ device: Optional[torch.device] = None,
123
+ ):
124
+ super().__init__()
125
+ self.patch_size = to_2tuple(patch_size)
126
+ self.flatten = flatten
127
+ self.dynamic_img_pad = dynamic_img_pad
128
+
129
+ self.proj = nn.Conv2d(
130
+ in_chans,
131
+ embed_dim,
132
+ kernel_size=patch_size,
133
+ stride=patch_size,
134
+ bias=bias,
135
+ device=device,
136
+ )
137
+ assert norm_layer is None
138
+ self.norm = norm_layer(embed_dim, device=device) if norm_layer else nn.Identity()
139
+
140
+ def forward(self, x):
141
+ B, _C, T, H, W = x.shape
142
+ if not self.dynamic_img_pad:
143
+ assert (
144
+ H % self.patch_size[0] == 0
145
+ ), f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
146
+ assert (
147
+ W % self.patch_size[1] == 0
148
+ ), f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
149
+ else:
150
+ pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
151
+ pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
152
+ x = F.pad(x, (0, pad_w, 0, pad_h))
153
+
154
+ x = rearrange(x, "B C T H W -> (B T) C H W", B=B, T=T)
155
+ x = self.proj(x)
156
+
157
+ # Flatten temporal and spatial dimensions.
158
+ if not self.flatten:
159
+ raise NotImplementedError("Must flatten output.")
160
+ x = rearrange(x, "(B T) C H W -> B (T H W) C", B=B, T=T)
161
+
162
+ x = self.norm(x)
163
+ return x
164
+
165
+
166
+ class RMSNorm(torch.nn.Module):
167
+ def __init__(self, hidden_size, eps=1e-5, device=None):
168
+ super().__init__()
169
+ self.eps = eps
170
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device))
171
+ self.register_parameter("bias", None)
172
+
173
+ def forward(self, x):
174
+ x_fp32 = x.float()
175
+ x_normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
176
+ return (x_normed * self.weight).type_as(x)
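A small sketch of the SwiGLU sizing rule used by FeedForward above; the example width and multiple_of are assumptions for illustration, not values read from a checkpoint.

def swiglu_hidden_size(hidden_size: int, multiple_of: int, ffn_dim_multiplier=None) -> int:
    hidden_size = int(2 * hidden_size / 3)        # keep parameter count close to a standard 2-layer FFN
    if ffn_dim_multiplier is not None:
        hidden_size = int(ffn_dim_multiplier * hidden_size)
    return multiple_of * ((hidden_size + multiple_of - 1) // multiple_of)  # round up to a multiple

print(swiglu_hidden_size(4 * 3072, multiple_of=256))   # 8192, so w1 becomes Linear(3072, 2 * 8192)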
src/genmo/mochi_preview/dit/joint_model/mod_rmsnorm.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch
2
+
3
+
4
+ class ModulatedRMSNorm(torch.autograd.Function):
5
+ @staticmethod
6
+ def forward(ctx, x, scale, eps=1e-6):
7
+ # Convert to fp32 for precision
8
+ x_fp32 = x.float()
9
+ scale_fp32 = scale.float()
10
+
11
+ # Compute RMS
12
+ mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
13
+ inv_rms = torch.rsqrt(mean_square + eps)
14
+
15
+ # Normalize and modulate
16
+ x_normed = x_fp32 * inv_rms
17
+ x_modulated = x_normed * (1 + scale_fp32.unsqueeze(1))
18
+
19
+ return x_modulated.type_as(x)
20
+
21
+
22
+ def modulated_rmsnorm(x, scale, eps=1e-6):
23
+ return ModulatedRMSNorm.apply(x, scale, eps)
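Minimal usage sketch for modulated_rmsnorm, assuming the module path shown in this diff; shapes are illustrative. The scale is per-sample and broadcast over the sequence dimension.

import torch
from genmo.mochi_preview.dit.joint_model.mod_rmsnorm import modulated_rmsnorm

B, L, D = 2, 10, 64
x = torch.randn(B, L, D)
scale = torch.randn(B, D)           # e.g. derived from the conditioning embedding
out = modulated_rmsnorm(x, scale)   # (B, L, D): rmsnorm(x) * (1 + scale[:, None, :])
assert out.shape == x.shape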
src/genmo/mochi_preview/dit/joint_model/residual_tanh_gated_rmsnorm.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+
3
+
4
+ class ResidualTanhGatedRMSNorm(torch.autograd.Function):
5
+ @staticmethod
6
+ def forward(ctx, x, x_res, gate, eps=1e-6):
7
+ # Convert to fp32 for precision
8
+ x_res_fp32 = x_res.float()
9
+
10
+ # Compute RMS
11
+ mean_square = x_res_fp32.pow(2).mean(-1, keepdim=True)
12
+ scale = torch.rsqrt(mean_square + eps)
13
+
14
+ # Apply tanh to gate
15
+ tanh_gate = torch.tanh(gate).unsqueeze(1)
16
+
17
+ # Normalize and apply gated scaling
18
+ x_normed = x_res_fp32 * scale * tanh_gate
19
+
20
+ # Apply residual connection
21
+ output = x + x_normed.type_as(x)
22
+
23
+ return output
24
+
25
+
26
+ def residual_tanh_gated_rmsnorm(x, x_res, gate, eps=1e-6):
27
+ return ResidualTanhGatedRMSNorm.apply(x, x_res, gate, eps)
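Companion sketch for residual_tanh_gated_rmsnorm (shapes illustrative): the per-sample gate is squashed with tanh before scaling the normalized branch, which is then added back to the residual stream.

import torch
from genmo.mochi_preview.dit.joint_model.residual_tanh_gated_rmsnorm import residual_tanh_gated_rmsnorm

B, L, D = 2, 10, 64
x = torch.randn(B, L, D)       # residual stream
x_res = torch.randn(B, L, D)   # block output to be merged in
gate = torch.randn(B, D)       # e.g. from the conditioning MLP
out = residual_tanh_gated_rmsnorm(x, x_res, gate)   # x + rmsnorm(x_res) * tanh(gate)[:, None, :]
assert out.shape == x.shape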
src/genmo/mochi_preview/dit/joint_model/rope_mixed.py ADDED
@@ -0,0 +1,88 @@
1
+ import functools
2
+ import math
3
+
4
+ import torch
5
+
6
+
7
+ def centers(start: float, stop, num, dtype=None, device=None):
8
+ """linspace through bin centers.
9
+
10
+ Args:
11
+ start (float): Start of the range.
12
+ stop (float): End of the range.
13
+ num (int): Number of points.
14
+ dtype (torch.dtype): Data type of the points.
15
+ device (torch.device): Device of the points.
16
+
17
+ Returns:
18
+ centers (Tensor): Centers of the bins. Shape: (num,).
19
+ """
20
+ edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
21
+ return (edges[:-1] + edges[1:]) / 2
22
+
23
+
24
+ @functools.lru_cache(maxsize=1)
25
+ def create_position_matrix(
26
+ T: int,
27
+ pH: int,
28
+ pW: int,
29
+ device: torch.device,
30
+ dtype: torch.dtype,
31
+ *,
32
+ target_area: float = 36864,
33
+ ):
34
+ """
35
+ Args:
36
+ T: int - Temporal dimension
37
+ pH: int - Height dimension after patchify
38
+ pW: int - Width dimension after patchify
39
+
40
+ Returns:
41
+ pos: [T * pH * pW, 3] - position matrix
42
+ """
43
+ with torch.no_grad():
44
+ # Create 1D tensors for each dimension
45
+ t = torch.arange(T, dtype=dtype)
46
+
47
+ # Positionally interpolate to area 36864.
48
+ # (3072x3072 frame with 16x16 patches = 192x192 latents).
49
+ # This automatically scales rope positions when the resolution changes.
50
+ # We use a large target area so the model is more sensitive
51
+ # to changes in the learned pos_frequencies matrix.
52
+ scale = math.sqrt(target_area / (pW * pH))
53
+ w = centers(-pW * scale / 2, pW * scale / 2, pW)
54
+ h = centers(-pH * scale / 2, pH * scale / 2, pH)
55
+
56
+ # Use meshgrid to create 3D grids
57
+ grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
58
+
59
+ # Stack and reshape the grids.
60
+ pos = torch.stack([grid_t, grid_h, grid_w], dim=-1) # [T, pH, pW, 3]
61
+ pos = pos.view(-1, 3) # [T * pH * pW, 3]
62
+ pos = pos.to(dtype=dtype, device=device)
63
+
64
+ return pos
65
+
66
+
67
+ def compute_mixed_rotation(
68
+ freqs: torch.Tensor,
69
+ pos: torch.Tensor,
70
+ ):
71
+ """
72
+ Project each 3-dim position into per-head, per-head-dim 1D frequencies.
73
+
74
+ Args:
75
+ freqs: [3, num_heads, num_freqs] - learned rotation frequency (for t, row, col) for each head position
76
+ pos: [N, 3] - position of each token
77
+ num_heads: int
78
+
79
+ Returns:
80
+ freqs_cos: [N, num_heads, num_freqs] - cosine components
81
+ freqs_sin: [N, num_heads, num_freqs] - sine components
82
+ """
83
+ with torch.autocast("cuda", enabled=False):
84
+ assert freqs.ndim == 3
85
+ freqs_sum = torch.einsum("Nd,dhf->Nhf", pos.to(freqs), freqs)
86
+ freqs_cos = torch.cos(freqs_sum)
87
+ freqs_sin = torch.sin(freqs_sum)
88
+ return freqs_cos, freqs_sin
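Shape sketch for the two helpers above; the grid sizes are illustrative and num_freqs is assumed to be head_dim // 2 (one frequency per rotated pair).

import torch
from genmo.mochi_preview.dit.joint_model.rope_mixed import (
    compute_mixed_rotation,
    create_position_matrix,
)

T, pH, pW, num_heads, head_dim = 3, 6, 8, 4, 16
pos = create_position_matrix(T, pH, pW, device=torch.device("cpu"), dtype=torch.float32)  # (T*pH*pW, 3)
freqs = 0.01 * torch.randn(3, num_heads, head_dim // 2)   # stands in for the learned pos_frequencies
freqs_cos, freqs_sin = compute_mixed_rotation(freqs, pos)
assert freqs_cos.shape == (T * pH * pW, num_heads, head_dim // 2)
assert torch.allclose(freqs_cos**2 + freqs_sin**2, torch.ones_like(freqs_cos), atol=1e-5)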
src/genmo/mochi_preview/dit/joint_model/temporal_rope.py ADDED
@@ -0,0 +1,34 @@
1
+ # Based on Llama3 Implementation.
2
+ import torch
3
+
4
+
5
+ def apply_rotary_emb_qk_real(
6
+ xqk: torch.Tensor,
7
+ freqs_cos: torch.Tensor,
8
+ freqs_sin: torch.Tensor,
9
+ ) -> torch.Tensor:
10
+ """
11
+ Apply rotary embeddings to input tensors using the given frequency tensor without complex numbers.
12
+
13
+ Args:
14
+ xqk (torch.Tensor): Query and/or Key tensors to apply rotary embeddings. Shape: (B, S, *, num_heads, D)
15
+ Can be either just query or just key, or both stacked along some batch or * dim.
16
+ freqs_cos (torch.Tensor): Precomputed cosine frequency tensor.
17
+ freqs_sin (torch.Tensor): Precomputed sine frequency tensor.
18
+
19
+ Returns:
20
+ torch.Tensor: The input tensor with rotary embeddings applied.
21
+ """
22
+ assert xqk.dtype == torch.bfloat16
23
+ # Split the last dimension into even and odd parts
24
+ xqk_even = xqk[..., 0::2]
25
+ xqk_odd = xqk[..., 1::2]
26
+
27
+ # Apply rotation
28
+ cos_part = (xqk_even * freqs_cos - xqk_odd * freqs_sin).type_as(xqk)
29
+ sin_part = (xqk_even * freqs_sin + xqk_odd * freqs_cos).type_as(xqk)
30
+
31
+ # Interleave the results back into the original shape
32
+ out = torch.stack([cos_part, sin_part], dim=-1).flatten(-2)
33
+ assert out.dtype == torch.bfloat16
34
+ return out
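A quick equivalence check for apply_rotary_emb_qk_real (illustrative sizes): treating each (even, odd) pair as a complex number, the rotation above is multiplication by exp(i * theta). The tolerance is loose because the inputs are bfloat16.

import torch
from genmo.mochi_preview.dit.joint_model.temporal_rope import apply_rotary_emb_qk_real

B, S, H, D = 1, 4, 2, 8
xqk = torch.randn(B, S, H, D, dtype=torch.bfloat16)
theta = torch.randn(S, H, D // 2)
out = apply_rotary_emb_qk_real(xqk, torch.cos(theta), torch.sin(theta))

x_c = torch.view_as_complex(xqk.float().unflatten(-1, (D // 2, 2)))          # (even, odd) -> complex
ref = torch.view_as_real(x_c * torch.polar(torch.ones_like(theta), theta)).flatten(-2)
assert torch.allclose(out.float(), ref, atol=1e-1)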
src/genmo/mochi_preview/dit/joint_model/utils.py ADDED
@@ -0,0 +1,185 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def modulate(x, shift, scale):
9
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
10
+
11
+
12
+ def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
13
+ """
14
+ Pool tokens in x using mask.
15
+
16
+ NOTE: We assume x does not require gradients.
17
+
18
+ Args:
19
+ x: (B, L, D) tensor of tokens.
20
+ mask: (B, L) boolean tensor indicating which tokens are not padding.
21
+
22
+ Returns:
23
+ pooled: (B, D) tensor of pooled tokens.
24
+ """
25
+ assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
26
+ assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
27
+ mask = mask[:, :, None].to(dtype=x.dtype)
28
+ mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
29
+ pooled = (x * mask).sum(dim=1, keepdim=keepdim)
30
+ return pooled
31
+
32
+
33
+ class AttentionPool(nn.Module):
34
+ def __init__(
35
+ self,
36
+ embed_dim: int,
37
+ num_heads: int,
38
+ output_dim: int = None,
39
+ device: Optional[torch.device] = None,
40
+ ):
41
+ """
42
+ Args:
43
+ spatial_dim (int): Number of tokens in sequence length.
44
+ embed_dim (int): Dimensionality of input tokens.
45
+ num_heads (int): Number of attention heads.
46
+ output_dim (int): Dimensionality of output tokens. Defaults to embed_dim.
47
+ """
48
+ super().__init__()
49
+ self.num_heads = num_heads
50
+ self.to_kv = nn.Linear(embed_dim, 2 * embed_dim, device=device)
51
+ self.to_q = nn.Linear(embed_dim, embed_dim, device=device)
52
+ self.to_out = nn.Linear(embed_dim, output_dim or embed_dim, device=device)
53
+
54
+ def forward(self, x, mask):
55
+ """
56
+ Args:
57
+ x (torch.Tensor): (B, L, D) tensor of input tokens.
58
+ mask (torch.Tensor): (B, L) boolean tensor indicating which tokens are not padding.
59
+
60
+ NOTE: We assume x does not require gradients.
61
+
62
+ Returns:
63
+ x (torch.Tensor): (B, D) tensor of pooled tokens.
64
+ """
65
+ D = x.size(2)
66
+
67
+ # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
68
+ attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
69
+ attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
70
+
71
+ # Average non-padding token features. These will be used as the query.
72
+ x_pool = pool_tokens(x, mask, keepdim=True) # (B, 1, D)
73
+
74
+ # Concat pooled features to input sequence.
75
+ x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
76
+
77
+ # Compute queries, keys, values. Only the mean token is used to create a query.
78
+ kv = self.to_kv(x) # (B, L+1, 2 * D)
79
+ q = self.to_q(x[:, 0]) # (B, D)
80
+
81
+ # Extract heads.
82
+ head_dim = D // self.num_heads
83
+ kv = kv.unflatten(2, (2, self.num_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
84
+ kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
85
+ k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
86
+ q = q.unflatten(1, (self.num_heads, head_dim)) # (B, H, head_dim)
87
+ q = q.unsqueeze(2) # (B, H, 1, head_dim)
88
+
89
+ # Compute attention.
90
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
91
+
92
+ # Concatenate heads and run output.
93
+ x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
94
+ x = self.to_out(x)
95
+ return x
96
+
97
+
98
+ class PadSplitXY(torch.autograd.Function):
99
+ """
100
+ Merge heads, pad and extract visual and text tokens,
101
+ and split along the sequence length.
102
+ """
103
+
104
+ @staticmethod
105
+ def forward(
106
+ ctx,
107
+ xy: torch.Tensor,
108
+ indices: torch.Tensor,
109
+ B: int,
110
+ N: int,
111
+ L: int,
112
+ dtype: torch.dtype,
113
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
114
+ """
115
+ Args:
116
+ xy: Packed tokens. Shape: (total <= B * (N + L), num_heads * head_dim).
117
+ indices: Valid token indices out of unpacked tensor. Shape: (total,)
118
+
119
+ Returns:
120
+ x: Visual tokens. Shape: (B, N, num_heads * head_dim).
121
+ y: Text tokens. Shape: (B, L, num_heads * head_dim).
122
+ """
123
+ ctx.save_for_backward(indices)
124
+ ctx.B, ctx.N, ctx.L = B, N, L
125
+ D = xy.size(1)
126
+
127
+ # Pad sequences to (B, N + L, dim).
128
+ assert indices.ndim == 1
129
+ output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
130
+ indices = indices.unsqueeze(1).expand(-1, D) # (total,) -> (total, num_heads * head_dim)
131
+ output.scatter_(0, indices, xy)
132
+ xy = output.view(B, N + L, D)
133
+
134
+ # Split visual and text tokens along the sequence length.
135
+ return torch.tensor_split(xy, (N,), dim=1)
136
+
137
+
138
+ def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
139
+ return PadSplitXY.apply(xy, indices, B, N, L, dtype)
140
+
141
+
142
+ class UnifyStreams(torch.autograd.Function):
143
+ """Unify visual and text streams."""
144
+
145
+ @staticmethod
146
+ def forward(
147
+ ctx,
148
+ q_x: torch.Tensor,
149
+ k_x: torch.Tensor,
150
+ v_x: torch.Tensor,
151
+ q_y: torch.Tensor,
152
+ k_y: torch.Tensor,
153
+ v_y: torch.Tensor,
154
+ indices: torch.Tensor,
155
+ ):
156
+ """
157
+ Args:
158
+ q_x: (B, N, num_heads, head_dim)
159
+ k_x: (B, N, num_heads, head_dim)
160
+ v_x: (B, N, num_heads, head_dim)
161
+ q_y: (B, L, num_heads, head_dim)
162
+ k_y: (B, L, num_heads, head_dim)
163
+ v_y: (B, L, num_heads, head_dim)
164
+ indices: (total <= B * (N + L))
165
+
166
+ Returns:
167
+ qkv: (total <= B * (N + L), 3, num_heads, head_dim)
168
+ """
169
+ ctx.save_for_backward(indices)
170
+ B, N, num_heads, head_dim = q_x.size()
171
+ ctx.B, ctx.N, ctx.L = B, N, q_y.size(1)
172
+ D = num_heads * head_dim
173
+
174
+ q = torch.cat([q_x, q_y], dim=1)
175
+ k = torch.cat([k_x, k_y], dim=1)
176
+ v = torch.cat([v_x, v_y], dim=1)
177
+ qkv = torch.stack([q, k, v], dim=2).view(B * (N + ctx.L), 3, D)
178
+
179
+ indices = indices[:, None, None].expand(-1, 3, D)
180
+ qkv = torch.gather(qkv, 0, indices) # (total, 3, num_heads * head_dim)
181
+ return qkv.unflatten(2, (num_heads, head_dim))
182
+
183
+
184
+ def unify_streams(q_x, k_x, v_x, q_y, k_y, v_y, indices) -> torch.Tensor:
185
+ return UnifyStreams.apply(q_x, k_x, v_x, q_y, k_y, v_y, indices)
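Minimal sketch of the masked mean implemented by pool_tokens above (illustrative shapes).

import torch
from genmo.mochi_preview.dit.joint_model.utils import pool_tokens

B, L, D = 2, 5, 8
x = torch.randn(B, L, D)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True]])    # sample 0 has two padding tokens
pooled = pool_tokens(x, mask)                            # (B, D)
assert torch.allclose(pooled[0], x[0, :3].mean(dim=0), atol=1e-5)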
src/genmo/mochi_preview/pipelines.py ADDED
@@ -0,0 +1,658 @@
1
+ import os
2
+ import random
3
+ from abc import ABC, abstractmethod
4
+ from contextlib import contextmanager
5
+ from functools import partial
6
+ from typing import Any, Dict, List, Literal, Optional, Union, cast
7
+
8
+ import numpy as np
9
+ import ray
10
+ import torch
11
+ import torch.distributed as dist
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from safetensors.torch import load_file
16
+ from torch import nn
17
+ from torch.distributed.fsdp import (
18
+ BackwardPrefetch,
19
+ MixedPrecision,
20
+ ShardingStrategy,
21
+ )
22
+ from torch.distributed.fsdp import (
23
+ FullyShardedDataParallel as FSDP,
24
+ )
25
+ from torch.distributed.fsdp.wrap import (
26
+ lambda_auto_wrap_policy,
27
+ transformer_auto_wrap_policy,
28
+ )
29
+ from transformers import T5EncoderModel, T5Tokenizer
30
+ from transformers.models.t5.modeling_t5 import T5Block
31
+
32
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
33
+ import genmo.mochi_preview.vae.cp_conv as cp_conv
34
+ from genmo.mochi_preview.vae.model import Decoder, apply_tiled
35
+ from genmo.lib.progress import get_new_progress_bar, progress_bar
36
+ from genmo.lib.utils import Timer
37
+
38
+
39
+ def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
40
+ if linear_steps is None:
41
+ linear_steps = num_steps // 2
42
+ linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
43
+ threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
44
+ quadratic_steps = num_steps - linear_steps
45
+ quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
46
+ linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
47
+ const = quadratic_coef * (linear_steps**2)
48
+ quadratic_sigma_schedule = [
49
+ quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
50
+ ]
51
+ sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
52
+ sigma_schedule = [1.0 - x for x in sigma_schedule]
53
+ return sigma_schedule
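A worked example of the schedule above with illustrative arguments: sigma starts at 1.0, drops linearly by threshold_noise over the first half of the steps, then follows a quadratic tail down to 0.0.

sched = linear_quadratic_schedule(num_steps=8, threshold_noise=0.1)
print(len(sched))                      # 9 == num_steps + 1, as required by sample_model below
print(sched[0], sched[4], sched[-1])   # 1.0, 0.9, 0.0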
54
+
55
+
56
+ T5_MODEL = "google/t5-v1_1-xxl"
57
+ MAX_T5_TOKEN_LENGTH = 256
58
+
59
+
60
+ def setup_fsdp_sync(model, device_id, *, param_dtype, auto_wrap_policy) -> FSDP:
61
+ model = FSDP(
62
+ model,
63
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
64
+ mixed_precision=MixedPrecision(
65
+ param_dtype=param_dtype,
66
+ reduce_dtype=torch.float32,
67
+ buffer_dtype=torch.float32,
68
+ ),
69
+ auto_wrap_policy=auto_wrap_policy,
70
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
71
+ limit_all_gathers=True,
72
+ device_id=device_id,
73
+ sync_module_states=True,
74
+ use_orig_params=True,
75
+ )
76
+ torch.cuda.synchronize()
77
+ return model
78
+
79
+
80
+ class ModelFactory(ABC):
81
+ def __init__(self, **kwargs):
82
+ self.kwargs = kwargs
83
+
84
+ @abstractmethod
85
+ def get_model(self, *, local_rank: int, device_id: Union[int, Literal["cpu"]], world_size: int) -> Any:
86
+ if device_id == "cpu":
87
+ assert world_size == 1, "CPU offload only supports single-GPU inference"
88
+
89
+
90
+ class T5ModelFactory(ModelFactory):
91
+ def __init__(self):
92
+ super().__init__()
93
+
94
+ def get_model(self, *, local_rank, device_id, world_size):
95
+ super().get_model(local_rank=local_rank, device_id=device_id, world_size=world_size)
96
+ model = T5EncoderModel.from_pretrained(T5_MODEL)
97
+ if world_size > 1:
98
+ model = setup_fsdp_sync(
99
+ model,
100
+ device_id=device_id,
101
+ param_dtype=torch.float32,
102
+ auto_wrap_policy=partial(
103
+ transformer_auto_wrap_policy,
104
+ transformer_layer_cls={
105
+ T5Block,
106
+ },
107
+ ),
108
+ )
109
+ elif isinstance(device_id, int):
110
+ model = model.to(torch.device(f"cuda:{device_id}")) # type: ignore
111
+ return model.eval()
112
+
113
+
114
+ class DitModelFactory(ModelFactory):
115
+ def __init__(self, *, model_path: str, model_dtype: str, attention_mode: Optional[str] = None):
116
+ if attention_mode is None:
117
+ from genmo.lib.attn_imports import flash_varlen_qkvpacked_attn # type: ignore
118
+
119
+ attention_mode = "sdpa" if flash_varlen_qkvpacked_attn is None else "flash"
120
+ print(f"Attention mode: {attention_mode}")
121
+ super().__init__(model_path=model_path, model_dtype=model_dtype, attention_mode=attention_mode)
122
+
123
+ def get_model(self, *, local_rank, device_id, world_size):
124
+ # TODO(ved): Set flag for torch.compile
125
+ from genmo.mochi_preview.dit.joint_model.asymm_models_joint import (
126
+ AsymmDiTJoint,
127
+ )
128
+
129
+ model: nn.Module = torch.nn.utils.skip_init(
130
+ AsymmDiTJoint,
131
+ depth=48,
132
+ patch_size=2,
133
+ num_heads=24,
134
+ hidden_size_x=3072,
135
+ hidden_size_y=1536,
136
+ mlp_ratio_x=4.0,
137
+ mlp_ratio_y=4.0,
138
+ in_channels=12,
139
+ qk_norm=True,
140
+ qkv_bias=False,
141
+ out_bias=True,
142
+ patch_embed_bias=True,
143
+ timestep_mlp_bias=True,
144
+ timestep_scale=1000.0,
145
+ t5_feat_dim=4096,
146
+ t5_token_length=256,
147
+ rope_theta=10000.0,
148
+ attention_mode=self.kwargs["attention_mode"],
149
+ )
150
+
151
+ if local_rank == 0:
152
+ # FSDP syncs weights from rank 0 to all other ranks
153
+ model.load_state_dict(load_file(self.kwargs["model_path"]))
154
+
155
+ if world_size > 1:
156
+ assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"
157
+ model = setup_fsdp_sync(
158
+ model,
159
+ device_id=device_id,
160
+ param_dtype=torch.bfloat16,
161
+ auto_wrap_policy=partial(
162
+ lambda_auto_wrap_policy,
163
+ lambda_fn=lambda m: m in model.blocks,
164
+ ),
165
+ )
166
+ elif isinstance(device_id, int):
167
+ model = model.to(torch.device(f"cuda:{device_id}"))
168
+ return model.eval()
169
+
170
+
171
+ class DecoderModelFactory(ModelFactory):
172
+ def __init__(self, *, model_path: str, model_stats_path: str):
173
+ super().__init__(model_path=model_path, model_stats_path=model_stats_path)
174
+
175
+ def get_model(self, *, local_rank, device_id, world_size):
176
+ # TODO(ved): Set flag for torch.compile
177
+ # TODO(ved): Use skip_init
178
+ import json
179
+
180
+ decoder = Decoder(
181
+ out_channels=3,
182
+ base_channels=128,
183
+ channel_multipliers=[1, 2, 4, 6],
184
+ temporal_expansions=[1, 2, 3],
185
+ spatial_expansions=[2, 2, 2],
186
+ num_res_blocks=[3, 3, 4, 6, 3],
187
+ latent_dim=12,
188
+ has_attention=[False, False, False, False, False],
189
+ padding_mode="replicate",
190
+ output_norm=False,
191
+ nonlinearity="silu",
192
+ output_nonlinearity="silu",
193
+ causal=True,
194
+ )
195
+ # VAE is not FSDP-wrapped
196
+ state_dict = load_file(self.kwargs["model_path"])
197
+ decoder.load_state_dict(state_dict, strict=True)
198
+ device = torch.device(f"cuda:{device_id}") if isinstance(device_id, int) else "cpu"
199
+ decoder.eval().to(device)
200
+ vae_stats = json.load(open(self.kwargs["model_stats_path"]))
201
+ decoder.register_buffer("vae_mean", torch.tensor(vae_stats["mean"], device=device))
202
+ decoder.register_buffer("vae_std", torch.tensor(vae_stats["std"], device=device))
203
+ return decoder
204
+
205
+
206
+ def get_conditioning(tokenizer, encoder, device, batch_inputs, *, prompt: str, negative_prompt: str):
207
+ if batch_inputs:
208
+ return dict(batched=get_conditioning_for_prompts(tokenizer, encoder, device, [prompt, negative_prompt]))
209
+ else:
210
+ cond_input = get_conditioning_for_prompts(tokenizer, encoder, device, [prompt])
211
+ null_input = get_conditioning_for_prompts(tokenizer, encoder, device, [negative_prompt])
212
+ return dict(cond=cond_input, null=null_input)
213
+
214
+
215
+ def get_conditioning_for_prompts(tokenizer, encoder, device, prompts: List[str]):
216
+ assert len(prompts) in [1, 2] # [neg] or [pos] or [pos, neg]
217
+ B = len(prompts)
218
+ t5_toks = tokenizer(
219
+ prompts,
220
+ padding="max_length",
221
+ truncation=True,
222
+ max_length=MAX_T5_TOKEN_LENGTH,
223
+ return_tensors="pt",
224
+ return_attention_mask=True,
225
+ )
226
+ caption_input_ids_t5 = t5_toks["input_ids"]
227
+ caption_attention_mask_t5 = t5_toks["attention_mask"].bool()
228
+ del t5_toks
229
+
230
+ assert caption_input_ids_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
231
+ assert caption_attention_mask_t5.shape == (B, MAX_T5_TOKEN_LENGTH)
232
+
233
+ # Special-case empty negative prompt by zero-ing it
234
+ if prompts[-1] == "":
235
+ caption_input_ids_t5[-1] = 0
236
+ caption_attention_mask_t5[-1] = False
237
+
238
+ caption_input_ids_t5 = caption_input_ids_t5.to(device, non_blocking=True)
239
+ caption_attention_mask_t5 = caption_attention_mask_t5.to(device, non_blocking=True)
240
+
241
+ y_mask = [caption_attention_mask_t5]
242
+ y_feat = [encoder(caption_input_ids_t5, caption_attention_mask_t5).last_hidden_state.detach()]
243
+ # Sometimes returns a tensor, other times a tuple; the cause is unclear.
244
+ # See: https://huggingface.co/genmo/mochi-1-preview/discussions/3
245
+ assert tuple(y_feat[-1].shape) == (B, MAX_T5_TOKEN_LENGTH, 4096)
246
+ assert y_feat[-1].dtype == torch.float32
247
+
248
+ return dict(y_mask=y_mask, y_feat=y_feat)
249
+
250
+
251
+ def compute_packed_indices(
252
+ device: torch.device, text_mask: torch.Tensor, num_latents: int
253
+ ) -> Dict[str, Union[torch.Tensor, int]]:
254
+ """
255
+ Based on https://github.com/Dao-AILab/flash-attention/blob/765741c1eeb86c96ee71a3291ad6968cfbf4e4a1/flash_attn/bert_padding.py#L60-L80
256
+
257
+ Args:
258
+ num_latents: Number of latent tokens
259
+ text_mask: (B, L) boolean tensor indicating which text tokens are not padding.
260
+
261
+ Returns:
262
+ packed_indices: Dict with keys for Flash Attention:
263
+ - valid_token_indices_kv: up to (B * (N + L),) tensor of valid token indices (non-padding)
264
+ in the packed sequence.
265
+ - cu_seqlens_kv: (B + 1,) tensor of cumulative sequence lengths in the packed sequence.
266
+ - max_seqlen_in_batch_kv: int of the maximum sequence length in the batch.
267
+ """
268
+ # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
269
+ PATCH_SIZE = 2
270
+ num_visual_tokens = num_latents // (PATCH_SIZE**2)
271
+ assert num_visual_tokens > 0
272
+
273
+ mask = F.pad(text_mask, (num_visual_tokens, 0), value=True) # (B, N + L)
274
+ seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32) # (B,)
275
+ valid_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten() # up to (B * (N + L),)
276
+ assert valid_token_indices.size(0) >= text_mask.size(0) * num_visual_tokens # At least (B * N,)
277
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
278
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
279
+
280
+ return {
281
+ "cu_seqlens_kv": cu_seqlens.to(device, non_blocking=True),
282
+ "max_seqlen_in_batch_kv": cast(int, max_seqlen_in_batch),
283
+ "valid_token_indices_kv": valid_token_indices.to(device, non_blocking=True),
284
+ }
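A tiny CPU illustration of the packing above (illustrative sizes): num_latents = 8 with patch size 2 gives 2 visual tokens per sample, which are prepended as always-valid entries before the text mask.

import torch

text_mask = torch.tensor([[True, True, False],
                          [True, False, False]])
packed = compute_packed_indices(torch.device("cpu"), text_mask, num_latents=8)
print(packed["cu_seqlens_kv"])            # tensor([0, 4, 7], dtype=torch.int32)
print(packed["max_seqlen_in_batch_kv"])   # 4
print(packed["valid_token_indices_kv"])   # tensor([0, 1, 2, 3, 5, 6, 7])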
285
+
286
+
287
+ def assert_eq(x, y, msg=None):
288
+ assert x == y, f"{msg or 'Assertion failed'}: {x} != {y}"
289
+
290
+
291
+ def sample_model(device, dit, conditioning, **args):
292
+ random.seed(args["seed"])
293
+ np.random.seed(args["seed"])
294
+ torch.manual_seed(args["seed"])
295
+
296
+ generator = torch.Generator(device=device)
297
+ generator.manual_seed(args["seed"])
298
+
299
+ w, h, t = args["width"], args["height"], args["num_frames"]
300
+ sample_steps = args["num_inference_steps"]
301
+ cfg_schedule = args["cfg_schedule"]
302
+ sigma_schedule = args["sigma_schedule"]
303
+
304
+ assert_eq(len(cfg_schedule), sample_steps, "cfg_schedule must have length sample_steps")
305
+ assert_eq((t - 1) % 6, 0, "t - 1 must be divisible by 6")
306
+ assert_eq(
307
+ len(sigma_schedule),
308
+ sample_steps + 1,
309
+ "sigma_schedule must have length sample_steps + 1",
310
+ )
311
+
312
+ B = 1
313
+ SPATIAL_DOWNSAMPLE = 8
314
+ TEMPORAL_DOWNSAMPLE = 6
315
+ IN_CHANNELS = 12
316
+ latent_t = ((t - 1) // TEMPORAL_DOWNSAMPLE) + 1
317
+ latent_w, latent_h = w // SPATIAL_DOWNSAMPLE, h // SPATIAL_DOWNSAMPLE
318
+
319
+ z = torch.randn(
320
+ (B, IN_CHANNELS, latent_t, latent_h, latent_w),
321
+ device=device,
322
+ dtype=torch.float32,
323
+ )
324
+
325
+ num_latents = latent_t * latent_h * latent_w
326
+ cond_batched = cond_text = cond_null = None
327
+ if "cond" in conditioning:
328
+ cond_text = conditioning["cond"]
329
+ cond_null = conditioning["null"]
330
+ cond_text["packed_indices"] = compute_packed_indices(device, cond_text["y_mask"][0], num_latents)
331
+ cond_null["packed_indices"] = compute_packed_indices(device, cond_null["y_mask"][0], num_latents)
332
+ else:
333
+ cond_batched = conditioning["batched"]
334
+ cond_batched["packed_indices"] = compute_packed_indices(device, cond_batched["y_mask"][0], num_latents)
335
+ z = repeat(z, "b ... -> (repeat b) ...", repeat=2)
336
+
337
+ def model_fn(*, z, sigma, cfg_scale):
338
+ if cond_batched:
339
+ with torch.autocast("cuda", dtype=torch.bfloat16):
340
+ out = dit(z, sigma, **cond_batched)
341
+ out_cond, out_uncond = torch.chunk(out, chunks=2, dim=0)
342
+ else:
343
+ nonlocal cond_text, cond_null
344
+ with torch.autocast("cuda", dtype=torch.bfloat16):
345
+ out_cond = dit(z, sigma, **cond_text)
346
+ out_uncond = dit(z, sigma, **cond_null)
347
+ assert out_cond.shape == out_uncond.shape
348
+ return out_uncond + cfg_scale * (out_cond - out_uncond), out_cond
349
+
350
+ for i in get_new_progress_bar(range(0, sample_steps), desc="Sampling"):
351
+ sigma = sigma_schedule[i]
352
+ dsigma = sigma - sigma_schedule[i + 1]
353
+
354
+ # `pred` estimates `z_0 - eps`.
355
+ pred, output_cond = model_fn(
356
+ z=z,
357
+ sigma=torch.full([B] if cond_text else [B * 2], sigma, device=z.device),
358
+ cfg_scale=cfg_schedule[i],
359
+ )
360
+ pred = pred.to(z)
361
+ output_cond = output_cond.to(z)
362
+ z = z + dsigma * pred
363
+
364
+ return z[:B] if cond_batched else z
365
+
366
+
367
+ def decoded_latents_to_frames(samples):
368
+ samples = samples.float()
369
+ samples = (samples + 1.0) / 2.0
370
+ samples.clamp_(0.0, 1.0)
371
+ frames = rearrange(samples, "b c t h w -> b t h w c")
372
+ return frames
373
+
374
+
375
+ def decode_latents(decoder, z):
376
+ cp_rank, cp_size = cp.get_cp_rank_size()
377
+ z = z.tensor_split(cp_size, dim=2)[cp_rank] # split along temporal dim
378
+ with torch.autocast("cuda", dtype=torch.bfloat16):
379
+ samples = decoder(z)
380
+ samples = cp_conv.gather_all_frames(samples)
381
+ return decoded_latents_to_frames(samples)
382
+
383
+
384
+ @torch.inference_mode()
385
+ def decode_latents_tiled_full(
386
+ decoder,
387
+ z,
388
+ *,
389
+ tile_sample_min_height: int = 240,
390
+ tile_sample_min_width: int = 424,
391
+ tile_overlap_factor_height: float = 0.1666,
392
+ tile_overlap_factor_width: float = 0.2,
393
+ auto_tile_size: bool = True,
394
+ frame_batch_size: int = 6,
395
+ ):
396
+ B, C, T, H, W = z.shape
397
+ assert frame_batch_size <= T, f"frame_batch_size must be <= T, got {frame_batch_size} > {T}"
398
+
399
+ tile_sample_min_height = tile_sample_min_height if not auto_tile_size else H // 2 * 8
400
+ tile_sample_min_width = tile_sample_min_width if not auto_tile_size else W // 2 * 8
401
+
402
+ tile_latent_min_height = int(tile_sample_min_height / 8)
403
+ tile_latent_min_width = int(tile_sample_min_width / 8)
404
+
405
+ def blend_v(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
406
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
407
+ for y in range(blend_extent):
408
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
409
+ y / blend_extent
410
+ )
411
+ return b
412
+
413
+ def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
414
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
415
+ for x in range(blend_extent):
416
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
417
+ x / blend_extent
418
+ )
419
+ return b
420
+
421
+ overlap_height = int(tile_latent_min_height * (1 - tile_overlap_factor_height))
422
+ overlap_width = int(tile_latent_min_width * (1 - tile_overlap_factor_width))
423
+ blend_extent_height = int(tile_sample_min_height * tile_overlap_factor_height)
424
+ blend_extent_width = int(tile_sample_min_width * tile_overlap_factor_width)
425
+ row_limit_height = tile_sample_min_height - blend_extent_height
426
+ row_limit_width = tile_sample_min_width - blend_extent_width
427
+
428
+ # Split z into overlapping tiles and decode them separately.
429
+ # The tiles have an overlap to avoid seams between tiles.
430
+ pbar = get_new_progress_bar(
431
+ desc="Decoding latent tiles",
432
+ total=len(range(0, H, overlap_height)) * len(range(0, W, overlap_width)) * len(range(T // frame_batch_size)),
433
+ )
434
+ rows = []
435
+ for i in range(0, H, overlap_height):
436
+ row = []
437
+ for j in range(0, W, overlap_width):
438
+ temporal = []
439
+ for k in range(T // frame_batch_size):
440
+ remaining_frames = T % frame_batch_size
441
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
442
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
443
+ tile = z[
444
+ :,
445
+ :,
446
+ start_frame:end_frame,
447
+ i : i + tile_latent_min_height,
448
+ j : j + tile_latent_min_width,
449
+ ]
450
+ tile = decoder(tile)
451
+ temporal.append(tile)
452
+ pbar.update(1)
453
+ row.append(torch.cat(temporal, dim=2))
454
+ rows.append(row)
455
+
456
+ result_rows = []
457
+ for i, row in enumerate(rows):
458
+ result_row = []
459
+ for j, tile in enumerate(row):
460
+ # blend the above tile and the left tile
461
+ # to the current tile and add the current tile to the result row
462
+ if i > 0:
463
+ tile = blend_v(rows[i - 1][j], tile, blend_extent_height)
464
+ if j > 0:
465
+ tile = blend_h(row[j - 1], tile, blend_extent_width)
466
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
467
+ result_rows.append(torch.cat(result_row, dim=4))
468
+
469
+ return decoded_latents_to_frames(torch.cat(result_rows, dim=3))
470
+
471
+ @torch.inference_mode()
472
+ def decode_latents_tiled_spatial(
473
+ decoder,
474
+ z,
475
+ *,
476
+ num_tiles_w: int,
477
+ num_tiles_h: int,
478
+ overlap: int = 0, # Number of pixels of overlap between adjacent tiles.
479
+ # Use a factor of 2 times the latent downsample factor.
480
+ min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
481
+ ):
482
+ decoded = apply_tiled(decoder, z, num_tiles_w, num_tiles_h, overlap, min_block_size)
483
+ assert decoded is not None, "Failed to decode latents with tiled spatial method"
484
+ return decoded
485
+
486
+ @contextmanager
487
+ def move_to_device(model: nn.Module, target_device):
488
+ og_device = next(model.parameters()).device
489
+ if og_device == target_device:
490
+ print(f"move_to_device is a no-op; model is already on {target_device}")
491
+ else:
492
+ print(f"moving model from {og_device} -> {target_device}")
493
+
494
+ model.to(target_device)
495
+ yield
496
+ if og_device != target_device:
497
+ print(f"moving model from {target_device} -> {og_device}")
498
+ model.to(og_device)
499
+
500
+
501
+ def t5_tokenizer():
502
+ return T5Tokenizer.from_pretrained(T5_MODEL, legacy=False)
503
+
504
+
505
+ class MochiSingleGPUPipeline:
506
+ def __init__(
507
+ self,
508
+ *,
509
+ text_encoder_factory: ModelFactory,
510
+ dit_factory: ModelFactory,
511
+ decoder_factory: ModelFactory,
512
+ cpu_offload: Optional[bool] = False,
513
+ decode_type: str = "full",
514
+ decode_args: Optional[Dict[str, Any]] = None,
515
+ ):
516
+ self.device = torch.device("cuda:0")
517
+ self.tokenizer = t5_tokenizer()
518
+ t = Timer()
519
+ self.cpu_offload = cpu_offload
520
+ self.decode_args = decode_args or {}
521
+ self.decode_type = decode_type
522
+ init_id = "cpu" if cpu_offload else 0
523
+ with t("load_text_encoder"):
524
+ self.text_encoder = text_encoder_factory.get_model(
525
+ local_rank=0,
526
+ device_id=init_id,
527
+ world_size=1,
528
+ )
529
+ with t("load_dit"):
530
+ self.dit = dit_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
531
+ with t("load_vae"):
532
+ self.decoder = decoder_factory.get_model(local_rank=0, device_id=init_id, world_size=1)
533
+ t.print_stats()
534
+
535
+ def __call__(self, batch_cfg, prompt, negative_prompt, **kwargs):
536
+ with progress_bar(type="tqdm"), torch.inference_mode():
537
+ print_max_memory = lambda: print(
538
+ f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB"
539
+ )
540
+ print_max_memory()
541
+ with move_to_device(self.text_encoder, self.device):
542
+ conditioning = get_conditioning(
543
+ self.tokenizer,
544
+ self.text_encoder,
545
+ self.device,
546
+ batch_cfg,
547
+ prompt=prompt,
548
+ negative_prompt=negative_prompt,
549
+ )
550
+ print_max_memory()
551
+ with move_to_device(self.dit, self.device):
552
+ latents = sample_model(self.device, self.dit, conditioning, **kwargs)
553
+ print_max_memory()
554
+ with move_to_device(self.decoder, self.device):
555
+ frames = (
556
+ decode_latents_tiled_full(self.decoder, latents, **self.decode_args)
557
+ if self.decode_type == "tiled_full"
558
+ else
559
+ decode_latents_tiled_spatial(self.decoder, latents, **self.decode_args)
560
+ if self.decode_type == "tiled_spatial"
561
+ else decode_latents(self.decoder, latents)
562
+ )
563
+ print_max_memory()
564
+ return frames.cpu().numpy()
565
+
566
+
567
+ ### ALL CODE BELOW HERE IS FOR MULTI-GPU MODE ###
568
+
569
+
570
+ # In multi-GPU mode, every model must live on a device that belongs to a predefined context-parallel group,
572
+ # so it does not make sense to manage the models individually.
572
+ class MultiGPUContext:
573
+ def __init__(
574
+ self,
575
+ *,
576
+ text_encoder_factory,
577
+ dit_factory,
578
+ decoder_factory,
579
+ device_id,
580
+ local_rank,
581
+ world_size,
582
+ ):
583
+ t = Timer()
584
+ self.device = torch.device(f"cuda:{device_id}")
585
+ print(f"Initializing rank {local_rank+1}/{world_size}")
586
+ assert world_size > 1, f"Multi-GPU mode requires world_size > 1, got {world_size}"
587
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
588
+ os.environ["MASTER_PORT"] = "29500"
589
+ with t("init_process_group"):
590
+ dist.init_process_group(
591
+ "nccl",
592
+ rank=local_rank,
593
+ world_size=world_size,
594
+ device_id=self.device, # force non-lazy init
595
+ )
596
+ pg = dist.group.WORLD
597
+ cp.set_cp_group(pg, list(range(world_size)), local_rank)
598
+ distributed_kwargs = dict(local_rank=local_rank, device_id=device_id, world_size=world_size)
599
+ self.world_size = world_size
600
+ self.tokenizer = t5_tokenizer()
601
+ with t("load_text_encoder"):
602
+ self.text_encoder = text_encoder_factory.get_model(**distributed_kwargs)
603
+ with t("load_dit"):
604
+ self.dit = dit_factory.get_model(**distributed_kwargs)
605
+ with t("load_vae"):
606
+ self.decoder = decoder_factory.get_model(**distributed_kwargs)
607
+ self.local_rank = local_rank
608
+ t.print_stats()
609
+
610
+ def run(self, *, fn, **kwargs):
611
+ return fn(self, **kwargs)
612
+
613
+
614
+ class MochiMultiGPUPipeline:
615
+ def __init__(
616
+ self,
617
+ *,
618
+ text_encoder_factory: ModelFactory,
619
+ dit_factory: ModelFactory,
620
+ decoder_factory: ModelFactory,
621
+ world_size: int,
622
+ ):
623
+ ray.init()
624
+ RemoteClass = ray.remote(MultiGPUContext)
625
+ self.ctxs = [
626
+ RemoteClass.options(num_gpus=1).remote(
627
+ text_encoder_factory=text_encoder_factory,
628
+ dit_factory=dit_factory,
629
+ decoder_factory=decoder_factory,
630
+ world_size=world_size,
631
+ device_id=0,
632
+ local_rank=i,
633
+ )
634
+ for i in range(world_size)
635
+ ]
636
+ for ctx in self.ctxs:
637
+ ray.get(ctx.__ray_ready__.remote())
638
+
639
+ def __call__(self, **kwargs):
640
+ def sample(ctx, *, batch_cfg, prompt, negative_prompt, **kwargs):
641
+ with progress_bar(type="ray_tqdm", enabled=ctx.local_rank == 0), torch.inference_mode():
642
+ conditioning = get_conditioning(
643
+ ctx.tokenizer,
644
+ ctx.text_encoder,
645
+ ctx.device,
646
+ batch_cfg,
647
+ prompt=prompt,
648
+ negative_prompt=negative_prompt,
649
+ )
650
+ latents = sample_model(ctx.device, ctx.dit, conditioning=conditioning, **kwargs)
651
+ if ctx.local_rank == 0:
652
+ torch.save(latents, "latents.pt")
653
+ frames = decode_latents(ctx.decoder, latents)
654
+ return frames.cpu().numpy()
655
+
656
+ return ray.get([ctx.run.remote(fn=sample, **kwargs, show_progress=i == 0) for i, ctx in enumerate(self.ctxs)])[
657
+ 0
658
+ ]
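A hypothetical single-GPU usage sketch. The weight paths, resolution, and sampling settings below are placeholders chosen to satisfy the assertions in sample_model, not values taken from this upload.

from genmo.mochi_preview.pipelines import (
    DecoderModelFactory, DitModelFactory, MochiSingleGPUPipeline,
    T5ModelFactory, linear_quadratic_schedule,
)

pipeline = MochiSingleGPUPipeline(
    text_encoder_factory=T5ModelFactory(),
    dit_factory=DitModelFactory(model_path="<dit>.safetensors", model_dtype="bf16"),
    decoder_factory=DecoderModelFactory(model_path="<vae>.safetensors", model_stats_path="<vae_stats>.json"),
    cpu_offload=True,
)

num_steps = 64
frames = pipeline(
    batch_cfg=False,                        # separate cond/uncond passes (see get_conditioning)
    prompt="a calm ocean at sunset",
    negative_prompt="",
    width=848, height=480, num_frames=31,   # (num_frames - 1) must be divisible by 6
    seed=12345,
    num_inference_steps=num_steps,
    sigma_schedule=linear_quadratic_schedule(num_steps, 0.025),
    cfg_schedule=[4.5] * num_steps,
)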
src/genmo/mochi_preview/vae/__init__.py ADDED
File without changes
src/genmo/mochi_preview/vae/cp_conv.py ADDED
@@ -0,0 +1,152 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn.functional as F
6
+
7
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
8
+
9
+
10
+ def cast_tuple(t, length=1):
11
+ return t if isinstance(t, tuple) else ((t,) * length)
12
+
13
+
14
+ def cp_pass_frames(x: torch.Tensor, frames_to_send: int) -> torch.Tensor:
15
+ """
16
+ Forward pass that handles communication between ranks for inference.
17
+ Args:
18
+ x: Tensor of shape (B, C, T, H, W)
19
+ frames_to_send: int, number of frames to communicate between ranks
20
+ Returns:
21
+ output: Tensor of shape (B, C, T', H, W)
22
+ """
23
+ cp_rank, cp_world_size = cp.get_cp_rank_size()
24
+ if frames_to_send == 0 or cp_world_size == 1:
25
+ return x
26
+
27
+ group = cp.get_cp_group()
28
+ global_rank = dist.get_rank()
29
+
30
+ # Send to next rank
31
+ if cp_rank < cp_world_size - 1:
32
+ assert x.size(2) >= frames_to_send
33
+ tail = x[:, :, -frames_to_send:].contiguous()
34
+ dist.send(tail, global_rank + 1, group=group)
35
+
36
+ # Receive from previous rank
37
+ if cp_rank > 0:
38
+ B, C, _, H, W = x.shape
39
+ recv_buffer = torch.empty(
40
+ (B, C, frames_to_send, H, W),
41
+ dtype=x.dtype,
42
+ device=x.device,
43
+ )
44
+ dist.recv(recv_buffer, global_rank - 1, group=group)
45
+ x = torch.cat([recv_buffer, x], dim=2)
46
+
47
+ return x
48
+
49
+
50
+ def _pad_to_max(x: torch.Tensor, max_T: int) -> torch.Tensor:
51
+ if max_T > x.size(2):
52
+ pad_T = max_T - x.size(2)
53
+ pad_dims = (0, 0, 0, 0, 0, pad_T)
54
+ return F.pad(x, pad_dims)
55
+ return x
56
+
57
+
58
+ def gather_all_frames(x: torch.Tensor) -> torch.Tensor:
59
+ """
60
+ Gathers all frames from all processes for inference.
61
+ Args:
62
+ x: Tensor of shape (B, C, T, H, W)
63
+ Returns:
64
+ output: Tensor of shape (B, C, T_total, H, W)
65
+ """
66
+ cp_rank, cp_size = cp.get_cp_rank_size()
67
+ cp_group = cp.get_cp_group()
68
+
69
+ # Ensure the tensor is contiguous for collective operations
70
+ x = x.contiguous()
71
+
72
+ # Get the local time dimension size
73
+ local_T = x.size(2)
74
+ local_T_tensor = torch.tensor([local_T], device=x.device, dtype=torch.int64)
75
+
76
+ # Gather all T sizes from all processes
77
+ all_T = [torch.zeros(1, dtype=torch.int64, device=x.device) for _ in range(cp_size)]
78
+ dist.all_gather(all_T, local_T_tensor, group=cp_group)
79
+ all_T = [t.item() for t in all_T]
80
+
81
+ # Pad the tensor at the end of the time dimension to match max_T
82
+ max_T = max(all_T)
83
+ x = _pad_to_max(x, max_T).contiguous()
84
+
85
+ # Prepare a list to hold the gathered tensors
86
+ gathered_x = [torch.zeros_like(x).contiguous() for _ in range(cp_size)]
87
+
88
+ # Perform the all_gather operation
89
+ dist.all_gather(gathered_x, x, group=cp_group)
90
+
91
+ # Slice each gathered tensor back to its original T size
92
+ for idx, t_size in enumerate(all_T):
93
+ gathered_x[idx] = gathered_x[idx][:, :, :t_size]
94
+
95
+ return torch.cat(gathered_x, dim=2)
96
+
97
+
98
+ def excessive_memory_usage(input: torch.Tensor, max_gb: float = 2.0) -> bool:
99
+ """Estimate memory usage based on input tensor size and data type."""
100
+ element_size = input.element_size() # Size in bytes of each element
101
+ memory_bytes = input.numel() * element_size
102
+ memory_gb = memory_bytes / 1024**3
103
+ return memory_gb > max_gb
104
+
105
+
106
+ class ContextParallelCausalConv3d(torch.nn.Conv3d):
107
+ def __init__(
108
+ self,
109
+ in_channels,
110
+ out_channels,
111
+ kernel_size: Union[int, Tuple[int, int, int]],
112
+ stride: Union[int, Tuple[int, int, int]],
113
+ **kwargs,
114
+ ):
115
+ kernel_size = cast_tuple(kernel_size, 3)
116
+ stride = cast_tuple(stride, 3)
117
+ height_pad = (kernel_size[1] - 1) // 2
118
+ width_pad = (kernel_size[2] - 1) // 2
119
+
120
+ super().__init__(
121
+ in_channels=in_channels,
122
+ out_channels=out_channels,
123
+ kernel_size=kernel_size,
124
+ stride=stride,
125
+ dilation=(1, 1, 1),
126
+ padding=(0, height_pad, width_pad),
127
+ **kwargs,
128
+ )
129
+
130
+ def forward(self, x: torch.Tensor):
131
+ cp_rank, cp_world_size = cp.get_cp_rank_size()
132
+
133
+ context_size = self.kernel_size[0] - 1
134
+ if cp_rank == 0:
135
+ mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
136
+ x = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=mode)
137
+
138
+ if cp_world_size == 1:
139
+ return super().forward(x)
140
+
141
+ if all(s == 1 for s in self.stride):
142
+ # Receive some frames from previous rank.
143
+ x = cp_pass_frames(x, context_size)
144
+ return super().forward(x)
145
+
146
+ # Less efficient implementation for strided convs.
147
+ # All gather x, infer and chunk.
148
+ x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
149
+ x = super().forward(x)
150
+ x_chunks = x.tensor_split(cp_world_size, dim=2)
151
+ assert len(x_chunks) == cp_world_size
152
+ return x_chunks[cp_rank]
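A single-process sketch of the causal padding handled above: a kernel of size k along time needs k - 1 frames of left context, which rank 0 gets from replicate-padding and later ranks receive from their neighbor via cp_pass_frames. Sizes are illustrative.

import torch
import torch.nn.functional as F

B, C, T, H, W = 1, 4, 8, 16, 16
k = 3
x = torch.randn(B, C, T, H, W)
conv = torch.nn.Conv3d(C, C, kernel_size=k, padding=(0, 1, 1))

x_padded = F.pad(x, (0, 0, 0, 0, k - 1, 0), mode="replicate")   # left-pad the time dimension only
y = conv(x_padded)
assert y.shape[2] == T        # causal: one output frame per input frame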
src/genmo/mochi_preview/vae/model.py ADDED
@@ -0,0 +1,808 @@
1
+ from typing import Callable, List, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+ import genmo.mochi_preview.dit.joint_model.context_parallel as cp
9
+ from genmo.mochi_preview.vae.cp_conv import cp_pass_frames, gather_all_frames
10
+
11
+
12
+ def cast_tuple(t, length=1):
13
+ return t if isinstance(t, tuple) else ((t,) * length)
14
+
15
+
16
+ class GroupNormSpatial(nn.GroupNorm):
17
+ """
18
+ GroupNorm applied per-frame.
19
+ """
20
+
21
+ def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
22
+ B, C, T, H, W = x.shape
23
+ x = rearrange(x, "B C T H W -> (B T) C H W")
24
+ # Run group norm in chunks.
25
+ output = torch.empty_like(x)
26
+ for b in range(0, B * T, chunk_size):
27
+ output[b : b + chunk_size] = super().forward(x[b : b + chunk_size])
28
+ return rearrange(output, "(B T) C H W -> B C T H W", B=B, T=T)
29
+
30
+
31
+ class SafeConv3d(torch.nn.Conv3d):
32
+ """
33
+ NOTE: No support for padding along time dimension.
34
+ Input must already be padded along time.
35
+ """
36
+
37
+ def forward(self, input):
38
+ memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
39
+ if memory_count > 2:
40
+ part_num = int(memory_count / 2) + 1
41
+ k = self.kernel_size[0]
42
+ input_idx = torch.arange(k - 1, input.size(2))
43
+ input_chunks_idx = torch.chunk(input_idx, part_num, dim=0)
44
+
45
+ # assert self.kernel_size == (3, 3, 3), f"kernel_size {self.kernel_size} != (3, 3, 3)"
46
+ assert self.stride[0] == 1, f"stride {self.stride}"
47
+ assert self.dilation[0] == 1, f"dilation {self.dilation}"
48
+ assert self.padding[0] == 0, f"padding {self.padding}"
49
+
50
+ # Compute output size
51
+ assert not input.requires_grad
52
+ B, _, T_in, H_in, W_in = input.shape
53
+ output_size = (
54
+ B,
55
+ self.out_channels,
56
+ T_in - k + 1,
57
+ H_in // self.stride[1],
58
+ W_in // self.stride[2],
59
+ )
60
+ output = torch.empty(output_size, dtype=input.dtype, device=input.device)
61
+ for input_chunk_idx in input_chunks_idx:
62
+ input_s = input_chunk_idx[0] - k + 1
63
+ input_e = input_chunk_idx[-1] + 1
64
+ input_chunk = input[:, :, input_s:input_e, :, :]
65
+ output_chunk = super(SafeConv3d, self).forward(input_chunk)
66
+
67
+ output_s = input_s
68
+ output_e = output_s + output_chunk.size(2)
69
+ output[:, :, output_s:output_e, :, :] = output_chunk
70
+
71
+ return output
72
+ else:
73
+ return super(SafeConv3d, self).forward(input)
74
+
75
+
76
+ class StridedSafeConv3d(torch.nn.Conv3d):
77
+ def forward(self, input, local_shard: bool = False):
78
+ assert self.stride[0] == self.kernel_size[0]
79
+ assert self.dilation[0] == 1
80
+ assert self.padding[0] == 0
81
+
82
+ kernel_size = self.kernel_size[0]
83
+ stride = self.stride[0]
84
+ T_in = input.size(2)
85
+ T_out = T_in // kernel_size
86
+
87
+ # Parallel implementation.
88
+ if local_shard:
89
+ idx = torch.arange(T_out)
90
+ idx = cp.local_shard(idx, dim=0)
91
+ start = idx.min() * stride
92
+ end = idx.max() * stride + kernel_size
93
+ local_input = input[:, :, start:end, :, :]
94
+ return torch.nn.Conv3d.forward(self, local_input)
95
+
96
+ raise NotImplementedError
97
+
98
+
99
+ class ContextParallelConv3d(SafeConv3d):
100
+ def __init__(
101
+ self,
102
+ in_channels,
103
+ out_channels,
104
+ kernel_size: Union[int, Tuple[int, int, int]],
105
+ stride: Union[int, Tuple[int, int, int]],
106
+ causal: bool = True,
107
+ context_parallel: bool = True,
108
+ **kwargs,
109
+ ):
110
+ self.causal = causal
111
+ self.context_parallel = context_parallel
112
+ kernel_size = cast_tuple(kernel_size, 3)
113
+ stride = cast_tuple(stride, 3)
114
+ height_pad = (kernel_size[1] - 1) // 2
115
+ width_pad = (kernel_size[2] - 1) // 2
116
+
117
+ super().__init__(
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ kernel_size=kernel_size,
121
+ stride=stride,
122
+ dilation=(1, 1, 1),
123
+ padding=(0, height_pad, width_pad),
124
+ **kwargs,
125
+ )
126
+
127
+ def forward(self, x: torch.Tensor):
128
+ cp_rank, cp_world_size = cp.get_cp_rank_size()
129
+
130
+ # Compute padding amounts.
131
+ context_size = self.kernel_size[0] - 1
132
+ if self.causal:
133
+ pad_front = context_size
134
+ pad_back = 0
135
+ else:
136
+ pad_front = context_size // 2
137
+ pad_back = context_size - pad_front
138
+
139
+ # Apply padding.
140
+ assert self.padding_mode == "replicate" # DEBUG
141
+ mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
142
+ if self.context_parallel and cp_world_size == 1:
143
+ x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
144
+ else:
145
+ if cp_rank == 0:
146
+ x = F.pad(x, (0, 0, 0, 0, pad_front, 0), mode=mode)
147
+ elif cp_rank == cp_world_size - 1 and pad_back:
148
+ x = F.pad(x, (0, 0, 0, 0, 0, pad_back), mode=mode)
149
+
150
+ if self.context_parallel and cp_world_size == 1:
151
+ return super().forward(x)
152
+
153
+ if self.stride[0] == 1:
154
+ # Receive some frames from previous rank.
155
+ x = cp_pass_frames(x, context_size)
156
+ return super().forward(x)
157
+
158
+ # Less efficient implementation for strided convs.
159
+ # All gather x, infer and chunk.
160
+ assert x.dtype == torch.bfloat16, f"Expected x to be of type torch.bfloat16, got {x.dtype}"
161
+
162
+ x = gather_all_frames(x) # [B, C, k - 1 + global_T, H, W]
163
+ return StridedSafeConv3d.forward(self, x, local_shard=True)
164
+
165
+
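+ # Illustrative sketch (assumptions: a single process, so cp.get_cp_rank_size() reports
+ # world size 1). With causal=True and temporal kernel k, the forward pass replicate-pads
+ # k - 1 frames at the front, so stride-1 convolutions preserve the frame count.
+ #   conv = ContextParallelConv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1),
+ #                                padding_mode="replicate", causal=True)
+ #   y = conv(torch.randn(1, 8, 5, 16, 16))   # -> [1, 8, 5, 16, 16]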
166
+ class Conv1x1(nn.Linear):
167
+ """*1x1 Conv implemented with a linear layer."""
168
+
169
+ def __init__(self, in_features: int, out_features: int, *args, **kwargs):
170
+ super().__init__(in_features, out_features, *args, **kwargs)
171
+
172
+ def forward(self, x: torch.Tensor):
173
+ """Forward pass.
174
+
175
+ Args:
176
+ x: Input tensor. Shape: [B, C, *] or [B, *, C].
177
+
178
+ Returns:
179
+ x: Output tensor. Shape: [B, C', *] or [B, *, C'].
180
+ """
181
+ x = x.movedim(1, -1)
182
+ x = super().forward(x)
183
+ x = x.movedim(-1, 1)
184
+ return x
185
+
186
+
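+ # Illustrative sketch (assumed shapes): Conv1x1 is a pointwise channel projection, so
+ # it accepts any [B, C, *] layout without materializing an actual convolution kernel.
+ #   proj = Conv1x1(64, 12)
+ #   y = proj(torch.randn(2, 64, 7, 16, 16))   # -> [2, 12, 7, 16, 16]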
187
+ class DepthToSpaceTime(nn.Module):
188
+ def __init__(
189
+ self,
190
+ temporal_expansion: int,
191
+ spatial_expansion: int,
192
+ ):
193
+ super().__init__()
194
+ self.temporal_expansion = temporal_expansion
195
+ self.spatial_expansion = spatial_expansion
196
+
197
+ # When printed, this module should show the temporal and spatial expansion factors.
198
+ def extra_repr(self):
199
+ return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ """Forward pass.
203
+
204
+ Args:
205
+ x: Input tensor. Shape: [B, C, T, H, W].
206
+
207
+ Returns:
208
+ x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s], where st and s are the temporal and spatial expansion factors.
209
+ """
210
+ x = rearrange(
211
+ x,
212
+ "B (C st sh sw) T H W -> B C (T st) (H sh) (W sw)",
213
+ st=self.temporal_expansion,
214
+ sh=self.spatial_expansion,
215
+ sw=self.spatial_expansion,
216
+ )
217
+
218
+ cp_rank, _ = cp.get_cp_rank_size()
219
+ if self.temporal_expansion > 1 and cp_rank == 0:
220
+ # Drop the first self.temporal_expansion - 1 frames.
221
+ # This is because we always want the 3x3x3 conv filter to only apply
222
+ # to the first frame, and the first frame doesn't need to be repeated.
223
+ assert all(x.shape)
224
+ x = x[:, :, self.temporal_expansion - 1 :]
225
+ assert all(x.shape)
226
+
227
+ return x
228
+
229
+
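+ # Illustrative sketch (assumption: running on context-parallel rank 0): channels are
+ # folded into time and space, then the first temporal_expansion - 1 frames are dropped
+ # so the first output frame is not repeated.
+ #   d2st = DepthToSpaceTime(temporal_expansion=2, spatial_expansion=2)
+ #   y = d2st(torch.randn(1, 8 * 2 * 2 * 2, 4, 8, 8))   # C = 8 * st * sh * sw
+ #   # y.shape == (1, 8, 7, 16, 16), i.e. T goes from 4 to 2 * 4 - 1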
230
+ def norm_fn(
231
+ in_channels: int,
232
+ affine: bool = True,
233
+ ):
234
+ return GroupNormSpatial(affine=affine, num_groups=32, num_channels=in_channels)
235
+
236
+
237
+ class ResBlock(nn.Module):
238
+ """Residual block that preserves the spatial dimensions."""
239
+
240
+ def __init__(
241
+ self,
242
+ channels: int,
243
+ *,
244
+ affine: bool = True,
245
+ attn_block: Optional[nn.Module] = None,
246
+ padding_mode: str = "replicate",
247
+ causal: bool = True,
248
+ ):
249
+ super().__init__()
250
+ self.channels = channels
251
+
252
+ assert causal
253
+ self.stack = nn.Sequential(
254
+ norm_fn(channels, affine=affine),
255
+ nn.SiLU(inplace=True),
256
+ ContextParallelConv3d(
257
+ in_channels=channels,
258
+ out_channels=channels,
259
+ kernel_size=(3, 3, 3),
260
+ stride=(1, 1, 1),
261
+ padding_mode=padding_mode,
262
+ bias=True,
263
+ causal=causal,
264
+ ),
265
+ norm_fn(channels, affine=affine),
266
+ nn.SiLU(inplace=True),
267
+ ContextParallelConv3d(
268
+ in_channels=channels,
269
+ out_channels=channels,
270
+ kernel_size=(3, 3, 3),
271
+ stride=(1, 1, 1),
272
+ padding_mode=padding_mode,
273
+ bias=True,
274
+ causal=causal,
275
+ ),
276
+ )
277
+
278
+ self.attn_block = attn_block if attn_block else nn.Identity()
279
+
280
+ def forward(self, x: torch.Tensor):
281
+ """Forward pass.
282
+
283
+ Args:
284
+ x: Input tensor. Shape: [B, C, T, H, W].
285
+ """
286
+ residual = x
287
+ x = self.stack(x)
288
+ x = x + residual
289
+ del residual
290
+
291
+ return self.attn_block(x)
292
+
293
+
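+ # Illustrative sketch (assumptions: single process, channel count divisible by the 32
+ # GroupNorm groups): the residual stack preserves the input shape end to end.
+ #   block = ResBlock(64, affine=True, causal=True)
+ #   y = block(torch.randn(1, 64, 5, 16, 16))   # -> [1, 64, 5, 16, 16]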
294
+ def prepare_for_attention(qkv: torch.Tensor, head_dim: int, qk_norm: bool = True):
295
+ """Prepare qkv tensor for attention and normalize qk.
296
+
297
+ Args:
298
+ qkv: Input tensor. Shape: [B, L, 3 * num_heads * head_dim].
299
+
300
+ Returns:
301
+ q, k, v: qkv tensor split into q, k, v. Shape: [B, num_heads, L, head_dim].
302
+ """
303
+ assert qkv.ndim == 3 # [B, L, 3 * num_heads * head_dim]
304
+ assert qkv.size(2) % (3 * head_dim) == 0
305
+ num_heads = qkv.size(2) // (3 * head_dim)
306
+ qkv = qkv.unflatten(2, (3, num_heads, head_dim))
307
+
308
+ q, k, v = qkv.unbind(2) # [B, L, num_heads, head_dim]
309
+ q = q.transpose(1, 2) # [B, num_heads, L, head_dim]
310
+ k = k.transpose(1, 2) # [B, num_heads, L, head_dim]
311
+ v = v.transpose(1, 2) # [B, num_heads, L, head_dim]
312
+
313
+ if qk_norm:
314
+ q = F.normalize(q, p=2, dim=-1)
315
+ k = F.normalize(k, p=2, dim=-1)
316
+
317
+ # Mixed precision can change the dtype of normed q/k to float32.
318
+ q = q.to(dtype=qkv.dtype)
319
+ k = k.to(dtype=qkv.dtype)
320
+
321
+ return q, k, v
322
+
323
+
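+ # Illustrative sketch (assumed sizes): a packed qkv of width 3 * num_heads * head_dim is
+ # split into q, k, v of shape [B, num_heads, L, head_dim], with q and k L2-normalized.
+ #   qkv = torch.randn(2, 10, 3 * 4 * 32)   # B=2, L=10, 4 heads of dim 32
+ #   q, k, v = prepare_for_attention(qkv, head_dim=32)
+ #   # q.shape == k.shape == v.shape == (2, 4, 10, 32)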
324
+ class Attention(nn.Module):
325
+ def __init__(
326
+ self,
327
+ dim: int,
328
+ head_dim: int = 32,
329
+ qkv_bias: bool = False,
330
+ out_bias: bool = True,
331
+ qk_norm: bool = True,
332
+ ) -> None:
333
+ super().__init__()
334
+ self.head_dim = head_dim
335
+ self.num_heads = dim // head_dim
336
+ self.qk_norm = qk_norm
337
+
338
+ self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
339
+ self.out = nn.Linear(dim, dim, bias=out_bias)
340
+
341
+ def forward(
342
+ self,
343
+ x: torch.Tensor,
344
+ *,
345
+ chunk_size=2**15,
346
+ ) -> torch.Tensor:
347
+ """Compute temporal self-attention.
348
+
349
+ Args:
350
+ x: Input tensor. Shape: [B, C, T, H, W].
351
+ chunk_size: Chunk size for large tensors.
352
+
353
+ Returns:
354
+ x: Output tensor. Shape: [B, C, T, H, W].
355
+ """
356
+ B, _, T, H, W = x.shape
357
+
358
+ if T == 1:
359
+ # No attention for single frame.
360
+ x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
361
+ qkv = self.qkv(x)
362
+ _, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
363
+ x = self.out(x)
364
+ return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
365
+
366
+ # 1D temporal attention.
367
+ x = rearrange(x, "B C t h w -> (B h w) t C")
368
+ qkv = self.qkv(x)
369
+
370
+ # Input: qkv with shape [B, t, 3 * num_heads * head_dim]
371
+ # Output: x with shape [B, num_heads, t, head_dim]
372
+ q, k, v = prepare_for_attention(qkv, self.head_dim, qk_norm=self.qk_norm)
373
+
374
+ attn_kwargs = dict(
375
+ attn_mask=None,
376
+ dropout_p=0.0,
377
+ is_causal=True,
378
+ scale=self.head_dim**-0.5,
379
+ )
380
+
381
+ if q.size(0) <= chunk_size:
382
+ x = F.scaled_dot_product_attention(q, k, v, **attn_kwargs) # [B, num_heads, t, head_dim]
383
+ else:
384
+ # Evaluate in chunks to avoid `RuntimeError: CUDA error: invalid configuration argument.`
385
+ # Chunks of 2**16 and up cause an error.
386
+ x = torch.empty_like(q)
387
+ for i in range(0, q.size(0), chunk_size):
388
+ qc = q[i : i + chunk_size]
389
+ kc = k[i : i + chunk_size]
390
+ vc = v[i : i + chunk_size]
391
+ chunk = F.scaled_dot_product_attention(qc, kc, vc, **attn_kwargs)
392
+ x[i : i + chunk_size].copy_(chunk)
393
+
394
+ assert x.size(0) == q.size(0)
395
+ x = x.transpose(1, 2) # [B, t, num_heads, head_dim]
396
+ x = x.flatten(2) # [B, t, num_heads * head_dim]
397
+
398
+ x = self.out(x)
399
+ x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
400
+ return x
401
+
402
+
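+ # Illustrative sketch (assumed sizes): temporal self-attention runs independently at each
+ # spatial location, so the [B, C, T, H, W] layout is preserved.
+ #   attn = Attention(dim=128, head_dim=32)
+ #   y = attn(torch.randn(1, 128, 6, 4, 4))   # -> [1, 128, 6, 4, 4]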
403
+ class AttentionBlock(nn.Module):
404
+ def __init__(
405
+ self,
406
+ dim: int,
407
+ **attn_kwargs,
408
+ ) -> None:
409
+ super().__init__()
410
+ self.norm = norm_fn(dim)
411
+ self.attn = Attention(dim, **attn_kwargs)
412
+
413
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
414
+ return x + self.attn(self.norm(x))
415
+
416
+
417
+ class CausalUpsampleBlock(nn.Module):
418
+ def __init__(
419
+ self,
420
+ in_channels: int,
421
+ out_channels: int,
422
+ num_res_blocks: int,
423
+ *,
424
+ temporal_expansion: int = 2,
425
+ spatial_expansion: int = 2,
426
+ **block_kwargs,
427
+ ):
428
+ super().__init__()
429
+
430
+ blocks = []
431
+ for _ in range(num_res_blocks):
432
+ blocks.append(block_fn(in_channels, **block_kwargs))
433
+ self.blocks = nn.Sequential(*blocks)
434
+
435
+ self.temporal_expansion = temporal_expansion
436
+ self.spatial_expansion = spatial_expansion
437
+
438
+ # Change channels in the final convolution layer.
439
+ self.proj = Conv1x1(
440
+ in_channels,
441
+ out_channels * temporal_expansion * (spatial_expansion**2),
442
+ )
443
+
444
+ self.d2st = DepthToSpaceTime(temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion)
445
+
446
+ def forward(self, x):
447
+ x = self.blocks(x)
448
+ x = self.proj(x)
449
+ x = self.d2st(x)
450
+ return x
451
+
452
+
453
+ def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
454
+ attn_block = AttentionBlock(channels) if has_attention else None
455
+
456
+ return ResBlock(channels, affine=True, attn_block=attn_block, **block_kwargs)
457
+
458
+
459
+ class DownsampleBlock(nn.Module):
460
+ def __init__(
461
+ self,
462
+ in_channels: int,
463
+ out_channels: int,
464
+ num_res_blocks,
465
+ *,
466
+ temporal_reduction=2,
467
+ spatial_reduction=2,
468
+ **block_kwargs,
469
+ ):
470
+ """
471
+ Downsample block for the VAE encoder.
472
+
473
+ Args:
474
+ in_channels: Number of input channels.
475
+ out_channels: Number of output channels.
476
+ num_res_blocks: Number of residual blocks.
477
+ temporal_reduction: Temporal reduction factor.
478
+ spatial_reduction: Spatial reduction factor.
479
+ """
480
+ super().__init__()
481
+ layers = []
482
+
483
+ # Change the channel count in the strided convolution.
484
+ # This lets the ResBlock have uniform channel count,
485
+ # as in ConvNeXt.
486
+ assert in_channels != out_channels
487
+ layers.append(
488
+ ContextParallelConv3d(
489
+ in_channels=in_channels,
490
+ out_channels=out_channels,
491
+ kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
492
+ stride=(temporal_reduction, spatial_reduction, spatial_reduction),
493
+ padding_mode="replicate",
494
+ bias=True,
495
+ )
496
+ )
497
+
498
+ for _ in range(num_res_blocks):
499
+ layers.append(block_fn(out_channels, **block_kwargs))
500
+
501
+ self.layers = nn.Sequential(*layers)
502
+
503
+ def forward(self, x):
504
+ return self.layers(x)
505
+
506
+
507
+ def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
508
+ num_freqs = (stop - start) // step
509
+ assert inputs.ndim == 5
510
+ C = inputs.size(1)
511
+
512
+ # Create Base 2 Fourier features.
513
+ freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
514
+ assert num_freqs == len(freqs)
515
+ w = torch.pow(2.0, freqs) * (2 * torch.pi) # [num_freqs]
516
+ C = inputs.shape[1]
517
+ w = w.repeat(C)[None, :, None, None, None] # [1, C * num_freqs, 1, 1, 1]
518
+
519
+ # Interleaved repeat of input channels to match w.
520
+ h = inputs.repeat_interleave(num_freqs, dim=1) # [B, C * num_freqs, T, H, W]
521
+ # Scale channels by frequency.
522
+ h = w * h
523
+
524
+ return torch.cat(
525
+ [
526
+ inputs,
527
+ torch.sin(h),
528
+ torch.cos(h),
529
+ ],
530
+ dim=1,
531
+ )
532
+
533
+
534
+ class FourierFeatures(nn.Module):
535
+ def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
536
+ super().__init__()
537
+ self.start = start
538
+ self.stop = stop
539
+ self.step = step
540
+
541
+ def forward(self, inputs):
542
+ """Add Fourier features to inputs.
543
+
544
+ Args:
545
+ inputs: Input tensor. Shape: [B, C, T, H, W]
546
+
547
+ Returns:
548
+ h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
549
+ """
550
+ return add_fourier_features(inputs, self.start, self.stop, self.step)
551
+
552
+
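+ # Illustrative sketch (assumed shapes): with the defaults start=6, stop=8, step=1 there are
+ # num_freqs = 2 frequencies, so the channel count grows by a factor of 1 + 2 * num_freqs = 5.
+ #   ff = FourierFeatures()
+ #   y = ff(torch.randn(1, 3, 4, 8, 8))   # -> [1, 15, 4, 8, 8]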
553
+ class Decoder(nn.Module):
554
+ def __init__(
555
+ self,
556
+ *,
557
+ out_channels: int = 3,
558
+ latent_dim: int,
559
+ base_channels: int,
560
+ channel_multipliers: List[int],
561
+ num_res_blocks: List[int],
562
+ temporal_expansions: Optional[List[int]] = None,
563
+ spatial_expansions: Optional[List[int]] = None,
564
+ has_attention: List[bool],
565
+ output_norm: bool = True,
566
+ nonlinearity: str = "silu",
567
+ output_nonlinearity: str = "silu",
568
+ causal: bool = True,
569
+ **block_kwargs,
570
+ ):
571
+ super().__init__()
572
+ self.input_channels = latent_dim
573
+ self.base_channels = base_channels
574
+ self.channel_multipliers = channel_multipliers
575
+ self.num_res_blocks = num_res_blocks
576
+ self.output_nonlinearity = output_nonlinearity
577
+ assert nonlinearity == "silu"
578
+ assert causal
579
+
580
+ ch = [mult * base_channels for mult in channel_multipliers]
581
+ self.num_up_blocks = len(ch) - 1
582
+ assert len(num_res_blocks) == self.num_up_blocks + 2
583
+
584
+ blocks = []
585
+
586
+ first_block = [nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))] # Input layer.
587
+ # First set of blocks preserve channel count.
588
+ for _ in range(num_res_blocks[-1]):
589
+ first_block.append(
590
+ block_fn(
591
+ ch[-1],
592
+ has_attention=has_attention[-1],
593
+ causal=causal,
594
+ **block_kwargs,
595
+ )
596
+ )
597
+ blocks.append(nn.Sequential(*first_block))
598
+
599
+ assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
600
+ assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2
601
+
602
+ upsample_block_fn = CausalUpsampleBlock
603
+
604
+ for i in range(self.num_up_blocks):
605
+ block = upsample_block_fn(
606
+ ch[-i - 1],
607
+ ch[-i - 2],
608
+ num_res_blocks=num_res_blocks[-i - 2],
609
+ has_attention=has_attention[-i - 2],
610
+ temporal_expansion=temporal_expansions[-i - 1],
611
+ spatial_expansion=spatial_expansions[-i - 1],
612
+ causal=causal,
613
+ **block_kwargs,
614
+ )
615
+ blocks.append(block)
616
+
617
+ assert not output_norm
618
+
619
+ # Last block. Preserve channel count.
620
+ last_block = []
621
+ for _ in range(num_res_blocks[0]):
622
+ last_block.append(block_fn(ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs))
623
+ blocks.append(nn.Sequential(*last_block))
624
+
625
+ self.blocks = nn.ModuleList(blocks)
626
+ self.output_proj = Conv1x1(ch[0], out_channels)
627
+
628
+ def unnormalize_latents(
629
+ self,
630
+ z: torch.Tensor,
631
+ mean: torch.Tensor,
632
+ std: torch.Tensor,
633
+ ) -> torch.Tensor:
634
+ """Unnormalize latents. Useful for decoding DiT samples.
635
+
636
+ Args:
637
+ z (torch.Tensor): [B, C_z, T_z, H_z, W_z], float
638
+
639
+ Returns:
640
+ torch.Tensor: [B, C_z, T_z, H_z, W_z], float
641
+ """
642
+ mean = mean[:, None, None, None]
643
+ std = std[:, None, None, None]
644
+
645
+ assert z.ndim == 5
646
+ assert z.size(1) == mean.size(0) == std.size(0)
647
+ return z * std.to(z) + mean.to(z)
648
+
649
+ def forward(self, x):
650
+ """Forward pass.
651
+
652
+ Args:
653
+ x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].
654
+
655
+ Returns:
656
+ x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
657
+ T + 1 = (t - 1) * 4.
658
+ H = h * 16, W = w * 16.
659
+ """
660
+ x = self.unnormalize_latents(x, self.vae_mean, self.vae_std)
661
+
662
+ for block in self.blocks:
663
+ x = block(x)
664
+
665
+ if self.output_nonlinearity == "silu":
666
+ x = F.silu(x, inplace=not self.training)
667
+ else:
668
+ assert not self.output_nonlinearity # StyleGAN3 omits the to-RGB nonlinearity.
669
+
670
+ return self.output_proj(x).contiguous()
671
+
672
+
673
+ def make_broadcastable(
674
+ tensor: torch.Tensor,
675
+ axis: int,
676
+ ndim: int,
677
+ ) -> torch.Tensor:
678
+ """
679
+ Reshapes the input tensor to have singleton dimensions in all axes except the specified axis.
680
+
681
+ Args:
682
+ tensor (torch.Tensor): The tensor to reshape. Typically 1D.
683
+ axis (int): The axis along which the tensor should retain its original size.
684
+ ndim (int): The total number of dimensions the reshaped tensor should have.
685
+
686
+ Returns:
687
+ torch.Tensor: The reshaped tensor with shape suitable for broadcasting.
688
+ """
689
+ if tensor.dim() != 1:
690
+ raise ValueError(f"Expected tensor to be 1D, but got {tensor.dim()}D tensor.")
691
+
692
+ axis = (axis + ndim) % ndim # Ensure the axis is within the tensor dimensions
693
+ shape = [1] * ndim # Start with all dimensions as 1
694
+ shape[axis] = tensor.size(0) # Set the specified axis to the size of the tensor
695
+ return tensor.view(*shape)
696
+
697
+
698
+ def blend(a: torch.Tensor, b: torch.Tensor, axis: int) -> torch.Tensor:
699
+ """
700
+ Blends two tensors `a` and `b` along the specified axis using linear interpolation.
701
+
702
+ Args:
703
+ a (torch.Tensor): The first tensor.
704
+ b (torch.Tensor): The second tensor. Must have the same shape as `a`.
705
+ axis (int): The axis along which to perform the blending.
706
+
707
+ Returns:
708
+ torch.Tensor: The blended tensor.
709
+ """
710
+ assert a.shape == b.shape, f"Tensors must have the same shape, got {a.shape} and {b.shape}"
711
+ steps = a.size(axis)
712
+
713
+ # Create a weight tensor that linearly interpolates from 0 to 1
714
+ start = 1 / (steps + 1)
715
+ end = steps / (steps + 1)
716
+ weight = torch.linspace(start, end, steps=steps, device=a.device, dtype=a.dtype)
717
+
718
+ # Make the weight tensor broadcastable across all dimensions
719
+ weight = make_broadcastable(weight, axis, a.dim())
720
+
721
+ # Perform the blending
722
+ return a * (1 - weight) + b * weight
723
+
724
+
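+ # Illustrative sketch (assumed values): blending 3 steps along the last axis uses weights
+ # [0.25, 0.5, 0.75], ramping from mostly `a` to mostly `b` without ever hitting 0 or 1.
+ #   a, b = torch.zeros(1, 3), torch.ones(1, 3)
+ #   blend(a, b, axis=-1)   # tensor([[0.2500, 0.5000, 0.7500]])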
725
+ def blend_horizontal(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
726
+ if overlap == 0:
727
+ return torch.cat([a, b], dim=-1)
728
+
729
+ assert a.size(-1) >= overlap
730
+ assert b.size(-1) >= overlap
731
+ a_left, a_overlap = a[..., :-overlap], a[..., -overlap:]
732
+ b_overlap, b_right = b[..., :overlap], b[..., overlap:]
733
+ return torch.cat([a_left, blend(a_overlap, b_overlap, -1), b_right], dim=-1)
734
+
735
+
736
+ def blend_vertical(a: torch.Tensor, b: torch.Tensor, overlap: int) -> torch.Tensor:
737
+ if overlap == 0:
738
+ return torch.cat([a, b], dim=-2)
739
+
740
+ assert a.size(-2) >= overlap
741
+ assert b.size(-2) >= overlap
742
+ a_top, a_overlap = a[..., :-overlap, :], a[..., -overlap:, :]
743
+ b_overlap, b_bottom = b[..., :overlap, :], b[..., overlap:, :]
744
+ return torch.cat([a_top, blend(a_overlap, b_overlap, -2), b_bottom], dim=-2)
745
+
746
+
747
+ def nearest_multiple(x: int, multiple: int) -> int:
748
+ return round(x / multiple) * multiple
749
+
750
+
751
+ def apply_tiled(
752
+ fn: Callable[[torch.Tensor], torch.Tensor],
753
+ x: torch.Tensor,
754
+ num_tiles_w: int,
755
+ num_tiles_h: int,
756
+ overlap: int = 0,  # Number of pixels of overlap between adjacent tiles.
757
+ # Use a factor of 2 times the latent downsample factor.
758
+ min_block_size: int = 1, # Minimum number of pixels in each dimension when subdividing.
759
+ ) -> Optional[torch.Tensor]:
760
+ if num_tiles_w == 1 and num_tiles_h == 1:
761
+ return fn(x)
762
+
763
+ assert num_tiles_w & (num_tiles_w - 1) == 0, f"num_tiles_w={num_tiles_w} must be a power of 2"
764
+ assert num_tiles_h & (num_tiles_h - 1) == 0, f"num_tiles_h={num_tiles_h} must be a power of 2"
765
+
766
+ H, W = x.shape[-2:]
767
+ assert H % min_block_size == 0
768
+ assert W % min_block_size == 0
769
+ ov = overlap // 2
770
+ assert ov % min_block_size == 0
771
+
772
+ if num_tiles_w >= 2:
773
+ # Subdivide horizontally.
774
+ half_W = nearest_multiple(W // 2, min_block_size)
775
+ left = x[..., :, : half_W + ov]
776
+ right = x[..., :, half_W - ov :]
777
+
778
+ assert num_tiles_w % 2 == 0, f"num_tiles_w={num_tiles_w} must be even"
779
+ left = apply_tiled(fn, left, num_tiles_w // 2, num_tiles_h, overlap, min_block_size)
780
+ right = apply_tiled(fn, right, num_tiles_w // 2, num_tiles_h, overlap, min_block_size)
781
+ if left is None or right is None:
782
+ return None
783
+
784
+ # If `fn` changed the resolution, adjust the overlap.
785
+ resample_factor = left.size(-1) / (half_W + ov)
786
+ out_overlap = int(overlap * resample_factor)
787
+
788
+ return blend_horizontal(left, right, out_overlap)
789
+
790
+ if num_tiles_h >= 2:
791
+ # Subdivide vertically.
792
+ half_H = nearest_multiple(H // 2, min_block_size)
793
+ top = x[..., : half_H + ov, :]
794
+ bottom = x[..., half_H - ov :, :]
795
+
796
+ assert num_tiles_h % 2 == 0, f"num_tiles_h={num_tiles_h} must be even"
797
+ top = apply_tiled(fn, top, num_tiles_w, num_tiles_h // 2, overlap, min_block_size)
798
+ bottom = apply_tiled(fn, bottom, num_tiles_w, num_tiles_h // 2, overlap, min_block_size)
799
+ if top is None or bottom is None:
800
+ return None
801
+
802
+ # If `fn` changed the resolution, adjust the overlap.
803
+ resample_factor = top.size(-2) / (half_H + ov)
804
+ out_overlap = int(overlap * resample_factor)
805
+
806
+ return blend_vertical(top, bottom, out_overlap)
807
+
808
+ raise ValueError(f"Invalid num_tiles_w={num_tiles_w} and num_tiles_h={num_tiles_h}")
uv.lock ADDED
The diff for this file is too large to render. See raw diff