sironagasuyagi commited on
Commit
910e2ad
·
verified ·
1 Parent(s): 43fbfb0

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +119 -0
  2. .gradio/certificate.pem +31 -0
  3. LICENSE +21 -0
  4. README - コピー.md +295 -0
  5. README.md +295 -12
  6. annotation/image_text.jsonl +20 -0
  7. annotation/video_text.jsonl +17 -0
  8. app.py +356 -0
  9. app_multigpu.py +143 -0
  10. assets/motivation.jpg +0 -0
  11. assets/the_great_wall.jpg +0 -0
  12. assets/user_study.jpg +0 -0
  13. assets/vbench.jpg +0 -0
  14. causal_video_vae_demo.ipynb +221 -0
  15. dataset/__init__.py +12 -0
  16. dataset/bucket_loader.py +148 -0
  17. dataset/dataloaders.py +190 -0
  18. dataset/dataset_cls.py +377 -0
  19. diffusion_schedulers/__init__.py +2 -0
  20. diffusion_schedulers/scheduling_cosine_ddpm.py +137 -0
  21. diffusion_schedulers/scheduling_flow_matching.py +297 -0
  22. docs/DiT.md +54 -0
  23. docs/VAE.md +42 -0
  24. image_generation_demo.ipynb +123 -0
  25. inference_multigpu.py +123 -0
  26. pyramid_dit/__init__.py +3 -0
  27. pyramid_dit/flux_modules/__init__.py +3 -0
  28. pyramid_dit/flux_modules/modeling_embedding.py +201 -0
  29. pyramid_dit/flux_modules/modeling_flux_block.py +1044 -0
  30. pyramid_dit/flux_modules/modeling_normalization.py +249 -0
  31. pyramid_dit/flux_modules/modeling_pyramid_flux.py +543 -0
  32. pyramid_dit/flux_modules/modeling_text_encoder.py +134 -0
  33. pyramid_dit/mmdit_modules/__init__.py +3 -0
  34. pyramid_dit/mmdit_modules/modeling_embedding.py +390 -0
  35. pyramid_dit/mmdit_modules/modeling_mmdit_block.py +671 -0
  36. pyramid_dit/mmdit_modules/modeling_normalization.py +179 -0
  37. pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py +497 -0
  38. pyramid_dit/mmdit_modules/modeling_text_encoder.py +140 -0
  39. pyramid_dit/pyramid_dit_for_video_gen_pipeline.py +1279 -0
  40. pyramid_flow_model.lnk +0 -0
  41. pyramid_flow_model/.gitattributes +35 -0
  42. pyramid_flow_model/README.md +191 -0
  43. pyramid_flow_model/causal_video_vae/config.json +92 -0
  44. pyramid_flow_model/causal_video_vae/diffusion_pytorch_model.bin +3 -0
  45. pyramid_flow_model/diffusion_transformer_384p/config.json +21 -0
  46. pyramid_flow_model/diffusion_transformer_384p/diffusion_pytorch_model.safetensors +3 -0
  47. pyramid_flow_model/diffusion_transformer_768p/config.json +21 -0
  48. pyramid_flow_model/diffusion_transformer_768p/diffusion_pytorch_model.safetensors +3 -0
  49. pyramid_flow_model/diffusion_transformer_image/config.json +21 -0
  50. pyramid_flow_model/diffusion_transformer_image/diffusion_pytorch_model.safetensors +3 -0
.gitignore ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Xcode
2
+ .DS_Store
3
+ .idea
4
+
5
+ # tyte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+ # C extensions
10
+ *.so
11
+ onnx_model/*.onnx
12
+ onnx_model/antelope/*.onnx
13
+
14
+
15
+ logs/
16
+ prompts/
17
+
18
+ # Distribution / packaging
19
+ .Python
20
+ build/
21
+ develop-eggs/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # celery beat schedule file
89
+ celerybeat-schedule
90
+
91
+ # SageMath parsed files
92
+ *.sage.py
93
+
94
+ # Environments
95
+ .env
96
+ .pt2/
97
+ .venv
98
+ env/
99
+ venv/
100
+ ENV/
101
+ env.bak/
102
+ venv.bak/
103
+
104
+ # Spyder project settings
105
+ .spyderproject
106
+ .spyproject
107
+
108
+ # Rope project settings
109
+ .ropeproject
110
+
111
+ # mkdocs documentation
112
+ /site
113
+
114
+ # mypy
115
+ .mypy_cache/
116
+ .dmypy.json
117
+ dmypy.json
118
+ .bak
119
+
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Yang Jin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README - コピー.md ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Pyramid-Flow
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.6.0
6
+ ---
7
+ <div align="center">
8
+
9
+ # ⚡️Pyramid Flow⚡️
10
+
11
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page ✨]](https://pyramid-flow.github.io) [[miniFLUX Model 🚀]](https://huggingface.co/rain1011/pyramid-flow-miniflux) [[SD3 Model ⚡️]](https://huggingface.co/rain1011/pyramid-flow-sd3) [[demo 🤗](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow)]
12
+
13
+ </div>
14
+
15
+ This is the official repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on **open-source datasets**, it can generate high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.
16
+
17
+ <table class="center" border="0" style="width: 100%; text-align: left;">
18
+ <tr>
19
+ <th>10s, 768p, 24fps</th>
20
+ <th>5s, 768p, 24fps</th>
21
+ <th>Image-to-video</th>
22
+ </tr>
23
+ <tr>
24
+ <td><video src="https://github.com/user-attachments/assets/9935da83-ae56-4672-8747-0f46e90f7b2b" autoplay muted loop playsinline></video></td>
25
+ <td><video src="https://github.com/user-attachments/assets/3412848b-64db-4d9e-8dbf-11403f6d02c5" autoplay muted loop playsinline></video></td>
26
+ <td><video src="https://github.com/user-attachments/assets/3bd7251f-7b2c-4bee-951d-656fdb45f427" autoplay muted loop playsinline></video></td>
27
+ </tr>
28
+ </table>
29
+
30
+ ## News
31
+ * `2024.11.13` 🚀🚀🚀 We release the [768p miniFLUX checkpoint](https://huggingface.co/rain1011/pyramid-flow-miniflux) (up to 10s).
32
+
33
+ > We have switched the model structure from SD3 to a mini FLUX to fix human structure issues, please try our 1024p image checkpoint, 384p video checkpoint (up to 5s) and 768p video checkpoint (up to 10s). The new miniflux model shows great improvement on human structure and motion stability
34
+
35
+ * `2024.10.29` ⚡️⚡️⚡️ We release [training code for VAE](#1-training-vae), [finetuning code for DiT](#2-finetuning-dit) and [new model checkpoints](https://huggingface.co/rain1011/pyramid-flow-miniflux) with FLUX structure trained from scratch.
36
+
37
+
38
+ * `2024.10.13` ✨✨✨ [Multi-GPU inference](#3-multi-gpu-inference) and [CPU offloading](#cpu-offloading) are supported. Use it with **less than 8GB** of GPU memory, with great speedup on multiple GPUs.
39
+
40
+ * `2024.10.11` 🤗🤗🤗 [Hugging Face demo](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) is available. Thanks [@multimodalart](https://huggingface.co/multimodalart) for the commit!
41
+
42
+ * `2024.10.10` 🚀🚀🚀 We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.
43
+
44
+ ## Table of Contents
45
+
46
+ * [Introduction](#introduction)
47
+ * [Installation](#installation)
48
+ * [Inference](#inference)
49
+ 1. [Quick Start with Gradio](#1-quick-start-with-gradio)
50
+ 2. [Inference Code](#2-inference-code)
51
+ 3. [Multi-GPU Inference](#3-multi-gpu-inference)
52
+ 4. [Usage Tips](#4-usage-tips)
53
+ * [Training](#Training)
54
+ 1. [Training VAE](#training-vae)
55
+ 2. [Finetuning DiT](#finetuning-dit)
56
+ * [Gallery](#gallery)
57
+ * [Comparison](#comparison)
58
+ * [Acknowledgement](#acknowledgement)
59
+ * [Citation](#citation)
60
+
61
+ ## Introduction
62
+
63
+ ![motivation](assets/motivation.jpg)
64
+
65
+ Existing video diffusion models operate at full resolution, spending a lot of computation on very noisy latents. By contrast, our method harnesses the flexibility of flow matching ([Lipman et al., 2023](https://openreview.net/forum?id=PqvMRDCJT9t); [Liu et al., 2023](https://openreview.net/forum?id=XVjTT1nw5z); [Albergo & Vanden-Eijnden, 2023](https://openreview.net/forum?id=li7qeBbCR1t)) to interpolate between latents of different resolutions and noise levels, allowing for simultaneous generation and decompression of visual content with better computational efficiency. The entire framework is end-to-end optimized with a single DiT ([Peebles & Xie, 2023](http://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html)), generating high-quality 10-second videos at 768p resolution and 24 FPS within 20.7k A100 GPU training hours.
66
+
67
+ ## Installation
68
+
69
+ We recommend setting up the environment with conda. The codebase currently uses Python 3.8.10 and PyTorch 2.1.2 ([guide](https://pytorch.org/get-started/previous-versions/#v212)), and we are actively working to support a wider range of versions.
70
+
71
+ ```bash
72
+ git clone https://github.com/jy0205/Pyramid-Flow
73
+ cd Pyramid-Flow
74
+
75
+ # create env using conda
76
+ conda create -n pyramid python==3.8.10
77
+ conda activate pyramid
78
+ pip install -r requirements.txt
79
+ ```
80
+
81
+ Then, download the model from [Huggingface](https://huggingface.co/rain1011) (there are two variants: [miniFLUX](https://huggingface.co/rain1011/pyramid-flow-miniflux) or [SD3](https://huggingface.co/rain1011/pyramid-flow-sd3)). The miniFLUX models support 1024p image, 384p and 768p video generation, and the SD3-based models support 768p and 384p video generation. The 384p checkpoint generates 5-second video at 24FPS, while the 768p checkpoint generates up to 10-second video at 24FPS.
82
+
83
+ ```python
84
+ from huggingface_hub import snapshot_download
85
+
86
+ model_path = 'PATH' # The local directory to save downloaded checkpoint
87
+ snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
88
+ ```
89
+
90
+ ## Inference
91
+
92
+ ### 1. Quick start with Gradio
93
+
94
+ To get started, first install [Gradio](https://www.gradio.app/guides/quickstart), set your model path at [#L36](https://github.com/jy0205/Pyramid-Flow/blob/3777f8b84bddfa2aa2b497ca919b3f40567712e6/app.py#L36), and then run on your local machine:
95
+
96
+ ```bash
97
+ python app.py
98
+ ```
99
+
100
+ The Gradio demo will be opened in a browser. Thanks to [@tpc2233](https://github.com/tpc2233) the commit, see [#48](https://github.com/jy0205/Pyramid-Flow/pull/48) for details.
101
+
102
+ Or, try it out effortlessly on [Hugging Face Space 🤗](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) created by [@multimodalart](https://huggingface.co/multimodalart). Due to GPU limits, this online demo can only generate 25 frames (export at 8FPS or 24FPS). Duplicate the space to generate longer videos.
103
+
104
+ #### Quick Start on Google Colab
105
+
106
+ To quickly try out Pyramid Flow on Google Colab, run the code below:
107
+
108
+ ```
109
+ # Setup
110
+ !git clone https://github.com/jy0205/Pyramid-Flow
111
+ %cd Pyramid-Flow
112
+ !pip install -r requirements.txt
113
+ !pip install gradio
114
+
115
+ # This code downloads miniFLUX
116
+ from huggingface_hub import snapshot_download
117
+
118
+ model_path = '/content/Pyramid-Flow'
119
+ snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
120
+
121
+ # Start
122
+ !python app.py
123
+ ```
124
+
125
+ ### 2. Inference Code
126
+
127
+ To use our model, please follow the inference code in `video_generation_demo.ipynb` at [this link](https://github.com/jy0205/Pyramid-Flow/blob/main/video_generation_demo.ipynb). We strongly recommend you to try the latest published pyramid-miniflux, which shows great improvement on human structure and motion stability. Set the param `model_name` to `pyramid_flux` to use. We further simplify it into the following two-step procedure. First, load the downloaded model:
128
+
129
+ ```python
130
+ import torch
131
+ from PIL import Image
132
+ from pyramid_dit import PyramidDiTForVideoGeneration
133
+ from diffusers.utils import load_image, export_to_video
134
+
135
+ torch.cuda.set_device(0)
136
+ model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16 (not support fp16 yet)
137
+
138
+ model = PyramidDiTForVideoGeneration(
139
+ 'PATH', # The downloaded checkpoint dir
140
+ model_name="pyramid_flux",
141
+ model_dtype,
142
+ model_variant='diffusion_transformer_768p',
143
+ )
144
+
145
+ model.vae.enable_tiling()
146
+ # model.vae.to("cuda")
147
+ # model.dit.to("cuda")
148
+ # model.text_encoder.to("cuda")
149
+
150
+ # if you're not using sequential offloading bellow uncomment the lines above ^
151
+ model.enable_sequential_cpu_offload()
152
+ ```
153
+
154
+ Then, you can try text-to-video generation on your own prompts. Noting that the 384p version only support 5s now (set temp up to 16)!
155
+
156
+ ```python
157
+ prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
158
+
159
+ # used for 384p model variant
160
+ # width = 640
161
+ # height = 384
162
+
163
+ # used for 768p model variant
164
+ width = 1280
165
+ height = 768
166
+
167
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
168
+ frames = model.generate(
169
+ prompt=prompt,
170
+ num_inference_steps=[20, 20, 20],
171
+ video_num_inference_steps=[10, 10, 10],
172
+ height=height,
173
+ width=width,
174
+ temp=16, # temp=16: 5s, temp=31: 10s
175
+ guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
176
+ video_guidance_scale=5.0, # The guidance for the other video latent
177
+ output_type="pil",
178
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
179
+ )
180
+
181
+ export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
182
+ ```
183
+
184
+ As an autoregressive model, our model also supports (text conditioned) image-to-video generation:
185
+
186
+ ```python
187
+ # used for 384p model variant
188
+ # width = 640
189
+ # height = 384
190
+
191
+ # used for 768p model variant
192
+ width = 1280
193
+ height = 768
194
+
195
+ image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((width, height))
196
+ prompt = "FPV flying over the Great Wall"
197
+
198
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
199
+ frames = model.generate_i2v(
200
+ prompt=prompt,
201
+ input_image=image,
202
+ num_inference_steps=[10, 10, 10],
203
+ temp=16,
204
+ video_guidance_scale=4.0,
205
+ output_type="pil",
206
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
207
+ )
208
+
209
+ export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
210
+ ```
211
+
212
+ #### CPU offloading
213
+
214
+ We also support two types of CPU offloading to reduce GPU memory requirements. Note that they may sacrifice efficiency.
215
+ * Adding a `cpu_offloading=True` parameter to the generate function allows inference with **less than 12GB** of GPU memory. This feature was contributed by [@Ednaordinary](https://github.com/Ednaordinary), see [#23](https://github.com/jy0205/Pyramid-Flow/pull/23) for details.
216
+ * Calling `model.enable_sequential_cpu_offload()` before the above procedure allows inference with **less than 8GB** of GPU memory. This feature was contributed by [@rodjjo](https://github.com/rodjjo), see [#75](https://github.com/jy0205/Pyramid-Flow/pull/75) for details.
217
+
218
+ #### MPS backend
219
+
220
+ Thanks to [@niw](https://github.com/niw), Apple Silicon users (e.g. MacBook Pro with M2 24GB) can also try our model using the MPS backend! Please see [#113](https://github.com/jy0205/Pyramid-Flow/pull/113) for the details.
221
+
222
+ ### 3. Multi-GPU Inference
223
+
224
+ For users with multiple GPUs, we provide an [inference script](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/inference_multigpu.sh) that uses sequence parallelism to save memory on each GPU. This also brings a big speedup, taking only 2.5 minutes to generate a 5s, 768p, 24fps video on 4 A100 GPUs (vs. 5.5 minutes on a single A100 GPU). Run it on 2 GPUs with the following command:
225
+
226
+ ```bash
227
+ CUDA_VISIBLE_DEVICES=0,1 sh scripts/inference_multigpu.sh
228
+ ```
229
+
230
+ It currently supports 2 or 4 GPUs (For SD3 Version), with more configurations available in the original script. You can also launch a [multi-GPU Gradio demo](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/app_multigpu_engine.sh) created by [@tpc2233](https://github.com/tpc2233), see [#59](https://github.com/jy0205/Pyramid-Flow/pull/59) for details.
231
+
232
+ > Spoiler: We didn't even use sequence parallelism in training, thanks to our efficient pyramid flow designs.
233
+
234
+ ### 4. Usage tips
235
+
236
+ * The `guidance_scale` parameter controls the visual quality. We suggest using a guidance within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
237
+ * The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates the autoregressive generation degradation, while a smaller value stabilizes the video.
238
+ * For 10-second video generation, we recommend using a guidance scale of 7 and a video guidance scale of 5.
239
+
240
+ ## Training
241
+
242
+ ### 1. Training VAE
243
+
244
+ The hardware requirements for training VAE are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE.md). This is a [MAGVIT-v2](https://arxiv.org/abs/2310.05737) like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on this part of VAE training code.
245
+
246
+ ### 2. Finetuning DiT
247
+
248
+ The hardware requirements for finetuning DiT are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT.md). We provide instructions for both autoregressive and non-autoregressive versions of Pyramid Flow. The former is more research oriented and the latter is more stable (but less efficient without temporal pyramid).
249
+
250
+ ## Gallery
251
+
252
+ The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).
253
+
254
+ <table class="center" border="0" style="width: 100%; text-align: left;">
255
+ <tr>
256
+ <td><video src="https://github.com/user-attachments/assets/5b44a57e-fa08-4554-84a2-2c7a99f2b343" autoplay muted loop playsinline></video></td>
257
+ <td><video src="https://github.com/user-attachments/assets/5afd5970-de72-40e2-900d-a20d18308e8e" autoplay muted loop playsinline></video></td>
258
+ </tr>
259
+ <tr>
260
+ <td><video src="https://github.com/user-attachments/assets/1d44daf8-017f-40e9-bf18-1e19c0a8983b" autoplay muted loop playsinline></video></td>
261
+ <td><video src="https://github.com/user-attachments/assets/7f5dd901-b7d7-48cc-b67a-3c5f9e1546d2" autoplay muted loop playsinline></video></td>
262
+ </tr>
263
+ </table>
264
+
265
+ ## Comparison
266
+
267
+ On VBench ([Huang et al., 2024](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard)), our method surpasses all the compared open-source baselines. Even with only public video data, it achieves comparable performance to commercial models like Kling ([Kuaishou, 2024](https://kling.kuaishou.com/en)) and Gen-3 Alpha ([Runway, 2024](https://runwayml.com/research/introducing-gen-3-alpha)), especially in the quality score (84.74 vs. 84.11 of Gen-3) and motion smoothness.
268
+
269
+ ![vbench](assets/vbench.jpg)
270
+
271
+ We conduct an additional user study with 20+ participants. As can be seen, our method is preferred over open-source models such as [Open-Sora](https://github.com/hpcaitech/Open-Sora) and [CogVideoX-2B](https://github.com/THUDM/CogVideo) especially in terms of motion smoothness.
272
+
273
+ ![user_study](assets/user_study.jpg)
274
+
275
+ ## Acknowledgement
276
+
277
+ We are grateful for the following awesome projects when implementing Pyramid Flow:
278
+
279
+ * [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
280
+ * [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
281
+ * [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
282
+ * [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
283
+ * [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM for our video recaptioning.
284
+
285
+ ## Citation
286
+
287
+ Consider giving this repository a star and cite Pyramid Flow in your publications if it helps your research.
288
+ ```
289
+ @article{jin2024pyramidal,
290
+ title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
291
+ author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
292
+ jounal={arXiv preprint arXiv:2410.05954},
293
+ year={2024}
294
+ }
295
+ ```
README.md CHANGED
@@ -1,12 +1,295 @@
1
- ---
2
- title: Pyramid Flow
3
- emoji: 😻
4
- colorFrom: green
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.6.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Pyramid-Flow
3
+ app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.6.0
6
+ ---
7
+ <div align="center">
8
+
9
+ # Pyramid Flow
10
+
11
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[miniFLUX Model]](https://huggingface.co/rain1011/pyramid-flow-miniflux) [[SD3 Model]](https://huggingface.co/rain1011/pyramid-flow-sd3) [[demo](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow)]
12
+
13
+ </div>
14
+
15
+ This is the official repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on **open-source datasets**, it can generate high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.
16
+
17
+ <table class="center" border="0" style="width: 100%; text-align: left;">
18
+ <tr>
19
+ <th>10s, 768p, 24fps</th>
20
+ <th>5s, 768p, 24fps</th>
21
+ <th>Image-to-video</th>
22
+ </tr>
23
+ <tr>
24
+ <td><video src="https://github.com/user-attachments/assets/9935da83-ae56-4672-8747-0f46e90f7b2b" autoplay muted loop playsinline></video></td>
25
+ <td><video src="https://github.com/user-attachments/assets/3412848b-64db-4d9e-8dbf-11403f6d02c5" autoplay muted loop playsinline></video></td>
26
+ <td><video src="https://github.com/user-attachments/assets/3bd7251f-7b2c-4bee-951d-656fdb45f427" autoplay muted loop playsinline></video></td>
27
+ </tr>
28
+ </table>
29
+
30
+ ## News
31
+ * `2024.11.13` We release the [768p miniFLUX checkpoint](https://huggingface.co/rain1011/pyramid-flow-miniflux) (up to 10s).
32
+
33
+ > We have switched the model structure from SD3 to a mini FLUX to fix human structure issues, please try our 1024p image checkpoint, 384p video checkpoint (up to 5s) and 768p video checkpoint (up to 10s). The new miniflux model shows great improvement on human structure and motion stability
34
+
35
+ * `2024.10.29` We release [training code for VAE](#1-training-vae), [finetuning code for DiT](#2-finetuning-dit) and [new model checkpoints](https://huggingface.co/rain1011/pyramid-flow-miniflux) with FLUX structure trained from scratch.
36
+
37
+
38
+ * `2024.10.13` [Multi-GPU inference](#3-multi-gpu-inference) and [CPU offloading](#cpu-offloading) are supported. Use it with **less than 8GB** of GPU memory, with great speedup on multiple GPUs.
39
+
40
+ * `2024.10.11` [Hugging Face demo](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) is available. Thanks [@multimodalart](https://huggingface.co/multimodalart) for the commit!
41
+
42
+ * `2024.10.10` We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.
43
+
44
+ ## Table of Contents
45
+
46
+ * [Introduction](#introduction)
47
+ * [Installation](#installation)
48
+ * [Inference](#inference)
49
+ 1. [Quick Start with Gradio](#1-quick-start-with-gradio)
50
+ 2. [Inference Code](#2-inference-code)
51
+ 3. [Multi-GPU Inference](#3-multi-gpu-inference)
52
+ 4. [Usage Tips](#4-usage-tips)
53
+ * [Training](#Training)
54
+ 1. [Training VAE](#training-vae)
55
+ 2. [Finetuning DiT](#finetuning-dit)
56
+ * [Gallery](#gallery)
57
+ * [Comparison](#comparison)
58
+ * [Acknowledgement](#acknowledgement)
59
+ * [Citation](#citation)
60
+
61
+ ## Introduction
62
+
63
+ ![motivation](assets/motivation.jpg)
64
+
65
+ Existing video diffusion models operate at full resolution, spending a lot of computation on very noisy latents. By contrast, our method harnesses the flexibility of flow matching ([Lipman et al., 2023](https://openreview.net/forum?id=PqvMRDCJT9t); [Liu et al., 2023](https://openreview.net/forum?id=XVjTT1nw5z); [Albergo & Vanden-Eijnden, 2023](https://openreview.net/forum?id=li7qeBbCR1t)) to interpolate between latents of different resolutions and noise levels, allowing for simultaneous generation and decompression of visual content with better computational efficiency. The entire framework is end-to-end optimized with a single DiT ([Peebles & Xie, 2023](http://openaccess.thecvf.com/content/ICCV2023/html/Peebles_Scalable_Diffusion_Models_with_Transformers_ICCV_2023_paper.html)), generating high-quality 10-second videos at 768p resolution and 24 FPS within 20.7k A100 GPU training hours.
66
+
67
+ ## Installation
68
+
69
+ We recommend setting up the environment with conda. The codebase currently uses Python 3.8.10 and PyTorch 2.1.2 ([guide](https://pytorch.org/get-started/previous-versions/#v212)), and we are actively working to support a wider range of versions.
70
+
71
+ ```bash
72
+ git clone https://github.com/jy0205/Pyramid-Flow
73
+ cd Pyramid-Flow
74
+
75
+ # create env using conda
76
+ conda create -n pyramid python==3.8.10
77
+ conda activate pyramid
78
+ pip install -r requirements.txt
79
+ ```
80
+
81
+ Then, download the model from [Huggingface](https://huggingface.co/rain1011) (there are two variants: [miniFLUX](https://huggingface.co/rain1011/pyramid-flow-miniflux) or [SD3](https://huggingface.co/rain1011/pyramid-flow-sd3)). The miniFLUX models support 1024p image, 384p and 768p video generation, and the SD3-based models support 768p and 384p video generation. The 384p checkpoint generates 5-second video at 24FPS, while the 768p checkpoint generates up to 10-second video at 24FPS.
82
+
83
+ ```python
84
+ from huggingface_hub import snapshot_download
85
+
86
+ model_path = 'PATH' # The local directory to save downloaded checkpoint
87
+ snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
88
+ ```
89
+
90
+ ## Inference
91
+
92
+ ### 1. Quick start with Gradio
93
+
94
+ To get started, first install [Gradio](https://www.gradio.app/guides/quickstart), set your model path at [#L36](https://github.com/jy0205/Pyramid-Flow/blob/3777f8b84bddfa2aa2b497ca919b3f40567712e6/app.py#L36), and then run on your local machine:
95
+
96
+ ```bash
97
+ python app.py
98
+ ```
99
+
100
+ The Gradio demo will be opened in a browser. Thanks to [@tpc2233](https://github.com/tpc2233) the commit, see [#48](https://github.com/jy0205/Pyramid-Flow/pull/48) for details.
101
+
102
+ Or, try it out effortlessly on [Hugging Face Space](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) created by [@multimodalart](https://huggingface.co/multimodalart). Due to GPU limits, this online demo can only generate 25 frames (export at 8FPS or 24FPS). Duplicate the space to generate longer videos.
103
+
104
+ #### Quick Start on Google Colab
105
+
106
+ To quickly try out Pyramid Flow on Google Colab, run the code below:
107
+
108
+ ```
109
+ # Setup
110
+ !git clone https://github.com/jy0205/Pyramid-Flow
111
+ %cd Pyramid-Flow
112
+ !pip install -r requirements.txt
113
+ !pip install gradio
114
+
115
+ # This code downloads miniFLUX
116
+ from huggingface_hub import snapshot_download
117
+
118
+ model_path = '/content/Pyramid-Flow'
119
+ snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
120
+
121
+ # Start
122
+ !python app.py
123
+ ```
124
+
125
+ ### 2. Inference Code
126
+
127
+ To use our model, please follow the inference code in `video_generation_demo.ipynb` at [this link](https://github.com/jy0205/Pyramid-Flow/blob/main/video_generation_demo.ipynb). We strongly recommend you to try the latest published pyramid-miniflux, which shows great improvement on human structure and motion stability. Set the param `model_name` to `pyramid_flux` to use. We further simplify it into the following two-step procedure. First, load the downloaded model:
128
+
129
+ ```python
130
+ import torch
131
+ from PIL import Image
132
+ from pyramid_dit import PyramidDiTForVideoGeneration
133
+ from diffusers.utils import load_image, export_to_video
134
+
135
+ torch.cuda.set_device(0)
136
+ model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16 (not support fp16 yet)
137
+
138
+ model = PyramidDiTForVideoGeneration(
139
+ 'PATH', # The downloaded checkpoint dir
140
+ model_name="pyramid_flux",
141
+ model_dtype,
142
+ model_variant='diffusion_transformer_768p',
143
+ )
144
+
145
+ model.vae.enable_tiling()
146
+ # model.vae.to("cuda")
147
+ # model.dit.to("cuda")
148
+ # model.text_encoder.to("cuda")
149
+
150
+ # if you're not using sequential offloading bellow uncomment the lines above ^
151
+ model.enable_sequential_cpu_offload()
152
+ ```
153
+
154
+ Then, you can try text-to-video generation on your own prompts. Noting that the 384p version only support 5s now (set temp up to 16)!
155
+
156
+ ```python
157
+ prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
158
+
159
+ # used for 384p model variant
160
+ # width = 640
161
+ # height = 384
162
+
163
+ # used for 768p model variant
164
+ width = 1280
165
+ height = 768
166
+
167
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
168
+ frames = model.generate(
169
+ prompt=prompt,
170
+ num_inference_steps=[20, 20, 20],
171
+ video_num_inference_steps=[10, 10, 10],
172
+ height=height,
173
+ width=width,
174
+ temp=16, # temp=16: 5s, temp=31: 10s
175
+ guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
176
+ video_guidance_scale=5.0, # The guidance for the other video latent
177
+ output_type="pil",
178
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
179
+ )
180
+
181
+ export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
182
+ ```
183
+
184
+ As an autoregressive model, our model also supports (text conditioned) image-to-video generation:
185
+
186
+ ```python
187
+ # used for 384p model variant
188
+ # width = 640
189
+ # height = 384
190
+
191
+ # used for 768p model variant
192
+ width = 1280
193
+ height = 768
194
+
195
+ image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((width, height))
196
+ prompt = "FPV flying over the Great Wall"
197
+
198
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
199
+ frames = model.generate_i2v(
200
+ prompt=prompt,
201
+ input_image=image,
202
+ num_inference_steps=[10, 10, 10],
203
+ temp=16,
204
+ video_guidance_scale=4.0,
205
+ output_type="pil",
206
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
207
+ )
208
+
209
+ export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
210
+ ```
211
+
212
+ #### CPU offloading
213
+
214
+ We also support two types of CPU offloading to reduce GPU memory requirements. Note that they may sacrifice efficiency.
215
+ * Adding a `cpu_offloading=True` parameter to the generate function allows inference with **less than 12GB** of GPU memory. This feature was contributed by [@Ednaordinary](https://github.com/Ednaordinary), see [#23](https://github.com/jy0205/Pyramid-Flow/pull/23) for details.
216
+ * Calling `model.enable_sequential_cpu_offload()` before the above procedure allows inference with **less than 8GB** of GPU memory. This feature was contributed by [@rodjjo](https://github.com/rodjjo), see [#75](https://github.com/jy0205/Pyramid-Flow/pull/75) for details.
217
+
218
+ #### MPS backend
219
+
220
+ Thanks to [@niw](https://github.com/niw), Apple Silicon users (e.g. MacBook Pro with M2 24GB) can also try our model using the MPS backend! Please see [#113](https://github.com/jy0205/Pyramid-Flow/pull/113) for the details.
221
+
222
+ ### 3. Multi-GPU Inference
223
+
224
+ For users with multiple GPUs, we provide an [inference script](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/inference_multigpu.sh) that uses sequence parallelism to save memory on each GPU. This also brings a big speedup, taking only 2.5 minutes to generate a 5s, 768p, 24fps video on 4 A100 GPUs (vs. 5.5 minutes on a single A100 GPU). Run it on 2 GPUs with the following command:
225
+
226
+ ```bash
227
+ CUDA_VISIBLE_DEVICES=0,1 sh scripts/inference_multigpu.sh
228
+ ```
229
+
230
+ It currently supports 2 or 4 GPUs (For SD3 Version), with more configurations available in the original script. You can also launch a [multi-GPU Gradio demo](https://github.com/jy0205/Pyramid-Flow/blob/main/scripts/app_multigpu_engine.sh) created by [@tpc2233](https://github.com/tpc2233), see [#59](https://github.com/jy0205/Pyramid-Flow/pull/59) for details.
231
+
232
+ > Spoiler: We didn't even use sequence parallelism in training, thanks to our efficient pyramid flow designs.
233
+
234
+ ### 4. Usage tips
235
+
236
+ * The `guidance_scale` parameter controls the visual quality. We suggest using a guidance within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
237
+ * The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates the autoregressive generation degradation, while a smaller value stabilizes the video.
238
+ * For 10-second video generation, we recommend using a guidance scale of 7 and a video guidance scale of 5.
239
+
240
+ ## Training
241
+
242
+ ### 1. Training VAE
243
+
244
+ The hardware requirements for training VAE are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE.md). This is a [MAGVIT-v2](https://arxiv.org/abs/2310.05737) like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on this part of VAE training code.
245
+
246
+ ### 2. Finetuning DiT
247
+
248
+ The hardware requirements for finetuning DiT are at least 8 A100 GPUs. Please refer to [this document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT.md). We provide instructions for both autoregressive and non-autoregressive versions of Pyramid Flow. The former is more research oriented and the latter is more stable (but less efficient without temporal pyramid).
249
+
250
+ ## Gallery
251
+
252
+ The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).
253
+
254
+ <table class="center" border="0" style="width: 100%; text-align: left;">
255
+ <tr>
256
+ <td><video src="https://github.com/user-attachments/assets/5b44a57e-fa08-4554-84a2-2c7a99f2b343" autoplay muted loop playsinline></video></td>
257
+ <td><video src="https://github.com/user-attachments/assets/5afd5970-de72-40e2-900d-a20d18308e8e" autoplay muted loop playsinline></video></td>
258
+ </tr>
259
+ <tr>
260
+ <td><video src="https://github.com/user-attachments/assets/1d44daf8-017f-40e9-bf18-1e19c0a8983b" autoplay muted loop playsinline></video></td>
261
+ <td><video src="https://github.com/user-attachments/assets/7f5dd901-b7d7-48cc-b67a-3c5f9e1546d2" autoplay muted loop playsinline></video></td>
262
+ </tr>
263
+ </table>
264
+
265
+ ## Comparison
266
+
267
+ On VBench ([Huang et al., 2024](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard)), our method surpasses all the compared open-source baselines. Even with only public video data, it achieves comparable performance to commercial models like Kling ([Kuaishou, 2024](https://kling.kuaishou.com/en)) and Gen-3 Alpha ([Runway, 2024](https://runwayml.com/research/introducing-gen-3-alpha)), especially in the quality score (84.74 vs. 84.11 of Gen-3) and motion smoothness.
268
+
269
+ ![vbench](assets/vbench.jpg)
270
+
271
+ We conduct an additional user study with 20+ participants. As can be seen, our method is preferred over open-source models such as [Open-Sora](https://github.com/hpcaitech/Open-Sora) and [CogVideoX-2B](https://github.com/THUDM/CogVideo) especially in terms of motion smoothness.
272
+
273
+ ![user_study](assets/user_study.jpg)
274
+
275
+ ## Acknowledgement
276
+
277
+ We are grateful for the following awesome projects when implementing Pyramid Flow:
278
+
279
+ * [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
280
+ * [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
281
+ * [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
282
+ * [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
283
+ * [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM for our video recaptioning.
284
+
285
+ ## Citation
286
+
287
+ Consider giving this repository a star and cite Pyramid Flow in your publications if it helps your research.
288
+ ```
289
+ @article{jin2024pyramidal,
290
+ title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
291
+ author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
292
+ jounal={arXiv preprint arXiv:2410.05954},
293
+ year={2024}
294
+ }
295
+ ```
annotation/image_text.jsonl ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"image": "SAM_filter/000424/sa_4749867.jpg", "text": "a cityscape with a large body of water, such as a lake or a river, in the foreground"}
2
+ {"image": "SAM_filter/000311/sa_3490721.jpg", "text": "a large, stately building with a white and blue color scheme, which gives it a grand and elegant appearance"}
3
+ {"image": "SAM_filter/000273/sa_3059407.jpg", "text": "a close-up of a green bag containing a package of Japanese soybeans, along with a bottle of sake, a traditional Japanese alcoholic beverage"}
4
+ {"image": "SAM_filter/000745/sa_8344729.jpg", "text": "a large, old-fashioned building with a red and white color scheme"}
5
+ {"image": "SAM_filter/000832/sa_9310794.jpg", "text": "a cityscape with a large tower, likely the Eiffel Tower, as the main focal point"}
6
+ {"image": "SAM_filter/000427/sa_4779422.jpg", "text": "a large cruise ship, specifically a Royal Caribbean cruise ship, docked at a pier in a harbor"}
7
+ {"image": "SAM_filter/000105/sa_1178255.jpg", "text": "a close-up view of a computer screen with a magnifying glass placed over it"}
8
+ {"image": "SAM_filter/000765/sa_8560467.jpg", "text": "a tree with a sign attached to it, which is located in a lush green field"}
9
+ {"image": "SAM_filter/000216/sa_2417372.jpg", "text": "a large airport terminal with a long blue and white rope-style security line"}
10
+ {"image": "SAM_filter/000385/sa_4308806.jpg", "text": "a close-up of a cell phone screen displaying a blue and white logo, which appears to be a bank logo"}
11
+ {"image": "SAM_filter/000931/sa_10425835.jpg", "text": "a large body of water, possibly a lake, with a lush green landscape surrounding it"}
12
+ {"image": "SAM_filter/000364/sa_4079002.jpg", "text": "a large, empty airport terminal with a long row of gray metal chairs arranged in a straight line"}
13
+ {"image": "SAM_filter/000474/sa_5306222.jpg", "text": "a large, modern building with a tall, glass structure, which is likely a museum"}
14
+ {"image": "SAM_filter/000584/sa_6536849.jpg", "text": "a city street scene with a black car parked in a parking lot, a building with a balcony, and a city skyline in the background"}
15
+ {"image": "SAM_filter/000188/sa_2104485.jpg", "text": "a large jet fighter airplane flying through the sky, captured in a high-quality photograph"}
16
+ {"image": "SAM_filter/000219/sa_2458908.jpg", "text": "a stone structure with a tall tower, which is situated in a lush green garden"}
17
+ {"image": "SAM_filter/000440/sa_4929413.jpg", "text": "a large city street with a mix of architectural styles, including a Gothic-style building and a modern building"}
18
+ {"image": "SAM_filter/000739/sa_8279296.jpg", "text": "a vintage blue and white bus parked on the side of a dirt road, with a building in the background"}
19
+ {"image": "SAM_filter/000809/sa_9052304.jpg", "text": "a large, old stone building with a clock tower, which is situated in a small town"}
20
+ {"image": "SAM_filter/000294/sa_3300200.jpg", "text": "a table with various utensils, including a bowl, spoon, and fork, placed on a wooden surface"}
annotation/video_text.jsonl ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {"video": "webvid10m/train/010451_010500/23388121.mp4", "text": "the serene beauty of a valley with a river, mountains, and clouds", "latent": "webvid10m/train/010451_010500/23388121-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/010451_010500/23388121-text.pt"}
2
+ {"video": "pexels/8440980-uhd_3840_2160_25fps.mp4", "text": "A group of people, including two men and two women, are seen sitting at a table, smiling and waving at the camera, and appear to be in a good mood", "latent": "pexels/8440980-uhd_3840_2160_25fps-latent-384-2.pt", "text_fea": "text_feature/pexels/8440980-uhd_3840_2160_25fps-text.pt"}
3
+ {"video": "webvid10m/train/176251_176300/1011015221.mp4", "text": "an aerial view of a large wheat field with a road running through it, and a car driving on the road", "latent": "webvid10m/train/176251_176300/1011015221-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/176251_176300/1011015221-text.pt"}
4
+ {"video": "webvid10m/train/005801_005850/22143805.mp4", "text": "a close-up of paint mixing in water, creating swirling patterns", "latent": "webvid10m/train/005801_005850/22143805-latent-384-8.pt", "text_fea": "text_feature/webvid10m/train/005801_005850/22143805-text.pt"}
5
+ {"video": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743.mp4", "text": "A baby girl in a pink shirt and striped pants sits in a high chair, eats a piece of bread, and looks at the camera", "latent": "OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-latent-384-0.pt", "text_fea": "text_feature/OpenVid-1M/videos/qsXY7FkNFwE_2_0to743-text.pt"}
6
+ {"video": "webvid10m/train/134901_134950/1037990273.mp4", "text": "a field of green wheat waving in the wind", "latent": "webvid10m/train/134901_134950/1037990273-latent-384-6.pt", "text_fea": "text_feature/webvid10m/train/134901_134950/1037990273-text.pt"}
7
+ {"video": "pexels/5263258-uhd_2160_4096_30fps.mp4", "text": "A dog sits patiently in front of its bowl, waiting for it to be filled with food", "latent": "pexels/5263258-uhd_2160_4096_30fps-latent-384-6.pt", "text_fea": "text_feature/pexels/5263258-uhd_2160_4096_30fps-text.pt"}
8
+ {"video": "webvid10m/train/117851_117900/6461432.mp4", "text": "A ladybug crawls along a blade of grass in a serene natural setting", "latent": "webvid10m/train/117851_117900/6461432-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/117851_117900/6461432-text.pt"}
9
+ {"video": "webvid10m/train/053051_053100/1058396656.mp4", "text": "a group of construction workers working on a rooftop, with a supervisor overseeing the work", "latent": "webvid10m/train/053051_053100/1058396656-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/053051_053100/1058396656-text.pt"}
10
+ {"video": "webvid10m/train/073651_073700/1021916425.mp4", "text": "an aerial view of a beautiful coastline with rocky islands, blue water, and a white cloud in the sky", "latent": "webvid10m/train/073651_073700/1021916425-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/073651_073700/1021916425-text.pt"}
11
+ {"video": "webvid10m/train/027051_027100/1032549941.mp4", "text": "a young woman waking up in bed, smiling at the camera, and then lying back down on the bed", "latent": "webvid10m/train/027051_027100/1032549941-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/027051_027100/1032549941-text.pt"}
12
+ {"video": "pexels/5564564-uhd_3840_2160_24fps.mp4", "text": "a person rolling out dough on a table using a rolling pin", "latent": "pexels/5564564-uhd_3840_2160_24fps-latent-384-8.pt", "text_fea": "text_feature/pexels/5564564-uhd_3840_2160_24fps-text.pt"}
13
+ {"video": "webvid10m/train/073701_073750/24008116.mp4", "text": "a cityscape with a moon in the sky, and the camera pans across the city", "latent": "webvid10m/train/073701_073750/24008116-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/073701_073750/24008116-text.pt"}
14
+ {"video": "webvid10m/train/118351_118400/23370991.mp4", "text": "a group of dolphins swimming in the ocean, with a person on a boat nearby", "latent": "webvid10m/train/118351_118400/23370991-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/118351_118400/23370991-text.pt"}
15
+ {"video": "webvid10m/train/022001_022050/1023013066.mp4", "text": "a bird's eye view of a beachfront city, highlighting the hotels, pools, and proximity to the ocean", "latent": "webvid10m/train/022001_022050/1023013066-latent-384-10.pt", "text_fea": "text_feature/webvid10m/train/022001_022050/1023013066-text.pt"}
16
+ {"video": "webvid10m/train/004601_004650/1015979020.mp4", "text": "a bridge over a body of water, with a boat passing under it", "latent": "webvid10m/train/004601_004650/1015979020-latent-384-4.pt", "text_fea": "text_feature/webvid10m/train/004601_004650/1015979020-text.pt"}
17
+ {"video": "webvid10m/train/149701_149750/1034525579.mp4", "text": "a group of owls and a moon, with the moon appearing to grow larger as the video progresses", "latent": "webvid10m/train/149701_149750/1034525579-latent-384-2.pt", "text_fea": "text_feature/webvid10m/train/149701_149750/1034525579-text.pt"}
app.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import gradio as gr
4
+ import torch
5
+ import PIL
6
+ from PIL import Image
7
+ from pyramid_dit import PyramidDiTForVideoGeneration
8
+ from diffusers.utils import export_to_video
9
+ from huggingface_hub import snapshot_download
10
+ import threading
11
+ import random
12
+
13
+ # Global model cache
14
+ model_cache = {}
15
+
16
+ # Lock to ensure thread-safe access to the model cache
17
+ model_cache_lock = threading.Lock()
18
+
19
+ # Configuration
20
+ model_name = "pyramid_flux" # or pyramid_mmdit
21
+ model_repo = "rain1011/pyramid-flow-sd3" if model_name == "pyramid_mmdit" else "rain1011/pyramid-flow-miniflux"
22
+
23
+ model_dtype = "bf16" # Support bf16 and fp32
24
+ variants = {
25
+ 'high': 'diffusion_transformer_768p', # For high-resolution version
26
+ 'low': 'diffusion_transformer_384p' # For low-resolution version
27
+ }
28
+ required_file = 'config.json' # Ensure config.json is present
29
+ width_high = 1280
30
+ height_high = 768
31
+ width_low = 640
32
+ height_low = 384
33
+ cpu_offloading = True # enable cpu_offloading by default
34
+
35
+ # Get the current working directory and create a folder to store the model
36
+ current_directory = os.getcwd()
37
+ model_path = os.path.join(current_directory, "pyramid_flow_model") # Directory to store the model
38
+
39
+ # Download the model if not already present
40
+ def download_model_from_hf(model_repo, model_dir, variants, required_file):
41
+ need_download = False
42
+ if not os.path.exists(model_dir):
43
+ print(f"[INFO] Model directory '{model_dir}' does not exist. Initiating download...")
44
+ need_download = True
45
+ else:
46
+ # Check if all required files exist for each variant
47
+ for variant_key, variant_dir in variants.items():
48
+ variant_path = os.path.join(model_dir, variant_dir)
49
+ file_path = os.path.join(variant_path, required_file)
50
+ if not os.path.exists(file_path):
51
+ print(f"[WARNING] Required file '{required_file}' missing in '{variant_path}'.")
52
+ need_download = True
53
+ break
54
+
55
+ if need_download:
56
+ print(f"[INFO] Downloading model from '{model_repo}' to '{model_dir}'...")
57
+ try:
58
+ snapshot_download(
59
+ repo_id=model_repo,
60
+ local_dir=model_dir,
61
+ local_dir_use_symlinks=False,
62
+ repo_type='model'
63
+ )
64
+ print("[INFO] Model download complete.")
65
+ except Exception as e:
66
+ print(f"[ERROR] Failed to download the model: {e}")
67
+ raise
68
+ else:
69
+ print(f"[INFO] All required model files are present in '{model_dir}'. Skipping download.")
70
+
71
+ # Download model from Hugging Face if not present
72
+ download_model_from_hf(model_repo, model_path, variants, required_file)
73
+
74
+ # Function to initialize the model based on user options
75
+ def initialize_model(variant):
76
+ print(f"[INFO] Initializing model with variant='{variant}', using bf16 precision...")
77
+
78
+ # Determine the correct variant directory
79
+ variant_dir = variants['high'] if variant == '768p' else variants['low']
80
+ base_path = model_path # Pass the base model path
81
+
82
+ print(f"[DEBUG] Model base path: {base_path}")
83
+
84
+ # Verify that config.json exists in the variant directory
85
+ config_path = os.path.join(model_path, variant_dir, 'config.json')
86
+ if not os.path.exists(config_path):
87
+ print(f"[ERROR] config.json not found in '{os.path.join(model_path, variant_dir)}'.")
88
+ raise FileNotFoundError(f"config.json not found in '{os.path.join(model_path, variant_dir)}'.")
89
+
90
+ if model_dtype == "bf16":
91
+ torch_dtype_selected = torch.bfloat16
92
+ else:
93
+ torch_dtype_selected = torch.float32
94
+
95
+ # Initialize the model
96
+ try:
97
+
98
+ model = PyramidDiTForVideoGeneration(
99
+ base_path, # Pass the base model path
100
+ model_name=model_name, # set to pyramid_flux or pyramid_mmdit
101
+ model_dtype=model_dtype, # Use bf16
102
+ model_variant=variant_dir, # Pass the variant directory name
103
+ cpu_offloading=cpu_offloading, # Pass the CPU offloading flag
104
+ )
105
+
106
+ # Always enable tiling for the VAE
107
+ model.vae.enable_tiling()
108
+
109
+ # Remove manual device placement when using CPU offloading
110
+ # The components will be moved to the appropriate devices automatically
111
+ if torch.cuda.is_available():
112
+ torch.cuda.set_device(0)
113
+ # Manual device replacement when not using CPU offloading
114
+ if not cpu_offloading:
115
+ model.vae.to("cuda")
116
+ model.dit.to("cuda")
117
+ model.text_encoder.to("cuda")
118
+ else:
119
+ print("[WARNING] CUDA is not available. Proceeding without GPU.")
120
+
121
+ print("[INFO] Model initialized successfully.")
122
+ return model, torch_dtype_selected
123
+ except Exception as e:
124
+ print(f"[ERROR] Error initializing model: {e}")
125
+ raise
126
+
127
+ # Function to get the model from cache or initialize it
128
+ def initialize_model_cached(variant, seed):
129
+ key = variant
130
+
131
+ if seed == 0:
132
+ seed = random.randint(0, 2**8 - 1)
133
+ torch.manual_seed(seed)
134
+ if torch.cuda.is_available():
135
+ torch.cuda.manual_seed(seed)
136
+ torch.cuda.manual_seed_all(seed)
137
+
138
+ # Check if the model is already in the cache
139
+ if key not in model_cache:
140
+ with model_cache_lock:
141
+ # Double-checked locking to prevent race conditions
142
+ if key not in model_cache:
143
+ model, dtype = initialize_model(variant)
144
+ model_cache[key] = (model, dtype)
145
+
146
+ return model_cache[key]
147
+
148
+ def resize_crop_image(img: PIL.Image.Image, tgt_width, tgt_height):
149
+ ori_width, ori_height = img.width, img.height
150
+ scale = max(tgt_width / ori_width, tgt_height / ori_height)
151
+ resized_width = round(ori_width * scale)
152
+ resized_height = round(ori_height * scale)
153
+ img = img.resize((resized_width, resized_height), resample=PIL.Image.LANCZOS)
154
+
155
+ left = (resized_width - tgt_width) / 2
156
+ top = (resized_height - tgt_height) / 2
157
+ right = (resized_width + tgt_width) / 2
158
+ bottom = (resized_height + tgt_height) / 2
159
+
160
+ # Crop the center of the image
161
+ img = img.crop((left, top, right, bottom))
162
+
163
+ return img
164
+
165
+ # Function to generate text-to-video
166
+ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, seed, progress=gr.Progress()):
167
+ progress(0, desc="Loading model")
168
+ print("[DEBUG] generate_text_to_video called.")
169
+ variant = '768p' if resolution == "768p" else '384p'
170
+ height = height_high if resolution == "768p" else height_low
171
+ width = width_high if resolution == "768p" else width_low
172
+
173
+ def progress_callback(i, m):
174
+ progress(i/m)
175
+
176
+ # Initialize model based on user options using cached function
177
+ try:
178
+ model, torch_dtype_selected = initialize_model_cached(variant, seed)
179
+ except Exception as e:
180
+ print(f"[ERROR] Model initialization failed: {e}")
181
+ return f"Model initialization failed: {e}"
182
+
183
+ try:
184
+ print("[INFO] Starting text-to-video generation...")
185
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
186
+ frames = model.generate(
187
+ prompt=prompt,
188
+ num_inference_steps=[20, 20, 20],
189
+ video_num_inference_steps=[10, 10, 10],
190
+ height=height,
191
+ width=width,
192
+ temp=temp,
193
+ guidance_scale=guidance_scale,
194
+ video_guidance_scale=video_guidance_scale,
195
+ output_type="pil",
196
+ cpu_offloading=cpu_offloading,
197
+ save_memory=True,
198
+ callback=progress_callback,
199
+ )
200
+ print("[INFO] Text-to-video generation completed.")
201
+ except Exception as e:
202
+ print(f"[ERROR] Error during text-to-video generation: {e}")
203
+ return f"Error during video generation: {e}"
204
+
205
+ video_path = f"{str(uuid.uuid4())}_text_to_video_sample.mp4"
206
+ try:
207
+ export_to_video(frames, video_path, fps=24)
208
+ print(f"[INFO] Video exported to {video_path}.")
209
+ except Exception as e:
210
+ print(f"[ERROR] Error exporting video: {e}")
211
+ return f"Error exporting video: {e}"
212
+ return video_path
213
+
214
+ # Function to generate image-to-video
215
+ def generate_image_to_video(image, prompt, temp, video_guidance_scale, resolution, seed, progress=gr.Progress()):
216
+ progress(0, desc="Loading model")
217
+ print("[DEBUG] generate_image_to_video called.")
218
+ variant = '768p' if resolution == "768p" else '384p'
219
+ height = height_high if resolution == "768p" else height_low
220
+ width = width_high if resolution == "768p" else width_low
221
+
222
+ try:
223
+ image = resize_crop_image(image, width, height)
224
+ print("[INFO] Image resized and cropped successfully.")
225
+ except Exception as e:
226
+ print(f"[ERROR] Error processing image: {e}")
227
+ return f"Error processing image: {e}"
228
+
229
+ def progress_callback(i, m):
230
+ progress(i/m)
231
+
232
+ # Initialize model based on user options using cached function
233
+ try:
234
+ model, torch_dtype_selected = initialize_model_cached(variant, seed)
235
+ except Exception as e:
236
+ print(f"[ERROR] Model initialization failed: {e}")
237
+ return f"Model initialization failed: {e}"
238
+
239
+ try:
240
+ print("[INFO] Starting image-to-video generation...")
241
+ with torch.no_grad(), torch.autocast('cuda', dtype=torch_dtype_selected):
242
+ frames = model.generate_i2v(
243
+ prompt=prompt,
244
+ input_image=image,
245
+ num_inference_steps=[10, 10, 10],
246
+ temp=temp,
247
+ video_guidance_scale=video_guidance_scale,
248
+ output_type="pil",
249
+ cpu_offloading=cpu_offloading,
250
+ save_memory=True,
251
+ callback=progress_callback,
252
+ )
253
+ print("[INFO] Image-to-video generation completed.")
254
+ except Exception as e:
255
+ print(f"[ERROR] Error during image-to-video generation: {e}")
256
+ return f"Error during video generation: {e}"
257
+
258
+ video_path = f"{str(uuid.uuid4())}_image_to_video_sample.mp4"
259
+ try:
260
+ export_to_video(frames, video_path, fps=24)
261
+ print(f"[INFO] Video exported to {video_path}.")
262
+ except Exception as e:
263
+ print(f"[ERROR] Error exporting video: {e}")
264
+ return f"Error exporting video: {e}"
265
+ return video_path
266
+
267
+ def update_slider(resolution):
268
+ if resolution == "768p":
269
+ return [gr.update(maximum=31), gr.update(maximum=31)]
270
+ else:
271
+ return [gr.update(maximum=16), gr.update(maximum=16)]
272
+
273
+ # Gradio interface
274
+ with gr.Blocks() as demo:
275
+ gr.Markdown(
276
+ """
277
+ # Pyramid Flow Video Generation Demo
278
+
279
+ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
280
+
281
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
282
+ """
283
+ )
284
+
285
+ # Shared settings
286
+ with gr.Row():
287
+ resolution_dropdown = gr.Dropdown(
288
+ choices=["768p", "384p"],
289
+ value="384p",
290
+ label="Model Resolution"
291
+ )
292
+
293
+ with gr.Tab("Text-to-Video"):
294
+ with gr.Row():
295
+ with gr.Column():
296
+ text_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
297
+ temp_slider = gr.Slider(1, 16, value=16, step=1, label="Duration")
298
+ guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
299
+ video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
300
+ text_seed = gr.Number(label="Inference Seed (Enter a positive number, 0 for random)", value=0)
301
+ txt_generate = gr.Button("Generate Video")
302
+ with gr.Column():
303
+ txt_output = gr.Video(label="Generated Video")
304
+ gr.Examples(
305
+ examples=[
306
+ ["A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors", 16, 7.0, 5.0, "384p"],
307
+ ["Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes", 16, 7.0, 5.0, "384p"],
308
+ # ["Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours", 31, 9.0, 5.0, "768p"],
309
+ ],
310
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, text_seed],
311
+ outputs=[txt_output],
312
+ fn=generate_text_to_video,
313
+ cache_examples='lazy',
314
+ )
315
+
316
+ with gr.Tab("Image-to-Video"):
317
+ with gr.Row():
318
+ with gr.Column():
319
+ image_input = gr.Image(type="pil", label="Input Image")
320
+ image_prompt = gr.Textbox(label="Prompt (Less than 128 words)", placeholder="Enter a text prompt for the video", lines=2)
321
+ image_temp_slider = gr.Slider(2, 16, value=16, step=1, label="Duration")
322
+ image_video_guidance_scale_slider = gr.Slider(1.0, 7.0, value=4.0, step=0.1, label="Video Guidance Scale")
323
+ image_seed = gr.Number(label="Inference Seed (Enter a positive number, 0 for random)", value=0)
324
+ img_generate = gr.Button("Generate Video")
325
+ with gr.Column():
326
+ img_output = gr.Video(label="Generated Video")
327
+ gr.Examples(
328
+ examples=[
329
+ ['assets/the_great_wall.jpg', 'FPV flying over the Great Wall', 16, 4.0, "384p"]
330
+ ],
331
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown, image_seed],
332
+ outputs=[img_output],
333
+ fn=generate_image_to_video,
334
+ cache_examples='lazy',
335
+ )
336
+
337
+ # Update generate functions to include resolution options
338
+ txt_generate.click(
339
+ generate_text_to_video,
340
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, text_seed],
341
+ outputs=txt_output
342
+ )
343
+
344
+ img_generate.click(
345
+ generate_image_to_video,
346
+ inputs=[image_input, image_prompt, image_temp_slider, image_video_guidance_scale_slider, resolution_dropdown, image_seed],
347
+ outputs=img_output
348
+ )
349
+ resolution_dropdown.change(
350
+ fn=update_slider,
351
+ inputs=resolution_dropdown,
352
+ outputs=[temp_slider, image_temp_slider]
353
+ )
354
+
355
+ # Launch Gradio app
356
+ demo.launch(share=True)
app_multigpu.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import gradio as gr
4
+ import subprocess
5
+ import tempfile
6
+ import shutil
7
+
8
+ def run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt):
9
+ """
10
+ Runs the external multi-GPU inference script and returns the path to the generated video.
11
+ """
12
+ # Create a temporary directory to store inputs and outputs
13
+ with tempfile.TemporaryDirectory() as tmpdir:
14
+ output_video = os.path.join(tmpdir, f"{uuid.uuid4()}_output.mp4")
15
+
16
+ # Path to the external shell script
17
+ script_path = "./scripts/app_multigpu_engine.sh" # Updated script path
18
+
19
+ # Prepare the command
20
+ cmd = [
21
+ script_path,
22
+ str(gpus),
23
+ variant,
24
+ model_path,
25
+ 't2v', # Task is always 't2v' since 'i2v' is removed
26
+ str(temp),
27
+ str(guidance_scale),
28
+ str(video_guidance_scale),
29
+ resolution,
30
+ output_video,
31
+ prompt # Pass the prompt directly as an argument
32
+ ]
33
+
34
+ try:
35
+ # Run the external script
36
+ subprocess.run(cmd, check=True)
37
+ except subprocess.CalledProcessError as e:
38
+ raise RuntimeError(f"Error during video generation: {e}")
39
+
40
+ # After generation, move the video to a permanent location
41
+ final_output = os.path.join("generated_videos", f"{uuid.uuid4()}_output.mp4")
42
+ os.makedirs("generated_videos", exist_ok=True)
43
+ shutil.move(output_video, final_output)
44
+
45
+ return final_output
46
+
47
+ def generate_text_to_video(prompt, temp, guidance_scale, video_guidance_scale, resolution, gpus):
48
+ model_path = "./pyramid_flow_model" # Use the model path as specified
49
+ # Determine variant based on resolution
50
+ if resolution == "768p":
51
+ variant = "diffusion_transformer_768p"
52
+ else:
53
+ variant = "diffusion_transformer_384p"
54
+ return run_inference_multigpu(gpus, variant, model_path, temp, guidance_scale, video_guidance_scale, resolution, prompt)
55
+
56
+ # Gradio interface
57
+ with gr.Blocks() as demo:
58
+ gr.Markdown(
59
+ """
60
+ # Pyramid Flow Video Generation Demo
61
+
62
+ Pyramid Flow is a training-efficient **Autoregressive Video Generation** model based on **Flow Matching**. It is trained only on open-source datasets within 20.7k A100 GPU hours.
63
+
64
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page]](https://pyramid-flow.github.io) [[Code]](https://github.com/jy0205/Pyramid-Flow) [[Model]](https://huggingface.co/rain1011/pyramid-flow-sd3)
65
+ """
66
+ )
67
+
68
+ # Shared settings
69
+ with gr.Row():
70
+ gpus_dropdown = gr.Dropdown(
71
+ choices=[2, 4],
72
+ value=4,
73
+ label="Number of GPUs"
74
+ )
75
+ resolution_dropdown = gr.Dropdown(
76
+ choices=["768p", "384p"],
77
+ value="768p",
78
+ label="Model Resolution"
79
+ )
80
+
81
+ with gr.Tab("Text-to-Video"):
82
+ with gr.Row():
83
+ with gr.Column():
84
+ text_prompt = gr.Textbox(
85
+ label="Prompt (Less than 128 words)",
86
+ placeholder="Enter a text prompt for the video",
87
+ lines=2
88
+ )
89
+ temp_slider = gr.Slider(1, 31, value=16, step=1, label="Duration")
90
+ guidance_scale_slider = gr.Slider(1.0, 15.0, value=9.0, step=0.1, label="Guidance Scale")
91
+ video_guidance_scale_slider = gr.Slider(1.0, 10.0, value=5.0, step=0.1, label="Video Guidance Scale")
92
+ txt_generate = gr.Button("Generate Video")
93
+ with gr.Column():
94
+ txt_output = gr.Video(label="Generated Video")
95
+ gr.Examples(
96
+ examples=[
97
+ [
98
+ "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors",
99
+ 16,
100
+ 9.0,
101
+ 5.0,
102
+ "768p",
103
+ 4
104
+ ],
105
+ [
106
+ "Beautiful, snowy Tokyo city is bustling. The camera moves through the bustling city street, following several people enjoying the beautiful snowy weather and shopping at nearby stalls. Gorgeous sakura petals are flying through the wind along with snowflakes",
107
+ 16,
108
+ 9.0,
109
+ 5.0,
110
+ "768p",
111
+ 4
112
+ ],
113
+ [
114
+ "Extreme close-up of chicken and green pepper kebabs grilling on a barbeque with flames. Shallow focus and light smoke. vivid colours",
115
+ 31,
116
+ 9.0,
117
+ 5.0,
118
+ "768p",
119
+ 4
120
+ ],
121
+ ],
122
+ inputs=[text_prompt, temp_slider, guidance_scale_slider, video_guidance_scale_slider, resolution_dropdown, gpus_dropdown],
123
+ outputs=[txt_output],
124
+ fn=generate_text_to_video,
125
+ cache_examples='lazy',
126
+ )
127
+
128
+ # Update generate function for Text-to-Video
129
+ txt_generate.click(
130
+ generate_text_to_video,
131
+ inputs=[
132
+ text_prompt,
133
+ temp_slider,
134
+ guidance_scale_slider,
135
+ video_guidance_scale_slider,
136
+ resolution_dropdown,
137
+ gpus_dropdown
138
+ ],
139
+ outputs=txt_output
140
+ )
141
+
142
+ # Launch Gradio app
143
+ demo.launch(share=True)
assets/motivation.jpg ADDED
assets/the_great_wall.jpg ADDED
assets/user_study.jpg ADDED
assets/vbench.jpg ADDED
causal_video_vae_demo.ipynb ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "import cv2\n",
12
+ "import torch\n",
13
+ "import numpy as np\n",
14
+ "import PIL\n",
15
+ "from PIL import Image\n",
16
+ "from einops import rearrange\n",
17
+ "from video_vae import CausalVideoVAELossWrapper\n",
18
+ "from torchvision import transforms as pth_transforms\n",
19
+ "from torchvision.transforms.functional import InterpolationMode\n",
20
+ "from IPython.display import Image as ipython_image\n",
21
+ "from diffusers.utils import load_image, export_to_video, export_to_gif\n",
22
+ "from IPython.display import HTML"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "model_path = \"pyramid-flow-miniflux/causal_video_vae\" # The video-vae checkpoint dir\n",
32
+ "model_dtype = 'bf16'\n",
33
+ "\n",
34
+ "device_id = 3\n",
35
+ "torch.cuda.set_device(device_id)\n",
36
+ "\n",
37
+ "model = CausalVideoVAELossWrapper(\n",
38
+ " model_path,\n",
39
+ " model_dtype,\n",
40
+ " interpolate=False, \n",
41
+ " add_discriminator=False,\n",
42
+ ")\n",
43
+ "model = model.to(\"cuda\")\n",
44
+ "\n",
45
+ "if model_dtype == \"bf16\":\n",
46
+ " torch_dtype = torch.bfloat16 \n",
47
+ "elif model_dtype == \"fp16\":\n",
48
+ " torch_dtype = torch.float16\n",
49
+ "else:\n",
50
+ " torch_dtype = torch.float32\n",
51
+ "\n",
52
+ "def image_transform(images, resize_width, resize_height):\n",
53
+ " transform_list = pth_transforms.Compose([\n",
54
+ " pth_transforms.Resize((resize_height, resize_width), InterpolationMode.BICUBIC, antialias=True),\n",
55
+ " pth_transforms.ToTensor(),\n",
56
+ " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
57
+ " ])\n",
58
+ " return torch.stack([transform_list(image) for image in images])\n",
59
+ "\n",
60
+ "\n",
61
+ "def get_transform(width, height, new_width=None, new_height=None, resize=False,):\n",
62
+ " transform_list = []\n",
63
+ "\n",
64
+ " if resize:\n",
65
+ " if new_width is None:\n",
66
+ " new_width = width // 8 * 8\n",
67
+ " if new_height is None:\n",
68
+ " new_height = height // 8 * 8\n",
69
+ " transform_list.append(pth_transforms.Resize((new_height, new_width), InterpolationMode.BICUBIC, antialias=True))\n",
70
+ " \n",
71
+ " transform_list.extend([\n",
72
+ " pth_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
73
+ " ])\n",
74
+ " transform_list = pth_transforms.Compose(transform_list)\n",
75
+ "\n",
76
+ " return transform_list\n",
77
+ "\n",
78
+ "\n",
79
+ "def load_video_and_transform(video_path, frame_number, new_width=None, new_height=None, max_frames=600, sample_fps=24, resize=False):\n",
80
+ " try:\n",
81
+ " video_capture = cv2.VideoCapture(video_path)\n",
82
+ " fps = video_capture.get(cv2.CAP_PROP_FPS)\n",
83
+ " frames = []\n",
84
+ " pil_frames = []\n",
85
+ " while True:\n",
86
+ " flag, frame = video_capture.read()\n",
87
+ " if not flag:\n",
88
+ " break\n",
89
+ " \n",
90
+ " pil_frames.append(np.ascontiguousarray(frame[:, :, ::-1]))\n",
91
+ " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
92
+ " frame = torch.from_numpy(frame)\n",
93
+ " frame = frame.permute(2, 0, 1)\n",
94
+ " frames.append(frame)\n",
95
+ " if len(frames) >= max_frames:\n",
96
+ " break\n",
97
+ "\n",
98
+ " video_capture.release()\n",
99
+ " interval = max(int(fps / sample_fps), 1)\n",
100
+ " pil_frames = pil_frames[::interval][:frame_number]\n",
101
+ " frames = frames[::interval][:frame_number]\n",
102
+ " frames = torch.stack(frames).float() / 255\n",
103
+ " width = frames.shape[-1]\n",
104
+ " height = frames.shape[-2]\n",
105
+ " video_transform = get_transform(width, height, new_width, new_height, resize=resize)\n",
106
+ " frames = video_transform(frames)\n",
107
+ " pil_frames = [Image.fromarray(frame).convert(\"RGB\") for frame in pil_frames]\n",
108
+ "\n",
109
+ " if resize:\n",
110
+ " if new_width is None:\n",
111
+ " new_width = width // 32 * 32\n",
112
+ " if new_height is None:\n",
113
+ " new_height = height // 32 * 32\n",
114
+ " pil_frames = [frame.resize((new_width or width, new_height or height), PIL.Image.BICUBIC) for frame in pil_frames]\n",
115
+ " return frames, pil_frames\n",
116
+ " except Exception:\n",
117
+ " return None\n",
118
+ "\n",
119
+ "\n",
120
+ "def show_video(ori_path, rec_path, width=\"100%\"):\n",
121
+ " html = ''\n",
122
+ " if ori_path is not None:\n",
123
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
124
+ " <source src=\"{ori_path}\" type=\"video/mp4\">\n",
125
+ " </video>\n",
126
+ " \"\"\"\n",
127
+ " \n",
128
+ " html += f\"\"\"<video controls=\"\" name=\"media\" data-fullscreen-container=\"true\" width=\"{width}\">\n",
129
+ " <source src=\"{rec_path}\" type=\"video/mp4\">\n",
130
+ " </video>\n",
131
+ " \"\"\"\n",
132
+ " return HTML(html)"
133
+ ]
134
+ },
135
+ {
136
+ "attachments": {},
137
+ "cell_type": "markdown",
138
+ "metadata": {},
139
+ "source": [
140
+ "### Image Reconstruction"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "image_path = 'image_path'\n",
150
+ "\n",
151
+ "image = Image.open(image_path).convert(\"RGB\")\n",
152
+ "resize_width = image.width // 8 * 8\n",
153
+ "resize_height = image.height // 8 * 8\n",
154
+ "input_image_tensor = image_transform([image], resize_width, resize_height)\n",
155
+ "input_image_tensor = input_image_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
156
+ "\n",
157
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
158
+ " latent = model.encode_latent(input_image_tensor.to(\"cuda\"), sample=True)\n",
159
+ " rec_images = model.decode_latent(latent)\n",
160
+ "\n",
161
+ "display(image)\n",
162
+ "display(rec_images[0])"
163
+ ]
164
+ },
165
+ {
166
+ "attachments": {},
167
+ "cell_type": "markdown",
168
+ "metadata": {},
169
+ "source": [
170
+ "### Video Reconstruction"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": null,
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": [
179
+ "video_path = 'video_path'\n",
180
+ "\n",
181
+ "frame_number = 57 # x*8 + 1\n",
182
+ "width = 640\n",
183
+ "height = 384\n",
184
+ "\n",
185
+ "video_frames_tensor, pil_video_frames = load_video_and_transform(video_path, frame_number, new_width=width, new_height=height, resize=True)\n",
186
+ "video_frames_tensor = video_frames_tensor.permute(1, 0, 2, 3).unsqueeze(0)\n",
187
+ "print(video_frames_tensor.shape)\n",
188
+ "\n",
189
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):\n",
190
+ " latent = model.encode_latent(video_frames_tensor.to(\"cuda\"), sample=False, window_size=8, temporal_chunk=True)\n",
191
+ " rec_frames = model.decode_latent(latent.float(), window_size=2, temporal_chunk=True)\n",
192
+ "\n",
193
+ "export_to_video(pil_video_frames, './ori_video.mp4', fps=24)\n",
194
+ "export_to_video(rec_frames, \"./rec_video.mp4\", fps=24)\n",
195
+ "show_video('./ori_video.mp4', \"./rec_video.mp4\", \"60%\")"
196
+ ]
197
+ }
198
+ ],
199
+ "metadata": {
200
+ "kernelspec": {
201
+ "display_name": "Python 3",
202
+ "language": "python",
203
+ "name": "python3"
204
+ },
205
+ "language_info": {
206
+ "codemirror_mode": {
207
+ "name": "ipython",
208
+ "version": 3
209
+ },
210
+ "file_extension": ".py",
211
+ "mimetype": "text/x-python",
212
+ "name": "python",
213
+ "nbconvert_exporter": "python",
214
+ "pygments_lexer": "ipython3",
215
+ "version": "3.8.10"
216
+ },
217
+ "orig_nbformat": 4
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 2
221
+ }
dataset/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .dataset_cls import (
2
+ ImageTextDataset,
3
+ LengthGroupedVideoTextDataset,
4
+ ImageDataset,
5
+ VideoDataset,
6
+ )
7
+
8
+ from .dataloaders import (
9
+ create_image_text_dataloaders,
10
+ create_length_grouped_video_text_dataloader,
11
+ create_mixed_dataloaders,
12
+ )
dataset/bucket_loader.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import numpy as np
4
+ import math
5
+ import random
6
+ import time
7
+
8
+
9
+ class Bucketeer:
10
+ def __init__(
11
+ self, dataloader,
12
+ sizes=[(256, 256), (192, 384), (192, 320), (384, 192), (320, 192)],
13
+ is_infinite=True, epoch=0,
14
+ ):
15
+ # Ratios and Sizes : (w h)
16
+ self.sizes = sizes
17
+ self.batch_size = dataloader.batch_size
18
+ self._dataloader = dataloader
19
+ self.iterator = iter(dataloader)
20
+ self.sampler = dataloader.sampler
21
+ self.buckets = {s: [] for s in self.sizes}
22
+ self.is_infinite = is_infinite
23
+ self._epoch = epoch
24
+
25
+ def get_available_batch(self):
26
+ available_size = []
27
+ for b in self.buckets:
28
+ if len(self.buckets[b]) >= self.batch_size:
29
+ available_size.append(b)
30
+
31
+ if len(available_size) == 0:
32
+ return None
33
+ else:
34
+ b = random.choice(available_size)
35
+ batch = self.buckets[b][:self.batch_size]
36
+ self.buckets[b] = self.buckets[b][self.batch_size:]
37
+ return batch
38
+
39
+ def __next__(self):
40
+ batch = self.get_available_batch()
41
+ while batch is None:
42
+ try:
43
+ elements = next(self.iterator)
44
+ except StopIteration:
45
+ # To make it infinity
46
+ if self.is_infinite:
47
+ self._epoch += 1
48
+ if hasattr(self._dataloader.sampler, "set_epoch"):
49
+ self._dataloader.sampler.set_epoch(self._epoch)
50
+ time.sleep(2) # Prevent possible deadlock during epoch transition
51
+ self.iterator = iter(self._dataloader)
52
+ elements = next(self.iterator)
53
+ else:
54
+ raise StopIteration
55
+
56
+ for dct in elements:
57
+ try:
58
+ img = dct['video']
59
+ size = (img.shape[-1], img.shape[-2])
60
+ self.buckets[size].append({**{'video': img}, **{k:dct[k] for k in dct if k != 'video'}})
61
+ except Exception as e:
62
+ continue
63
+
64
+ batch = self.get_available_batch()
65
+
66
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
67
+ return {k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
68
+
69
+ def __iter__(self):
70
+ return self
71
+
72
+ def __len__(self):
73
+ return len(self.iterator)
74
+
75
+
76
+ class TemporalLengthBucketeer:
77
+ def __init__(
78
+ self, dataloader, max_frames=16, epoch=0,
79
+ ):
80
+ self.batch_size = dataloader.batch_size
81
+ self._dataloader = dataloader
82
+ self.iterator = iter(dataloader)
83
+ self.buckets = {temp: [] for temp in range(1, max_frames + 1)}
84
+ self._epoch = epoch
85
+
86
+ def get_available_batch(self):
87
+ available_size = []
88
+ for b in self.buckets:
89
+ if len(self.buckets[b]) >= self.batch_size:
90
+ available_size.append(b)
91
+
92
+ if len(available_size) == 0:
93
+ return None
94
+ else:
95
+ b = random.choice(available_size)
96
+ batch = self.buckets[b][:self.batch_size]
97
+ self.buckets[b] = self.buckets[b][self.batch_size:]
98
+ return batch
99
+
100
+ def __next__(self):
101
+ batch = self.get_available_batch()
102
+ while batch is None:
103
+ try:
104
+ elements = next(self.iterator)
105
+ except StopIteration:
106
+ # To make it infinity
107
+ self._epoch += 1
108
+ if hasattr(self._dataloader.sampler, "set_epoch"):
109
+ self._dataloader.sampler.set_epoch(self._epoch)
110
+ time.sleep(2) # Prevent possible deadlock during epoch transition
111
+ self.iterator = iter(self._dataloader)
112
+ elements = next(self.iterator)
113
+
114
+ for dct in elements:
115
+ try:
116
+ video_latent = dct['video']
117
+ temp = video_latent.shape[2]
118
+ self.buckets[temp].append({**{'video': video_latent}, **{k:dct[k] for k in dct if k != 'video'}})
119
+ except Exception as e:
120
+ continue
121
+
122
+ batch = self.get_available_batch()
123
+
124
+ out = {k:[batch[i][k] for i in range(len(batch))] for k in batch[0]}
125
+ out = {k: torch.cat(o, dim=0) if isinstance(o[0], torch.Tensor) else o for k, o in out.items()}
126
+
127
+ if 'prompt_embed' in out:
128
+ # Loading the pre-extrcted textual features
129
+ prompt_embeds = out['prompt_embed'].clone()
130
+ del out['prompt_embed']
131
+ prompt_attention_mask = out['prompt_attention_mask'].clone()
132
+ del out['prompt_attention_mask']
133
+ pooled_prompt_embeds = out['pooled_prompt_embed'].clone()
134
+ del out['pooled_prompt_embed']
135
+
136
+ out['text'] = {
137
+ 'prompt_embeds' : prompt_embeds,
138
+ 'prompt_attention_mask': prompt_attention_mask,
139
+ 'pooled_prompt_embeds': pooled_prompt_embeds,
140
+ }
141
+
142
+ return out
143
+
144
+ def __iter__(self):
145
+ return self
146
+
147
+ def __len__(self):
148
+ return len(self.iterator)
dataset/dataloaders.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import time
5
+ import random
6
+ from typing import Iterable
7
+
8
+ from collections import OrderedDict
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset, DataLoader, ConcatDataset, IterableDataset, DistributedSampler, RandomSampler
11
+ from torch.utils.data.dataloader import default_collate
12
+ from torchvision import transforms
13
+ from torchvision.transforms.functional import InterpolationMode
14
+ from torchvision.transforms import functional as F
15
+ from .bucket_loader import Bucketeer, TemporalLengthBucketeer
16
+
17
+
18
+ class IterLoader:
19
+ """
20
+ A wrapper to convert DataLoader as an infinite iterator.
21
+
22
+ Modified from:
23
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
24
+ """
25
+
26
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False, epoch: int = 0):
27
+ self._dataloader = dataloader
28
+ self.iter_loader = iter(self._dataloader)
29
+ self._use_distributed = use_distributed
30
+ self._epoch = epoch
31
+
32
+ @property
33
+ def epoch(self) -> int:
34
+ return self._epoch
35
+
36
+ def __next__(self):
37
+ try:
38
+ data = next(self.iter_loader)
39
+ except StopIteration:
40
+ self._epoch += 1
41
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
42
+ self._dataloader.sampler.set_epoch(self._epoch)
43
+ time.sleep(2) # Prevent possible deadlock during epoch transition
44
+ self.iter_loader = iter(self._dataloader)
45
+ data = next(self.iter_loader)
46
+
47
+ return data
48
+
49
+ def __iter__(self):
50
+ return self
51
+
52
+ def __len__(self):
53
+ return len(self._dataloader)
54
+
55
+
56
+ def identity(x):
57
+ return x
58
+
59
+
60
+ def create_image_text_dataloaders(dataset, batch_size, num_workers,
61
+ multi_aspect_ratio=True, epoch=0, sizes=[(512, 512), (384, 640), (640, 384)],
62
+ use_distributed=True, world_size=None, rank=None,
63
+ ):
64
+ """
65
+ The dataset has already been splited by different rank
66
+ """
67
+ if use_distributed:
68
+ assert world_size is not None
69
+ assert rank is not None
70
+ sampler = DistributedSampler(
71
+ dataset,
72
+ shuffle=True,
73
+ num_replicas=world_size,
74
+ rank=rank,
75
+ seed=epoch,
76
+ )
77
+ else:
78
+ sampler = RandomSampler(dataset)
79
+
80
+ dataloader = DataLoader(
81
+ dataset,
82
+ batch_size=batch_size,
83
+ num_workers=num_workers,
84
+ pin_memory=True,
85
+ sampler=sampler,
86
+ collate_fn=identity if multi_aspect_ratio else default_collate,
87
+ drop_last=True,
88
+ )
89
+
90
+ if multi_aspect_ratio:
91
+ dataloader_iterator = Bucketeer(
92
+ dataloader,
93
+ sizes=sizes,
94
+ is_infinite=True, epoch=epoch,
95
+ )
96
+ else:
97
+ dataloader_iterator = iter(dataloader)
98
+
99
+ # To make it infinite
100
+ loader = IterLoader(dataloader_iterator, use_distributed=False, epoch=epoch)
101
+
102
+ return loader
103
+
104
+
105
+ def create_length_grouped_video_text_dataloader(dataset, batch_size, num_workers, max_frames,
106
+ world_size=None, rank=None, epoch=0, use_distributed=False):
107
+ if use_distributed:
108
+ assert world_size is not None
109
+ assert rank is not None
110
+ sampler = DistributedSampler(
111
+ dataset,
112
+ shuffle=True,
113
+ num_replicas=world_size,
114
+ rank=rank,
115
+ seed=epoch,
116
+ )
117
+ else:
118
+ sampler = RandomSampler(dataset)
119
+
120
+ dataloader = DataLoader(
121
+ dataset,
122
+ batch_size=batch_size,
123
+ num_workers=num_workers,
124
+ pin_memory=True,
125
+ sampler=sampler,
126
+ collate_fn=identity,
127
+ drop_last=True,
128
+ )
129
+
130
+ # make it infinite
131
+ dataloader_iterator = TemporalLengthBucketeer(
132
+ dataloader,
133
+ max_frames=max_frames,
134
+ epoch=epoch,
135
+ )
136
+
137
+ return dataloader_iterator
138
+
139
+
140
+ def create_mixed_dataloaders(
141
+ dataset, batch_size, num_workers, world_size=None, rank=None, epoch=0,
142
+ image_mix_ratio=0.1, use_image_video_mixed_training=True,
143
+ ):
144
+ """
145
+ The video & image mixed training dataloader builder
146
+ """
147
+
148
+ assert world_size is not None
149
+ assert rank is not None
150
+
151
+ image_gpus = max(1, int(world_size * image_mix_ratio))
152
+ if use_image_video_mixed_training:
153
+ video_gpus = world_size - image_gpus
154
+ else:
155
+ # only use video data
156
+ video_gpus = world_size
157
+ image_gpus = 0
158
+
159
+ print(f"{image_gpus} gpus for image, {video_gpus} gpus for video")
160
+
161
+ if rank < video_gpus:
162
+ sampler = DistributedSampler(
163
+ dataset,
164
+ shuffle=True,
165
+ num_replicas=video_gpus,
166
+ rank=rank,
167
+ seed=epoch,
168
+ )
169
+ else:
170
+ sampler = DistributedSampler(
171
+ dataset,
172
+ shuffle=True,
173
+ num_replicas=image_gpus,
174
+ rank=rank - video_gpus,
175
+ seed=epoch,
176
+ )
177
+
178
+ loader = DataLoader(
179
+ dataset,
180
+ batch_size=batch_size,
181
+ num_workers=num_workers,
182
+ pin_memory=True,
183
+ sampler=sampler,
184
+ collate_fn=default_collate,
185
+ drop_last=True,
186
+ )
187
+
188
+ # To make it infinite
189
+ loader = IterLoader(loader, use_distributed=True, epoch=epoch)
190
+ return loader
dataset/dataset_cls.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import jsonlines
4
+ import torch
5
+ import math
6
+ import random
7
+ import cv2
8
+
9
+ from tqdm import tqdm
10
+ from collections import OrderedDict
11
+
12
+ from PIL import Image
13
+ from PIL import ImageFile
14
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
15
+
16
+ import numpy as np
17
+ import subprocess
18
+ from torch.utils.data import Dataset, DataLoader
19
+ from torchvision import transforms
20
+ from torchvision.transforms.functional import InterpolationMode
21
+ from torchvision.transforms import functional as F
22
+
23
+
24
+ class ImageTextDataset(Dataset):
25
+ """
26
+ Usage:
27
+ The dataset class for image-text pairs, used for image generation training
28
+ It supports multi-aspect ratio training
29
+ params:
30
+ anno_file: The annotation file list
31
+ add_normalize: whether to normalize the input image pixel to [-1, 1], default: True
32
+ ratios: The aspect ratios during training, format: width / height
33
+ sizes: The resoultion of training images, format: (width, height)
34
+ """
35
+ def __init__(
36
+ self, anno_file, add_normalize=True,
37
+ ratios=[1/1, 3/5, 5/3],
38
+ sizes=[(1024, 1024), (768, 1280), (1280, 768)],
39
+ crop_mode='random', p_random_ratio=0.0,
40
+ ):
41
+ # Ratios and Sizes : (w h)
42
+ super().__init__()
43
+
44
+ self.image_annos = []
45
+ if not isinstance(anno_file, list):
46
+ anno_file = [anno_file]
47
+
48
+ for anno_file_ in anno_file:
49
+ print(f"Load image annotation files from {anno_file_}")
50
+ with jsonlines.open(anno_file_, 'r') as reader:
51
+ for item in reader:
52
+ self.image_annos.append(item)
53
+
54
+ print(f"Totally Remained {len(self.image_annos)} images")
55
+
56
+ transform_list = [
57
+ transforms.ToTensor(),
58
+ ]
59
+
60
+ if add_normalize:
61
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
62
+
63
+ self.transform = transforms.Compose(transform_list)
64
+
65
+ print(f"Transform List is {transform_list}")
66
+
67
+ assert crop_mode in ['center', 'random']
68
+ self.crop_mode = crop_mode
69
+ self.ratios = ratios
70
+ self.sizes = sizes
71
+ self.p_random_ratio = p_random_ratio
72
+
73
+ def get_closest_size(self, x):
74
+ if self.p_random_ratio > 0 and np.random.rand() < self.p_random_ratio:
75
+ best_size_idx = np.random.randint(len(self.ratios))
76
+ else:
77
+ w, h = x.width, x.height
78
+ best_size_idx = np.argmin([abs(w/h-r) for r in self.ratios])
79
+ return self.sizes[best_size_idx]
80
+
81
+ def get_resize_size(self, orig_size, tgt_size):
82
+ if (tgt_size[1]/tgt_size[0] - 1) * (orig_size[1]/orig_size[0] - 1) >= 0:
83
+ alt_min = int(math.ceil(max(tgt_size)*min(orig_size)/max(orig_size)))
84
+ resize_size = max(alt_min, min(tgt_size))
85
+ else:
86
+ alt_max = int(math.ceil(min(tgt_size)*max(orig_size)/min(orig_size)))
87
+ resize_size = max(alt_max, max(tgt_size))
88
+ return resize_size
89
+
90
+ def __len__(self):
91
+ return len(self.image_annos)
92
+
93
+ def __getitem__(self, index):
94
+ image_anno = self.image_annos[index]
95
+
96
+ try:
97
+ img = Image.open(image_anno['image']).convert("RGB")
98
+ text = image_anno['text']
99
+
100
+ assert isinstance(text, str), "Text should be str"
101
+
102
+ size = self.get_closest_size(img)
103
+ resize_size = self.get_resize_size((img.width, img.height), size)
104
+
105
+ img = transforms.functional.resize(img, resize_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True)
106
+
107
+ if self.crop_mode == 'center':
108
+ img = transforms.functional.center_crop(img, (size[1], size[0]))
109
+ elif self.crop_mode == 'random':
110
+ img = transforms.RandomCrop((size[1], size[0]))(img)
111
+ else:
112
+ img = transforms.functional.center_crop(img, (size[1], size[0]))
113
+
114
+ image_tensor = self.transform(img)
115
+
116
+ return {
117
+ "video": image_tensor, # using keyname `video`, to be compatible with video
118
+ "text" : text,
119
+ "identifier": 'image',
120
+ }
121
+
122
+ except Exception as e:
123
+ print(f'Load Image Error with {e}')
124
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
125
+
126
+
127
+ class LengthGroupedVideoTextDataset(Dataset):
128
+ """
129
+ Usage:
130
+ The dataset class for video-text pairs, used for video generation training
131
+ It groups the video with the same frames together
132
+ Now only supporting fixed resolution during training
133
+ params:
134
+ anno_file: The annotation file list
135
+ max_frames: The maximum temporal lengths (This is the vae latent temporal length) 16 => (16 - 1) * 8 + 1 = 121 frames
136
+ load_vae_latent: Loading the pre-extracted vae latents during training, we recommend to extract the latents in advance
137
+ to reduce the time cost per batch
138
+ load_text_fea: Loading the pre-extracted text features during training, we recommend to extract the prompt textual features
139
+ in advance, since the T5 encoder will cost many GPU memories
140
+ """
141
+
142
+ def __init__(self, anno_file, max_frames=16, resolution='384p', load_vae_latent=True, load_text_fea=True):
143
+ super().__init__()
144
+
145
+ self.video_annos = []
146
+ self.max_frames = max_frames
147
+ self.load_vae_latent = load_vae_latent
148
+ self.load_text_fea = load_text_fea
149
+ self.resolution = resolution
150
+
151
+ assert load_vae_latent, "Now only support loading vae latents, we will support to directly load video frames in the future"
152
+
153
+ if not isinstance(anno_file, list):
154
+ anno_file = [anno_file]
155
+
156
+ for anno_file_ in anno_file:
157
+ with jsonlines.open(anno_file_, 'r') as reader:
158
+ for item in tqdm(reader):
159
+ self.video_annos.append(item)
160
+
161
+ print(f"Totally Remained {len(self.video_annos)} videos")
162
+
163
+ def __len__(self):
164
+ return len(self.video_annos)
165
+
166
+ def __getitem__(self, index):
167
+ try:
168
+ video_anno = self.video_annos[index]
169
+ text = video_anno['text']
170
+ latent_path = video_anno['latent']
171
+ latent = torch.load(latent_path, map_location='cpu') # loading the pre-extracted video latents
172
+
173
+ # TODO: remove the hard code latent shape checking
174
+ if self.resolution == '384p':
175
+ assert latent.shape[-1] == 640 // 8
176
+ assert latent.shape[-2] == 384 // 8
177
+ else:
178
+ assert self.resolution == '768p'
179
+ assert latent.shape[-1] == 1280 // 8
180
+ assert latent.shape[-2] == 768 // 8
181
+
182
+ cur_temp = latent.shape[2]
183
+ cur_temp = min(cur_temp, self.max_frames)
184
+
185
+ video_latent = latent[:,:,:cur_temp].float()
186
+ assert video_latent.shape[1] == 16
187
+
188
+ if self.load_text_fea:
189
+ text_fea_path = video_anno['text_fea']
190
+ text_fea = torch.load(text_fea_path, map_location='cpu')
191
+ return {
192
+ 'video': video_latent,
193
+ 'prompt_embed': text_fea['prompt_embed'],
194
+ 'prompt_attention_mask': text_fea['prompt_attention_mask'],
195
+ 'pooled_prompt_embed': text_fea['pooled_prompt_embed'],
196
+ "identifier": 'video',
197
+ }
198
+
199
+ else:
200
+ return {
201
+ 'video': video_latent,
202
+ 'text': text,
203
+ "identifier": 'video',
204
+ }
205
+
206
+ except Exception as e:
207
+ print(f'Load Video Error with {e}')
208
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
209
+
210
+
211
+ class VideoFrameProcessor:
212
+ # load a video and transform
213
+ def __init__(self, resolution=256, num_frames=24, add_normalize=True, sample_fps=24):
214
+
215
+ image_size = resolution
216
+
217
+ transform_list = [
218
+ transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
219
+ transforms.CenterCrop(image_size),
220
+ ]
221
+
222
+ if add_normalize:
223
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
224
+
225
+ print(f"Transform List is {transform_list}")
226
+ self.num_frames = num_frames
227
+ self.transform = transforms.Compose(transform_list)
228
+ self.sample_fps = sample_fps
229
+
230
+ def __call__(self, video_path):
231
+ try:
232
+ video_capture = cv2.VideoCapture(video_path)
233
+ fps = video_capture.get(cv2.CAP_PROP_FPS)
234
+ frames = []
235
+
236
+ while True:
237
+ flag, frame = video_capture.read()
238
+ if not flag:
239
+ break
240
+
241
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
242
+ frame = torch.from_numpy(frame)
243
+ frame = frame.permute(2, 0, 1)
244
+ frames.append(frame)
245
+
246
+ video_capture.release()
247
+ sample_fps = self.sample_fps
248
+ interval = max(int(fps / sample_fps), 1)
249
+ frames = frames[::interval]
250
+
251
+ if len(frames) < self.num_frames:
252
+ num_frame_to_pack = self.num_frames - len(frames)
253
+ recurrent_num = num_frame_to_pack // len(frames)
254
+ frames = frames + recurrent_num * frames + frames[:(num_frame_to_pack % len(frames))]
255
+ assert len(frames) >= self.num_frames, f'{len(frames)}'
256
+
257
+ start_indexs = list(range(0, max(0, len(frames) - self.num_frames + 1)))
258
+ start_index = random.choice(start_indexs)
259
+
260
+ filtered_frames = frames[start_index : start_index+self.num_frames]
261
+ assert len(filtered_frames) == self.num_frames, f"The sampled frames should equals to {self.num_frames}"
262
+
263
+ filtered_frames = torch.stack(filtered_frames).float() / 255
264
+ filtered_frames = self.transform(filtered_frames)
265
+ filtered_frames = filtered_frames.permute(1, 0, 2, 3)
266
+
267
+ return filtered_frames, None
268
+
269
+ except Exception as e:
270
+ print(f"Load video: {video_path} Error, Exception {e}")
271
+ return None, None
272
+
273
+
274
+ class VideoDataset(Dataset):
275
+ def __init__(self, anno_file, resolution=256, max_frames=6, add_normalize=True):
276
+ super().__init__()
277
+
278
+ self.video_annos = []
279
+ self.max_frames = max_frames
280
+
281
+ if not isinstance(anno_file, list):
282
+ anno_file = [anno_file]
283
+
284
+ print(f"The training video clip frame number is {max_frames} ")
285
+
286
+ for anno_file_ in anno_file:
287
+ print(f"Load annotation file from {anno_file_}")
288
+
289
+ with jsonlines.open(anno_file_, 'r') as reader:
290
+ for item in tqdm(reader):
291
+ self.video_annos.append(item)
292
+
293
+ print(f"Totally Remained {len(self.video_annos)} videos")
294
+
295
+ self.video_processor = VideoFrameProcessor(resolution, max_frames, add_normalize)
296
+
297
+ def __len__(self):
298
+ return len(self.video_annos)
299
+
300
+ def __getitem__(self, index):
301
+ video_anno = self.video_annos[index]
302
+ video_path = video_anno['video']
303
+
304
+ try:
305
+ video_tensors, video_frames = self.video_processor(video_path)
306
+
307
+ assert video_tensors.shape[1] == self.max_frames
308
+
309
+ return {
310
+ "video": video_tensors,
311
+ "identifier": 'video',
312
+ }
313
+
314
+ except Exception as e:
315
+ print('Loading Video Error with {e}')
316
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
317
+
318
+
319
+ class ImageDataset(Dataset):
320
+ def __init__(self, anno_file, resolution=256, max_frames=8, add_normalize=True):
321
+ super().__init__()
322
+
323
+ self.image_annos = []
324
+ self.max_frames = max_frames
325
+ image_paths = []
326
+
327
+ if not isinstance(anno_file, list):
328
+ anno_file = [anno_file]
329
+
330
+ for anno_file_ in anno_file:
331
+ print(f"Load annotation file from {anno_file_}")
332
+ with jsonlines.open(anno_file_, 'r') as reader:
333
+ for item in tqdm(reader):
334
+ image_paths.append(item['image'])
335
+
336
+ print(f"Totally Remained {len(image_paths)} images")
337
+
338
+ # pack multiple frames
339
+ for idx in range(0, len(image_paths), self.max_frames):
340
+ image_path_shard = image_paths[idx : idx + self.max_frames]
341
+ if len(image_path_shard) < self.max_frames:
342
+ image_path_shard = image_path_shard + image_paths[:self.max_frames - len(image_path_shard)]
343
+ assert len(image_path_shard) == self.max_frames
344
+ self.image_annos.append(image_path_shard)
345
+
346
+ image_size = resolution
347
+ transform_list = [
348
+ transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC, antialias=True),
349
+ transforms.CenterCrop(image_size),
350
+ transforms.ToTensor(),
351
+ ]
352
+ if add_normalize:
353
+ transform_list.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
354
+
355
+ print(f"Transform List is {transform_list}")
356
+ self.transform = transforms.Compose(transform_list)
357
+
358
+ def __len__(self):
359
+ return len(self.image_annos)
360
+
361
+ def __getitem__(self, index):
362
+ image_paths = self.image_annos[index]
363
+
364
+ try:
365
+ packed_pil_frames = [Image.open(image_path).convert("RGB") for image_path in image_paths]
366
+ filtered_frames = [self.transform(frame) for frame in packed_pil_frames]
367
+ filtered_frames = torch.stack(filtered_frames) # [t, c, h, w]
368
+ filtered_frames = filtered_frames.permute(1, 0, 2, 3) # [c, t, h, w]
369
+
370
+ return {
371
+ "video": filtered_frames,
372
+ "identifier": 'image',
373
+ }
374
+
375
+ except Exception as e:
376
+ print(f'Load Images Error with {e}')
377
+ return self.__getitem__(random.randint(0, self.__len__() - 1))
diffusion_schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .scheduling_cosine_ddpm import DDPMCosineScheduler
2
+ from .scheduling_flow_matching import PyramidFlowMatchEulerDiscreteScheduler
diffusion_schedulers/scheduling_cosine_ddpm.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+
12
+
13
+ @dataclass
14
+ class DDPMSchedulerOutput(BaseOutput):
15
+ """
16
+ Output class for the scheduler's step function output.
17
+
18
+ Args:
19
+ prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
21
+ denoising loop.
22
+ """
23
+
24
+ prev_sample: torch.Tensor
25
+
26
+
27
+ class DDPMCosineScheduler(SchedulerMixin, ConfigMixin):
28
+
29
+ @register_to_config
30
+ def __init__(
31
+ self,
32
+ scaler: float = 1.0,
33
+ s: float = 0.008,
34
+ ):
35
+ self.scaler = scaler
36
+ self.s = torch.tensor([s])
37
+ self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
38
+
39
+ # standard deviation of the initial noise distribution
40
+ self.init_noise_sigma = 1.0
41
+
42
+ def _alpha_cumprod(self, t, device):
43
+ if self.scaler > 1:
44
+ t = 1 - (1 - t) ** self.scaler
45
+ elif self.scaler < 1:
46
+ t = t**self.scaler
47
+ alpha_cumprod = torch.cos(
48
+ (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
49
+ ) ** 2 / self._init_alpha_cumprod.to(device)
50
+ return alpha_cumprod.clamp(0.0001, 0.9999)
51
+
52
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
53
+ """
54
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
55
+ current timestep.
56
+
57
+ Args:
58
+ sample (`torch.Tensor`): input sample
59
+ timestep (`int`, optional): current timestep
60
+
61
+ Returns:
62
+ `torch.Tensor`: scaled input sample
63
+ """
64
+ return sample
65
+
66
+ def set_timesteps(
67
+ self,
68
+ num_inference_steps: int = None,
69
+ timesteps: Optional[List[int]] = None,
70
+ device: Union[str, torch.device] = None,
71
+ ):
72
+ """
73
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
74
+
75
+ Args:
76
+ num_inference_steps (`Dict[float, int]`):
77
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
78
+ `timesteps` must be `None`.
79
+ device (`str` or `torch.device`, optional):
80
+ the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10}
81
+ """
82
+ if timesteps is None:
83
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
84
+ if not isinstance(timesteps, torch.Tensor):
85
+ timesteps = torch.Tensor(timesteps).to(device)
86
+ self.timesteps = timesteps
87
+
88
+ def step(
89
+ self,
90
+ model_output: torch.Tensor,
91
+ timestep: int,
92
+ sample: torch.Tensor,
93
+ generator=None,
94
+ return_dict: bool = True,
95
+ ) -> Union[DDPMSchedulerOutput, Tuple]:
96
+ dtype = model_output.dtype
97
+ device = model_output.device
98
+ t = timestep
99
+
100
+ prev_t = self.previous_timestep(t)
101
+
102
+ alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
103
+ alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
104
+ alpha = alpha_cumprod / alpha_cumprod_prev
105
+
106
+ mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
107
+
108
+ std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
109
+ std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
110
+ pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
111
+
112
+ if not return_dict:
113
+ return (pred.to(dtype),)
114
+
115
+ return DDPMSchedulerOutput(prev_sample=pred.to(dtype))
116
+
117
+ def add_noise(
118
+ self,
119
+ original_samples: torch.Tensor,
120
+ noise: torch.Tensor,
121
+ timesteps: torch.Tensor,
122
+ ) -> torch.Tensor:
123
+ device = original_samples.device
124
+ dtype = original_samples.dtype
125
+ alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
126
+ timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
127
+ )
128
+ noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
129
+ return noisy_samples.to(dtype=dtype)
130
+
131
+ def __len__(self):
132
+ return self.config.num_train_timesteps
133
+
134
+ def previous_timestep(self, timestep):
135
+ index = (self.timesteps - timestep[0]).abs().argmin().item()
136
+ prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
137
+ return prev_t
diffusion_schedulers/scheduling_flow_matching.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union, List
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+
12
+
13
+ @dataclass
14
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
15
+ """
16
+ Output class for the scheduler's `step` function output.
17
+
18
+ Args:
19
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
20
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
21
+ denoising loop.
22
+ """
23
+
24
+ prev_sample: torch.FloatTensor
25
+
26
+
27
+ class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
28
+ """
29
+ Euler scheduler.
30
+
31
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
32
+ methods the library implements for all schedulers such as loading and saving.
33
+
34
+ Args:
35
+ num_train_timesteps (`int`, defaults to 1000):
36
+ The number of diffusion steps to train the model.
37
+ timestep_spacing (`str`, defaults to `"linspace"`):
38
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
39
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
40
+ shift (`float`, defaults to 1.0):
41
+ The shift value for the timestep schedule.
42
+ """
43
+
44
+ _compatibles = []
45
+ order = 1
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ num_train_timesteps: int = 1000,
51
+ shift: float = 1.0, # Following Stable diffusion 3,
52
+ stages: int = 3,
53
+ stage_range: List = [0, 1/3, 2/3, 1],
54
+ gamma: float = 1/3,
55
+ ):
56
+
57
+ self.timestep_ratios = {} # The timestep ratio for each stage
58
+ self.timesteps_per_stage = {} # The detailed timesteps per stage
59
+ self.sigmas_per_stage = {}
60
+ self.start_sigmas = {}
61
+ self.end_sigmas = {}
62
+ self.ori_start_sigmas = {}
63
+
64
+ # self.init_sigmas()
65
+ self.init_sigmas_for_each_stage()
66
+ self.sigma_min = self.sigmas[-1].item()
67
+ self.sigma_max = self.sigmas[0].item()
68
+ self.gamma = gamma
69
+
70
+ def init_sigmas(self):
71
+ """
72
+ initialize the global timesteps and sigmas
73
+ """
74
+ num_train_timesteps = self.config.num_train_timesteps
75
+ shift = self.config.shift
76
+
77
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
78
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
79
+
80
+ sigmas = timesteps / num_train_timesteps
81
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
82
+
83
+ self.timesteps = sigmas * num_train_timesteps
84
+
85
+ self._step_index = None
86
+ self._begin_index = None
87
+
88
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
89
+
90
+ def init_sigmas_for_each_stage(self):
91
+ """
92
+ Init the timesteps for each stage
93
+ """
94
+ self.init_sigmas()
95
+
96
+ stage_distance = []
97
+ stages = self.config.stages
98
+ training_steps = self.config.num_train_timesteps
99
+ stage_range = self.config.stage_range
100
+
101
+ # Init the start and end point of each stage
102
+ for i_s in range(stages):
103
+ # To decide the start and ends point
104
+ start_indice = int(stage_range[i_s] * training_steps)
105
+ start_indice = max(start_indice, 0)
106
+ end_indice = int(stage_range[i_s+1] * training_steps)
107
+ end_indice = min(end_indice, training_steps)
108
+ start_sigma = self.sigmas[start_indice].item()
109
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
110
+ self.ori_start_sigmas[i_s] = start_sigma
111
+
112
+ if i_s != 0:
113
+ ori_sigma = 1 - start_sigma
114
+ gamma = self.config.gamma
115
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
116
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
117
+ start_sigma = 1 - corrected_sigma
118
+
119
+ stage_distance.append(start_sigma - end_sigma)
120
+ self.start_sigmas[i_s] = start_sigma
121
+ self.end_sigmas[i_s] = end_sigma
122
+
123
+ # Determine the ratio of each stage according to flow length
124
+ tot_distance = sum(stage_distance)
125
+ for i_s in range(stages):
126
+ if i_s == 0:
127
+ start_ratio = 0.0
128
+ else:
129
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
130
+ if i_s == stages - 1:
131
+ end_ratio = 1.0
132
+ else:
133
+ end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
134
+
135
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
136
+
137
+ # Determine the timesteps and sigmas for each stage
138
+ for i_s in range(stages):
139
+ timestep_ratio = self.timestep_ratios[i_s]
140
+ timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
141
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
142
+ timesteps = np.linspace(
143
+ timestep_max, timestep_min, training_steps + 1,
144
+ )
145
+ self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
146
+ stage_sigmas = np.linspace(
147
+ 1, 0, training_steps + 1,
148
+ )
149
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
150
+
151
+ @property
152
+ def step_index(self):
153
+ """
154
+ The index counter for current timestep. It will increase 1 after each scheduler step.
155
+ """
156
+ return self._step_index
157
+
158
+ @property
159
+ def begin_index(self):
160
+ """
161
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
162
+ """
163
+ return self._begin_index
164
+
165
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
166
+ def set_begin_index(self, begin_index: int = 0):
167
+ """
168
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
169
+
170
+ Args:
171
+ begin_index (`int`):
172
+ The begin index for the scheduler.
173
+ """
174
+ self._begin_index = begin_index
175
+
176
+ def _sigma_to_t(self, sigma):
177
+ return sigma * self.config.num_train_timesteps
178
+
179
+ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
180
+ """
181
+ Setting the timesteps and sigmas for each stage
182
+ """
183
+ self.num_inference_steps = num_inference_steps
184
+ training_steps = self.config.num_train_timesteps
185
+ self.init_sigmas()
186
+
187
+ stage_timesteps = self.timesteps_per_stage[stage_index]
188
+ timestep_max = stage_timesteps[0].item()
189
+ timestep_min = stage_timesteps[-1].item()
190
+
191
+ timesteps = np.linspace(
192
+ timestep_max, timestep_min, num_inference_steps,
193
+ )
194
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
195
+
196
+ stage_sigmas = self.sigmas_per_stage[stage_index]
197
+ sigma_max = stage_sigmas[0].item()
198
+ sigma_min = stage_sigmas[-1].item()
199
+
200
+ ratios = np.linspace(
201
+ sigma_max, sigma_min, num_inference_steps
202
+ )
203
+ sigmas = torch.from_numpy(ratios).to(device=device)
204
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
205
+
206
+ self._step_index = None
207
+
208
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
209
+ if schedule_timesteps is None:
210
+ schedule_timesteps = self.timesteps
211
+
212
+ indices = (schedule_timesteps == timestep).nonzero()
213
+
214
+ # The sigma index that is taken for the **very** first `step`
215
+ # is always the second index (or the last index if there is only 1)
216
+ # This way we can ensure we don't accidentally skip a sigma in
217
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
218
+ pos = 1 if len(indices) > 1 else 0
219
+
220
+ return indices[pos].item()
221
+
222
+ def _init_step_index(self, timestep):
223
+ if self.begin_index is None:
224
+ if isinstance(timestep, torch.Tensor):
225
+ timestep = timestep.to(self.timesteps.device)
226
+ self._step_index = self.index_for_timestep(timestep)
227
+ else:
228
+ self._step_index = self._begin_index
229
+
230
+ def step(
231
+ self,
232
+ model_output: torch.FloatTensor,
233
+ timestep: Union[float, torch.FloatTensor],
234
+ sample: torch.FloatTensor,
235
+ generator: Optional[torch.Generator] = None,
236
+ return_dict: bool = True,
237
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
238
+ """
239
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
240
+ process from the learned model outputs (most often the predicted noise).
241
+
242
+ Args:
243
+ model_output (`torch.FloatTensor`):
244
+ The direct output from learned diffusion model.
245
+ timestep (`float`):
246
+ The current discrete timestep in the diffusion chain.
247
+ sample (`torch.FloatTensor`):
248
+ A current instance of a sample created by the diffusion process.
249
+ generator (`torch.Generator`, *optional*):
250
+ A random number generator.
251
+ return_dict (`bool`):
252
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
253
+ tuple.
254
+
255
+ Returns:
256
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
257
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
258
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
259
+ """
260
+
261
+ if (
262
+ isinstance(timestep, int)
263
+ or isinstance(timestep, torch.IntTensor)
264
+ or isinstance(timestep, torch.LongTensor)
265
+ ):
266
+ raise ValueError(
267
+ (
268
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
269
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
270
+ " one of the `scheduler.timesteps` as a timestep."
271
+ ),
272
+ )
273
+
274
+ if self.step_index is None:
275
+ self._step_index = 0
276
+
277
+ # Upcast to avoid precision issues when computing prev_sample
278
+ sample = sample.to(torch.float32)
279
+
280
+ sigma = self.sigmas[self.step_index]
281
+ sigma_next = self.sigmas[self.step_index + 1]
282
+
283
+ prev_sample = sample + (sigma_next - sigma) * model_output
284
+
285
+ # Cast sample back to model compatible dtype
286
+ prev_sample = prev_sample.to(model_output.dtype)
287
+
288
+ # upon completion increase step index by one
289
+ self._step_index += 1
290
+
291
+ if not return_dict:
292
+ return (prev_sample,)
293
+
294
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
295
+
296
+ def __len__(self):
297
+ return self.config.num_train_timesteps
docs/DiT.md ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pyramid Flow's DiT Finetuning Guide
2
+
3
+ This is the finetuning guide for the DiT in Pyramid Flow. We provide instructions for both autoregressive and non-autoregressive versions. The former is more research oriented and the latter is more stable (but less efficient without temporal pyramid). Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/VAE) for VAE finetuning.
4
+
5
+ ## Hardware Requirements
6
+
7
+ + DiT finetuning: At least 8 A100 GPUs.
8
+
9
+
10
+ ## Prepare the Dataset
11
+
12
+ The training dataset should be arranged into a json file, with `video`, `text` fields. Since the video vae latent extraction is very slow, we strongly recommend you to pre-extract the video vae latents to save the training time. We provide a video vae latent extraction script in folder `tools`. You can run it with the following command:
13
+
14
+ ```bash
15
+ sh scripts/extract_vae_latent.sh
16
+ ```
17
+
18
+ (optional) Since the T5 text encoder will cost a lot of GPU memory, pre-extract the text features will save the training memory. We also provide a text feature extraction script in folder `tools`. You can run it with the following command:
19
+
20
+ ```bash
21
+ sh scripts/extract_text_feature.sh
22
+ ```
23
+
24
+ The final training annotation json file should look like the following format:
25
+
26
+ ```
27
+ {"video": video_path, "text": text prompt, "latent": extracted video vae latent, "text_fea": extracted text feature}
28
+ ```
29
+
30
+ We provide the example json annotation files for [video](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/video_text.jsonl) and [image](https://github.com/jy0205/Pyramid-Flow/blob/main/annotation/image_text.jsonl)) training in the `annotation` folder. You can refer them to prepare your training dataset.
31
+
32
+
33
+ ## Run Training
34
+ We provide two types of training scripts: (1) autoregressive video generation training with temporal pyramid. (2) Full-sequence diffusion training with pyramid-flow for both text-to-image and text-to-video training. This corresponds to the following two script files. Running these training scripts using at least 8 GPUs:
35
+
36
+ + `scripts/train_pyramid_flow.sh`: The autoregressive video generation training with temporal pyramid.
37
+
38
+ ```bash
39
+ sh scripts/train_pyramid_flow.sh
40
+ ```
41
+
42
+ + `scripts/train_pyramid_flow_without_ar.sh`: Using pyramid-flow for full-sequence diffusion training.
43
+
44
+ ```bash
45
+ sh scripts/train_pyramid_flow_without_ar.sh
46
+ ```
47
+
48
+
49
+ ## Tips
50
+
51
+ + For the 768p version, make sure to add the args: `--gradient_checkpointing`
52
+ + Param `NUM_FRAMES` should be set to a multiple of 8
53
+ + For the param `video_sync_group`, it indicates the number of process that accepts the same input video, used for temporal pyramid AR training. We recommend to set this value to 4, 8 or 16. (16 is better if you have more GPUs)
54
+ + Make sure to set `NUM_FRAMES % VIDEO_SYNC_GROUP == 0`, `GPUS % VIDEO_SYNC_GROUP == 0`, and `BATCH_SIZE % 4 == 0`
docs/VAE.md ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Pyramid Flow's VAE Training Guide
2
+
3
+ This is the training guide for a [MAGVIT-v2](https://arxiv.org/abs/2310.05737) like continuous 3D VAE, which should be quite flexible. Feel free to build your own video generative model on this part of VAE training code. Please refer to [another document](https://github.com/jy0205/Pyramid-Flow/blob/main/docs/DiT) for DiT finetuning.
4
+
5
+ ## Hardware Requirements
6
+
7
+ + VAE training: At least 8 A100 GPUs.
8
+
9
+
10
+ ## Prepare the Dataset
11
+
12
+ The training of our causal video vae uses both image and video data. Both of them should be arranged into a json file, with `video` or `image` field. The final training annotation json file should look like the following format:
13
+
14
+ ```
15
+ # For Video
16
+ {"video": video_path}
17
+
18
+ # For Image
19
+ {"image": image_path}
20
+ ```
21
+
22
+ ## Run Training
23
+
24
+ The causal video vae undergoes a two-stage training.
25
+ + Stage-1: image and video mixed training
26
+ + Stage-2: pure video training, using context parallel to load video with more video frames
27
+
28
+ The VAE training script is `scripts/train_causal_video_vae.sh`, run it as follows:
29
+
30
+ ```bash
31
+ sh scripts/train_causal_video_vae.sh
32
+ ```
33
+
34
+ We also provide a VAE demo `causal_video_vae_demo.ipynb` for image and video reconstruction.
35
+
36
+
37
+ ## Tips
38
+
39
+ + For stage-1, we use a mixed image and video training. Add the param `--use_image_video_mixed_training` to support the mixed training. We set the image ratio to 0.1 by default.
40
+ + Set the `resolution` to 256 is enough for VAE training.
41
+ + For stage-1, the `max_frames` is set to 17. It means we use 17 sampled video frames for training.
42
+ + For stage-2, we open the param `use_context_parallel` to distribute long video frames to multiple GPUs. Make sure to set `GPUS % CONTEXT_SIZE == 0` and `NUM_FRAMES=17 * CONTEXT_SIZE + 1`
image_generation_demo.ipynb ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import json\n",
11
+ "import torch\n",
12
+ "import numpy as np\n",
13
+ "import PIL\n",
14
+ "from PIL import Image\n",
15
+ "from IPython.display import HTML\n",
16
+ "from pyramid_dit import PyramidDiTForVideoGeneration\n",
17
+ "from IPython.display import Image as ipython_image\n",
18
+ "from diffusers.utils import load_image, export_to_video, export_to_gif"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "variant='diffusion_transformer_image' # For low resolution\n",
28
+ "model_name = \"pyramid_flux\"\n",
29
+ "\n",
30
+ "model_path = \"/home/jinyang06/models/pyramid-flow-miniflux\" # The downloaded checkpoint dir\n",
31
+ "model_dtype = 'bf16'\n",
32
+ "\n",
33
+ "device_id = 0\n",
34
+ "torch.cuda.set_device(device_id)\n",
35
+ "\n",
36
+ "model = PyramidDiTForVideoGeneration(\n",
37
+ " model_path,\n",
38
+ " model_dtype,\n",
39
+ " model_name=model_name,\n",
40
+ " model_variant=variant,\n",
41
+ ")\n",
42
+ "\n",
43
+ "model.vae.to(\"cuda\")\n",
44
+ "model.dit.to(\"cuda\")\n",
45
+ "model.text_encoder.to(\"cuda\")\n",
46
+ "\n",
47
+ "model.vae.enable_tiling()\n",
48
+ "\n",
49
+ "if model_dtype == \"bf16\":\n",
50
+ " torch_dtype = torch.bfloat16 \n",
51
+ "elif model_dtype == \"fp16\":\n",
52
+ " torch_dtype = torch.float16\n",
53
+ "else:\n",
54
+ " torch_dtype = torch.float32"
55
+ ]
56
+ },
57
+ {
58
+ "attachments": {},
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "### Text-to-Image"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "prompt = \"shoulder and full head portrait of a beautiful 19 year old girl, brunette, smiling, stunning, highly detailed, glamour lighting, HDR, photorealistic, hyperrealism, octane render, unreal engine\"\n",
72
+ "\n",
73
+ "# now support 3 aspect ratios\n",
74
+ "resolution_dict = {\n",
75
+ " '1:1' : (1024, 1024),\n",
76
+ " '5:3' : (1280, 768),\n",
77
+ " '3:5' : (768, 1280),\n",
78
+ "}\n",
79
+ "\n",
80
+ "ratio = '1:1' # 1:1, 5:3, 3:5\n",
81
+ "\n",
82
+ "width, height = resolution_dict[ratio]\n",
83
+ "\n",
84
+ "\n",
85
+ "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
86
+ " images = model.generate(\n",
87
+ " prompt=prompt,\n",
88
+ " num_inference_steps=[20, 20, 20],\n",
89
+ " height=height,\n",
90
+ " width=width,\n",
91
+ " temp=1,\n",
92
+ " guidance_scale=9.0, \n",
93
+ " output_type=\"pil\",\n",
94
+ " save_memory=False, \n",
95
+ " )\n",
96
+ "\n",
97
+ "display(images[0])"
98
+ ]
99
+ }
100
+ ],
101
+ "metadata": {
102
+ "kernelspec": {
103
+ "display_name": "Python 3",
104
+ "language": "python",
105
+ "name": "python3"
106
+ },
107
+ "language_info": {
108
+ "codemirror_mode": {
109
+ "name": "ipython",
110
+ "version": 3
111
+ },
112
+ "file_extension": ".py",
113
+ "mimetype": "text/x-python",
114
+ "name": "python",
115
+ "nbconvert_exporter": "python",
116
+ "pygments_lexer": "ipython3",
117
+ "version": "3.8.10"
118
+ },
119
+ "orig_nbformat": 4
120
+ },
121
+ "nbformat": 4,
122
+ "nbformat_minor": 2
123
+ }
inference_multigpu.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import sys
4
+ import argparse
5
+ import random
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from diffusers.utils import export_to_video
10
+ from pyramid_dit import PyramidDiTForVideoGeneration
11
+ from trainer_misc import init_distributed_mode, init_sequence_parallel_group
12
+ import PIL
13
+ from PIL import Image
14
+
15
+
16
+ def get_args():
17
+ parser = argparse.ArgumentParser('Pytorch Multi-process Script', add_help=False)
18
+ parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The model name", choices=["pyramid_flux", "pyramid_mmdit"])
19
+ parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16")
20
+ parser.add_argument('--model_path', default='/home/jinyang06/models/pyramid-flow', type=str, help='Set it to the downloaded checkpoint dir')
21
+ parser.add_argument('--variant', default='diffusion_transformer_768p', type=str,)
22
+ parser.add_argument('--task', default='t2v', type=str, choices=['i2v', 't2v'])
23
+ parser.add_argument('--temp', default=16, type=int, help='The generated latent num, num_frames = temp * 8 + 1')
24
+ parser.add_argument('--sp_group_size', default=2, type=int, help="The number of gpus used for inference, should be 2 or 4")
25
+ parser.add_argument('--sp_proc_num', default=-1, type=int, help="The number of process used for video training, default=-1 means using all process.")
26
+
27
+ return parser.parse_args()
28
+
29
+
30
+ def main():
31
+ args = get_args()
32
+
33
+ # setup DDP
34
+ init_distributed_mode(args)
35
+
36
+ assert args.world_size == args.sp_group_size, "The sequence parallel size should be DDP world size"
37
+
38
+ # Enable sequence parallel
39
+ init_sequence_parallel_group(args)
40
+
41
+ device = torch.device('cuda')
42
+ rank = args.rank
43
+ model_dtype = args.model_dtype
44
+
45
+ model = PyramidDiTForVideoGeneration(
46
+ args.model_path,
47
+ model_dtype,
48
+ model_name=args.model_name,
49
+ model_variant=args.variant,
50
+ )
51
+
52
+ model.vae.to(device)
53
+ model.dit.to(device)
54
+ model.text_encoder.to(device)
55
+ model.vae.enable_tiling()
56
+
57
+ if model_dtype == "bf16":
58
+ torch_dtype = torch.bfloat16
59
+ elif model_dtype == "fp16":
60
+ torch_dtype = torch.float16
61
+ else:
62
+ torch_dtype = torch.float32
63
+
64
+ # The video generation config
65
+ if args.variant == 'diffusion_transformer_768p':
66
+ width = 1280
67
+ height = 768
68
+ else:
69
+ assert args.variant == 'diffusion_transformer_384p'
70
+ width = 640
71
+ height = 384
72
+
73
+ if args.task == 't2v':
74
+ prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
75
+
76
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
77
+ frames = model.generate(
78
+ prompt=prompt,
79
+ num_inference_steps=[20, 20, 20],
80
+ video_num_inference_steps=[10, 10, 10],
81
+ height=height,
82
+ width=width,
83
+ temp=args.temp,
84
+ guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
85
+ video_guidance_scale=5.0, # The guidance for the other video latent
86
+ output_type="pil",
87
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
88
+ cpu_offloading=False, # If OOM, set it to True to reduce memory usage
89
+ inference_multigpu=True,
90
+ )
91
+ if rank == 0:
92
+ export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
93
+
94
+ else:
95
+ assert args.task == 'i2v'
96
+
97
+ image_path = 'assets/the_great_wall.jpg'
98
+ image = Image.open(image_path).convert("RGB")
99
+ image = image.resize((width, height))
100
+
101
+ prompt = "FPV flying over the Great Wall"
102
+
103
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):
104
+ frames = model.generate_i2v(
105
+ prompt=prompt,
106
+ input_image=image,
107
+ num_inference_steps=[10, 10, 10],
108
+ temp=args.temp,
109
+ video_guidance_scale=4.0,
110
+ output_type="pil",
111
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
112
+ cpu_offloading=False, # If OOM, set it to True to reduce memory usage
113
+ inference_multigpu=True,
114
+ )
115
+
116
+ if rank == 0:
117
+ export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
118
+
119
+ torch.distributed.barrier()
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
pyramid_dit/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .pyramid_dit_for_video_gen_pipeline import PyramidDiTForVideoGeneration
2
+ from .flux_modules import FluxSingleTransformerBlock, FluxTransformerBlock, FluxTextEncoderWithMask
3
+ from .mmdit_modules import JointTransformerBlock, SD3TextEncoderWithMask
pyramid_dit/flux_modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modeling_pyramid_flux import PyramidFluxTransformer
2
+ from .modeling_text_encoder import FluxTextEncoderWithMask
3
+ from .modeling_flux_block import FluxSingleTransformerBlock, FluxTransformerBlock
pyramid_dit/flux_modules/modeling_embedding.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from diffusers.models.activations import get_activation, FP32SiLU
10
+
11
+ def get_timestep_embedding(
12
+ timesteps: torch.Tensor,
13
+ embedding_dim: int,
14
+ flip_sin_to_cos: bool = False,
15
+ downscale_freq_shift: float = 1,
16
+ scale: float = 1,
17
+ max_period: int = 10000,
18
+ ):
19
+ """
20
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
21
+
22
+ Args
23
+ timesteps (torch.Tensor):
24
+ a 1-D Tensor of N indices, one per batch element. These may be fractional.
25
+ embedding_dim (int):
26
+ the dimension of the output.
27
+ flip_sin_to_cos (bool):
28
+ Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
29
+ downscale_freq_shift (float):
30
+ Controls the delta between frequencies between dimensions
31
+ scale (float):
32
+ Scaling factor applied to the embeddings.
33
+ max_period (int):
34
+ Controls the maximum frequency of the embeddings
35
+ Returns
36
+ torch.Tensor: an [N x dim] Tensor of positional embeddings.
37
+ """
38
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
39
+
40
+ half_dim = embedding_dim // 2
41
+ exponent = -math.log(max_period) * torch.arange(
42
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
43
+ )
44
+ exponent = exponent / (half_dim - downscale_freq_shift)
45
+
46
+ emb = torch.exp(exponent)
47
+ emb = timesteps[:, None].float() * emb[None, :]
48
+
49
+ # scale embeddings
50
+ emb = scale * emb
51
+
52
+ # concat sine and cosine embeddings
53
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
54
+
55
+ # flip sine and cosine embeddings
56
+ if flip_sin_to_cos:
57
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
58
+
59
+ # zero pad
60
+ if embedding_dim % 2 == 1:
61
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
62
+ return emb
63
+
64
+
65
+ class Timesteps(nn.Module):
66
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
67
+ super().__init__()
68
+ self.num_channels = num_channels
69
+ self.flip_sin_to_cos = flip_sin_to_cos
70
+ self.downscale_freq_shift = downscale_freq_shift
71
+ self.scale = scale
72
+
73
+ def forward(self, timesteps):
74
+ t_emb = get_timestep_embedding(
75
+ timesteps,
76
+ self.num_channels,
77
+ flip_sin_to_cos=self.flip_sin_to_cos,
78
+ downscale_freq_shift=self.downscale_freq_shift,
79
+ scale=self.scale,
80
+ )
81
+ return t_emb
82
+
83
+
84
+ class TimestepEmbedding(nn.Module):
85
+ def __init__(
86
+ self,
87
+ in_channels: int,
88
+ time_embed_dim: int,
89
+ act_fn: str = "silu",
90
+ out_dim: int = None,
91
+ post_act_fn: Optional[str] = None,
92
+ cond_proj_dim=None,
93
+ sample_proj_bias=True,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
98
+
99
+ if cond_proj_dim is not None:
100
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
101
+ else:
102
+ self.cond_proj = None
103
+
104
+ self.act = get_activation(act_fn)
105
+
106
+ if out_dim is not None:
107
+ time_embed_dim_out = out_dim
108
+ else:
109
+ time_embed_dim_out = time_embed_dim
110
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
111
+
112
+ if post_act_fn is None:
113
+ self.post_act = None
114
+ else:
115
+ self.post_act = get_activation(post_act_fn)
116
+
117
+ def forward(self, sample, condition=None):
118
+ if condition is not None:
119
+ sample = sample + self.cond_proj(condition)
120
+ sample = self.linear_1(sample)
121
+
122
+ if self.act is not None:
123
+ sample = self.act(sample)
124
+
125
+ sample = self.linear_2(sample)
126
+
127
+ if self.post_act is not None:
128
+ sample = self.post_act(sample)
129
+ return sample
130
+
131
+
132
+ class PixArtAlphaTextProjection(nn.Module):
133
+ """
134
+ Projects caption embeddings. Also handles dropout for classifier-free guidance.
135
+
136
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
137
+ """
138
+
139
+ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
140
+ super().__init__()
141
+ if out_features is None:
142
+ out_features = hidden_size
143
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
144
+ if act_fn == "gelu_tanh":
145
+ self.act_1 = nn.GELU(approximate="tanh")
146
+ elif act_fn == "silu":
147
+ self.act_1 = nn.SiLU()
148
+ elif act_fn == "silu_fp32":
149
+ self.act_1 = FP32SiLU()
150
+ else:
151
+ raise ValueError(f"Unknown activation function: {act_fn}")
152
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
153
+
154
+ def forward(self, caption):
155
+ hidden_states = self.linear_1(caption)
156
+ hidden_states = self.act_1(hidden_states)
157
+ hidden_states = self.linear_2(hidden_states)
158
+ return hidden_states
159
+
160
+
161
+ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
162
+ def __init__(self, embedding_dim, pooled_projection_dim):
163
+ super().__init__()
164
+
165
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
166
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
167
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
168
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
169
+
170
+ def forward(self, timestep, guidance, pooled_projection):
171
+ timesteps_proj = self.time_proj(timestep)
172
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
173
+
174
+ guidance_proj = self.time_proj(guidance)
175
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype)) # (N, D)
176
+
177
+ time_guidance_emb = timesteps_emb + guidance_emb
178
+
179
+ pooled_projections = self.text_embedder(pooled_projection)
180
+ conditioning = time_guidance_emb + pooled_projections
181
+
182
+ return conditioning
183
+
184
+
185
+ class CombinedTimestepTextProjEmbeddings(nn.Module):
186
+ def __init__(self, embedding_dim, pooled_projection_dim):
187
+ super().__init__()
188
+
189
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
190
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
191
+ self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
192
+
193
+ def forward(self, timestep, pooled_projection):
194
+ timesteps_proj = self.time_proj(timestep)
195
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
196
+
197
+ pooled_projections = self.text_embedder(pooled_projection)
198
+
199
+ conditioning = timesteps_emb + pooled_projections
200
+
201
+ return conditioning
pyramid_dit/flux_modules/modeling_flux_block.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import inspect
7
+ from einops import rearrange
8
+
9
+ from diffusers.utils import deprecate
10
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, SwiGLU
11
+
12
+ from .modeling_normalization import (
13
+ AdaLayerNormContinuous, AdaLayerNormZero,
14
+ AdaLayerNormZeroSingle, FP32LayerNorm, RMSNorm
15
+ )
16
+
17
+ from trainer_misc import (
18
+ is_sequence_parallel_initialized,
19
+ get_sequence_parallel_group,
20
+ get_sequence_parallel_world_size,
21
+ all_to_all,
22
+ )
23
+
24
+ try:
25
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
26
+ from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
27
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
28
+ except:
29
+ flash_attn_func = None
30
+ flash_attn_qkvpacked_func = None
31
+ flash_attn_varlen_func = None
32
+
33
+
34
+ def apply_rope(xq, xk, freqs_cis):
35
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
36
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
37
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
38
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
39
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
40
+
41
+
42
+ class FeedForward(nn.Module):
43
+ r"""
44
+ A feed-forward layer.
45
+
46
+ Parameters:
47
+ dim (`int`): The number of channels in the input.
48
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
49
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
50
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
51
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
52
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
53
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ dim: int,
59
+ dim_out: Optional[int] = None,
60
+ mult: int = 4,
61
+ dropout: float = 0.0,
62
+ activation_fn: str = "geglu",
63
+ final_dropout: bool = False,
64
+ inner_dim=None,
65
+ bias: bool = True,
66
+ ):
67
+ super().__init__()
68
+ if inner_dim is None:
69
+ inner_dim = int(dim * mult)
70
+ dim_out = dim_out if dim_out is not None else dim
71
+
72
+ if activation_fn == "gelu":
73
+ act_fn = GELU(dim, inner_dim, bias=bias)
74
+ if activation_fn == "gelu-approximate":
75
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
76
+ elif activation_fn == "geglu":
77
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
78
+ elif activation_fn == "geglu-approximate":
79
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
80
+ elif activation_fn == "swiglu":
81
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
82
+
83
+ self.net = nn.ModuleList([])
84
+ # project in
85
+ self.net.append(act_fn)
86
+ # project dropout
87
+ self.net.append(nn.Dropout(dropout))
88
+ # project out
89
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
90
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
91
+ if final_dropout:
92
+ self.net.append(nn.Dropout(dropout))
93
+
94
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
95
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
96
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
97
+ deprecate("scale", "1.0.0", deprecation_message)
98
+ for module in self.net:
99
+ hidden_states = module(hidden_states)
100
+ return hidden_states
101
+
102
+
103
+ class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
104
+
105
+ def __init__(self):
106
+ pass
107
+
108
+ def __call__(
109
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
110
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
111
+ ):
112
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
113
+
114
+ batch_size = query.shape[0]
115
+ qkv_list = []
116
+ num_stages = len(hidden_length)
117
+
118
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
119
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
120
+
121
+ # To sync the encoder query, key and values
122
+ sp_group = get_sequence_parallel_group()
123
+ sp_group_size = get_sequence_parallel_world_size()
124
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
125
+
126
+ output_hidden = torch.zeros_like(qkv[:,:,0])
127
+ output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
128
+ encoder_length = encoder_qkv.shape[1]
129
+
130
+ i_sum = 0
131
+ for i_p, length in enumerate(hidden_length):
132
+ # get the query, key, value from padding sequence
133
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
134
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
135
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
136
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, pad_seq, 3, nhead, dim]
137
+
138
+ if image_rotary_emb is not None:
139
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
140
+
141
+ indices = encoder_attention_mask[i_p]['indices']
142
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
143
+ i_sum += length
144
+
145
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
146
+ qkv = torch.cat(qkv_list, dim=0)
147
+ query, key, value = qkv.unbind(1)
148
+
149
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
150
+ max_seqlen_q = cu_seqlens.max().item()
151
+ max_seqlen_k = max_seqlen_q
152
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
153
+ cu_seqlens_k = cu_seqlens_q.clone()
154
+
155
+ output = flash_attn_varlen_func(
156
+ query,
157
+ key,
158
+ value,
159
+ cu_seqlens_q=cu_seqlens_q,
160
+ cu_seqlens_k=cu_seqlens_k,
161
+ max_seqlen_q=max_seqlen_q,
162
+ max_seqlen_k=max_seqlen_k,
163
+ dropout_p=0.0,
164
+ causal=False,
165
+ softmax_scale=scale,
166
+ )
167
+
168
+ # To merge the tokens
169
+ i_sum = 0;token_sum = 0
170
+ for i_p, length in enumerate(hidden_length):
171
+ tot_token_num = token_lengths[i_p]
172
+ stage_output = output[token_sum : token_sum + tot_token_num]
173
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
174
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
175
+ stage_hidden_output = stage_output[:, encoder_length:]
176
+ stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
177
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
178
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
179
+ token_sum += tot_token_num
180
+ i_sum += length
181
+
182
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
183
+ output_hidden = output_hidden.flatten(2, 3)
184
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
185
+
186
+ return output_hidden, output_encoder_hidden
187
+
188
+
189
+ class VarlenFlashSelfAttentionWithT5Mask:
190
+
191
+ def __init__(self):
192
+ pass
193
+
194
+ def __call__(
195
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
196
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
197
+ ):
198
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
199
+
200
+ batch_size = query.shape[0]
201
+ output_hidden = torch.zeros_like(query)
202
+ output_encoder_hidden = torch.zeros_like(encoder_query)
203
+ encoder_length = encoder_query.shape[1]
204
+
205
+ qkv_list = []
206
+ num_stages = len(hidden_length)
207
+
208
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
209
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
210
+
211
+ i_sum = 0
212
+ for i_p, length in enumerate(hidden_length):
213
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
214
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
215
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
216
+
217
+ if image_rotary_emb is not None:
218
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
219
+
220
+ indices = encoder_attention_mask[i_p]['indices']
221
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
222
+ i_sum += length
223
+
224
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
225
+ qkv = torch.cat(qkv_list, dim=0)
226
+ query, key, value = qkv.unbind(1)
227
+
228
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
229
+ max_seqlen_q = cu_seqlens.max().item()
230
+ max_seqlen_k = max_seqlen_q
231
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
232
+ cu_seqlens_k = cu_seqlens_q.clone()
233
+
234
+ output = flash_attn_varlen_func(
235
+ query,
236
+ key,
237
+ value,
238
+ cu_seqlens_q=cu_seqlens_q,
239
+ cu_seqlens_k=cu_seqlens_k,
240
+ max_seqlen_q=max_seqlen_q,
241
+ max_seqlen_k=max_seqlen_k,
242
+ dropout_p=0.0,
243
+ causal=False,
244
+ softmax_scale=scale,
245
+ )
246
+
247
+ # To merge the tokens
248
+ i_sum = 0;token_sum = 0
249
+ for i_p, length in enumerate(hidden_length):
250
+ tot_token_num = token_lengths[i_p]
251
+ stage_output = output[token_sum : token_sum + tot_token_num]
252
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
253
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
254
+ stage_hidden_output = stage_output[:, encoder_length:]
255
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
256
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
257
+ token_sum += tot_token_num
258
+ i_sum += length
259
+
260
+ output_hidden = output_hidden.flatten(2, 3)
261
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
262
+
263
+ return output_hidden, output_encoder_hidden
264
+
265
+
266
+ class SequenceParallelVarlenSelfAttentionWithT5Mask:
267
+
268
+ def __init__(self):
269
+ pass
270
+
271
+ def __call__(
272
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
273
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
274
+ ):
275
+ assert attention_mask is not None, "The attention mask needed to be set"
276
+
277
+ num_stages = len(hidden_length)
278
+
279
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
280
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
281
+
282
+ # To sync the encoder query, key and values
283
+ sp_group = get_sequence_parallel_group()
284
+ sp_group_size = get_sequence_parallel_world_size()
285
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
286
+ encoder_length = encoder_qkv.shape[1]
287
+
288
+ i_sum = 0
289
+ output_encoder_hidden_list = []
290
+ output_hidden_list = []
291
+
292
+ for i_p, length in enumerate(hidden_length):
293
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
294
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
295
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
296
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
297
+
298
+ if image_rotary_emb is not None:
299
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
300
+
301
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
302
+ query = query.transpose(1, 2)
303
+ key = key.transpose(1, 2)
304
+ value = value.transpose(1, 2)
305
+
306
+ stage_hidden_states = F.scaled_dot_product_attention(
307
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
308
+ )
309
+ stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
310
+
311
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
312
+
313
+ output_hidden = stage_hidden_states[:, encoder_length:]
314
+ output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
315
+ output_hidden_list.append(output_hidden)
316
+
317
+ i_sum += length
318
+
319
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s nhead d]
320
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
321
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
322
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
323
+ output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
324
+
325
+ return output_hidden, output_encoder_hidden
326
+
327
+
328
+ class VarlenSelfAttentionWithT5Mask:
329
+
330
+ def __init__(self):
331
+ pass
332
+
333
+ def __call__(
334
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
335
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
336
+ ):
337
+ assert attention_mask is not None, "The attention mask needed to be set"
338
+
339
+ encoder_length = encoder_query.shape[1]
340
+ num_stages = len(hidden_length)
341
+
342
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
343
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
344
+
345
+ i_sum = 0
346
+ output_encoder_hidden_list = []
347
+ output_hidden_list = []
348
+
349
+ for i_p, length in enumerate(hidden_length):
350
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
351
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
352
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
353
+
354
+ if image_rotary_emb is not None:
355
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
356
+
357
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
358
+ query = query.transpose(1, 2)
359
+ key = key.transpose(1, 2)
360
+ value = value.transpose(1, 2)
361
+
362
+ # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
363
+ stage_hidden_states = F.scaled_dot_product_attention(
364
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
365
+ )
366
+ stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
367
+
368
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
369
+ output_hidden_list.append(stage_hidden_states[:, encoder_length:])
370
+ i_sum += length
371
+
372
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s d]
373
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
374
+ output_hidden = torch.cat(output_hidden_list, dim=1)
375
+
376
+ return output_hidden, output_encoder_hidden
377
+
378
+
379
+ class SequenceParallelVarlenFlashAttnSingle:
380
+
381
+ def __init__(self):
382
+ pass
383
+
384
+ def __call__(
385
+ self, query, key, value, heads, scale,
386
+ hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
387
+ ):
388
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
389
+
390
+ batch_size = query.shape[0]
391
+ qkv_list = []
392
+ num_stages = len(hidden_length)
393
+
394
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
395
+ output_hidden = torch.zeros_like(qkv[:,:,0])
396
+
397
+ sp_group = get_sequence_parallel_group()
398
+ sp_group_size = get_sequence_parallel_world_size()
399
+
400
+ i_sum = 0
401
+ for i_p, length in enumerate(hidden_length):
402
+ # get the query, key, value from padding sequence
403
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
404
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
405
+
406
+ if image_rotary_emb is not None:
407
+ qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
408
+
409
+ indices = encoder_attention_mask[i_p]['indices']
410
+ qkv_list.append(index_first_axis(rearrange(qkv_tokens, "b s ... -> (b s) ..."), indices))
411
+ i_sum += length
412
+
413
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
414
+ qkv = torch.cat(qkv_list, dim=0)
415
+ query, key, value = qkv.unbind(1)
416
+
417
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
418
+ max_seqlen_q = cu_seqlens.max().item()
419
+ max_seqlen_k = max_seqlen_q
420
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
421
+ cu_seqlens_k = cu_seqlens_q.clone()
422
+
423
+ output = flash_attn_varlen_func(
424
+ query,
425
+ key,
426
+ value,
427
+ cu_seqlens_q=cu_seqlens_q,
428
+ cu_seqlens_k=cu_seqlens_k,
429
+ max_seqlen_q=max_seqlen_q,
430
+ max_seqlen_k=max_seqlen_k,
431
+ dropout_p=0.0,
432
+ causal=False,
433
+ softmax_scale=scale,
434
+ )
435
+
436
+ # To merge the tokens
437
+ i_sum = 0;token_sum = 0
438
+ for i_p, length in enumerate(hidden_length):
439
+ tot_token_num = token_lengths[i_p]
440
+ stage_output = output[token_sum : token_sum + tot_token_num]
441
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, length * sp_group_size)
442
+ stage_hidden_output = all_to_all(stage_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
443
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
444
+ token_sum += tot_token_num
445
+ i_sum += length
446
+
447
+ output_hidden = output_hidden.flatten(2, 3)
448
+
449
+ return output_hidden
450
+
451
+
452
+ class VarlenFlashSelfAttnSingle:
453
+
454
+ def __init__(self):
455
+ pass
456
+
457
+ def __call__(
458
+ self, query, key, value, heads, scale,
459
+ hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
460
+ ):
461
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
462
+
463
+ batch_size = query.shape[0]
464
+ output_hidden = torch.zeros_like(query)
465
+
466
+ qkv_list = []
467
+ num_stages = len(hidden_length)
468
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
469
+
470
+ i_sum = 0
471
+ for i_p, length in enumerate(hidden_length):
472
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
473
+
474
+ if image_rotary_emb is not None:
475
+ qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
476
+
477
+ indices = encoder_attention_mask[i_p]['indices']
478
+ qkv_list.append(index_first_axis(rearrange(qkv_tokens, "b s ... -> (b s) ..."), indices))
479
+ i_sum += length
480
+
481
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
482
+ qkv = torch.cat(qkv_list, dim=0)
483
+ query, key, value = qkv.unbind(1)
484
+
485
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
486
+ max_seqlen_q = cu_seqlens.max().item()
487
+ max_seqlen_k = max_seqlen_q
488
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
489
+ cu_seqlens_k = cu_seqlens_q.clone()
490
+
491
+ output = flash_attn_varlen_func(
492
+ query,
493
+ key,
494
+ value,
495
+ cu_seqlens_q=cu_seqlens_q,
496
+ cu_seqlens_k=cu_seqlens_k,
497
+ max_seqlen_q=max_seqlen_q,
498
+ max_seqlen_k=max_seqlen_k,
499
+ dropout_p=0.0,
500
+ causal=False,
501
+ softmax_scale=scale,
502
+ )
503
+
504
+ # To merge the tokens
505
+ i_sum = 0;token_sum = 0
506
+ for i_p, length in enumerate(hidden_length):
507
+ tot_token_num = token_lengths[i_p]
508
+ stage_output = output[token_sum : token_sum + tot_token_num]
509
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, length)
510
+ output_hidden[:, i_sum:i_sum+length] = stage_output
511
+ token_sum += tot_token_num
512
+ i_sum += length
513
+
514
+ output_hidden = output_hidden.flatten(2, 3)
515
+
516
+ return output_hidden
517
+
518
+
519
+ class SequenceParallelVarlenAttnSingle:
520
+
521
+ def __init__(self):
522
+ pass
523
+
524
+ def __call__(
525
+ self, query, key, value, heads, scale,
526
+ hidden_length=None, image_rotary_emb=None, attention_mask=None,
527
+ ):
528
+ assert attention_mask is not None, "The attention mask needed to be set"
529
+
530
+ num_stages = len(hidden_length)
531
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
532
+
533
+ # To sync the encoder query, key and values
534
+ sp_group = get_sequence_parallel_group()
535
+ sp_group_size = get_sequence_parallel_world_size()
536
+
537
+ i_sum = 0
538
+ output_hidden_list = []
539
+
540
+ for i_p, length in enumerate(hidden_length):
541
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
542
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
543
+
544
+ if image_rotary_emb is not None:
545
+ qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
546
+
547
+ query, key, value = qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
548
+ query = query.transpose(1, 2).contiguous()
549
+ key = key.transpose(1, 2).contiguous()
550
+ value = value.transpose(1, 2).contiguous()
551
+
552
+ stage_hidden_states = F.scaled_dot_product_attention(
553
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
554
+ )
555
+ stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
556
+
557
+ output_hidden = stage_hidden_states
558
+ output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
559
+ output_hidden_list.append(output_hidden)
560
+
561
+ i_sum += length
562
+
563
+ output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
564
+
565
+ return output_hidden
566
+
567
+
568
+ class VarlenSelfAttnSingle:
569
+
570
+ def __init__(self):
571
+ pass
572
+
573
+ def __call__(
574
+ self, query, key, value, heads, scale,
575
+ hidden_length=None, image_rotary_emb=None, attention_mask=None,
576
+ ):
577
+ assert attention_mask is not None, "The attention mask needed to be set"
578
+
579
+ num_stages = len(hidden_length)
580
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
581
+
582
+ i_sum = 0
583
+ output_hidden_list = []
584
+
585
+ for i_p, length in enumerate(hidden_length):
586
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
587
+
588
+ if image_rotary_emb is not None:
589
+ qkv_tokens[:,:,0], qkv_tokens[:,:,1] = apply_rope(qkv_tokens[:,:,0], qkv_tokens[:,:,1], image_rotary_emb[i_p])
590
+
591
+ query, key, value = qkv_tokens.unbind(2)
592
+ query = query.transpose(1, 2).contiguous()
593
+ key = key.transpose(1, 2).contiguous()
594
+ value = value.transpose(1, 2).contiguous()
595
+
596
+ stage_hidden_states = F.scaled_dot_product_attention(
597
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
598
+ )
599
+ stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
600
+
601
+ output_hidden_list.append(stage_hidden_states)
602
+ i_sum += length
603
+
604
+ output_hidden = torch.cat(output_hidden_list, dim=1)
605
+
606
+ return output_hidden
607
+
608
+
609
+ class Attention(nn.Module):
610
+
611
+ def __init__(
612
+ self,
613
+ query_dim: int,
614
+ cross_attention_dim: Optional[int] = None,
615
+ heads: int = 8,
616
+ dim_head: int = 64,
617
+ dropout: float = 0.0,
618
+ bias: bool = False,
619
+ qk_norm: Optional[str] = None,
620
+ added_kv_proj_dim: Optional[int] = None,
621
+ added_proj_bias: Optional[bool] = True,
622
+ out_bias: bool = True,
623
+ only_cross_attention: bool = False,
624
+ eps: float = 1e-5,
625
+ processor: Optional["AttnProcessor"] = None,
626
+ out_dim: int = None,
627
+ context_pre_only=None,
628
+ pre_only=False,
629
+ ):
630
+ super().__init__()
631
+
632
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
633
+ self.inner_kv_dim = self.inner_dim
634
+ self.query_dim = query_dim
635
+ self.use_bias = bias
636
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
637
+
638
+ self.dropout = dropout
639
+ self.out_dim = out_dim if out_dim is not None else query_dim
640
+ self.context_pre_only = context_pre_only
641
+ self.pre_only = pre_only
642
+
643
+ self.scale = dim_head**-0.5
644
+ self.heads = out_dim // dim_head if out_dim is not None else heads
645
+
646
+
647
+ self.added_kv_proj_dim = added_kv_proj_dim
648
+ self.only_cross_attention = only_cross_attention
649
+
650
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
651
+ raise ValueError(
652
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
653
+ )
654
+
655
+ if qk_norm is None:
656
+ self.norm_q = None
657
+ self.norm_k = None
658
+ elif qk_norm == "rms_norm":
659
+ self.norm_q = RMSNorm(dim_head, eps=eps)
660
+ self.norm_k = RMSNorm(dim_head, eps=eps)
661
+ else:
662
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
663
+
664
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
665
+
666
+ if not self.only_cross_attention:
667
+ # only relevant for the `AddedKVProcessor` classes
668
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
669
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
670
+ else:
671
+ self.to_k = None
672
+ self.to_v = None
673
+
674
+ self.added_proj_bias = added_proj_bias
675
+ if self.added_kv_proj_dim is not None:
676
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
677
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
678
+ if self.context_pre_only is not None:
679
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
680
+
681
+ if not self.pre_only:
682
+ self.to_out = nn.ModuleList([])
683
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
684
+ self.to_out.append(nn.Dropout(dropout))
685
+
686
+ if self.context_pre_only is not None and not self.context_pre_only:
687
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
688
+
689
+ if qk_norm is not None and added_kv_proj_dim is not None:
690
+ if qk_norm == "fp32_layer_norm":
691
+ self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
692
+ self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
693
+ elif qk_norm == "rms_norm":
694
+ self.norm_added_q = RMSNorm(dim_head, eps=eps)
695
+ self.norm_added_k = RMSNorm(dim_head, eps=eps)
696
+ else:
697
+ self.norm_added_q = None
698
+ self.norm_added_k = None
699
+
700
+ # set attention processor
701
+ self.set_processor(processor)
702
+
703
+ def set_processor(self, processor: "AttnProcessor") -> None:
704
+ self.processor = processor
705
+
706
+ def forward(
707
+ self,
708
+ hidden_states: torch.Tensor,
709
+ encoder_hidden_states: Optional[torch.Tensor] = None,
710
+ encoder_attention_mask: Optional[torch.Tensor] = None,
711
+ attention_mask: Optional[torch.Tensor] = None,
712
+ hidden_length: List = None,
713
+ image_rotary_emb: Optional[torch.Tensor] = None,
714
+ ) -> torch.Tensor:
715
+
716
+ return self.processor(
717
+ self,
718
+ hidden_states,
719
+ encoder_hidden_states=encoder_hidden_states,
720
+ encoder_attention_mask=encoder_attention_mask,
721
+ attention_mask=attention_mask,
722
+ hidden_length=hidden_length,
723
+ image_rotary_emb=image_rotary_emb,
724
+ )
725
+
726
+
727
+ class FluxSingleAttnProcessor2_0:
728
+ r"""
729
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
730
+ """
731
+ def __init__(self, use_flash_attn=False):
732
+ self.use_flash_attn = use_flash_attn
733
+
734
+ if self.use_flash_attn:
735
+ if is_sequence_parallel_initialized():
736
+ self.varlen_flash_attn = SequenceParallelVarlenFlashAttnSingle()
737
+ else:
738
+ self.varlen_flash_attn = VarlenFlashSelfAttnSingle()
739
+ else:
740
+ if is_sequence_parallel_initialized():
741
+ self.varlen_attn = SequenceParallelVarlenAttnSingle()
742
+ else:
743
+ self.varlen_attn = VarlenSelfAttnSingle()
744
+
745
+ def __call__(
746
+ self,
747
+ attn: Attention,
748
+ hidden_states: torch.Tensor,
749
+ encoder_hidden_states: Optional[torch.Tensor] = None,
750
+ encoder_attention_mask: Optional[torch.Tensor] = None,
751
+ attention_mask: Optional[torch.FloatTensor] = None,
752
+ hidden_length: List = None,
753
+ image_rotary_emb: Optional[torch.Tensor] = None,
754
+ ) -> torch.Tensor:
755
+
756
+ query = attn.to_q(hidden_states)
757
+ key = attn.to_k(hidden_states)
758
+ value = attn.to_v(hidden_states)
759
+
760
+ inner_dim = key.shape[-1]
761
+ head_dim = inner_dim // attn.heads
762
+
763
+ query = query.view(query.shape[0], -1, attn.heads, head_dim)
764
+ key = key.view(key.shape[0], -1, attn.heads, head_dim)
765
+ value = value.view(value.shape[0], -1, attn.heads, head_dim)
766
+
767
+ if attn.norm_q is not None:
768
+ query = attn.norm_q(query)
769
+ if attn.norm_k is not None:
770
+ key = attn.norm_k(key)
771
+
772
+ if self.use_flash_attn:
773
+ hidden_states = self.varlen_flash_attn(
774
+ query, key, value,
775
+ attn.heads, attn.scale, hidden_length,
776
+ image_rotary_emb, encoder_attention_mask,
777
+ )
778
+ else:
779
+ hidden_states = self.varlen_attn(
780
+ query, key, value,
781
+ attn.heads, attn.scale, hidden_length,
782
+ image_rotary_emb, attention_mask,
783
+ )
784
+
785
+ return hidden_states
786
+
787
+
788
+ class FluxAttnProcessor2_0:
789
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
790
+
791
+ def __init__(self, use_flash_attn=False):
792
+ self.use_flash_attn = use_flash_attn
793
+
794
+ if self.use_flash_attn:
795
+ if is_sequence_parallel_initialized():
796
+ self.varlen_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
797
+ else:
798
+ self.varlen_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
799
+ else:
800
+ if is_sequence_parallel_initialized():
801
+ self.varlen_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
802
+ else:
803
+ self.varlen_attn = VarlenSelfAttentionWithT5Mask()
804
+
805
+ def __call__(
806
+ self,
807
+ attn: Attention,
808
+ hidden_states: torch.FloatTensor,
809
+ encoder_hidden_states: torch.FloatTensor = None,
810
+ encoder_attention_mask: Optional[torch.Tensor] = None,
811
+ attention_mask: Optional[torch.FloatTensor] = None,
812
+ hidden_length: List = None,
813
+ image_rotary_emb: Optional[torch.Tensor] = None,
814
+ ) -> torch.FloatTensor:
815
+ # `sample` projections.
816
+ query = attn.to_q(hidden_states)
817
+ key = attn.to_k(hidden_states)
818
+ value = attn.to_v(hidden_states)
819
+
820
+ inner_dim = key.shape[-1]
821
+ head_dim = inner_dim // attn.heads
822
+
823
+ query = query.view(query.shape[0], -1, attn.heads, head_dim)
824
+ key = key.view(key.shape[0], -1, attn.heads, head_dim)
825
+ value = value.view(value.shape[0], -1, attn.heads, head_dim)
826
+
827
+ if attn.norm_q is not None:
828
+ query = attn.norm_q(query)
829
+ if attn.norm_k is not None:
830
+ key = attn.norm_k(key)
831
+
832
+ # `context` projections.
833
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
834
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
835
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
836
+
837
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
838
+ encoder_hidden_states_query_proj.shape[0], -1, attn.heads, head_dim
839
+ )
840
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
841
+ encoder_hidden_states_key_proj.shape[0], -1, attn.heads, head_dim
842
+ )
843
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
844
+ encoder_hidden_states_value_proj.shape[0], -1, attn.heads, head_dim
845
+ )
846
+
847
+ if attn.norm_added_q is not None:
848
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
849
+ if attn.norm_added_k is not None:
850
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
851
+
852
+ if self.use_flash_attn:
853
+ hidden_states, encoder_hidden_states = self.varlen_flash_attn(
854
+ query, key, value,
855
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
856
+ encoder_hidden_states_value_proj, attn.heads, attn.scale, hidden_length,
857
+ image_rotary_emb, encoder_attention_mask,
858
+ )
859
+ else:
860
+ hidden_states, encoder_hidden_states = self.varlen_attn(
861
+ query, key, value,
862
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
863
+ encoder_hidden_states_value_proj, attn.heads, attn.scale, hidden_length,
864
+ image_rotary_emb, attention_mask,
865
+ )
866
+
867
+ # linear proj
868
+ hidden_states = attn.to_out[0](hidden_states)
869
+ # dropout
870
+ hidden_states = attn.to_out[1](hidden_states)
871
+
872
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
873
+
874
+ return hidden_states, encoder_hidden_states
875
+
876
+
877
+ class FluxSingleTransformerBlock(nn.Module):
878
+ r"""
879
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
880
+
881
+ Reference: https://arxiv.org/abs/2403.03206
882
+
883
+ Parameters:
884
+ dim (`int`): The number of channels in the input and output.
885
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
886
+ attention_head_dim (`int`): The number of channels in each head.
887
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
888
+ processing of `context` conditions.
889
+ """
890
+
891
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0, use_flash_attn=False):
892
+ super().__init__()
893
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
894
+
895
+ self.norm = AdaLayerNormZeroSingle(dim)
896
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
897
+ self.act_mlp = nn.GELU(approximate="tanh")
898
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
899
+
900
+ processor = FluxSingleAttnProcessor2_0(use_flash_attn)
901
+ self.attn = Attention(
902
+ query_dim=dim,
903
+ cross_attention_dim=None,
904
+ dim_head=attention_head_dim,
905
+ heads=num_attention_heads,
906
+ out_dim=dim,
907
+ bias=True,
908
+ processor=processor,
909
+ qk_norm="rms_norm",
910
+ eps=1e-6,
911
+ pre_only=True,
912
+ )
913
+
914
+ def forward(
915
+ self,
916
+ hidden_states: torch.FloatTensor,
917
+ temb: torch.FloatTensor,
918
+ encoder_attention_mask=None,
919
+ attention_mask=None,
920
+ hidden_length=None,
921
+ image_rotary_emb=None,
922
+ ):
923
+ residual = hidden_states
924
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb, hidden_length=hidden_length)
925
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
926
+
927
+ attn_output = self.attn(
928
+ hidden_states=norm_hidden_states,
929
+ encoder_hidden_states=None,
930
+ encoder_attention_mask=encoder_attention_mask,
931
+ attention_mask=attention_mask,
932
+ hidden_length=hidden_length,
933
+ image_rotary_emb=image_rotary_emb,
934
+ )
935
+
936
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
937
+ hidden_states = gate * self.proj_out(hidden_states)
938
+ hidden_states = residual + hidden_states
939
+ if hidden_states.dtype == torch.float16:
940
+ hidden_states = hidden_states.clip(-65504, 65504)
941
+
942
+ return hidden_states
943
+
944
+
945
+ class FluxTransformerBlock(nn.Module):
946
+ r"""
947
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
948
+
949
+ Reference: https://arxiv.org/abs/2403.03206
950
+
951
+ Parameters:
952
+ dim (`int`): The number of channels in the input and output.
953
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
954
+ attention_head_dim (`int`): The number of channels in each head.
955
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
956
+ processing of `context` conditions.
957
+ """
958
+
959
+ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6, use_flash_attn=False):
960
+ super().__init__()
961
+
962
+ self.norm1 = AdaLayerNormZero(dim)
963
+
964
+ self.norm1_context = AdaLayerNormZero(dim)
965
+
966
+ if hasattr(F, "scaled_dot_product_attention"):
967
+ processor = FluxAttnProcessor2_0(use_flash_attn)
968
+ else:
969
+ raise ValueError(
970
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
971
+ )
972
+ self.attn = Attention(
973
+ query_dim=dim,
974
+ cross_attention_dim=None,
975
+ added_kv_proj_dim=dim,
976
+ dim_head=attention_head_dim,
977
+ heads=num_attention_heads,
978
+ out_dim=dim,
979
+ context_pre_only=False,
980
+ bias=True,
981
+ processor=processor,
982
+ qk_norm=qk_norm,
983
+ eps=eps,
984
+ )
985
+
986
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
987
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
988
+
989
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
990
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
991
+
992
+ def forward(
993
+ self,
994
+ hidden_states: torch.FloatTensor,
995
+ encoder_hidden_states: torch.FloatTensor,
996
+ encoder_attention_mask: torch.FloatTensor,
997
+ temb: torch.FloatTensor,
998
+ attention_mask: torch.FloatTensor = None,
999
+ hidden_length: List = None,
1000
+ image_rotary_emb=None,
1001
+ ):
1002
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
1003
+
1004
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
1005
+ encoder_hidden_states, emb=temb
1006
+ )
1007
+
1008
+ # Attention.
1009
+ attn_output, context_attn_output = self.attn(
1010
+ hidden_states=norm_hidden_states,
1011
+ encoder_hidden_states=norm_encoder_hidden_states,
1012
+ encoder_attention_mask=encoder_attention_mask,
1013
+ attention_mask=attention_mask,
1014
+ hidden_length=hidden_length,
1015
+ image_rotary_emb=image_rotary_emb,
1016
+ )
1017
+
1018
+ # Process attention outputs for the `hidden_states`.
1019
+ attn_output = gate_msa * attn_output
1020
+ hidden_states = hidden_states + attn_output
1021
+
1022
+ norm_hidden_states = self.norm2(hidden_states)
1023
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
1024
+
1025
+ ff_output = self.ff(norm_hidden_states)
1026
+ ff_output = gate_mlp * ff_output
1027
+
1028
+ hidden_states = hidden_states + ff_output
1029
+
1030
+ # Process attention outputs for the `encoder_hidden_states`.
1031
+
1032
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
1033
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
1034
+
1035
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
1036
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
1037
+
1038
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
1039
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
1040
+
1041
+ if encoder_hidden_states.dtype == torch.float16:
1042
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
1043
+
1044
+ return encoder_hidden_states, hidden_states
pyramid_dit/flux_modules/modeling_normalization.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.utils import is_torch_version
9
+
10
+
11
+ if is_torch_version(">=", "2.1.0"):
12
+ LayerNorm = nn.LayerNorm
13
+ else:
14
+ # Has optional bias parameter compared to torch layer norm
15
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
18
+ super().__init__()
19
+
20
+ self.eps = eps
21
+
22
+ if isinstance(dim, numbers.Integral):
23
+ dim = (dim,)
24
+
25
+ self.dim = torch.Size(dim)
26
+
27
+ if elementwise_affine:
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
30
+ else:
31
+ self.weight = None
32
+ self.bias = None
33
+
34
+ def forward(self, input):
35
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
36
+
37
+
38
+ class FP32LayerNorm(nn.LayerNorm):
39
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
40
+ origin_dtype = inputs.dtype
41
+ return F.layer_norm(
42
+ inputs.float(),
43
+ self.normalized_shape,
44
+ self.weight.float() if self.weight is not None else None,
45
+ self.bias.float() if self.bias is not None else None,
46
+ self.eps,
47
+ ).to(origin_dtype)
48
+
49
+
50
+ class RMSNorm(nn.Module):
51
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
52
+ super().__init__()
53
+
54
+ self.eps = eps
55
+
56
+ if isinstance(dim, numbers.Integral):
57
+ dim = (dim,)
58
+
59
+ self.dim = torch.Size(dim)
60
+
61
+ if elementwise_affine:
62
+ self.weight = nn.Parameter(torch.ones(dim))
63
+ else:
64
+ self.weight = None
65
+
66
+ def forward(self, hidden_states):
67
+ input_dtype = hidden_states.dtype
68
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
69
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
70
+
71
+ if self.weight is not None:
72
+ # convert into half-precision if necessary
73
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
74
+ hidden_states = hidden_states.to(self.weight.dtype)
75
+ hidden_states = hidden_states * self.weight
76
+ else:
77
+ hidden_states = hidden_states.to(input_dtype)
78
+
79
+ return hidden_states
80
+
81
+
82
+ class AdaLayerNormContinuous(nn.Module):
83
+ def __init__(
84
+ self,
85
+ embedding_dim: int,
86
+ conditioning_embedding_dim: int,
87
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
88
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
89
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
90
+ # However, this is how it was implemented in the original code, and it's rather likely you should
91
+ # set `elementwise_affine` to False.
92
+ elementwise_affine=True,
93
+ eps=1e-5,
94
+ bias=True,
95
+ norm_type="layer_norm",
96
+ ):
97
+ super().__init__()
98
+ self.silu = nn.SiLU()
99
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
100
+ if norm_type == "layer_norm":
101
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
102
+ elif norm_type == "rms_norm":
103
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
104
+ else:
105
+ raise ValueError(f"unknown norm_type {norm_type}")
106
+
107
+ def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
108
+ assert hidden_length is not None
109
+
110
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
111
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
112
+
113
+ i_sum = 0
114
+ num_stages = len(hidden_length)
115
+ for i_p, length in enumerate(hidden_length):
116
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
117
+ i_sum += length
118
+
119
+ batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
120
+ x = self.norm(x) * (1 + batch_scale) + batch_shift
121
+ return x
122
+
123
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
124
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
125
+ if hidden_length is not None:
126
+ return self.forward_with_pad(x, conditioning_embedding, hidden_length)
127
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
128
+ scale, shift = torch.chunk(emb, 2, dim=1)
129
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
130
+ return x
131
+
132
+
133
+ class AdaLayerNormZero(nn.Module):
134
+ r"""
135
+ Norm layer adaptive layer norm zero (adaLN-Zero).
136
+
137
+ Parameters:
138
+ embedding_dim (`int`): The size of each embedding vector.
139
+ num_embeddings (`int`): The size of the embeddings dictionary.
140
+ """
141
+
142
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
143
+ super().__init__()
144
+ self.emb = None
145
+
146
+ self.silu = nn.SiLU()
147
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
148
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
149
+
150
+ def forward_with_pad(
151
+ self,
152
+ x: torch.Tensor,
153
+ timestep: Optional[torch.Tensor] = None,
154
+ class_labels: Optional[torch.LongTensor] = None,
155
+ hidden_dtype: Optional[torch.dtype] = None,
156
+ emb: Optional[torch.Tensor] = None,
157
+ hidden_length: Optional[torch.Tensor] = None,
158
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
159
+ # hidden_length: [[20, 30], [30, 40], [50, 60]]
160
+ # x: [bs, seq_len, dim]
161
+ if self.emb is not None:
162
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
163
+
164
+ emb = self.linear(self.silu(emb))
165
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 6)
166
+
167
+ i_sum = 0
168
+ num_stages = len(hidden_length)
169
+ for i_p, length in enumerate(hidden_length):
170
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
171
+ i_sum += length
172
+
173
+ batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
174
+ x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
175
+ return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp
176
+
177
+ def forward(
178
+ self,
179
+ x: torch.Tensor,
180
+ timestep: Optional[torch.Tensor] = None,
181
+ class_labels: Optional[torch.LongTensor] = None,
182
+ hidden_dtype: Optional[torch.dtype] = None,
183
+ emb: Optional[torch.Tensor] = None,
184
+ hidden_length: Optional[torch.Tensor] = None,
185
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
186
+ if hidden_length is not None:
187
+ return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
188
+ if self.emb is not None:
189
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
190
+ emb = self.linear(self.silu(emb))
191
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
192
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
193
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
194
+
195
+
196
+ class AdaLayerNormZeroSingle(nn.Module):
197
+ r"""
198
+ Norm layer adaptive layer norm zero (adaLN-Zero).
199
+
200
+ Parameters:
201
+ embedding_dim (`int`): The size of each embedding vector.
202
+ num_embeddings (`int`): The size of the embeddings dictionary.
203
+ """
204
+
205
+ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
206
+ super().__init__()
207
+
208
+ self.silu = nn.SiLU()
209
+ self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
210
+ if norm_type == "layer_norm":
211
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
212
+ else:
213
+ raise ValueError(
214
+ f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
215
+ )
216
+
217
+ def forward_with_pad(
218
+ self,
219
+ x: torch.Tensor,
220
+ emb: Optional[torch.Tensor] = None,
221
+ hidden_length: Optional[torch.Tensor] = None,
222
+ ):
223
+ emb = self.linear(self.silu(emb))
224
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 3)
225
+
226
+ i_sum = 0
227
+ num_stages = len(hidden_length)
228
+ for i_p, length in enumerate(hidden_length):
229
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
230
+ i_sum += length
231
+
232
+ batch_shift_msa, batch_scale_msa, batch_gate_msa = batch_emb.chunk(3, dim=2)
233
+
234
+ x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
235
+
236
+ return x, batch_gate_msa
237
+
238
+ def forward(
239
+ self,
240
+ x: torch.Tensor,
241
+ emb: Optional[torch.Tensor] = None,
242
+ hidden_length: Optional[torch.Tensor] = None,
243
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
244
+ if hidden_length is not None:
245
+ return self.forward_with_pad(x, emb, hidden_length)
246
+ emb = self.linear(self.silu(emb))
247
+ shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
248
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
249
+ return x, gate_msa
pyramid_dit/flux_modules/modeling_pyramid_flux.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional, Union
2
+
3
+ import torch
4
+ import os
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from tqdm import tqdm
9
+
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.utils import is_torch_version
14
+
15
+ from .modeling_normalization import AdaLayerNormContinuous
16
+ from .modeling_embedding import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
17
+ from .modeling_flux_block import FluxTransformerBlock, FluxSingleTransformerBlock
18
+
19
+ from trainer_misc import (
20
+ is_sequence_parallel_initialized,
21
+ get_sequence_parallel_group,
22
+ get_sequence_parallel_world_size,
23
+ get_sequence_parallel_rank,
24
+ all_to_all,
25
+ )
26
+
27
+
28
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
29
+ assert dim % 2 == 0, "The dimension must be even."
30
+
31
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
32
+ omega = 1.0 / (theta**scale)
33
+
34
+ batch_size, seq_length = pos.shape
35
+ out = torch.einsum("...n,d->...nd", pos, omega)
36
+ cos_out = torch.cos(out)
37
+ sin_out = torch.sin(out)
38
+
39
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
40
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
41
+ return out.float()
42
+
43
+
44
+ class EmbedND(nn.Module):
45
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
46
+ super().__init__()
47
+ self.dim = dim
48
+ self.theta = theta
49
+ self.axes_dim = axes_dim
50
+
51
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
52
+ n_axes = ids.shape[-1]
53
+ emb = torch.cat(
54
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
55
+ dim=-3,
56
+ )
57
+ return emb.unsqueeze(2)
58
+
59
+
60
+ class PyramidFluxTransformer(ModelMixin, ConfigMixin):
61
+ """
62
+ The Transformer model introduced in Flux.
63
+
64
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
65
+
66
+ Parameters:
67
+ patch_size (`int`): Patch size to turn the input data into small patches.
68
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
69
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
70
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
71
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
72
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
73
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
74
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
75
+ """
76
+
77
+ _supports_gradient_checkpointing = True
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ patch_size: int = 1,
83
+ in_channels: int = 64,
84
+ num_layers: int = 19,
85
+ num_single_layers: int = 38,
86
+ attention_head_dim: int = 64,
87
+ num_attention_heads: int = 24,
88
+ joint_attention_dim: int = 4096,
89
+ pooled_projection_dim: int = 768,
90
+ axes_dims_rope: List[int] = [16, 24, 24],
91
+ use_flash_attn: bool = False,
92
+ use_temporal_causal: bool = True,
93
+ interp_condition_pos: bool = True,
94
+ use_gradient_checkpointing: bool = False,
95
+ gradient_checkpointing_ratio: float = 0.6,
96
+ ):
97
+ super().__init__()
98
+ self.out_channels = in_channels
99
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
100
+
101
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
102
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
103
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
104
+ )
105
+
106
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
107
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
108
+
109
+ self.transformer_blocks = nn.ModuleList(
110
+ [
111
+ FluxTransformerBlock(
112
+ dim=self.inner_dim,
113
+ num_attention_heads=self.config.num_attention_heads,
114
+ attention_head_dim=self.config.attention_head_dim,
115
+ use_flash_attn=use_flash_attn,
116
+ )
117
+ for i in range(self.config.num_layers)
118
+ ]
119
+ )
120
+
121
+ self.single_transformer_blocks = nn.ModuleList(
122
+ [
123
+ FluxSingleTransformerBlock(
124
+ dim=self.inner_dim,
125
+ num_attention_heads=self.config.num_attention_heads,
126
+ attention_head_dim=self.config.attention_head_dim,
127
+ use_flash_attn=use_flash_attn,
128
+ )
129
+ for i in range(self.config.num_single_layers)
130
+ ]
131
+ )
132
+
133
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
134
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
135
+
136
+ self.gradient_checkpointing = use_gradient_checkpointing
137
+ self.gradient_checkpointing_ratio = gradient_checkpointing_ratio
138
+
139
+ self.use_temporal_causal = use_temporal_causal
140
+ if self.use_temporal_causal:
141
+ print("Using temporal causal attention")
142
+
143
+ self.use_flash_attn = use_flash_attn
144
+ if self.use_flash_attn:
145
+ print("Using Flash attention")
146
+
147
+ self.patch_size = 2 # hard-code for now
148
+
149
+ # init weights
150
+ self.initialize_weights()
151
+
152
+ def initialize_weights(self):
153
+ # Initialize transformer layers:
154
+ def _basic_init(module):
155
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
156
+ torch.nn.init.xavier_uniform_(module.weight)
157
+ if module.bias is not None:
158
+ nn.init.constant_(module.bias, 0)
159
+ self.apply(_basic_init)
160
+
161
+ # Initialize all the conditioning to normal init
162
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
163
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
164
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
165
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
166
+ nn.init.normal_(self.context_embedder.weight, std=0.02)
167
+
168
+ # Zero-out adaLN modulation layers in DiT blocks:
169
+ for block in self.transformer_blocks:
170
+ nn.init.constant_(block.norm1.linear.weight, 0)
171
+ nn.init.constant_(block.norm1.linear.bias, 0)
172
+ nn.init.constant_(block.norm1_context.linear.weight, 0)
173
+ nn.init.constant_(block.norm1_context.linear.bias, 0)
174
+
175
+ for block in self.single_transformer_blocks:
176
+ nn.init.constant_(block.norm.linear.weight, 0)
177
+ nn.init.constant_(block.norm.linear.bias, 0)
178
+
179
+ # Zero-out output layers:
180
+ nn.init.constant_(self.norm_out.linear.weight, 0)
181
+ nn.init.constant_(self.norm_out.linear.bias, 0)
182
+ nn.init.constant_(self.proj_out.weight, 0)
183
+ nn.init.constant_(self.proj_out.bias, 0)
184
+
185
+ @torch.no_grad()
186
+ def _prepare_image_ids(self, batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=0):
187
+ latent_image_ids = torch.zeros(temp, height, width, 3)
188
+
189
+ # Temporal Rope
190
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
191
+
192
+ # height Rope
193
+ if height != train_height:
194
+ height_pos = F.interpolate(torch.arange(train_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
195
+ else:
196
+ height_pos = torch.arange(train_height).float()
197
+
198
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
199
+
200
+ # width rope
201
+ if width != train_width:
202
+ width_pos = F.interpolate(torch.arange(train_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
203
+ else:
204
+ width_pos = torch.arange(train_width).float()
205
+
206
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
207
+
208
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
209
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
210
+
211
+ return latent_image_ids.to(device=device)
212
+
213
+ @torch.no_grad()
214
+ def _prepare_pyramid_image_ids(self, sample, batch_size, device):
215
+ image_ids_list = []
216
+
217
+ for i_b, sample_ in enumerate(sample):
218
+ if not isinstance(sample_, list):
219
+ sample_ = [sample_]
220
+
221
+ cur_image_ids = []
222
+ start_time_stamp = 0
223
+
224
+ train_height = sample_[-1].shape[-2] // self.patch_size
225
+ train_width = sample_[-1].shape[-1] // self.patch_size
226
+
227
+ for clip_ in sample_:
228
+ _, _, temp, height, width = clip_.shape
229
+ height = height // self.patch_size
230
+ width = width // self.patch_size
231
+ cur_image_ids.append(self._prepare_image_ids(batch_size, temp, height, width, train_height, train_width, device, start_time_stamp=start_time_stamp))
232
+ start_time_stamp += temp
233
+
234
+ cur_image_ids = torch.cat(cur_image_ids, dim=1)
235
+ image_ids_list.append(cur_image_ids)
236
+
237
+ return image_ids_list
238
+
239
+ def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
240
+ """
241
+ Merge the input video with different resolutions into one sequence
242
+ Sample: From low resolution to high resolution
243
+ """
244
+ if isinstance(sample[0], list):
245
+ device = sample[0][-1].device
246
+ pad_batch_size = sample[0][-1].shape[0]
247
+ else:
248
+ device = sample[0].device
249
+ pad_batch_size = sample[0].shape[0]
250
+
251
+ num_stages = len(sample)
252
+ height_list = [];width_list = [];temp_list = []
253
+ trainable_token_list = []
254
+
255
+ for i_b, sample_ in enumerate(sample):
256
+ if isinstance(sample_, list):
257
+ sample_ = sample_[-1]
258
+ _, _, temp, height, width = sample_.shape
259
+ height = height // self.patch_size
260
+ width = width // self.patch_size
261
+ temp_list.append(temp)
262
+ height_list.append(height)
263
+ width_list.append(width)
264
+ trainable_token_list.append(height * width * temp)
265
+
266
+ # prepare the RoPE IDs,
267
+ image_ids_list = self._prepare_pyramid_image_ids(sample, pad_batch_size, device)
268
+ text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 3).to(device=device)
269
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
270
+ image_rotary_emb = [self.pos_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
271
+
272
+ if is_sequence_parallel_initialized():
273
+ sp_group = get_sequence_parallel_group()
274
+ sp_group_size = get_sequence_parallel_world_size()
275
+ concat_output = True if self.training else False
276
+ image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for x_ in image_rotary_emb]
277
+ input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for input_ids in input_ids_list]
278
+
279
+ hidden_states, hidden_length = [], []
280
+
281
+ for sample_ in sample:
282
+ video_tokens = []
283
+
284
+ for each_latent in sample_:
285
+ each_latent = rearrange(each_latent, 'b c t h w -> b t h w c')
286
+ each_latent = rearrange(each_latent, 'b t (h p1) (w p2) c -> b (t h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size)
287
+ video_tokens.append(each_latent)
288
+
289
+ video_tokens = torch.cat(video_tokens, dim=1)
290
+ video_tokens = self.x_embedder(video_tokens)
291
+ hidden_states.append(video_tokens)
292
+ hidden_length.append(video_tokens.shape[1])
293
+
294
+ # prepare the attention mask
295
+ if self.use_flash_attn:
296
+ attention_mask = None
297
+ indices_list = []
298
+ for i_p, length in enumerate(hidden_length):
299
+ pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
300
+ pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)
301
+
302
+ if is_sequence_parallel_initialized():
303
+ sp_group = get_sequence_parallel_group()
304
+ sp_group_size = get_sequence_parallel_world_size()
305
+ pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
306
+ pad_attention_mask = pad_attention_mask.squeeze(2)
307
+
308
+ seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
309
+ indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()
310
+
311
+ indices_list.append(
312
+ {
313
+ 'indices': indices,
314
+ 'seqlens_in_batch': seqlens_in_batch,
315
+ }
316
+ )
317
+ encoder_attention_mask = indices_list
318
+ else:
319
+ assert encoder_attention_mask.shape[1] == encoder_hidden_length
320
+ real_batch_size = encoder_attention_mask.shape[0]
321
+
322
+ # prepare text ids
323
+ text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
324
+ text_ids = text_ids.to(device)
325
+ text_ids[encoder_attention_mask == 0] = 0
326
+
327
+ # prepare image ids
328
+ image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
329
+ image_ids = image_ids.to(device)
330
+ image_ids_list = []
331
+ for i_p, length in enumerate(hidden_length):
332
+ image_ids_list.append(image_ids[i_p::num_stages][:, :length])
333
+
334
+ if is_sequence_parallel_initialized():
335
+ sp_group = get_sequence_parallel_group()
336
+ sp_group_size = get_sequence_parallel_world_size()
337
+ concat_output = True if self.training else False
338
+ text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2)
339
+ image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2) for image_ids_ in image_ids_list]
340
+
341
+ attention_mask = []
342
+ for i_p in range(len(hidden_length)):
343
+ image_ids = image_ids_list[i_p]
344
+ token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
345
+ stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len]
346
+ if self.use_temporal_causal:
347
+ input_order_ids = input_ids_list[i_p][:,:,0]
348
+ temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
349
+ stage_attention_mask = stage_attention_mask & temporal_causal_mask
350
+ attention_mask.append(stage_attention_mask)
351
+
352
+ return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
353
+
354
+ def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
355
+ # To split the hidden states
356
+ batch_size = batch_hidden_states.shape[0]
357
+ output_hidden_list = []
358
+ batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)
359
+
360
+ if is_sequence_parallel_initialized():
361
+ sp_group_size = get_sequence_parallel_world_size()
362
+ if self.training:
363
+ batch_size = batch_size // sp_group_size
364
+
365
+ for i_p, length in enumerate(hidden_length):
366
+ width, height, temp = widths[i_p], heights[i_p], temps[i_p]
367
+ trainable_token_num = trainable_token_list[i_p]
368
+ hidden_states = batch_hidden_states[i_p]
369
+
370
+ if is_sequence_parallel_initialized():
371
+ sp_group = get_sequence_parallel_group()
372
+ sp_group_size = get_sequence_parallel_world_size()
373
+
374
+ if not self.training:
375
+ hidden_states = hidden_states.repeat(sp_group_size, 1, 1)
376
+
377
+ hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
378
+
379
+ # only the trainable token are taking part in loss computation
380
+ hidden_states = hidden_states[:, -trainable_token_num:]
381
+
382
+ # unpatchify
383
+ hidden_states = hidden_states.reshape(
384
+ shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels // 4)
385
+ )
386
+ hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
387
+ hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
388
+ output_hidden_list.append(hidden_states)
389
+
390
+ return output_hidden_list
391
+
392
+ def forward(
393
+ self,
394
+ sample: torch.FloatTensor, # [num_stages]
395
+ encoder_hidden_states: torch.Tensor = None,
396
+ encoder_attention_mask: torch.FloatTensor = None,
397
+ pooled_projections: torch.Tensor = None,
398
+ timestep_ratio: torch.LongTensor = None,
399
+ ):
400
+ temb = self.time_text_embed(timestep_ratio, pooled_projections)
401
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
402
+ encoder_hidden_length = encoder_hidden_states.shape[1]
403
+
404
+ # Get the input sequence
405
+ hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, attention_mask, \
406
+ image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
407
+
408
+ # split the long latents if necessary
409
+ if is_sequence_parallel_initialized():
410
+ sp_group = get_sequence_parallel_group()
411
+ sp_group_size = get_sequence_parallel_world_size()
412
+ concat_output = True if self.training else False
413
+
414
+ # sync the input hidden states
415
+ batch_hidden_states = []
416
+ for i_p, hidden_states_ in enumerate(hidden_states):
417
+ assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
418
+ hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
419
+ hidden_length[i_p] = hidden_length[i_p] // sp_group_size
420
+ batch_hidden_states.append(hidden_states_)
421
+
422
+ # sync the encoder hidden states
423
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
424
+ encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
425
+ temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
426
+ temb = temb.squeeze(1)
427
+ else:
428
+ hidden_states = torch.cat(hidden_states, dim=1)
429
+
430
+ for index_block, block in enumerate(self.transformer_blocks):
431
+ if self.training and self.gradient_checkpointing and (index_block <= int(len(self.transformer_blocks) * self.gradient_checkpointing_ratio)):
432
+
433
+ def create_custom_forward(module):
434
+ def custom_forward(*inputs):
435
+ return module(*inputs)
436
+
437
+ return custom_forward
438
+
439
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
440
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
441
+ create_custom_forward(block),
442
+ hidden_states,
443
+ encoder_hidden_states,
444
+ encoder_attention_mask,
445
+ temb,
446
+ attention_mask,
447
+ hidden_length,
448
+ image_rotary_emb,
449
+ **ckpt_kwargs,
450
+ )
451
+
452
+ else:
453
+ encoder_hidden_states, hidden_states = block(
454
+ hidden_states=hidden_states,
455
+ encoder_hidden_states=encoder_hidden_states,
456
+ encoder_attention_mask=encoder_attention_mask,
457
+ temb=temb,
458
+ attention_mask=attention_mask,
459
+ hidden_length=hidden_length,
460
+ image_rotary_emb=image_rotary_emb,
461
+ )
462
+
463
+ # remerge for single attention block
464
+ num_stages = len(hidden_length)
465
+ batch_hidden_states = list(torch.split(hidden_states, hidden_length, dim=1))
466
+ concat_hidden_length = []
467
+
468
+ if is_sequence_parallel_initialized():
469
+ sp_group = get_sequence_parallel_group()
470
+ sp_group_size = get_sequence_parallel_world_size()
471
+ encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
472
+
473
+ for i_p in range(len(hidden_length)):
474
+
475
+ if is_sequence_parallel_initialized():
476
+ sp_group = get_sequence_parallel_group()
477
+ sp_group_size = get_sequence_parallel_world_size()
478
+ batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
479
+
480
+ batch_hidden_states[i_p] = torch.cat([encoder_hidden_states[i_p::num_stages], batch_hidden_states[i_p]], dim=1)
481
+
482
+ if is_sequence_parallel_initialized():
483
+ sp_group = get_sequence_parallel_group()
484
+ sp_group_size = get_sequence_parallel_world_size()
485
+ batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
486
+
487
+ concat_hidden_length.append(batch_hidden_states[i_p].shape[1])
488
+
489
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
490
+
491
+ for index_block, block in enumerate(self.single_transformer_blocks):
492
+ if self.training and self.gradient_checkpointing and (index_block <= int(len(self.single_transformer_blocks) * self.gradient_checkpointing_ratio)):
493
+
494
+ def create_custom_forward(module):
495
+ def custom_forward(*inputs):
496
+ return module(*inputs)
497
+
498
+ return custom_forward
499
+
500
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
501
+ hidden_states = torch.utils.checkpoint.checkpoint(
502
+ create_custom_forward(block),
503
+ hidden_states,
504
+ temb,
505
+ encoder_attention_mask,
506
+ attention_mask,
507
+ concat_hidden_length,
508
+ image_rotary_emb,
509
+ **ckpt_kwargs,
510
+ )
511
+
512
+ else:
513
+ hidden_states = block(
514
+ hidden_states=hidden_states,
515
+ temb=temb,
516
+ encoder_attention_mask=encoder_attention_mask, # used for
517
+ attention_mask=attention_mask,
518
+ hidden_length=concat_hidden_length,
519
+ image_rotary_emb=image_rotary_emb,
520
+ )
521
+
522
+ batch_hidden_states = list(torch.split(hidden_states, concat_hidden_length, dim=1))
523
+
524
+ for i_p in range(len(concat_hidden_length)):
525
+ if is_sequence_parallel_initialized():
526
+ sp_group = get_sequence_parallel_group()
527
+ sp_group_size = get_sequence_parallel_world_size()
528
+ batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
529
+
530
+ batch_hidden_states[i_p] = batch_hidden_states[i_p][:, encoder_hidden_length :, ...]
531
+
532
+ if is_sequence_parallel_initialized():
533
+ sp_group = get_sequence_parallel_group()
534
+ sp_group_size = get_sequence_parallel_world_size()
535
+ batch_hidden_states[i_p] = all_to_all(batch_hidden_states[i_p], sp_group, sp_group_size, scatter_dim=1, gather_dim=0)
536
+
537
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
538
+ hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
539
+ hidden_states = self.proj_out(hidden_states)
540
+
541
+ output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)
542
+
543
+ return output
pyramid_dit/flux_modules/modeling_text_encoder.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+
5
+ from transformers import (
6
+ CLIPTextModel,
7
+ CLIPTokenizer,
8
+ T5EncoderModel,
9
+ T5TokenizerFast,
10
+ )
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Union
13
+
14
+
15
+ class FluxTextEncoderWithMask(nn.Module):
16
+ def __init__(self, model_path, torch_dtype):
17
+ super().__init__()
18
+ # CLIP-G
19
+ self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'), torch_dtype=torch_dtype)
20
+ self.tokenizer_max_length = (
21
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
22
+ )
23
+ self.text_encoder = CLIPTextModel.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)
24
+
25
+ # T5
26
+ self.tokenizer_2 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
27
+ self.text_encoder_2 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)
28
+
29
+ self._freeze()
30
+
31
+ def _freeze(self):
32
+ for param in self.parameters():
33
+ param.requires_grad = False
34
+
35
+ def _get_t5_prompt_embeds(
36
+ self,
37
+ prompt: Union[str, List[str]] = None,
38
+ num_images_per_prompt: int = 1,
39
+ max_sequence_length: int = 128,
40
+ device: Optional[torch.device] = None,
41
+ ):
42
+
43
+ prompt = [prompt] if isinstance(prompt, str) else prompt
44
+ batch_size = len(prompt)
45
+
46
+ text_inputs = self.tokenizer_2(
47
+ prompt,
48
+ padding="max_length",
49
+ max_length=max_sequence_length,
50
+ truncation=True,
51
+ return_length=False,
52
+ return_overflowing_tokens=False,
53
+ return_tensors="pt",
54
+ )
55
+ text_input_ids = text_inputs.input_ids
56
+ prompt_attention_mask = text_inputs.attention_mask
57
+ prompt_attention_mask = prompt_attention_mask.to(device)
58
+
59
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
60
+
61
+ dtype = self.text_encoder_2.dtype
62
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
63
+
64
+ _, seq_len, _ = prompt_embeds.shape
65
+
66
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
67
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
68
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
69
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
70
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
71
+
72
+ return prompt_embeds, prompt_attention_mask
73
+
74
+ def _get_clip_prompt_embeds(
75
+ self,
76
+ prompt: Union[str, List[str]],
77
+ num_images_per_prompt: int = 1,
78
+ device: Optional[torch.device] = None,
79
+ ):
80
+
81
+ prompt = [prompt] if isinstance(prompt, str) else prompt
82
+ batch_size = len(prompt)
83
+
84
+ text_inputs = self.tokenizer(
85
+ prompt,
86
+ padding="max_length",
87
+ max_length=self.tokenizer_max_length,
88
+ truncation=True,
89
+ return_overflowing_tokens=False,
90
+ return_length=False,
91
+ return_tensors="pt",
92
+ )
93
+
94
+ text_input_ids = text_inputs.input_ids
95
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
96
+
97
+ # Use pooled output of CLIPTextModel
98
+ prompt_embeds = prompt_embeds.pooler_output
99
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
100
+
101
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
102
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
103
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
104
+
105
+ return prompt_embeds
106
+
107
+ def encode_prompt(self,
108
+ prompt,
109
+ num_images_per_prompt=1,
110
+ device=None,
111
+ ):
112
+ prompt = [prompt] if isinstance(prompt, str) else prompt
113
+
114
+ batch_size = len(prompt)
115
+
116
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
117
+ prompt=prompt,
118
+ device=device,
119
+ num_images_per_prompt=num_images_per_prompt,
120
+ )
121
+
122
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
123
+ prompt=prompt,
124
+ num_images_per_prompt=num_images_per_prompt,
125
+ device=device,
126
+ )
127
+
128
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
129
+
130
+ def forward(self, input_prompts, device):
131
+ with torch.no_grad():
132
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, device=device)
133
+
134
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
pyramid_dit/mmdit_modules/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .modeling_text_encoder import SD3TextEncoderWithMask
2
+ from .modeling_pyramid_mmdit import PyramidDiffusionMMDiT
3
+ from .modeling_mmdit_block import JointTransformerBlock
pyramid_dit/mmdit_modules/modeling_embedding.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+
8
+ from diffusers.models.activations import get_activation
9
+ from einops import rearrange
10
+
11
+
12
+ def get_1d_sincos_pos_embed(
13
+ embed_dim, num_frames, cls_token=False, extra_tokens=0,
14
+ ):
15
+ t = np.arange(num_frames, dtype=np.float32)
16
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, t) # (T, D)
17
+ if cls_token and extra_tokens > 0:
18
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
19
+ return pos_embed
20
+
21
+
22
+ def get_2d_sincos_pos_embed(
23
+ embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
24
+ ):
25
+ """
26
+ grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
27
+ [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
28
+ """
29
+ if isinstance(grid_size, int):
30
+ grid_size = (grid_size, grid_size)
31
+
32
+ grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
33
+ grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
34
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
35
+ grid = np.stack(grid, axis=0)
36
+
37
+ grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
38
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
39
+ if cls_token and extra_tokens > 0:
40
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
41
+ return pos_embed
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ if embed_dim % 2 != 0:
46
+ raise ValueError("embed_dim must be divisible by 2")
47
+
48
+ # use half of dimensions to encode grid_h
49
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
50
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
51
+
52
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
53
+ return emb
54
+
55
+
56
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
57
+ """
58
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
59
+ """
60
+ if embed_dim % 2 != 0:
61
+ raise ValueError("embed_dim must be divisible by 2")
62
+
63
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
64
+ omega /= embed_dim / 2.0
65
+ omega = 1.0 / 10000**omega # (D/2,)
66
+
67
+ pos = pos.reshape(-1) # (M,)
68
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
69
+
70
+ emb_sin = np.sin(out) # (M, D/2)
71
+ emb_cos = np.cos(out) # (M, D/2)
72
+
73
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
74
+ return emb
75
+
76
+
77
+ def get_timestep_embedding(
78
+ timesteps: torch.Tensor,
79
+ embedding_dim: int,
80
+ flip_sin_to_cos: bool = False,
81
+ downscale_freq_shift: float = 1,
82
+ scale: float = 1,
83
+ max_period: int = 10000,
84
+ ):
85
+ """
86
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
87
+ :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional.
88
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
89
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
90
+ """
91
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
92
+
93
+ half_dim = embedding_dim // 2
94
+ exponent = -math.log(max_period) * torch.arange(
95
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
96
+ )
97
+ exponent = exponent / (half_dim - downscale_freq_shift)
98
+
99
+ emb = torch.exp(exponent)
100
+ emb = timesteps[:, None].float() * emb[None, :]
101
+
102
+ # scale embeddings
103
+ emb = scale * emb
104
+
105
+ # concat sine and cosine embeddings
106
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
107
+
108
+ # flip sine and cosine embeddings
109
+ if flip_sin_to_cos:
110
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
111
+
112
+ # zero pad
113
+ if embedding_dim % 2 == 1:
114
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
115
+ return emb
116
+
117
+
118
+ class Timesteps(nn.Module):
119
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
120
+ super().__init__()
121
+ self.num_channels = num_channels
122
+ self.flip_sin_to_cos = flip_sin_to_cos
123
+ self.downscale_freq_shift = downscale_freq_shift
124
+
125
+ def forward(self, timesteps):
126
+ t_emb = get_timestep_embedding(
127
+ timesteps,
128
+ self.num_channels,
129
+ flip_sin_to_cos=self.flip_sin_to_cos,
130
+ downscale_freq_shift=self.downscale_freq_shift,
131
+ )
132
+ return t_emb
133
+
134
+
135
+ class TimestepEmbedding(nn.Module):
136
+ def __init__(
137
+ self,
138
+ in_channels: int,
139
+ time_embed_dim: int,
140
+ act_fn: str = "silu",
141
+ out_dim: int = None,
142
+ post_act_fn: Optional[str] = None,
143
+ sample_proj_bias=True,
144
+ ):
145
+ super().__init__()
146
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
147
+ self.act = get_activation(act_fn)
148
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, sample_proj_bias)
149
+
150
+ def forward(self, sample):
151
+ sample = self.linear_1(sample)
152
+ sample = self.act(sample)
153
+ sample = self.linear_2(sample)
154
+ return sample
155
+
156
+
157
+ class TextProjection(nn.Module):
158
+ def __init__(self, in_features, hidden_size, act_fn="silu"):
159
+ super().__init__()
160
+ self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
161
+ self.act_1 = get_activation(act_fn)
162
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
163
+
164
+ def forward(self, caption):
165
+ hidden_states = self.linear_1(caption)
166
+ hidden_states = self.act_1(hidden_states)
167
+ hidden_states = self.linear_2(hidden_states)
168
+ return hidden_states
169
+
170
+
171
+ class CombinedTimestepConditionEmbeddings(nn.Module):
172
+ def __init__(self, embedding_dim, pooled_projection_dim):
173
+ super().__init__()
174
+
175
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
176
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
177
+ self.text_embedder = TextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
178
+
179
+ def forward(self, timestep, pooled_projection):
180
+ timesteps_proj = self.time_proj(timestep)
181
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
182
+ pooled_projections = self.text_embedder(pooled_projection)
183
+ conditioning = timesteps_emb + pooled_projections
184
+ return conditioning
185
+
186
+
187
+ class CombinedTimestepEmbeddings(nn.Module):
188
+ def __init__(self, embedding_dim):
189
+ super().__init__()
190
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
191
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
192
+
193
+ def forward(self, timestep):
194
+ timesteps_proj = self.time_proj(timestep)
195
+ timesteps_emb = self.timestep_embedder(timesteps_proj) # (N, D)
196
+ return timesteps_emb
197
+
198
+
199
+ class PatchEmbed3D(nn.Module):
200
+ """Support the 3D Tensor input"""
201
+
202
+ def __init__(
203
+ self,
204
+ height=128,
205
+ width=128,
206
+ patch_size=2,
207
+ in_channels=16,
208
+ embed_dim=1536,
209
+ layer_norm=False,
210
+ bias=True,
211
+ interpolation_scale=1,
212
+ pos_embed_type="sincos",
213
+ temp_pos_embed_type='rope',
214
+ pos_embed_max_size=192, # For SD3 cropping
215
+ max_num_frames=64,
216
+ add_temp_pos_embed=False,
217
+ interp_condition_pos=False,
218
+ ):
219
+ super().__init__()
220
+
221
+ num_patches = (height // patch_size) * (width // patch_size)
222
+ self.layer_norm = layer_norm
223
+ self.pos_embed_max_size = pos_embed_max_size
224
+
225
+ self.proj = nn.Conv2d(
226
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
227
+ )
228
+ if layer_norm:
229
+ self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
230
+ else:
231
+ self.norm = None
232
+
233
+ self.patch_size = patch_size
234
+ self.height, self.width = height // patch_size, width // patch_size
235
+ self.base_size = height // patch_size
236
+ self.interpolation_scale = interpolation_scale
237
+ self.add_temp_pos_embed = add_temp_pos_embed
238
+
239
+ # Calculate positional embeddings based on max size or default
240
+ if pos_embed_max_size:
241
+ grid_size = pos_embed_max_size
242
+ else:
243
+ grid_size = int(num_patches**0.5)
244
+
245
+ if pos_embed_type is None:
246
+ self.pos_embed = None
247
+
248
+ elif pos_embed_type == "sincos":
249
+ pos_embed = get_2d_sincos_pos_embed(
250
+ embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
251
+ )
252
+ persistent = True if pos_embed_max_size else False
253
+ self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
254
+
255
+ if add_temp_pos_embed and temp_pos_embed_type == 'sincos':
256
+ time_pos_embed = get_1d_sincos_pos_embed(embed_dim, max_num_frames)
257
+ self.register_buffer("temp_pos_embed", torch.from_numpy(time_pos_embed).float().unsqueeze(0), persistent=True)
258
+
259
+ elif pos_embed_type == "rope":
260
+ print("Using the rotary position embedding")
261
+
262
+ else:
263
+ raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
264
+
265
+ self.pos_embed_type = pos_embed_type
266
+ self.temp_pos_embed_type = temp_pos_embed_type
267
+ self.interp_condition_pos = interp_condition_pos
268
+
269
+ def cropped_pos_embed(self, height, width, ori_height, ori_width):
270
+ """Crops positional embeddings for SD3 compatibility."""
271
+ if self.pos_embed_max_size is None:
272
+ raise ValueError("`pos_embed_max_size` must be set for cropping.")
273
+
274
+ height = height // self.patch_size
275
+ width = width // self.patch_size
276
+ ori_height = ori_height // self.patch_size
277
+ ori_width = ori_width // self.patch_size
278
+
279
+ assert ori_height >= height, "The ori_height needs >= height"
280
+ assert ori_width >= width, "The ori_width needs >= width"
281
+
282
+ if height > self.pos_embed_max_size:
283
+ raise ValueError(
284
+ f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
285
+ )
286
+ if width > self.pos_embed_max_size:
287
+ raise ValueError(
288
+ f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
289
+ )
290
+
291
+ if self.interp_condition_pos:
292
+ top = (self.pos_embed_max_size - ori_height) // 2
293
+ left = (self.pos_embed_max_size - ori_width) // 2
294
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
295
+ spatial_pos_embed = spatial_pos_embed[:, top : top + ori_height, left : left + ori_width, :] # [b h w c]
296
+ if ori_height != height or ori_width != width:
297
+ spatial_pos_embed = spatial_pos_embed.permute(0, 3, 1, 2)
298
+ spatial_pos_embed = torch.nn.functional.interpolate(spatial_pos_embed, size=(height, width), mode='bilinear')
299
+ spatial_pos_embed = spatial_pos_embed.permute(0, 2, 3, 1)
300
+ else:
301
+ top = (self.pos_embed_max_size - height) // 2
302
+ left = (self.pos_embed_max_size - width) // 2
303
+ spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
304
+ spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
305
+
306
+ spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
307
+
308
+ return spatial_pos_embed
309
+
310
+ def forward_func(self, latent, time_index=0, ori_height=None, ori_width=None):
311
+ if self.pos_embed_max_size is not None:
312
+ height, width = latent.shape[-2:]
313
+ else:
314
+ height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
315
+
316
+ bs = latent.shape[0]
317
+ temp = latent.shape[2]
318
+
319
+ latent = rearrange(latent, 'b c t h w -> (b t) c h w')
320
+ latent = self.proj(latent)
321
+ latent = latent.flatten(2).transpose(1, 2) # (BT)CHW -> (BT)NC
322
+
323
+ if self.layer_norm:
324
+ latent = self.norm(latent)
325
+
326
+ if self.pos_embed_type == 'sincos':
327
+ # Spatial position embedding, Interpolate or crop positional embeddings as needed
328
+ if self.pos_embed_max_size:
329
+ pos_embed = self.cropped_pos_embed(height, width, ori_height, ori_width)
330
+ else:
331
+ raise NotImplementedError("Not implemented sincos pos embed without sd3 max pos crop")
332
+ if self.height != height or self.width != width:
333
+ pos_embed = get_2d_sincos_pos_embed(
334
+ embed_dim=self.pos_embed.shape[-1],
335
+ grid_size=(height, width),
336
+ base_size=self.base_size,
337
+ interpolation_scale=self.interpolation_scale,
338
+ )
339
+ pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
340
+ else:
341
+ pos_embed = self.pos_embed
342
+
343
+ if self.add_temp_pos_embed and self.temp_pos_embed_type == 'sincos':
344
+ latent_dtype = latent.dtype
345
+ latent = latent + pos_embed
346
+ latent = rearrange(latent, '(b t) n c -> (b n) t c', t=temp)
347
+ latent = latent + self.temp_pos_embed[:, time_index:time_index + temp, :]
348
+ latent = latent.to(latent_dtype)
349
+ latent = rearrange(latent, '(b n) t c -> b t n c', b=bs)
350
+ else:
351
+ latent = (latent + pos_embed).to(latent.dtype)
352
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
353
+
354
+ else:
355
+ assert self.pos_embed_type == "rope", "Only supporting the sincos and rope embedding"
356
+ latent = rearrange(latent, '(b t) n c -> b t n c', b=bs, t=temp)
357
+
358
+ return latent
359
+
360
+ def forward(self, latent):
361
+ """
362
+ Arguments:
363
+ past_condition_latents (Torch.FloatTensor): The past latent during the generation
364
+ flatten_input (bool): True indicate flatten the latent into 1D sequence
365
+ """
366
+
367
+ if isinstance(latent, list):
368
+ output_list = []
369
+
370
+ for latent_ in latent:
371
+ if not isinstance(latent_, list):
372
+ latent_ = [latent_]
373
+
374
+ output_latent = []
375
+ time_index = 0
376
+ ori_height, ori_width = latent_[-1].shape[-2:]
377
+ for each_latent in latent_:
378
+ hidden_state = self.forward_func(each_latent, time_index=time_index, ori_height=ori_height, ori_width=ori_width)
379
+ time_index += each_latent.shape[2]
380
+ hidden_state = rearrange(hidden_state, "b t n c -> b (t n) c")
381
+ output_latent.append(hidden_state)
382
+
383
+ output_latent = torch.cat(output_latent, dim=1)
384
+ output_list.append(output_latent)
385
+
386
+ return output_list
387
+ else:
388
+ hidden_states = self.forward_func(latent)
389
+ hidden_states = rearrange(hidden_states, "b t n c -> b (t n) c")
390
+ return hidden_states
pyramid_dit/mmdit_modules/modeling_mmdit_block.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Tuple, List
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
7
+
8
+ try:
9
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
10
+ from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ except:
13
+ flash_attn_func = None
14
+ flash_attn_qkvpacked_func = None
15
+ flash_attn_varlen_func = None
16
+
17
+ from trainer_misc import (
18
+ is_sequence_parallel_initialized,
19
+ get_sequence_parallel_group,
20
+ get_sequence_parallel_world_size,
21
+ all_to_all,
22
+ )
23
+
24
+ from .modeling_normalization import AdaLayerNormZero, AdaLayerNormContinuous, RMSNorm
25
+
26
+
27
+ class FeedForward(nn.Module):
28
+ r"""
29
+ A feed-forward layer.
30
+
31
+ Parameters:
32
+ dim (`int`): The number of channels in the input.
33
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
34
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
35
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
36
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
37
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
38
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
39
+ """
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ dim_out: Optional[int] = None,
44
+ mult: int = 4,
45
+ dropout: float = 0.0,
46
+ activation_fn: str = "geglu",
47
+ final_dropout: bool = False,
48
+ inner_dim=None,
49
+ bias: bool = True,
50
+ ):
51
+ super().__init__()
52
+ if inner_dim is None:
53
+ inner_dim = int(dim * mult)
54
+ dim_out = dim_out if dim_out is not None else dim
55
+
56
+ if activation_fn == "gelu":
57
+ act_fn = GELU(dim, inner_dim, bias=bias)
58
+ if activation_fn == "gelu-approximate":
59
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
60
+ elif activation_fn == "geglu":
61
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
62
+ elif activation_fn == "geglu-approximate":
63
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
64
+
65
+ self.net = nn.ModuleList([])
66
+ # project in
67
+ self.net.append(act_fn)
68
+ # project dropout
69
+ self.net.append(nn.Dropout(dropout))
70
+ # project out
71
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
72
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
73
+ if final_dropout:
74
+ self.net.append(nn.Dropout(dropout))
75
+
76
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
77
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
78
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
79
+ deprecate("scale", "1.0.0", deprecation_message)
80
+ for module in self.net:
81
+ hidden_states = module(hidden_states)
82
+ return hidden_states
83
+
84
+
85
+ class VarlenFlashSelfAttentionWithT5Mask:
86
+
87
+ def __init__(self):
88
+ pass
89
+
90
+ def apply_rope(self, xq, xk, freqs_cis):
91
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
92
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
93
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
94
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
95
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
96
+
97
+ def __call__(
98
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
99
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
100
+ ):
101
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
102
+
103
+ batch_size = query.shape[0]
104
+ output_hidden = torch.zeros_like(query)
105
+ output_encoder_hidden = torch.zeros_like(encoder_query)
106
+ encoder_length = encoder_query.shape[1]
107
+
108
+ qkv_list = []
109
+ num_stages = len(hidden_length)
110
+
111
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
112
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
113
+
114
+ i_sum = 0
115
+ for i_p, length in enumerate(hidden_length):
116
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
117
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
118
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
119
+
120
+ if image_rotary_emb is not None:
121
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
122
+
123
+ indices = encoder_attention_mask[i_p]['indices']
124
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
125
+ i_sum += length
126
+
127
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
128
+ qkv = torch.cat(qkv_list, dim=0)
129
+ query, key, value = qkv.unbind(1)
130
+
131
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
132
+ max_seqlen_q = cu_seqlens.max().item()
133
+ max_seqlen_k = max_seqlen_q
134
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
135
+ cu_seqlens_k = cu_seqlens_q.clone()
136
+
137
+ output = flash_attn_varlen_func(
138
+ query,
139
+ key,
140
+ value,
141
+ cu_seqlens_q=cu_seqlens_q,
142
+ cu_seqlens_k=cu_seqlens_k,
143
+ max_seqlen_q=max_seqlen_q,
144
+ max_seqlen_k=max_seqlen_k,
145
+ dropout_p=0.0,
146
+ causal=False,
147
+ softmax_scale=scale,
148
+ )
149
+
150
+ # To merge the tokens
151
+ i_sum = 0;token_sum = 0
152
+ for i_p, length in enumerate(hidden_length):
153
+ tot_token_num = token_lengths[i_p]
154
+ stage_output = output[token_sum : token_sum + tot_token_num]
155
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length)
156
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
157
+ stage_hidden_output = stage_output[:, encoder_length:]
158
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
159
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
160
+ token_sum += tot_token_num
161
+ i_sum += length
162
+
163
+ output_hidden = output_hidden.flatten(2, 3)
164
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
165
+
166
+ return output_hidden, output_encoder_hidden
167
+
168
+
169
+ class SequenceParallelVarlenFlashSelfAttentionWithT5Mask:
170
+
171
+ def __init__(self):
172
+ pass
173
+
174
+ def apply_rope(self, xq, xk, freqs_cis):
175
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
176
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
177
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
178
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
179
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
180
+
181
+ def __call__(
182
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
183
+ heads, scale, hidden_length=None, image_rotary_emb=None, encoder_attention_mask=None,
184
+ ):
185
+ assert encoder_attention_mask is not None, "The encoder-hidden mask needed to be set"
186
+
187
+ batch_size = query.shape[0]
188
+ qkv_list = []
189
+ num_stages = len(hidden_length)
190
+
191
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
192
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
193
+
194
+ # To sync the encoder query, key and values
195
+ sp_group = get_sequence_parallel_group()
196
+ sp_group_size = get_sequence_parallel_world_size()
197
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
198
+
199
+ output_hidden = torch.zeros_like(qkv[:,:,0])
200
+ output_encoder_hidden = torch.zeros_like(encoder_qkv[:,:,0])
201
+ encoder_length = encoder_qkv.shape[1]
202
+
203
+ i_sum = 0
204
+ for i_p, length in enumerate(hidden_length):
205
+ # get the query, key, value from padding sequence
206
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
207
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
208
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
209
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, pad_seq, 3, nhead, dim]
210
+
211
+ if image_rotary_emb is not None:
212
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
213
+
214
+ indices = encoder_attention_mask[i_p]['indices']
215
+ qkv_list.append(index_first_axis(rearrange(concat_qkv_tokens, "b s ... -> (b s) ..."), indices))
216
+ i_sum += length
217
+
218
+ token_lengths = [x_.shape[0] for x_ in qkv_list]
219
+ qkv = torch.cat(qkv_list, dim=0)
220
+ query, key, value = qkv.unbind(1)
221
+
222
+ cu_seqlens = torch.cat([x_['seqlens_in_batch'] for x_ in encoder_attention_mask], dim=0)
223
+ max_seqlen_q = cu_seqlens.max().item()
224
+ max_seqlen_k = max_seqlen_q
225
+ cu_seqlens_q = F.pad(torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32), (1, 0))
226
+ cu_seqlens_k = cu_seqlens_q.clone()
227
+
228
+ output = flash_attn_varlen_func(
229
+ query,
230
+ key,
231
+ value,
232
+ cu_seqlens_q=cu_seqlens_q,
233
+ cu_seqlens_k=cu_seqlens_k,
234
+ max_seqlen_q=max_seqlen_q,
235
+ max_seqlen_k=max_seqlen_k,
236
+ dropout_p=0.0,
237
+ causal=False,
238
+ softmax_scale=scale,
239
+ )
240
+
241
+ # To merge the tokens
242
+ i_sum = 0;token_sum = 0
243
+ for i_p, length in enumerate(hidden_length):
244
+ tot_token_num = token_lengths[i_p]
245
+ stage_output = output[token_sum : token_sum + tot_token_num]
246
+ stage_output = pad_input(stage_output, encoder_attention_mask[i_p]['indices'], batch_size, encoder_length + length * sp_group_size)
247
+ stage_encoder_hidden_output = stage_output[:, :encoder_length]
248
+ stage_hidden_output = stage_output[:, encoder_length:]
249
+ stage_hidden_output = all_to_all(stage_hidden_output, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
250
+ output_hidden[:, i_sum:i_sum+length] = stage_hidden_output
251
+ output_encoder_hidden[i_p::num_stages] = stage_encoder_hidden_output
252
+ token_sum += tot_token_num
253
+ i_sum += length
254
+
255
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
256
+ output_hidden = output_hidden.flatten(2, 3)
257
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
258
+
259
+ return output_hidden, output_encoder_hidden
260
+
261
+
262
+ class VarlenSelfAttentionWithT5Mask:
263
+
264
+ """
265
+ For chunk stage attention without using flash attention
266
+ """
267
+
268
+ def __init__(self):
269
+ pass
270
+
271
+ def apply_rope(self, xq, xk, freqs_cis):
272
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
273
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
274
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
275
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
276
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
277
+
278
+ def __call__(
279
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
280
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
281
+ ):
282
+ assert attention_mask is not None, "The attention mask needed to be set"
283
+
284
+ encoder_length = encoder_query.shape[1]
285
+ num_stages = len(hidden_length)
286
+
287
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
288
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
289
+
290
+ i_sum = 0
291
+ output_encoder_hidden_list = []
292
+ output_hidden_list = []
293
+
294
+ for i_p, length in enumerate(hidden_length):
295
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
296
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
297
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
298
+
299
+ if image_rotary_emb is not None:
300
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
301
+
302
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
303
+ query = query.transpose(1, 2)
304
+ key = key.transpose(1, 2)
305
+ value = value.transpose(1, 2)
306
+
307
+ # with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
308
+ stage_hidden_states = F.scaled_dot_product_attention(
309
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
310
+ )
311
+ stage_hidden_states = stage_hidden_states.transpose(1, 2).flatten(2, 3) # [bs, tot_seq, dim]
312
+
313
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
314
+ output_hidden_list.append(stage_hidden_states[:, encoder_length:])
315
+ i_sum += length
316
+
317
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s d]
318
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s d -> (b n) s d')
319
+ output_hidden = torch.cat(output_hidden_list, dim=1)
320
+
321
+ return output_hidden, output_encoder_hidden
322
+
323
+
324
+ class SequenceParallelVarlenSelfAttentionWithT5Mask:
325
+ """
326
+ For chunk stage attention without using flash attention
327
+ """
328
+
329
+ def __init__(self):
330
+ pass
331
+
332
+ def apply_rope(self, xq, xk, freqs_cis):
333
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
334
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
335
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
336
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
337
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
338
+
339
+ def __call__(
340
+ self, query, key, value, encoder_query, encoder_key, encoder_value,
341
+ heads, scale, hidden_length=None, image_rotary_emb=None, attention_mask=None,
342
+ ):
343
+ assert attention_mask is not None, "The attention mask needed to be set"
344
+
345
+ num_stages = len(hidden_length)
346
+
347
+ encoder_qkv = torch.stack([encoder_query, encoder_key, encoder_value], dim=2) # [bs, sub_seq, 3, head, head_dim]
348
+ qkv = torch.stack([query, key, value], dim=2) # [bs, sub_seq, 3, head, head_dim]
349
+
350
+ # To sync the encoder query, key and values
351
+ sp_group = get_sequence_parallel_group()
352
+ sp_group_size = get_sequence_parallel_world_size()
353
+ encoder_qkv = all_to_all(encoder_qkv, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
354
+ encoder_length = encoder_qkv.shape[1]
355
+
356
+ i_sum = 0
357
+ output_encoder_hidden_list = []
358
+ output_hidden_list = []
359
+
360
+ for i_p, length in enumerate(hidden_length):
361
+ encoder_qkv_tokens = encoder_qkv[i_p::num_stages]
362
+ qkv_tokens = qkv[:, i_sum:i_sum+length]
363
+ qkv_tokens = all_to_all(qkv_tokens, sp_group, sp_group_size, scatter_dim=3, gather_dim=1) # [bs, seq, 3, sub_head, head_dim]
364
+ concat_qkv_tokens = torch.cat([encoder_qkv_tokens, qkv_tokens], dim=1) # [bs, tot_seq, 3, nhead, dim]
365
+
366
+ if image_rotary_emb is not None:
367
+ concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1] = self.apply_rope(concat_qkv_tokens[:,:,0], concat_qkv_tokens[:,:,1], image_rotary_emb[i_p])
368
+
369
+ query, key, value = concat_qkv_tokens.unbind(2) # [bs, tot_seq, nhead, dim]
370
+ query = query.transpose(1, 2)
371
+ key = key.transpose(1, 2)
372
+ value = value.transpose(1, 2)
373
+
374
+ stage_hidden_states = F.scaled_dot_product_attention(
375
+ query, key, value, dropout_p=0.0, is_causal=False, attn_mask=attention_mask[i_p],
376
+ )
377
+ stage_hidden_states = stage_hidden_states.transpose(1, 2) # [bs, tot_seq, nhead, dim]
378
+
379
+ output_encoder_hidden_list.append(stage_hidden_states[:, :encoder_length])
380
+
381
+ output_hidden = stage_hidden_states[:, encoder_length:]
382
+ output_hidden = all_to_all(output_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
383
+ output_hidden_list.append(output_hidden)
384
+
385
+ i_sum += length
386
+
387
+ output_encoder_hidden = torch.stack(output_encoder_hidden_list, dim=1) # [b n s nhead d]
388
+ output_encoder_hidden = rearrange(output_encoder_hidden, 'b n s h d -> (b n) s h d')
389
+ output_encoder_hidden = all_to_all(output_encoder_hidden, sp_group, sp_group_size, scatter_dim=1, gather_dim=2)
390
+ output_encoder_hidden = output_encoder_hidden.flatten(2, 3)
391
+ output_hidden = torch.cat(output_hidden_list, dim=1).flatten(2, 3)
392
+
393
+ return output_hidden, output_encoder_hidden
394
+
395
+
396
+ class JointAttention(nn.Module):
397
+
398
+ def __init__(
399
+ self,
400
+ query_dim: int,
401
+ cross_attention_dim: Optional[int] = None,
402
+ heads: int = 8,
403
+ dim_head: int = 64,
404
+ dropout: float = 0.0,
405
+ bias: bool = False,
406
+ qk_norm: Optional[str] = None,
407
+ added_kv_proj_dim: Optional[int] = None,
408
+ out_bias: bool = True,
409
+ eps: float = 1e-5,
410
+ out_dim: int = None,
411
+ context_pre_only=None,
412
+ use_flash_attn=True,
413
+ ):
414
+ """
415
+ Fixing the QKNorm, following the flux, norm the head dimension
416
+ """
417
+ super().__init__()
418
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
419
+ self.query_dim = query_dim
420
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
421
+ self.use_bias = bias
422
+ self.dropout = dropout
423
+
424
+ self.out_dim = out_dim if out_dim is not None else query_dim
425
+ self.context_pre_only = context_pre_only
426
+
427
+ self.scale = dim_head**-0.5
428
+ self.heads = out_dim // dim_head if out_dim is not None else heads
429
+ self.added_kv_proj_dim = added_kv_proj_dim
430
+
431
+ if qk_norm is None:
432
+ self.norm_q = None
433
+ self.norm_k = None
434
+ elif qk_norm == "layer_norm":
435
+ self.norm_q = nn.LayerNorm(dim_head, eps=eps)
436
+ self.norm_k = nn.LayerNorm(dim_head, eps=eps)
437
+ elif qk_norm == 'rms_norm':
438
+ self.norm_q = RMSNorm(dim_head, eps=eps)
439
+ self.norm_k = RMSNorm(dim_head, eps=eps)
440
+ else:
441
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
442
+
443
+ self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
444
+ self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
445
+ self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
446
+
447
+ if self.added_kv_proj_dim is not None:
448
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
449
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
450
+ self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
451
+
452
+ if qk_norm is None:
453
+ self.norm_add_q = None
454
+ self.norm_add_k = None
455
+ elif qk_norm == "layer_norm":
456
+ self.norm_add_q = nn.LayerNorm(dim_head, eps=eps)
457
+ self.norm_add_k = nn.LayerNorm(dim_head, eps=eps)
458
+ elif qk_norm == 'rms_norm':
459
+ self.norm_add_q = RMSNorm(dim_head, eps=eps)
460
+ self.norm_add_k = RMSNorm(dim_head, eps=eps)
461
+ else:
462
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
463
+
464
+ self.to_out = nn.ModuleList([])
465
+ self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
466
+ self.to_out.append(nn.Dropout(dropout))
467
+
468
+ if not self.context_pre_only:
469
+ self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
470
+
471
+ self.use_flash_attn = use_flash_attn
472
+
473
+ if flash_attn_func is None:
474
+ self.use_flash_attn = False
475
+
476
+ # print(f"Using flash-attention: {self.use_flash_attn}")
477
+ if self.use_flash_attn:
478
+ if is_sequence_parallel_initialized():
479
+ self.var_flash_attn = SequenceParallelVarlenFlashSelfAttentionWithT5Mask()
480
+ else:
481
+ self.var_flash_attn = VarlenFlashSelfAttentionWithT5Mask()
482
+ else:
483
+ if is_sequence_parallel_initialized():
484
+ self.var_len_attn = SequenceParallelVarlenSelfAttentionWithT5Mask()
485
+ else:
486
+ self.var_len_attn = VarlenSelfAttentionWithT5Mask()
487
+
488
+
489
+ def forward(
490
+ self,
491
+ hidden_states: torch.FloatTensor,
492
+ encoder_hidden_states: torch.FloatTensor = None,
493
+ encoder_attention_mask: torch.FloatTensor = None,
494
+ attention_mask: torch.FloatTensor = None, # [B, L, S]
495
+ hidden_length: torch.Tensor = None,
496
+ image_rotary_emb: torch.Tensor = None,
497
+ **kwargs,
498
+ ) -> torch.FloatTensor:
499
+ # This function is only used during training
500
+ # `sample` projections.
501
+ query = self.to_q(hidden_states)
502
+ key = self.to_k(hidden_states)
503
+ value = self.to_v(hidden_states)
504
+
505
+ inner_dim = key.shape[-1]
506
+ head_dim = inner_dim // self.heads
507
+
508
+ query = query.view(query.shape[0], -1, self.heads, head_dim)
509
+ key = key.view(key.shape[0], -1, self.heads, head_dim)
510
+ value = value.view(value.shape[0], -1, self.heads, head_dim)
511
+
512
+ if self.norm_q is not None:
513
+ query = self.norm_q(query)
514
+
515
+ if self.norm_k is not None:
516
+ key = self.norm_k(key)
517
+
518
+ # `context` projections.
519
+ encoder_hidden_states_query_proj = self.add_q_proj(encoder_hidden_states)
520
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
521
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
522
+
523
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
524
+ encoder_hidden_states_query_proj.shape[0], -1, self.heads, head_dim
525
+ )
526
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
527
+ encoder_hidden_states_key_proj.shape[0], -1, self.heads, head_dim
528
+ )
529
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
530
+ encoder_hidden_states_value_proj.shape[0], -1, self.heads, head_dim
531
+ )
532
+
533
+ if self.norm_add_q is not None:
534
+ encoder_hidden_states_query_proj = self.norm_add_q(encoder_hidden_states_query_proj)
535
+
536
+ if self.norm_add_k is not None:
537
+ encoder_hidden_states_key_proj = self.norm_add_k(encoder_hidden_states_key_proj)
538
+
539
+ # To cat the hidden and encoder hidden, perform attention compuataion, and then split
540
+ if self.use_flash_attn:
541
+ hidden_states, encoder_hidden_states = self.var_flash_attn(
542
+ query, key, value,
543
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
544
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
545
+ image_rotary_emb, encoder_attention_mask,
546
+ )
547
+ else:
548
+ hidden_states, encoder_hidden_states = self.var_len_attn(
549
+ query, key, value,
550
+ encoder_hidden_states_query_proj, encoder_hidden_states_key_proj,
551
+ encoder_hidden_states_value_proj, self.heads, self.scale, hidden_length,
552
+ image_rotary_emb, attention_mask,
553
+ )
554
+
555
+ # linear proj
556
+ hidden_states = self.to_out[0](hidden_states)
557
+ # dropout
558
+ hidden_states = self.to_out[1](hidden_states)
559
+ if not self.context_pre_only:
560
+ encoder_hidden_states = self.to_add_out(encoder_hidden_states)
561
+
562
+ return hidden_states, encoder_hidden_states
563
+
564
+
565
+ class JointTransformerBlock(nn.Module):
566
+ r"""
567
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
568
+
569
+ Reference: https://arxiv.org/abs/2403.03206
570
+
571
+ Parameters:
572
+ dim (`int`): The number of channels in the input and output.
573
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
574
+ attention_head_dim (`int`): The number of channels in each head.
575
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
576
+ processing of `context` conditions.
577
+ """
578
+
579
+ def __init__(
580
+ self, dim, num_attention_heads, attention_head_dim, qk_norm=None,
581
+ context_pre_only=False, use_flash_attn=True,
582
+ ):
583
+ super().__init__()
584
+
585
+ self.context_pre_only = context_pre_only
586
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
587
+
588
+ self.norm1 = AdaLayerNormZero(dim)
589
+
590
+ if context_norm_type == "ada_norm_continous":
591
+ self.norm1_context = AdaLayerNormContinuous(
592
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
593
+ )
594
+ elif context_norm_type == "ada_norm_zero":
595
+ self.norm1_context = AdaLayerNormZero(dim)
596
+ else:
597
+ raise ValueError(
598
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
599
+ )
600
+
601
+ self.attn = JointAttention(
602
+ query_dim=dim,
603
+ cross_attention_dim=None,
604
+ added_kv_proj_dim=dim,
605
+ dim_head=attention_head_dim // num_attention_heads,
606
+ heads=num_attention_heads,
607
+ out_dim=attention_head_dim,
608
+ qk_norm=qk_norm,
609
+ context_pre_only=context_pre_only,
610
+ bias=True,
611
+ use_flash_attn=use_flash_attn,
612
+ )
613
+
614
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
615
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
616
+
617
+ if not context_pre_only:
618
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
619
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
620
+ else:
621
+ self.norm2_context = None
622
+ self.ff_context = None
623
+
624
+ def forward(
625
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor,
626
+ encoder_attention_mask: torch.FloatTensor, temb: torch.FloatTensor,
627
+ attention_mask: torch.FloatTensor = None, hidden_length: List = None,
628
+ image_rotary_emb: torch.FloatTensor = None,
629
+ ):
630
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb, hidden_length=hidden_length)
631
+
632
+ if self.context_pre_only:
633
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
634
+ else:
635
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
636
+ encoder_hidden_states, emb=temb,
637
+ )
638
+
639
+ # Attention
640
+ attn_output, context_attn_output = self.attn(
641
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
642
+ encoder_attention_mask=encoder_attention_mask, attention_mask=attention_mask,
643
+ hidden_length=hidden_length, image_rotary_emb=image_rotary_emb,
644
+ )
645
+
646
+ # Process attention outputs for the `hidden_states`.
647
+ attn_output = gate_msa * attn_output
648
+ hidden_states = hidden_states + attn_output
649
+
650
+ norm_hidden_states = self.norm2(hidden_states)
651
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
652
+
653
+ ff_output = self.ff(norm_hidden_states)
654
+ ff_output = gate_mlp * ff_output
655
+
656
+ hidden_states = hidden_states + ff_output
657
+
658
+ # Process attention outputs for the `encoder_hidden_states`.
659
+ if self.context_pre_only:
660
+ encoder_hidden_states = None
661
+ else:
662
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
663
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
664
+
665
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
666
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
667
+
668
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
669
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
670
+
671
+ return encoder_hidden_states, hidden_states
pyramid_dit/mmdit_modules/modeling_normalization.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numbers
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from diffusers.utils import is_torch_version
9
+
10
+
11
+ if is_torch_version(">=", "2.1.0"):
12
+ LayerNorm = nn.LayerNorm
13
+ else:
14
+ # Has optional bias parameter compared to torch layer norm
15
+ # TODO: replace with torch layernorm once min required torch version >= 2.1
16
+ class LayerNorm(nn.Module):
17
+ def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
18
+ super().__init__()
19
+
20
+ self.eps = eps
21
+
22
+ if isinstance(dim, numbers.Integral):
23
+ dim = (dim,)
24
+
25
+ self.dim = torch.Size(dim)
26
+
27
+ if elementwise_affine:
28
+ self.weight = nn.Parameter(torch.ones(dim))
29
+ self.bias = nn.Parameter(torch.zeros(dim)) if bias else None
30
+ else:
31
+ self.weight = None
32
+ self.bias = None
33
+
34
+ def forward(self, input):
35
+ return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)
36
+
37
+
38
+ class RMSNorm(nn.Module):
39
+ def __init__(self, dim, eps: float, elementwise_affine: bool = True):
40
+ super().__init__()
41
+
42
+ self.eps = eps
43
+
44
+ if isinstance(dim, numbers.Integral):
45
+ dim = (dim,)
46
+
47
+ self.dim = torch.Size(dim)
48
+
49
+ if elementwise_affine:
50
+ self.weight = nn.Parameter(torch.ones(dim))
51
+ else:
52
+ self.weight = None
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
57
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
58
+
59
+ if self.weight is not None:
60
+ # convert into half-precision if necessary
61
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
62
+ hidden_states = hidden_states.to(self.weight.dtype)
63
+ hidden_states = hidden_states * self.weight
64
+
65
+ hidden_states = hidden_states.to(input_dtype)
66
+
67
+ return hidden_states
68
+
69
+
70
+ class AdaLayerNormContinuous(nn.Module):
71
+ def __init__(
72
+ self,
73
+ embedding_dim: int,
74
+ conditioning_embedding_dim: int,
75
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
76
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
77
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
78
+ # However, this is how it was implemented in the original code, and it's rather likely you should
79
+ # set `elementwise_affine` to False.
80
+ elementwise_affine=True,
81
+ eps=1e-5,
82
+ bias=True,
83
+ norm_type="layer_norm",
84
+ ):
85
+ super().__init__()
86
+ self.silu = nn.SiLU()
87
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
88
+ if norm_type == "layer_norm":
89
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
90
+ elif norm_type == "rms_norm":
91
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
92
+ else:
93
+ raise ValueError(f"unknown norm_type {norm_type}")
94
+
95
+ def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
96
+ assert hidden_length is not None
97
+
98
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
99
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 2)
100
+
101
+ i_sum = 0
102
+ num_stages = len(hidden_length)
103
+ for i_p, length in enumerate(hidden_length):
104
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
105
+ i_sum += length
106
+
107
+ batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
108
+ x = self.norm(x) * (1 + batch_scale) + batch_shift
109
+ return x
110
+
111
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
112
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
113
+ if hidden_length is not None:
114
+ return self.forward_with_pad(x, conditioning_embedding, hidden_length)
115
+ emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
116
+ scale, shift = torch.chunk(emb, 2, dim=1)
117
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
118
+ return x
119
+
120
+
121
+ class AdaLayerNormZero(nn.Module):
122
+ r"""
123
+ Norm layer adaptive layer norm zero (adaLN-Zero).
124
+
125
+ Parameters:
126
+ embedding_dim (`int`): The size of each embedding vector.
127
+ num_embeddings (`int`): The size of the embeddings dictionary.
128
+ """
129
+
130
+ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
131
+ super().__init__()
132
+ self.emb = None
133
+ self.silu = nn.SiLU()
134
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
135
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
136
+
137
+ def forward_with_pad(
138
+ self,
139
+ x: torch.Tensor,
140
+ timestep: Optional[torch.Tensor] = None,
141
+ class_labels: Optional[torch.LongTensor] = None,
142
+ hidden_dtype: Optional[torch.dtype] = None,
143
+ emb: Optional[torch.Tensor] = None,
144
+ hidden_length: Optional[torch.Tensor] = None,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
146
+ # x: [bs, seq_len, dim]
147
+ if self.emb is not None:
148
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
149
+
150
+ emb = self.linear(self.silu(emb))
151
+ batch_emb = torch.zeros_like(x).repeat(1, 1, 6)
152
+
153
+ i_sum = 0
154
+ num_stages = len(hidden_length)
155
+ for i_p, length in enumerate(hidden_length):
156
+ batch_emb[:, i_sum:i_sum+length] = emb[i_p::num_stages][:,None]
157
+ i_sum += length
158
+
159
+ batch_shift_msa, batch_scale_msa, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp = batch_emb.chunk(6, dim=2)
160
+ x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa
161
+ return x, batch_gate_msa, batch_shift_mlp, batch_scale_mlp, batch_gate_mlp
162
+
163
+ def forward(
164
+ self,
165
+ x: torch.Tensor,
166
+ timestep: Optional[torch.Tensor] = None,
167
+ class_labels: Optional[torch.LongTensor] = None,
168
+ hidden_dtype: Optional[torch.dtype] = None,
169
+ emb: Optional[torch.Tensor] = None,
170
+ hidden_length: Optional[torch.Tensor] = None,
171
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
+ if hidden_length is not None:
173
+ return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
174
+ if self.emb is not None:
175
+ emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
176
+ emb = self.linear(self.silu(emb))
177
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
178
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
179
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
pyramid_dit/mmdit_modules/modeling_pyramid_mmdit.py ADDED
@@ -0,0 +1,497 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import torch.nn.functional as F
5
+
6
+ from einops import rearrange
7
+ from diffusers.utils.torch_utils import randn_tensor
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.utils import is_torch_version
11
+ from typing import Any, Callable, Dict, List, Optional, Union
12
+
13
+ from .modeling_embedding import PatchEmbed3D, CombinedTimestepConditionEmbeddings
14
+ from .modeling_normalization import AdaLayerNormContinuous
15
+ from .modeling_mmdit_block import JointTransformerBlock
16
+
17
+ from trainer_misc import (
18
+ is_sequence_parallel_initialized,
19
+ get_sequence_parallel_group,
20
+ get_sequence_parallel_world_size,
21
+ get_sequence_parallel_rank,
22
+ all_to_all,
23
+ )
24
+
25
+ from IPython import embed
26
+
27
+
28
+ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
29
+ assert dim % 2 == 0, "The dimension must be even."
30
+
31
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
32
+ omega = 1.0 / (theta**scale)
33
+
34
+ batch_size, seq_length = pos.shape
35
+ out = torch.einsum("...n,d->...nd", pos, omega)
36
+ cos_out = torch.cos(out)
37
+ sin_out = torch.sin(out)
38
+
39
+ stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
40
+ out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
41
+ return out.float()
42
+
43
+
44
+ class EmbedNDRoPE(nn.Module):
45
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
46
+ super().__init__()
47
+ self.dim = dim
48
+ self.theta = theta
49
+ self.axes_dim = axes_dim
50
+
51
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
52
+ n_axes = ids.shape[-1]
53
+ emb = torch.cat(
54
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
55
+ dim=-3,
56
+ )
57
+ return emb.unsqueeze(2)
58
+
59
+
60
+ class PyramidDiffusionMMDiT(ModelMixin, ConfigMixin):
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ sample_size: int = 128,
67
+ patch_size: int = 2,
68
+ in_channels: int = 16,
69
+ num_layers: int = 24,
70
+ attention_head_dim: int = 64,
71
+ num_attention_heads: int = 24,
72
+ caption_projection_dim: int = 1152,
73
+ pooled_projection_dim: int = 2048,
74
+ pos_embed_max_size: int = 192,
75
+ max_num_frames: int = 200,
76
+ qk_norm: str = 'rms_norm',
77
+ pos_embed_type: str = 'rope',
78
+ temp_pos_embed_type: str = 'sincos',
79
+ joint_attention_dim: int = 4096,
80
+ use_gradient_checkpointing: bool = False,
81
+ use_flash_attn: bool = True,
82
+ use_temporal_causal: bool = False,
83
+ use_t5_mask: bool = False,
84
+ add_temp_pos_embed: bool = False,
85
+ interp_condition_pos: bool = False,
86
+ gradient_checkpointing_ratio: float = 0.6,
87
+ ):
88
+ super().__init__()
89
+
90
+ self.out_channels = in_channels
91
+ self.inner_dim = num_attention_heads * attention_head_dim
92
+ assert temp_pos_embed_type in ['rope', 'sincos']
93
+
94
+ # The input latent embeder, using the name pos_embed to remain the same with SD#
95
+ self.pos_embed = PatchEmbed3D(
96
+ height=sample_size,
97
+ width=sample_size,
98
+ patch_size=patch_size,
99
+ in_channels=in_channels,
100
+ embed_dim=self.inner_dim,
101
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
102
+ max_num_frames=max_num_frames,
103
+ pos_embed_type=pos_embed_type,
104
+ temp_pos_embed_type=temp_pos_embed_type,
105
+ add_temp_pos_embed=add_temp_pos_embed,
106
+ interp_condition_pos=interp_condition_pos,
107
+ )
108
+
109
+ # The RoPE EMbedding
110
+ if pos_embed_type == 'rope':
111
+ self.rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[16, 24, 24])
112
+ else:
113
+ self.rope_embed = None
114
+
115
+ if temp_pos_embed_type == 'rope':
116
+ self.temp_rope_embed = EmbedNDRoPE(self.inner_dim, 10000, axes_dim=[attention_head_dim])
117
+ else:
118
+ self.temp_rope_embed = None
119
+
120
+ self.time_text_embed = CombinedTimestepConditionEmbeddings(
121
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim,
122
+ )
123
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
124
+
125
+ self.transformer_blocks = nn.ModuleList(
126
+ [
127
+ JointTransformerBlock(
128
+ dim=self.inner_dim,
129
+ num_attention_heads=num_attention_heads,
130
+ attention_head_dim=self.inner_dim,
131
+ qk_norm=qk_norm,
132
+ context_pre_only=i == num_layers - 1,
133
+ use_flash_attn=use_flash_attn,
134
+ )
135
+ for i in range(num_layers)
136
+ ]
137
+ )
138
+
139
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
140
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
141
+ self.gradient_checkpointing = use_gradient_checkpointing
142
+ self.gradient_checkpointing_ratio = gradient_checkpointing_ratio
143
+
144
+ self.patch_size = patch_size
145
+ self.use_flash_attn = use_flash_attn
146
+ self.use_temporal_causal = use_temporal_causal
147
+ self.pos_embed_type = pos_embed_type
148
+ self.temp_pos_embed_type = temp_pos_embed_type
149
+ self.add_temp_pos_embed = add_temp_pos_embed
150
+
151
+ if self.use_temporal_causal:
152
+ print("Using temporal causal attention")
153
+ assert self.use_flash_attn is False, "The flash attention does not support temporal causal"
154
+
155
+ if interp_condition_pos:
156
+ print("We interp the position embedding of condition latents")
157
+
158
+ # init weights
159
+ self.initialize_weights()
160
+
161
+ def initialize_weights(self):
162
+ # Initialize transformer layers:
163
+ def _basic_init(module):
164
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
165
+ torch.nn.init.xavier_uniform_(module.weight)
166
+ if module.bias is not None:
167
+ nn.init.constant_(module.bias, 0)
168
+ self.apply(_basic_init)
169
+
170
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
171
+ w = self.pos_embed.proj.weight.data
172
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
173
+ nn.init.constant_(self.pos_embed.proj.bias, 0)
174
+
175
+ # Initialize all the conditioning to normal init
176
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_1.weight, std=0.02)
177
+ nn.init.normal_(self.time_text_embed.timestep_embedder.linear_2.weight, std=0.02)
178
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_1.weight, std=0.02)
179
+ nn.init.normal_(self.time_text_embed.text_embedder.linear_2.weight, std=0.02)
180
+ nn.init.normal_(self.context_embedder.weight, std=0.02)
181
+
182
+ # Zero-out adaLN modulation layers in DiT blocks:
183
+ for block in self.transformer_blocks:
184
+ nn.init.constant_(block.norm1.linear.weight, 0)
185
+ nn.init.constant_(block.norm1.linear.bias, 0)
186
+ nn.init.constant_(block.norm1_context.linear.weight, 0)
187
+ nn.init.constant_(block.norm1_context.linear.bias, 0)
188
+
189
+ # Zero-out output layers:
190
+ nn.init.constant_(self.norm_out.linear.weight, 0)
191
+ nn.init.constant_(self.norm_out.linear.bias, 0)
192
+ nn.init.constant_(self.proj_out.weight, 0)
193
+ nn.init.constant_(self.proj_out.bias, 0)
194
+
195
+ @torch.no_grad()
196
+ def _prepare_latent_image_ids(self, batch_size, temp, height, width, device):
197
+ latent_image_ids = torch.zeros(temp, height, width, 3)
198
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
199
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[None, :, None]
200
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, None, :]
201
+
202
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
203
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
204
+ return latent_image_ids.to(device=device)
205
+
206
+ @torch.no_grad()
207
+ def _prepare_pyramid_latent_image_ids(self, batch_size, temp_list, height_list, width_list, device):
208
+ base_width = width_list[-1]; base_height = height_list[-1]
209
+ assert base_width == max(width_list)
210
+ assert base_height == max(height_list)
211
+
212
+ image_ids_list = []
213
+ for temp, height, width in zip(temp_list, height_list, width_list):
214
+ latent_image_ids = torch.zeros(temp, height, width, 3)
215
+
216
+ if height != base_height:
217
+ height_pos = F.interpolate(torch.arange(base_height)[None, None, :].float(), height, mode='linear').squeeze(0, 1)
218
+ else:
219
+ height_pos = torch.arange(base_height).float()
220
+ if width != base_width:
221
+ width_pos = F.interpolate(torch.arange(base_width)[None, None, :].float(), width, mode='linear').squeeze(0, 1)
222
+ else:
223
+ width_pos = torch.arange(base_width).float()
224
+
225
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(temp)[:, None, None]
226
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + height_pos[None, :, None]
227
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + width_pos[None, None, :]
228
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
229
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c').to(device)
230
+ image_ids_list.append(latent_image_ids)
231
+
232
+ return image_ids_list
233
+
234
+ @torch.no_grad()
235
+ def _prepare_temporal_rope_ids(self, batch_size, temp, height, width, device, start_time_stamp=0):
236
+ latent_image_ids = torch.zeros(temp, height, width, 1)
237
+ latent_image_ids[..., 0] = latent_image_ids[..., 0] + torch.arange(start_time_stamp, start_time_stamp + temp)[:, None, None]
238
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1, 1)
239
+ latent_image_ids = rearrange(latent_image_ids, 'b t h w c -> b (t h w) c')
240
+ return latent_image_ids.to(device=device)
241
+
242
+ @torch.no_grad()
243
+ def _prepare_pyramid_temporal_rope_ids(self, sample, batch_size, device):
244
+ image_ids_list = []
245
+
246
+ for i_b, sample_ in enumerate(sample):
247
+ if not isinstance(sample_, list):
248
+ sample_ = [sample_]
249
+
250
+ cur_image_ids = []
251
+ start_time_stamp = 0
252
+
253
+ for clip_ in sample_:
254
+ _, _, temp, height, width = clip_.shape
255
+ height = height // self.patch_size
256
+ width = width // self.patch_size
257
+ cur_image_ids.append(self._prepare_temporal_rope_ids(batch_size, temp, height, width, device, start_time_stamp=start_time_stamp))
258
+ start_time_stamp += temp
259
+
260
+ cur_image_ids = torch.cat(cur_image_ids, dim=1)
261
+ image_ids_list.append(cur_image_ids)
262
+
263
+ return image_ids_list
264
+
265
+ def merge_input(self, sample, encoder_hidden_length, encoder_attention_mask):
266
+ """
267
+ Merge the input video with different resolutions into one sequence
268
+ Sample: From low resolution to high resolution
269
+ """
270
+ if isinstance(sample[0], list):
271
+ device = sample[0][-1].device
272
+ pad_batch_size = sample[0][-1].shape[0]
273
+ else:
274
+ device = sample[0].device
275
+ pad_batch_size = sample[0].shape[0]
276
+
277
+ num_stages = len(sample)
278
+ height_list = [];width_list = [];temp_list = []
279
+ trainable_token_list = []
280
+
281
+ for i_b, sample_ in enumerate(sample):
282
+ if isinstance(sample_, list):
283
+ sample_ = sample_[-1]
284
+ _, _, temp, height, width = sample_.shape
285
+ height = height // self.patch_size
286
+ width = width // self.patch_size
287
+ temp_list.append(temp)
288
+ height_list.append(height)
289
+ width_list.append(width)
290
+ trainable_token_list.append(height * width * temp)
291
+
292
+ # prepare the RoPE embedding if needed
293
+ if self.pos_embed_type == 'rope':
294
+ # TODO: support the 3D Rope for video
295
+ raise NotImplementedError("Not compatible with video generation now")
296
+ text_ids = torch.zeros(pad_batch_size, encoder_hidden_length, 3).to(device=device)
297
+ image_ids_list = self._prepare_pyramid_latent_image_ids(pad_batch_size, temp_list, height_list, width_list, device)
298
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
299
+ image_rotary_emb = [self.rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
300
+ else:
301
+ if self.temp_pos_embed_type == 'rope' and self.add_temp_pos_embed:
302
+ image_ids_list = self._prepare_pyramid_temporal_rope_ids(sample, pad_batch_size, device)
303
+ text_ids = torch.zeros(pad_batch_size, encoder_attention_mask.shape[1], 1).to(device=device)
304
+ input_ids_list = [torch.cat([text_ids, image_ids], dim=1) for image_ids in image_ids_list]
305
+ image_rotary_emb = [self.temp_rope_embed(input_ids) for input_ids in input_ids_list] # [bs, seq_len, 1, head_dim // 2, 2, 2]
306
+
307
+ if is_sequence_parallel_initialized():
308
+ sp_group = get_sequence_parallel_group()
309
+ sp_group_size = get_sequence_parallel_world_size()
310
+ concat_output = True if self.training else False
311
+ image_rotary_emb = [all_to_all(x_.repeat(1, 1, sp_group_size, 1, 1, 1), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for x_ in image_rotary_emb]
312
+ input_ids_list = [all_to_all(input_ids.repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output) for input_ids in input_ids_list]
313
+
314
+ else:
315
+ image_rotary_emb = None
316
+
317
+ hidden_states = self.pos_embed(sample) # hidden states is a list of [b c t h w] b = real_b // num_stages
318
+ hidden_length = []
319
+
320
+ for i_b in range(num_stages):
321
+ hidden_length.append(hidden_states[i_b].shape[1])
322
+
323
+ # prepare the attention mask
324
+ if self.use_flash_attn:
325
+ attention_mask = None
326
+ indices_list = []
327
+ for i_p, length in enumerate(hidden_length):
328
+ pad_attention_mask = torch.ones((pad_batch_size, length), dtype=encoder_attention_mask.dtype).to(device)
329
+ pad_attention_mask = torch.cat([encoder_attention_mask[i_p::num_stages], pad_attention_mask], dim=1)
330
+
331
+ if is_sequence_parallel_initialized():
332
+ sp_group = get_sequence_parallel_group()
333
+ sp_group_size = get_sequence_parallel_world_size()
334
+ pad_attention_mask = all_to_all(pad_attention_mask.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0)
335
+ pad_attention_mask = pad_attention_mask.squeeze(2)
336
+
337
+ seqlens_in_batch = pad_attention_mask.sum(dim=-1, dtype=torch.int32)
338
+ indices = torch.nonzero(pad_attention_mask.flatten(), as_tuple=False).flatten()
339
+
340
+ indices_list.append(
341
+ {
342
+ 'indices': indices,
343
+ 'seqlens_in_batch': seqlens_in_batch,
344
+ }
345
+ )
346
+ encoder_attention_mask = indices_list
347
+ else:
348
+ assert encoder_attention_mask.shape[1] == encoder_hidden_length
349
+ real_batch_size = encoder_attention_mask.shape[0]
350
+ # prepare text ids
351
+ text_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, encoder_hidden_length)
352
+ text_ids = text_ids.to(device)
353
+ text_ids[encoder_attention_mask == 0] = 0
354
+
355
+ # prepare image ids
356
+ image_ids = torch.arange(1, real_batch_size + 1, dtype=encoder_attention_mask.dtype).unsqueeze(1).repeat(1, max(hidden_length))
357
+ image_ids = image_ids.to(device)
358
+ image_ids_list = []
359
+ for i_p, length in enumerate(hidden_length):
360
+ image_ids_list.append(image_ids[i_p::num_stages][:, :length])
361
+
362
+ if is_sequence_parallel_initialized():
363
+ sp_group = get_sequence_parallel_group()
364
+ sp_group_size = get_sequence_parallel_world_size()
365
+ concat_output = True if self.training else False
366
+ text_ids = all_to_all(text_ids.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2)
367
+ image_ids_list = [all_to_all(image_ids_.unsqueeze(2).repeat(1, 1, sp_group_size), sp_group, sp_group_size, scatter_dim=2, gather_dim=0, concat_output=concat_output).squeeze(2) for image_ids_ in image_ids_list]
368
+
369
+ attention_mask = []
370
+ for i_p in range(len(hidden_length)):
371
+ image_ids = image_ids_list[i_p]
372
+ token_ids = torch.cat([text_ids[i_p::num_stages], image_ids], dim=1)
373
+ stage_attention_mask = rearrange(token_ids, 'b i -> b 1 i 1') == rearrange(token_ids, 'b j -> b 1 1 j') # [bs, 1, q_len, k_len]
374
+ if self.use_temporal_causal:
375
+ input_order_ids = input_ids_list[i_p].squeeze(2)
376
+ temporal_causal_mask = rearrange(input_order_ids, 'b i -> b 1 i 1') >= rearrange(input_order_ids, 'b j -> b 1 1 j')
377
+ stage_attention_mask = stage_attention_mask & temporal_causal_mask
378
+ attention_mask.append(stage_attention_mask)
379
+
380
+ return hidden_states, hidden_length, temp_list, height_list, width_list, trainable_token_list, encoder_attention_mask, attention_mask, image_rotary_emb
381
+
382
+ def split_output(self, batch_hidden_states, hidden_length, temps, heights, widths, trainable_token_list):
383
+ # To split the hidden states
384
+ batch_size = batch_hidden_states.shape[0]
385
+ output_hidden_list = []
386
+ batch_hidden_states = torch.split(batch_hidden_states, hidden_length, dim=1)
387
+
388
+ if is_sequence_parallel_initialized():
389
+ sp_group_size = get_sequence_parallel_world_size()
390
+ if self.training:
391
+ batch_size = batch_size // sp_group_size
392
+
393
+ for i_p, length in enumerate(hidden_length):
394
+ width, height, temp = widths[i_p], heights[i_p], temps[i_p]
395
+ trainable_token_num = trainable_token_list[i_p]
396
+ hidden_states = batch_hidden_states[i_p]
397
+
398
+ if is_sequence_parallel_initialized():
399
+ sp_group = get_sequence_parallel_group()
400
+ sp_group_size = get_sequence_parallel_world_size()
401
+
402
+ if not self.training:
403
+ hidden_states = hidden_states.repeat(sp_group_size, 1, 1)
404
+
405
+ hidden_states = all_to_all(hidden_states, sp_group, sp_group_size, scatter_dim=0, gather_dim=1)
406
+
407
+ # only the trainable token are taking part in loss computation
408
+ hidden_states = hidden_states[:, -trainable_token_num:]
409
+
410
+ # unpatchify
411
+ hidden_states = hidden_states.reshape(
412
+ shape=(batch_size, temp, height, width, self.patch_size, self.patch_size, self.out_channels)
413
+ )
414
+ hidden_states = rearrange(hidden_states, "b t h w p1 p2 c -> b t (h p1) (w p2) c")
415
+ hidden_states = rearrange(hidden_states, "b t h w c -> b c t h w")
416
+ output_hidden_list.append(hidden_states)
417
+
418
+ return output_hidden_list
419
+
420
+ def forward(
421
+ self,
422
+ sample: torch.FloatTensor, # [num_stages]
423
+ encoder_hidden_states: torch.FloatTensor = None,
424
+ encoder_attention_mask: torch.FloatTensor = None,
425
+ pooled_projections: torch.FloatTensor = None,
426
+ timestep_ratio: torch.FloatTensor = None,
427
+ ):
428
+ # Get the timestep embedding
429
+ temb = self.time_text_embed(timestep_ratio, pooled_projections)
430
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
431
+ encoder_hidden_length = encoder_hidden_states.shape[1]
432
+
433
+ # Get the input sequence
434
+ hidden_states, hidden_length, temps, heights, widths, trainable_token_list, encoder_attention_mask, \
435
+ attention_mask, image_rotary_emb = self.merge_input(sample, encoder_hidden_length, encoder_attention_mask)
436
+
437
+ # split the long latents if necessary
438
+ if is_sequence_parallel_initialized():
439
+ sp_group = get_sequence_parallel_group()
440
+ sp_group_size = get_sequence_parallel_world_size()
441
+ concat_output = True if self.training else False
442
+
443
+ # sync the input hidden states
444
+ batch_hidden_states = []
445
+ for i_p, hidden_states_ in enumerate(hidden_states):
446
+ assert hidden_states_.shape[1] % sp_group_size == 0, "The sequence length should be divided by sequence parallel size"
447
+ hidden_states_ = all_to_all(hidden_states_, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
448
+ hidden_length[i_p] = hidden_length[i_p] // sp_group_size
449
+ batch_hidden_states.append(hidden_states_)
450
+
451
+ # sync the encoder hidden states
452
+ hidden_states = torch.cat(batch_hidden_states, dim=1)
453
+ encoder_hidden_states = all_to_all(encoder_hidden_states, sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
454
+ temb = all_to_all(temb.unsqueeze(1).repeat(1, sp_group_size, 1), sp_group, sp_group_size, scatter_dim=1, gather_dim=0, concat_output=concat_output)
455
+ temb = temb.squeeze(1)
456
+ else:
457
+ hidden_states = torch.cat(hidden_states, dim=1)
458
+
459
+ # print(hidden_length)
460
+ for i_b, block in enumerate(self.transformer_blocks):
461
+ if self.training and self.gradient_checkpointing and (i_b >= int(len(self.transformer_blocks) * self.gradient_checkpointing_ratio)):
462
+ def create_custom_forward(module):
463
+ def custom_forward(*inputs):
464
+ return module(*inputs)
465
+
466
+ return custom_forward
467
+
468
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
469
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
470
+ create_custom_forward(block),
471
+ hidden_states,
472
+ encoder_hidden_states,
473
+ encoder_attention_mask,
474
+ temb,
475
+ attention_mask,
476
+ hidden_length,
477
+ image_rotary_emb,
478
+ **ckpt_kwargs,
479
+ )
480
+
481
+ else:
482
+ encoder_hidden_states, hidden_states = block(
483
+ hidden_states=hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ encoder_attention_mask=encoder_attention_mask,
486
+ temb=temb,
487
+ attention_mask=attention_mask,
488
+ hidden_length=hidden_length,
489
+ image_rotary_emb=image_rotary_emb,
490
+ )
491
+
492
+ hidden_states = self.norm_out(hidden_states, temb, hidden_length=hidden_length)
493
+ hidden_states = self.proj_out(hidden_states)
494
+
495
+ output = self.split_output(hidden_states, hidden_length, temps, heights, widths, trainable_token_list)
496
+
497
+ return output
pyramid_dit/mmdit_modules/modeling_text_encoder.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+
5
+ from transformers import (
6
+ CLIPTextModelWithProjection,
7
+ CLIPTokenizer,
8
+ T5EncoderModel,
9
+ T5TokenizerFast,
10
+ )
11
+
12
+ from typing import Any, Callable, Dict, List, Optional, Union
13
+
14
+
15
+ class SD3TextEncoderWithMask(nn.Module):
16
+ def __init__(self, model_path, torch_dtype):
17
+ super().__init__()
18
+ # CLIP-L
19
+ self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer'))
20
+ self.tokenizer_max_length = self.tokenizer.model_max_length
21
+ self.text_encoder = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder'), torch_dtype=torch_dtype)
22
+
23
+ # CLIP-G
24
+ self.tokenizer_2 = CLIPTokenizer.from_pretrained(os.path.join(model_path, 'tokenizer_2'))
25
+ self.text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(os.path.join(model_path, 'text_encoder_2'), torch_dtype=torch_dtype)
26
+
27
+ # T5
28
+ self.tokenizer_3 = T5TokenizerFast.from_pretrained(os.path.join(model_path, 'tokenizer_3'))
29
+ self.text_encoder_3 = T5EncoderModel.from_pretrained(os.path.join(model_path, 'text_encoder_3'), torch_dtype=torch_dtype)
30
+
31
+ self._freeze()
32
+
33
+ def _freeze(self):
34
+ for param in self.parameters():
35
+ param.requires_grad = False
36
+
37
+ def _get_t5_prompt_embeds(
38
+ self,
39
+ prompt: Union[str, List[str]] = None,
40
+ num_images_per_prompt: int = 1,
41
+ device: Optional[torch.device] = None,
42
+ max_sequence_length: int = 128,
43
+ ):
44
+ prompt = [prompt] if isinstance(prompt, str) else prompt
45
+ batch_size = len(prompt)
46
+
47
+ text_inputs = self.tokenizer_3(
48
+ prompt,
49
+ padding="max_length",
50
+ max_length=max_sequence_length,
51
+ truncation=True,
52
+ add_special_tokens=True,
53
+ return_tensors="pt",
54
+ )
55
+ text_input_ids = text_inputs.input_ids
56
+ prompt_attention_mask = text_inputs.attention_mask
57
+ prompt_attention_mask = prompt_attention_mask.to(device)
58
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
59
+ dtype = self.text_encoder_3.dtype
60
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
61
+
62
+ _, seq_len, _ = prompt_embeds.shape
63
+
64
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
65
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
66
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
67
+ prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
68
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
69
+
70
+ return prompt_embeds, prompt_attention_mask
71
+
72
+ def _get_clip_prompt_embeds(
73
+ self,
74
+ prompt: Union[str, List[str]],
75
+ num_images_per_prompt: int = 1,
76
+ device: Optional[torch.device] = None,
77
+ clip_skip: Optional[int] = None,
78
+ clip_model_index: int = 0,
79
+ ):
80
+
81
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
82
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
83
+
84
+ tokenizer = clip_tokenizers[clip_model_index]
85
+ text_encoder = clip_text_encoders[clip_model_index]
86
+
87
+ batch_size = len(prompt)
88
+
89
+ text_inputs = tokenizer(
90
+ prompt,
91
+ padding="max_length",
92
+ max_length=self.tokenizer_max_length,
93
+ truncation=True,
94
+ return_tensors="pt",
95
+ )
96
+
97
+ text_input_ids = text_inputs.input_ids
98
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
99
+ pooled_prompt_embeds = prompt_embeds[0]
100
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
101
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
102
+
103
+ return pooled_prompt_embeds
104
+
105
+ def encode_prompt(self,
106
+ prompt,
107
+ num_images_per_prompt=1,
108
+ clip_skip: Optional[int] = None,
109
+ device=None,
110
+ ):
111
+ prompt = [prompt] if isinstance(prompt, str) else prompt
112
+
113
+ pooled_prompt_embed = self._get_clip_prompt_embeds(
114
+ prompt=prompt,
115
+ device=device,
116
+ num_images_per_prompt=num_images_per_prompt,
117
+ clip_skip=clip_skip,
118
+ clip_model_index=0,
119
+ )
120
+ pooled_prompt_2_embed = self._get_clip_prompt_embeds(
121
+ prompt=prompt,
122
+ device=device,
123
+ num_images_per_prompt=num_images_per_prompt,
124
+ clip_skip=clip_skip,
125
+ clip_model_index=1,
126
+ )
127
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
128
+
129
+ prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
130
+ prompt=prompt,
131
+ num_images_per_prompt=num_images_per_prompt,
132
+ device=device,
133
+ )
134
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
135
+
136
+ def forward(self, input_prompts, device):
137
+ with torch.no_grad():
138
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.encode_prompt(input_prompts, 1, clip_skip=None, device=device)
139
+
140
+ return prompt_embeds, prompt_attention_mask, pooled_prompt_embeds
pyramid_dit/pyramid_dit_for_video_gen_pipeline.py ADDED
@@ -0,0 +1,1279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import gc
4
+ import sys
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import OrderedDict
9
+ from einops import rearrange
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ import numpy as np
12
+ import math
13
+ import random
14
+ import PIL
15
+ from PIL import Image
16
+ from tqdm import tqdm
17
+ from torchvision import transforms
18
+ from copy import deepcopy
19
+ from typing import Any, Callable, Dict, List, Optional, Union
20
+ from accelerate import Accelerator, cpu_offload
21
+ from diffusion_schedulers import PyramidFlowMatchEulerDiscreteScheduler
22
+ from video_vae.modeling_causal_vae import CausalVideoVAE
23
+
24
+ from trainer_misc import (
25
+ all_to_all,
26
+ is_sequence_parallel_initialized,
27
+ get_sequence_parallel_group,
28
+ get_sequence_parallel_group_rank,
29
+ get_sequence_parallel_rank,
30
+ get_sequence_parallel_world_size,
31
+ get_rank,
32
+ )
33
+
34
+ from .mmdit_modules import (
35
+ PyramidDiffusionMMDiT,
36
+ SD3TextEncoderWithMask,
37
+ )
38
+
39
+ from .flux_modules import (
40
+ PyramidFluxTransformer,
41
+ FluxTextEncoderWithMask,
42
+ )
43
+
44
+
45
+ def compute_density_for_timestep_sampling(
46
+ weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
47
+ ):
48
+ if weighting_scheme == "logit_normal":
49
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
50
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
51
+ u = torch.nn.functional.sigmoid(u)
52
+ elif weighting_scheme == "mode":
53
+ u = torch.rand(size=(batch_size,), device="cpu")
54
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
55
+ else:
56
+ u = torch.rand(size=(batch_size,), device="cpu")
57
+ return u
58
+
59
+
60
+ def build_pyramid_dit(
61
+ model_name : str,
62
+ model_path : str,
63
+ torch_dtype,
64
+ use_flash_attn : bool,
65
+ use_mixed_training: bool,
66
+ interp_condition_pos: bool = True,
67
+ use_gradient_checkpointing: bool = False,
68
+ use_temporal_causal: bool = True,
69
+ gradient_checkpointing_ratio: float = 0.6,
70
+ ):
71
+ model_dtype = torch.float32 if use_mixed_training else torch_dtype
72
+ if model_name == "pyramid_flux":
73
+ dit = PyramidFluxTransformer.from_pretrained(
74
+ model_path, torch_dtype=model_dtype,
75
+ use_gradient_checkpointing=use_gradient_checkpointing,
76
+ gradient_checkpointing_ratio=gradient_checkpointing_ratio,
77
+ use_flash_attn=use_flash_attn, use_temporal_causal=use_temporal_causal,
78
+ interp_condition_pos=interp_condition_pos, axes_dims_rope=[16, 24, 24],
79
+ )
80
+ elif model_name == "pyramid_mmdit":
81
+ dit = PyramidDiffusionMMDiT.from_pretrained(
82
+ model_path, torch_dtype=model_dtype, use_gradient_checkpointing=use_gradient_checkpointing,
83
+ gradient_checkpointing_ratio=gradient_checkpointing_ratio,
84
+ use_flash_attn=use_flash_attn, use_t5_mask=True,
85
+ add_temp_pos_embed=True, temp_pos_embed_type='rope',
86
+ use_temporal_causal=use_temporal_causal, interp_condition_pos=interp_condition_pos,
87
+ )
88
+ else:
89
+ raise NotImplementedError(f"Unsupported DiT architecture, please set the model_name to `pyramid_flux` or `pyramid_mmdit`")
90
+
91
+ return dit
92
+
93
+
94
+ def build_text_encoder(
95
+ model_name : str,
96
+ model_path : str,
97
+ torch_dtype,
98
+ load_text_encoder: bool = True,
99
+ ):
100
+ # The text encoder
101
+ if load_text_encoder:
102
+ if model_name == "pyramid_flux":
103
+ text_encoder = FluxTextEncoderWithMask(model_path, torch_dtype=torch_dtype)
104
+ elif model_name == "pyramid_mmdit":
105
+ text_encoder = SD3TextEncoderWithMask(model_path, torch_dtype=torch_dtype)
106
+ else:
107
+ raise NotImplementedError(f"Unsupported Text Encoder architecture, please set the model_name to `pyramid_flux` or `pyramid_mmdit`")
108
+ else:
109
+ text_encoder = None
110
+
111
+ return text_encoder
112
+
113
+
114
+ class PyramidDiTForVideoGeneration:
115
+ """
116
+ The pyramid dit for both image and video generation, The running class wrapper
117
+ This class is mainly for fixed unit implementation: 1 + n + n + n
118
+ """
119
+ def __init__(self, model_path, model_dtype='bf16', model_name='pyramid_mmdit', use_gradient_checkpointing=False,
120
+ return_log=True, model_variant="diffusion_transformer_768p", timestep_shift=1.0, stage_range=[0, 1/3, 2/3, 1],
121
+ sample_ratios=[1, 1, 1], scheduler_gamma=1/3, use_mixed_training=False, use_flash_attn=False,
122
+ load_text_encoder=True, load_vae=True, max_temporal_length=31, frame_per_unit=1, use_temporal_causal=True,
123
+ corrupt_ratio=1/3, interp_condition_pos=True, stages=[1, 2, 4], video_sync_group=8, gradient_checkpointing_ratio=0.6, **kwargs,
124
+ ):
125
+ super().__init__()
126
+
127
+ if model_dtype == 'bf16':
128
+ torch_dtype = torch.bfloat16
129
+ elif model_dtype == 'fp16':
130
+ torch_dtype = torch.float16
131
+ else:
132
+ torch_dtype = torch.float32
133
+
134
+ self.stages = stages
135
+ self.sample_ratios = sample_ratios
136
+ self.corrupt_ratio = corrupt_ratio
137
+
138
+ dit_path = os.path.join(model_path, model_variant)
139
+
140
+ # The dit
141
+ self.dit = build_pyramid_dit(
142
+ model_name, dit_path, torch_dtype,
143
+ use_flash_attn=use_flash_attn, use_mixed_training=use_mixed_training,
144
+ interp_condition_pos=interp_condition_pos, use_gradient_checkpointing=use_gradient_checkpointing,
145
+ use_temporal_causal=use_temporal_causal, gradient_checkpointing_ratio=gradient_checkpointing_ratio,
146
+ )
147
+
148
+ # The text encoder
149
+ self.text_encoder = build_text_encoder(
150
+ model_name, model_path, torch_dtype, load_text_encoder=load_text_encoder,
151
+ )
152
+ self.load_text_encoder = load_text_encoder
153
+
154
+ # The base video vae decoder
155
+ if load_vae:
156
+ self.vae = CausalVideoVAE.from_pretrained(os.path.join(model_path, 'causal_video_vae'), torch_dtype=torch_dtype, interpolate=False)
157
+ # Freeze vae
158
+ for parameter in self.vae.parameters():
159
+ parameter.requires_grad = False
160
+ else:
161
+ self.vae = None
162
+ self.load_vae = load_vae
163
+
164
+ # For the image latent
165
+ if model_name == "pyramid_flux":
166
+ self.vae_shift_factor = -0.04
167
+ self.vae_scale_factor = 1 / 1.8726
168
+ elif model_name == "pyramid_mmdit":
169
+ self.vae_shift_factor = 0.1490
170
+ self.vae_scale_factor = 1 / 1.8415
171
+ else:
172
+ raise NotImplementedError(f"Unsupported model name : {model_name}")
173
+
174
+ # For the video latent
175
+ self.vae_video_shift_factor = -0.2343
176
+ self.vae_video_scale_factor = 1 / 3.0986
177
+
178
+ self.downsample = 8
179
+
180
+ # Configure the video training hyper-parameters
181
+ # The video sequence: one frame + N * unit
182
+ self.frame_per_unit = frame_per_unit
183
+ self.max_temporal_length = max_temporal_length
184
+ assert (max_temporal_length - 1) % frame_per_unit == 0, "The frame number should be divided by the frame number per unit"
185
+ self.num_units_per_video = 1 + ((max_temporal_length - 1) // frame_per_unit) + int(sum(sample_ratios))
186
+
187
+ self.scheduler = PyramidFlowMatchEulerDiscreteScheduler(
188
+ shift=timestep_shift, stages=len(self.stages),
189
+ stage_range=stage_range, gamma=scheduler_gamma,
190
+ )
191
+ print(f"The start sigmas and end sigmas of each stage is Start: {self.scheduler.start_sigmas}, End: {self.scheduler.end_sigmas}, Ori_start: {self.scheduler.ori_start_sigmas}")
192
+
193
+ self.cfg_rate = 0.1
194
+ self.return_log = return_log
195
+ self.use_flash_attn = use_flash_attn
196
+ self.model_name = model_name
197
+ self.sequential_offload_enabled = False
198
+ self.accumulate_steps = 0
199
+ self.video_sync_group = video_sync_group
200
+
201
+ def _enable_sequential_cpu_offload(self, model):
202
+ self.sequential_offload_enabled = True
203
+ torch_device = torch.device("cuda")
204
+ device_type = torch_device.type
205
+ device = torch.device(f"{device_type}:0")
206
+ offload_buffers = len(model._parameters) > 0
207
+ cpu_offload(model, device, offload_buffers=offload_buffers)
208
+
209
+ def enable_sequential_cpu_offload(self):
210
+ self._enable_sequential_cpu_offload(self.text_encoder)
211
+ self._enable_sequential_cpu_offload(self.dit)
212
+
213
+ def load_checkpoint(self, checkpoint_path, model_key='model', **kwargs):
214
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
215
+ dit_checkpoint = OrderedDict()
216
+ for key in checkpoint:
217
+ if key.startswith('vae') or key.startswith('text_encoder'):
218
+ continue
219
+ if key.startswith('dit'):
220
+ new_key = key.split('.')
221
+ new_key = '.'.join(new_key[1:])
222
+ dit_checkpoint[new_key] = checkpoint[key]
223
+ else:
224
+ dit_checkpoint[key] = checkpoint[key]
225
+
226
+ load_result = self.dit.load_state_dict(dit_checkpoint, strict=True)
227
+ print(f"Load checkpoint from {checkpoint_path}, load result: {load_result}")
228
+
229
+ def load_vae_checkpoint(self, vae_checkpoint_path, model_key='model'):
230
+ checkpoint = torch.load(vae_checkpoint_path, map_location='cpu')
231
+ checkpoint = checkpoint[model_key]
232
+ loaded_checkpoint = OrderedDict()
233
+
234
+ for key in checkpoint.keys():
235
+ if key.startswith('vae.'):
236
+ new_key = key.split('.')
237
+ new_key = '.'.join(new_key[1:])
238
+ loaded_checkpoint[new_key] = checkpoint[key]
239
+
240
+ load_result = self.vae.load_state_dict(loaded_checkpoint)
241
+ print(f"Load the VAE from {vae_checkpoint_path}, load result: {load_result}")
242
+
243
+ @torch.no_grad()
244
+ def add_pyramid_noise(
245
+ self,
246
+ latents_list,
247
+ sample_ratios=[1, 1, 1],
248
+ ):
249
+ """
250
+ add the noise for each pyramidal stage
251
+ noting that, this method is a general strategy for pyramid-flow, it
252
+ can be used for both image and video training.
253
+ You can also use this method to train pyramid-flow with full-sequence
254
+ diffusion in video generation (without using temporal pyramid and autoregressive modeling)
255
+
256
+ Params:
257
+ latent_list: [low_res, mid_res, high_res] The vae latents of all stages
258
+ sample_ratios: The proportion of each stage in the training batch
259
+ """
260
+ noise = torch.randn_like(latents_list[-1])
261
+ device = noise.device
262
+ dtype = latents_list[-1].dtype
263
+ t = noise.shape[2]
264
+
265
+ stages = len(self.stages)
266
+ tot_samples = noise.shape[0]
267
+ assert tot_samples % (int(sum(sample_ratios))) == 0
268
+ assert stages == len(sample_ratios)
269
+
270
+ height, width = noise.shape[-2], noise.shape[-1]
271
+ noise_list = [noise]
272
+ cur_noise = noise
273
+ for i_s in range(stages-1):
274
+ height //= 2;width //= 2
275
+ cur_noise = rearrange(cur_noise, 'b c t h w -> (b t) c h w')
276
+ cur_noise = F.interpolate(cur_noise, size=(height, width), mode='bilinear') * 2
277
+ cur_noise = rearrange(cur_noise, '(b t) c h w -> b c t h w', t=t)
278
+ noise_list.append(cur_noise)
279
+
280
+ noise_list = list(reversed(noise_list)) # make sure from low res to high res
281
+
282
+ # To calculate the padding batchsize and column size
283
+ batch_size = tot_samples // int(sum(sample_ratios))
284
+ column_size = int(sum(sample_ratios))
285
+
286
+ column_to_stage = {}
287
+ i_sum = 0
288
+ for i_s, column_num in enumerate(sample_ratios):
289
+ for index in range(i_sum, i_sum + column_num):
290
+ column_to_stage[index] = i_s
291
+ i_sum += column_num
292
+
293
+ noisy_latents_list = []
294
+ ratios_list = []
295
+ targets_list = []
296
+ timesteps_list = []
297
+ training_steps = self.scheduler.config.num_train_timesteps
298
+
299
+ # from low resolution to high resolution
300
+ for index in range(column_size):
301
+ i_s = column_to_stage[index]
302
+ clean_latent = latents_list[i_s][index::column_size] # [bs, c, t, h, w]
303
+ last_clean_latent = None if i_s == 0 else latents_list[i_s-1][index::column_size]
304
+ start_sigma = self.scheduler.start_sigmas[i_s]
305
+ end_sigma = self.scheduler.end_sigmas[i_s]
306
+
307
+ if i_s == 0:
308
+ start_point = noise_list[i_s][index::column_size]
309
+ else:
310
+ # Get the upsampled latent
311
+ last_clean_latent = rearrange(last_clean_latent, 'b c t h w -> (b t) c h w')
312
+ last_clean_latent = F.interpolate(last_clean_latent, size=(last_clean_latent.shape[-2] * 2, last_clean_latent.shape[-1] * 2), mode='nearest')
313
+ last_clean_latent = rearrange(last_clean_latent, '(b t) c h w -> b c t h w', t=t)
314
+ start_point = start_sigma * noise_list[i_s][index::column_size] + (1 - start_sigma) * last_clean_latent
315
+
316
+ if i_s == stages - 1:
317
+ end_point = clean_latent
318
+ else:
319
+ end_point = end_sigma * noise_list[i_s][index::column_size] + (1 - end_sigma) * clean_latent
320
+
321
+ # To sample a timestep
322
+ u = compute_density_for_timestep_sampling(
323
+ weighting_scheme='random',
324
+ batch_size=batch_size,
325
+ logit_mean=0.0,
326
+ logit_std=1.0,
327
+ mode_scale=1.29,
328
+ )
329
+
330
+ indices = (u * training_steps).long() # Totally 1000 training steps per stage
331
+ indices = indices.clamp(0, training_steps-1)
332
+ timesteps = self.scheduler.timesteps_per_stage[i_s][indices].to(device=device)
333
+ ratios = self.scheduler.sigmas_per_stage[i_s][indices].to(device=device)
334
+
335
+ while len(ratios.shape) < start_point.ndim:
336
+ ratios = ratios.unsqueeze(-1)
337
+
338
+ # interpolate the latent
339
+ noisy_latents = ratios * start_point + (1 - ratios) * end_point
340
+
341
+ last_cond_noisy_sigma = torch.rand(size=(batch_size,), device=device) * self.corrupt_ratio
342
+
343
+ # [stage1_latent, stage2_latent, ..., stagen_latent], which will be concat after patching
344
+ noisy_latents_list.append([noisy_latents.to(dtype)])
345
+ ratios_list.append(ratios.to(dtype))
346
+ timesteps_list.append(timesteps.to(dtype))
347
+ targets_list.append(start_point - end_point) # The standard rectified flow matching objective
348
+
349
+ return noisy_latents_list, ratios_list, timesteps_list, targets_list
350
+
351
+ def sample_stage_length(self, num_stages, max_units=None):
352
+ max_units_in_training = 1 + ((self.max_temporal_length - 1) // self.frame_per_unit)
353
+ cur_rank = get_rank()
354
+
355
+ self.accumulate_steps = self.accumulate_steps + 1
356
+ total_turns = max_units_in_training // self.video_sync_group
357
+ update_turn = self.accumulate_steps % total_turns
358
+
359
+ # # uniformly sampling each position
360
+ cur_highres_unit = max(int((cur_rank % self.video_sync_group + 1) + update_turn * self.video_sync_group), 1)
361
+ cur_mid_res_unit = max(1 + max_units_in_training - cur_highres_unit, 1)
362
+ cur_low_res_unit = cur_mid_res_unit
363
+
364
+ if max_units is not None:
365
+ cur_highres_unit = min(cur_highres_unit, max_units)
366
+ cur_mid_res_unit = min(cur_mid_res_unit, max_units)
367
+ cur_low_res_unit = min(cur_low_res_unit, max_units)
368
+
369
+ length_list = [cur_low_res_unit, cur_mid_res_unit, cur_highres_unit]
370
+
371
+ assert len(length_list) == num_stages
372
+
373
+ return length_list
374
+
375
+ @torch.no_grad()
376
+ def add_pyramid_noise_with_temporal_pyramid(
377
+ self,
378
+ latents_list,
379
+ sample_ratios=[1, 1, 1],
380
+ ):
381
+ """
382
+ add the noise for each pyramidal stage, used for AR video training with temporal pyramid
383
+ Params:
384
+ latent_list: [low_res, mid_res, high_res] The vae latents of all stages
385
+ sample_ratios: The proportion of each stage in the training batch
386
+ """
387
+ stages = len(self.stages)
388
+ tot_samples = latents_list[0].shape[0]
389
+ device = latents_list[0].device
390
+ dtype = latents_list[0].dtype
391
+
392
+ assert tot_samples % (int(sum(sample_ratios))) == 0
393
+ assert stages == len(sample_ratios)
394
+
395
+ noise = torch.randn_like(latents_list[-1])
396
+ t = noise.shape[2]
397
+
398
+ # To allocate the temporal length of each stage, ensuring the sum == constant
399
+ max_units = 1 + (t - 1) // self.frame_per_unit
400
+
401
+ if is_sequence_parallel_initialized():
402
+ max_units_per_sample = torch.LongTensor([max_units]).to(device)
403
+ sp_group = get_sequence_parallel_group()
404
+ sp_group_size = get_sequence_parallel_world_size()
405
+ max_units_per_sample = all_to_all(max_units_per_sample.unsqueeze(1).repeat(1, sp_group_size), sp_group, sp_group_size, scatter_dim=1, gather_dim=0).squeeze(1)
406
+ max_units = min(max_units_per_sample.cpu().tolist())
407
+
408
+ num_units_per_stage = self.sample_stage_length(stages, max_units=max_units) # [The unit number of each stage]
409
+
410
+ # we needs to sync the length alloc of each sequence parallel group
411
+ if is_sequence_parallel_initialized():
412
+ num_units_per_stage = torch.LongTensor(num_units_per_stage).to(device)
413
+ sp_group_rank = get_sequence_parallel_group_rank()
414
+ global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
415
+ torch.distributed.broadcast(num_units_per_stage, global_src_rank, group=get_sequence_parallel_group())
416
+ num_units_per_stage = num_units_per_stage.tolist()
417
+
418
+ height, width = noise.shape[-2], noise.shape[-1]
419
+ noise_list = [noise]
420
+ cur_noise = noise
421
+ for i_s in range(stages-1):
422
+ height //= 2;width //= 2
423
+ cur_noise = rearrange(cur_noise, 'b c t h w -> (b t) c h w')
424
+ cur_noise = F.interpolate(cur_noise, size=(height, width), mode='bilinear') * 2
425
+ cur_noise = rearrange(cur_noise, '(b t) c h w -> b c t h w', t=t)
426
+ noise_list.append(cur_noise)
427
+
428
+ noise_list = list(reversed(noise_list)) # make sure from low res to high res
429
+
430
+ # To calculate the batchsize and column size
431
+ batch_size = tot_samples // int(sum(sample_ratios))
432
+ column_size = int(sum(sample_ratios))
433
+
434
+ column_to_stage = {}
435
+ i_sum = 0
436
+ for i_s, column_num in enumerate(sample_ratios):
437
+ for index in range(i_sum, i_sum + column_num):
438
+ column_to_stage[index] = i_s
439
+ i_sum += column_num
440
+
441
+ noisy_latents_list = []
442
+ ratios_list = []
443
+ targets_list = []
444
+ timesteps_list = []
445
+ training_steps = self.scheduler.config.num_train_timesteps
446
+
447
+ # from low resolution to high resolution
448
+ for index in range(column_size):
449
+ # First prepare the trainable latent construction
450
+ i_s = column_to_stage[index]
451
+ clean_latent = latents_list[i_s][index::column_size] # [bs, c, t, h, w]
452
+ last_clean_latent = None if i_s == 0 else latents_list[i_s-1][index::column_size]
453
+ start_sigma = self.scheduler.start_sigmas[i_s]
454
+ end_sigma = self.scheduler.end_sigmas[i_s]
455
+
456
+ if i_s == 0:
457
+ start_point = noise_list[i_s][index::column_size]
458
+ else:
459
+ # Get the upsampled latent
460
+ last_clean_latent = rearrange(last_clean_latent, 'b c t h w -> (b t) c h w')
461
+ last_clean_latent = F.interpolate(last_clean_latent, size=(last_clean_latent.shape[-2] * 2, last_clean_latent.shape[-1] * 2), mode='nearest')
462
+ last_clean_latent = rearrange(last_clean_latent, '(b t) c h w -> b c t h w', t=t)
463
+ start_point = start_sigma * noise_list[i_s][index::column_size] + (1 - start_sigma) * last_clean_latent
464
+
465
+ if i_s == stages - 1:
466
+ end_point = clean_latent
467
+ else:
468
+ end_point = end_sigma * noise_list[i_s][index::column_size] + (1 - end_sigma) * clean_latent
469
+
470
+ # To sample a timestep
471
+ u = compute_density_for_timestep_sampling(
472
+ weighting_scheme='random',
473
+ batch_size=batch_size,
474
+ logit_mean=0.0,
475
+ logit_std=1.0,
476
+ mode_scale=1.29,
477
+ )
478
+
479
+ indices = (u * training_steps).long() # Totally 1000 training steps per stage
480
+ indices = indices.clamp(0, training_steps-1)
481
+ timesteps = self.scheduler.timesteps_per_stage[i_s][indices].to(device=device)
482
+ ratios = self.scheduler.sigmas_per_stage[i_s][indices].to(device=device)
483
+ noise_ratios = ratios * start_sigma + (1 - ratios) * end_sigma
484
+
485
+ while len(ratios.shape) < start_point.ndim:
486
+ ratios = ratios.unsqueeze(-1)
487
+
488
+ # interpolate the latent
489
+ noisy_latents = ratios * start_point + (1 - ratios) * end_point
490
+
491
+ # The flow matching object
492
+ target_latents = start_point - end_point
493
+
494
+ # pad the noisy previous
495
+ num_units = num_units_per_stage[i_s]
496
+ num_units = min(num_units, 1 + (t - 1) // self.frame_per_unit)
497
+ actual_frames = 1 + (num_units - 1) * self.frame_per_unit
498
+
499
+ noisy_latents = noisy_latents[:, :, :actual_frames]
500
+ target_latents = target_latents[:, :, :actual_frames]
501
+
502
+ clean_latent = clean_latent[:, :, :actual_frames]
503
+ stage_noise = noise_list[i_s][index::column_size][:, :, :actual_frames]
504
+
505
+ # only the last latent takes part in training
506
+ noisy_latents = noisy_latents[:, :, -self.frame_per_unit:]
507
+ target_latents = target_latents[:, :, -self.frame_per_unit:]
508
+
509
+ last_cond_noisy_sigma = torch.rand(size=(batch_size,), device=device) * self.corrupt_ratio
510
+
511
+ if num_units == 1:
512
+ stage_input = [noisy_latents.to(dtype)]
513
+ else:
514
+ # add the random noise for the last cond clip
515
+ last_cond_latent = clean_latent[:, :, -(2*self.frame_per_unit):-self.frame_per_unit]
516
+
517
+ while len(last_cond_noisy_sigma.shape) < last_cond_latent.ndim:
518
+ last_cond_noisy_sigma = last_cond_noisy_sigma.unsqueeze(-1)
519
+
520
+ # We adding some noise to corrupt the clean condition
521
+ last_cond_latent = last_cond_noisy_sigma * torch.randn_like(last_cond_latent) + (1 - last_cond_noisy_sigma) * last_cond_latent
522
+
523
+ # concat the corrupted condition and the input noisy latents
524
+ stage_input = [noisy_latents.to(dtype), last_cond_latent.to(dtype)]
525
+
526
+ cur_unit_num = 2
527
+ cur_stage = i_s
528
+
529
+ while cur_unit_num < num_units:
530
+ cur_stage = max(cur_stage - 1, 0)
531
+ if cur_stage == 0:
532
+ break
533
+ cur_unit_num += 1
534
+ cond_latents = latents_list[cur_stage][index::column_size][:, :, :actual_frames]
535
+ cond_latents = cond_latents[:, :, -(cur_unit_num * self.frame_per_unit) : -((cur_unit_num - 1) * self.frame_per_unit)]
536
+ cond_latents = last_cond_noisy_sigma * torch.randn_like(cond_latents) + (1 - last_cond_noisy_sigma) * cond_latents
537
+ stage_input.append(cond_latents.to(dtype))
538
+
539
+ if cur_stage == 0 and cur_unit_num < num_units:
540
+ cond_latents = latents_list[0][index::column_size][:, :, :actual_frames]
541
+ cond_latents = cond_latents[:, :, :-(cur_unit_num * self.frame_per_unit)]
542
+
543
+ cond_latents = last_cond_noisy_sigma * torch.randn_like(cond_latents) + (1 - last_cond_noisy_sigma) * cond_latents
544
+ stage_input.append(cond_latents.to(dtype))
545
+
546
+ stage_input = list(reversed(stage_input))
547
+ noisy_latents_list.append(stage_input)
548
+ ratios_list.append(ratios.to(dtype))
549
+ timesteps_list.append(timesteps.to(dtype))
550
+ targets_list.append(target_latents) # The standard rectified flow matching objective
551
+
552
+ return noisy_latents_list, ratios_list, timesteps_list, targets_list
553
+
554
+ @torch.no_grad()
555
+ def get_pyramid_latent(self, x, stage_num):
556
+ # x is the origin vae latent
557
+ vae_latent_list = []
558
+ vae_latent_list.append(x)
559
+
560
+ temp, height, width = x.shape[-3], x.shape[-2], x.shape[-1]
561
+ for _ in range(stage_num):
562
+ height //= 2
563
+ width //= 2
564
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
565
+ x = torch.nn.functional.interpolate(x, size=(height, width), mode='bilinear')
566
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=temp)
567
+ vae_latent_list.append(x)
568
+
569
+ vae_latent_list = list(reversed(vae_latent_list))
570
+ return vae_latent_list
571
+
572
+ @torch.no_grad()
573
+ def get_vae_latent(self, video, use_temporal_pyramid=True):
574
+ if self.load_vae:
575
+ assert video.shape[1] == 3, "The vae is loaded, the input should be raw pixels"
576
+ video = self.vae.encode(video).latent_dist.sample() # [b c t h w]
577
+
578
+ if video.shape[2] == 1:
579
+ # is image
580
+ video = (video - self.vae_shift_factor) * self.vae_scale_factor
581
+ else:
582
+ # is video
583
+ video[:, :, :1] = (video[:, :, :1] - self.vae_shift_factor) * self.vae_scale_factor
584
+ video[:, :, 1:] = (video[:, :, 1:] - self.vae_video_shift_factor) * self.vae_video_scale_factor
585
+
586
+ # Get the pyramidal stages
587
+ vae_latent_list = self.get_pyramid_latent(video, len(self.stages) - 1)
588
+
589
+ if use_temporal_pyramid:
590
+ noisy_latents_list, ratios_list, timesteps_list, targets_list = self.add_pyramid_noise_with_temporal_pyramid(vae_latent_list, self.sample_ratios)
591
+ else:
592
+ # Only use the spatial pyramidal (without temporal ar)
593
+ noisy_latents_list, ratios_list, timesteps_list, targets_list = self.add_pyramid_noise(vae_latent_list, self.sample_ratios)
594
+
595
+ return noisy_latents_list, ratios_list, timesteps_list, targets_list
596
+
597
+ @torch.no_grad()
598
+ def get_text_embeddings(self, text, rand_idx, device):
599
+ if self.load_text_encoder:
600
+ batch_size = len(text) # Text is a str list
601
+ for idx in range(batch_size):
602
+ if rand_idx[idx].item():
603
+ text[idx] = ''
604
+ return self.text_encoder(text, device) # [b s c]
605
+ else:
606
+ batch_size = len(text['prompt_embeds'])
607
+
608
+ for idx in range(batch_size):
609
+ if rand_idx[idx].item():
610
+ text['prompt_embeds'][idx] = self.null_text_embeds['prompt_embed'].to(device)
611
+ text['prompt_attention_mask'][idx] = self.null_text_embeds['prompt_attention_mask'].to(device)
612
+ text['pooled_prompt_embeds'][idx] = self.null_text_embeds['pooled_prompt_embed'].to(device)
613
+
614
+ return text['prompt_embeds'], text['prompt_attention_mask'], text['pooled_prompt_embeds']
615
+
616
+ def calculate_loss(self, model_preds_list, targets_list):
617
+ loss_list = []
618
+
619
+ for model_pred, target in zip(model_preds_list, targets_list):
620
+ # Compute the loss.
621
+ loss_weight = torch.ones_like(target)
622
+
623
+ loss = torch.mean(
624
+ (loss_weight.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
625
+ 1,
626
+ )
627
+ loss_list.append(loss)
628
+
629
+ diffusion_loss = torch.cat(loss_list, dim=0).mean()
630
+
631
+ if self.return_log:
632
+ log = {}
633
+ split="train"
634
+ log[f'{split}/loss'] = diffusion_loss.detach()
635
+ return diffusion_loss, log
636
+ else:
637
+ return diffusion_loss, {}
638
+
639
+ def __call__(self, video, text, identifier=['video'], use_temporal_pyramid=True, accelerator: Accelerator=None):
640
+ xdim = video.ndim
641
+ device = video.device
642
+
643
+ if 'video' in identifier:
644
+ assert 'image' not in identifier
645
+ is_image = False
646
+ else:
647
+ assert 'video' not in identifier
648
+ video = video.unsqueeze(2) # 'b c h w -> b c 1 h w'
649
+ is_image = True
650
+
651
+ # TODO: now have 3 stages, firstly get the vae latents
652
+ with torch.no_grad(), accelerator.autocast():
653
+ # 10% prob drop the text
654
+ batch_size = len(video)
655
+ rand_idx = torch.rand((batch_size,)) <= self.cfg_rate
656
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.get_text_embeddings(text, rand_idx, device)
657
+ noisy_latents_list, ratios_list, timesteps_list, targets_list = self.get_vae_latent(video, use_temporal_pyramid=use_temporal_pyramid)
658
+
659
+ timesteps = torch.cat([timestep.unsqueeze(-1) for timestep in timesteps_list], dim=-1)
660
+ timesteps = timesteps.reshape(-1)
661
+
662
+ assert timesteps.shape[0] == prompt_embeds.shape[0]
663
+
664
+ # DiT forward
665
+ model_preds_list = self.dit(
666
+ sample=noisy_latents_list,
667
+ timestep_ratio=timesteps,
668
+ encoder_hidden_states=prompt_embeds,
669
+ encoder_attention_mask=prompt_attention_mask,
670
+ pooled_projections=pooled_prompt_embeds,
671
+ )
672
+
673
+ # calculate the loss
674
+ return self.calculate_loss(model_preds_list, targets_list)
675
+
676
+ def prepare_latents(
677
+ self,
678
+ batch_size,
679
+ num_channels_latents,
680
+ temp,
681
+ height,
682
+ width,
683
+ dtype,
684
+ device,
685
+ generator,
686
+ ):
687
+ shape = (
688
+ batch_size,
689
+ num_channels_latents,
690
+ int(temp),
691
+ int(height) // self.downsample,
692
+ int(width) // self.downsample,
693
+ )
694
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
695
+ return latents
696
+
697
+ def sample_block_noise(self, bs, ch, temp, height, width):
698
+ gamma = self.scheduler.config.gamma
699
+ dist = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(4), torch.eye(4) * (1 + gamma) - torch.ones(4, 4) * gamma)
700
+ block_number = bs * ch * temp * (height // 2) * (width // 2)
701
+ noise = torch.stack([dist.sample() for _ in range(block_number)]) # [block number, 4]
702
+ noise = rearrange(noise, '(b c t h w) (p q) -> b c t (h p) (w q)',b=bs,c=ch,t=temp,h=height//2,w=width//2,p=2,q=2)
703
+ return noise
704
+
705
+ @torch.no_grad()
706
+ def generate_one_unit(
707
+ self,
708
+ latents,
709
+ past_conditions, # List of past conditions, contains the conditions of each stage
710
+ prompt_embeds,
711
+ prompt_attention_mask,
712
+ pooled_prompt_embeds,
713
+ num_inference_steps,
714
+ height,
715
+ width,
716
+ temp,
717
+ device,
718
+ dtype,
719
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
720
+ is_first_frame: bool = False,
721
+ ):
722
+ stages = self.stages
723
+ intermed_latents = []
724
+
725
+ for i_s in range(len(stages)):
726
+ self.scheduler.set_timesteps(num_inference_steps[i_s], i_s, device=device)
727
+ timesteps = self.scheduler.timesteps
728
+
729
+ if i_s > 0:
730
+ height *= 2; width *= 2
731
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
732
+ latents = F.interpolate(latents, size=(height, width), mode='nearest')
733
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
734
+ # Fix the stage
735
+ ori_sigma = 1 - self.scheduler.ori_start_sigmas[i_s] # the original coeff of signal
736
+ gamma = self.scheduler.config.gamma
737
+ alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
738
+ beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)
739
+
740
+ bs, ch, temp, height, width = latents.shape
741
+ noise = self.sample_block_noise(bs, ch, temp, height, width)
742
+ noise = noise.to(device=device, dtype=dtype)
743
+ latents = alpha * latents + beta * noise # To fix the block artifact
744
+
745
+ for idx, t in enumerate(timesteps):
746
+ # expand the latents if we are doing classifier free guidance
747
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
748
+
749
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
750
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
751
+
752
+ if is_sequence_parallel_initialized():
753
+ # sync the input latent
754
+ sp_group_rank = get_sequence_parallel_group_rank()
755
+ global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
756
+ torch.distributed.broadcast(latent_model_input, global_src_rank, group=get_sequence_parallel_group())
757
+
758
+ latent_model_input = past_conditions[i_s] + [latent_model_input]
759
+
760
+ noise_pred = self.dit(
761
+ sample=[latent_model_input],
762
+ timestep_ratio=timestep,
763
+ encoder_hidden_states=prompt_embeds,
764
+ encoder_attention_mask=prompt_attention_mask,
765
+ pooled_projections=pooled_prompt_embeds,
766
+ )
767
+
768
+ noise_pred = noise_pred[0]
769
+
770
+ # perform guidance
771
+ if self.do_classifier_free_guidance:
772
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
773
+ if is_first_frame:
774
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
775
+ else:
776
+ noise_pred = noise_pred_uncond + self.video_guidance_scale * (noise_pred_text - noise_pred_uncond)
777
+
778
+ # compute the previous noisy sample x_t -> x_t-1
779
+ latents = self.scheduler.step(
780
+ model_output=noise_pred,
781
+ timestep=timestep,
782
+ sample=latents,
783
+ generator=generator,
784
+ ).prev_sample
785
+
786
+ intermed_latents.append(latents)
787
+
788
+ return intermed_latents
789
+
790
+ @torch.no_grad()
791
+ def generate_i2v(
792
+ self,
793
+ prompt: Union[str, List[str]] = '',
794
+ input_image: PIL.Image = None,
795
+ temp: int = 1,
796
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
797
+ guidance_scale: float = 7.0,
798
+ video_guidance_scale: float = 4.0,
799
+ min_guidance_scale: float = 2.0,
800
+ use_linear_guidance: bool = False,
801
+ alpha: float = 0.5,
802
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
803
+ num_images_per_prompt: Optional[int] = 1,
804
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
805
+ output_type: Optional[str] = "pil",
806
+ save_memory: bool = True,
807
+ cpu_offloading: bool = False, # If true, reload device will be cuda.
808
+ inference_multigpu: bool = False,
809
+ callback: Optional[Callable[[int, int, Dict], None]] = None,
810
+ ):
811
+ if self.sequential_offload_enabled and not cpu_offloading:
812
+ print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
813
+ cpu_offloading=True
814
+ device = self.device if not cpu_offloading else torch.device("cuda")
815
+ dtype = self.dtype
816
+ if cpu_offloading:
817
+ # skip caring about the text encoder here as its about to be used anyways.
818
+ if not self.sequential_offload_enabled:
819
+ if str(self.dit.device) != "cpu":
820
+ print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
821
+ self.dit.to("cpu")
822
+ torch.cuda.empty_cache()
823
+ if str(self.vae.device) != "cpu":
824
+ print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
825
+ self.vae.to("cpu")
826
+ torch.cuda.empty_cache()
827
+
828
+ width = input_image.width
829
+ height = input_image.height
830
+
831
+ assert temp % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
832
+
833
+ if isinstance(prompt, str):
834
+ batch_size = 1
835
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
836
+ else:
837
+ assert isinstance(prompt, list)
838
+ batch_size = len(prompt)
839
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
840
+
841
+ if isinstance(num_inference_steps, int):
842
+ num_inference_steps = [num_inference_steps] * len(self.stages)
843
+
844
+ negative_prompt = negative_prompt or ""
845
+
846
+ # Get the text embeddings
847
+ if cpu_offloading and not self.sequential_offload_enabled:
848
+ self.text_encoder.to("cuda")
849
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
850
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
851
+
852
+ if cpu_offloading:
853
+ if not self.sequential_offload_enabled:
854
+ self.text_encoder.to("cpu")
855
+ self.vae.to("cuda")
856
+ torch.cuda.empty_cache()
857
+
858
+ if use_linear_guidance:
859
+ max_guidance_scale = guidance_scale
860
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp+1)]
861
+ print(guidance_scale_list)
862
+
863
+ self._guidance_scale = guidance_scale
864
+ self._video_guidance_scale = video_guidance_scale
865
+
866
+ if self.do_classifier_free_guidance:
867
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
868
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
869
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
870
+
871
+ if is_sequence_parallel_initialized():
872
+ # sync the prompt embedding across multiple GPUs
873
+ sp_group_rank = get_sequence_parallel_group_rank()
874
+ global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
875
+ torch.distributed.broadcast(prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
876
+ torch.distributed.broadcast(pooled_prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
877
+ torch.distributed.broadcast(prompt_attention_mask, global_src_rank, group=get_sequence_parallel_group())
878
+
879
+ # Create the initial random noise
880
+ num_channels_latents = (self.dit.config.in_channels // 4) if self.model_name == "pyramid_flux" else self.dit.config.in_channels
881
+ latents = self.prepare_latents(
882
+ batch_size * num_images_per_prompt,
883
+ num_channels_latents,
884
+ temp,
885
+ height,
886
+ width,
887
+ prompt_embeds.dtype,
888
+ device,
889
+ generator,
890
+ )
891
+
892
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
893
+
894
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
895
+ # by defalut, we needs to start from the block noise
896
+ for _ in range(len(self.stages)-1):
897
+ height //= 2;width //= 2
898
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
899
+
900
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
901
+
902
+ num_units = temp // self.frame_per_unit
903
+ stages = self.stages
904
+
905
+ # encode the image latents
906
+ image_transform = transforms.Compose([
907
+ transforms.ToTensor(),
908
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
909
+ ])
910
+ input_image_tensor = image_transform(input_image).unsqueeze(0).unsqueeze(2) # [b c 1 h w]
911
+ input_image_latent = (self.vae.encode(input_image_tensor.to(self.vae.device, dtype=self.vae.dtype)).latent_dist.sample() - self.vae_shift_factor) * self.vae_scale_factor # [b c 1 h w]
912
+
913
+ if is_sequence_parallel_initialized():
914
+ # sync the image latent across multiple GPUs
915
+ sp_group_rank = get_sequence_parallel_group_rank()
916
+ global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
917
+ torch.distributed.broadcast(input_image_latent, global_src_rank, group=get_sequence_parallel_group())
918
+
919
+ generated_latents_list = [input_image_latent] # The generated results
920
+ last_generated_latents = input_image_latent
921
+
922
+ if cpu_offloading:
923
+ self.vae.to("cpu")
924
+ if not self.sequential_offload_enabled:
925
+ self.dit.to("cuda")
926
+ torch.cuda.empty_cache()
927
+
928
+ for unit_index in tqdm(range(1, num_units)):
929
+ gc.collect()
930
+ torch.cuda.empty_cache()
931
+
932
+ if callback:
933
+ callback(unit_index, num_units)
934
+
935
+ if use_linear_guidance:
936
+ self._guidance_scale = guidance_scale_list[unit_index]
937
+ self._video_guidance_scale = guidance_scale_list[unit_index]
938
+
939
+ # prepare the condition latents
940
+ past_condition_latents = []
941
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
942
+
943
+ for i_s in range(len(stages)):
944
+ last_cond_latent = clean_latents_list[i_s][:,:,-self.frame_per_unit:]
945
+
946
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
947
+
948
+ # pad the past clean latents
949
+ cur_unit_num = unit_index
950
+ cur_stage = i_s
951
+ cur_unit_ptx = 1
952
+
953
+ while cur_unit_ptx < cur_unit_num:
954
+ cur_stage = max(cur_stage - 1, 0)
955
+ if cur_stage == 0:
956
+ break
957
+ cur_unit_ptx += 1
958
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
959
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
960
+
961
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
962
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
963
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
964
+
965
+ stage_input = list(reversed(stage_input))
966
+ past_condition_latents.append(stage_input)
967
+
968
+ intermed_latents = self.generate_one_unit(
969
+ latents[:,:,(unit_index - 1) * self.frame_per_unit:unit_index * self.frame_per_unit],
970
+ past_condition_latents,
971
+ prompt_embeds,
972
+ prompt_attention_mask,
973
+ pooled_prompt_embeds,
974
+ num_inference_steps,
975
+ height,
976
+ width,
977
+ self.frame_per_unit,
978
+ device,
979
+ dtype,
980
+ generator,
981
+ is_first_frame=False,
982
+ )
983
+
984
+ generated_latents_list.append(intermed_latents[-1])
985
+ last_generated_latents = intermed_latents
986
+
987
+ generated_latents = torch.cat(generated_latents_list, dim=2)
988
+
989
+ if output_type == "latent":
990
+ image = generated_latents
991
+ else:
992
+ if cpu_offloading:
993
+ if not self.sequential_offload_enabled:
994
+ self.dit.to("cpu")
995
+ self.vae.to("cuda")
996
+ torch.cuda.empty_cache()
997
+ image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
998
+ if cpu_offloading:
999
+ self.vae.to("cpu")
1000
+ torch.cuda.empty_cache()
1001
+ # not technically necessary, but returns the pipeline to its original state
1002
+
1003
+ return image
1004
+
1005
+ @torch.no_grad()
1006
+ def generate(
1007
+ self,
1008
+ prompt: Union[str, List[str]] = None,
1009
+ height: Optional[int] = None,
1010
+ width: Optional[int] = None,
1011
+ temp: int = 1,
1012
+ num_inference_steps: Optional[Union[int, List[int]]] = 28,
1013
+ video_num_inference_steps: Optional[Union[int, List[int]]] = 28,
1014
+ guidance_scale: float = 7.0,
1015
+ video_guidance_scale: float = 7.0,
1016
+ min_guidance_scale: float = 2.0,
1017
+ use_linear_guidance: bool = False,
1018
+ alpha: float = 0.5,
1019
+ negative_prompt: Optional[Union[str, List[str]]]="cartoon style, worst quality, low quality, blurry, absolute black, absolute white, low res, extra limbs, extra digits, misplaced objects, mutated anatomy, monochrome, horror",
1020
+ num_images_per_prompt: Optional[int] = 1,
1021
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1022
+ output_type: Optional[str] = "pil",
1023
+ save_memory: bool = True,
1024
+ cpu_offloading: bool = False, # If true, reload device will be cuda.
1025
+ inference_multigpu: bool = False,
1026
+ callback: Optional[Callable[[int, int, Dict], None]] = None,
1027
+ ):
1028
+ if self.sequential_offload_enabled and not cpu_offloading:
1029
+ print("Warning: overriding cpu_offloading set to false, as it's needed for sequential cpu offload")
1030
+ cpu_offloading=True
1031
+ device = self.device if not cpu_offloading else torch.device("cuda")
1032
+ dtype = self.dtype
1033
+ if cpu_offloading:
1034
+ # skip caring about the text encoder here as its about to be used anyways.
1035
+ if not self.sequential_offload_enabled:
1036
+ if str(self.dit.device) != "cpu":
1037
+ print("(dit) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
1038
+ self.dit.to("cpu")
1039
+ torch.cuda.empty_cache()
1040
+ if str(self.vae.device) != "cpu":
1041
+ print("(vae) Warning: Do not preload pipeline components (i.e. to cuda) with cpu offloading enabled! Otherwise, a second transfer will occur needlessly taking up time.")
1042
+ self.vae.to("cpu")
1043
+ torch.cuda.empty_cache()
1044
+
1045
+
1046
+ assert (temp - 1) % self.frame_per_unit == 0, "The frames should be divided by frame_per unit"
1047
+
1048
+ if isinstance(prompt, str):
1049
+ batch_size = 1
1050
+ prompt = prompt + ", hyper quality, Ultra HD, 8K" # adding this prompt to improve aesthetics
1051
+ else:
1052
+ assert isinstance(prompt, list)
1053
+ batch_size = len(prompt)
1054
+ prompt = [_ + ", hyper quality, Ultra HD, 8K" for _ in prompt]
1055
+
1056
+ if isinstance(num_inference_steps, int):
1057
+ num_inference_steps = [num_inference_steps] * len(self.stages)
1058
+
1059
+ if isinstance(video_num_inference_steps, int):
1060
+ video_num_inference_steps = [video_num_inference_steps] * len(self.stages)
1061
+
1062
+ negative_prompt = negative_prompt or ""
1063
+
1064
+ # Get the text embeddings
1065
+ if cpu_offloading and not self.sequential_offload_enabled:
1066
+ self.text_encoder.to("cuda")
1067
+ prompt_embeds, prompt_attention_mask, pooled_prompt_embeds = self.text_encoder(prompt, device)
1068
+ negative_prompt_embeds, negative_prompt_attention_mask, negative_pooled_prompt_embeds = self.text_encoder(negative_prompt, device)
1069
+ if cpu_offloading:
1070
+ if not self.sequential_offload_enabled:
1071
+ self.text_encoder.to("cpu")
1072
+ self.dit.to("cuda")
1073
+ torch.cuda.empty_cache()
1074
+
1075
+ if use_linear_guidance:
1076
+ max_guidance_scale = guidance_scale
1077
+ # guidance_scale_list = torch.linspace(max_guidance_scale, min_guidance_scale, temp).tolist()
1078
+ guidance_scale_list = [max(max_guidance_scale - alpha * t_, min_guidance_scale) for t_ in range(temp)]
1079
+ print(guidance_scale_list)
1080
+
1081
+ self._guidance_scale = guidance_scale
1082
+ self._video_guidance_scale = video_guidance_scale
1083
+
1084
+ if self.do_classifier_free_guidance:
1085
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1086
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1087
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
1088
+
1089
+ if is_sequence_parallel_initialized():
1090
+ # sync the prompt embedding across multiple GPUs
1091
+ sp_group_rank = get_sequence_parallel_group_rank()
1092
+ global_src_rank = sp_group_rank * get_sequence_parallel_world_size()
1093
+ torch.distributed.broadcast(prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
1094
+ torch.distributed.broadcast(pooled_prompt_embeds, global_src_rank, group=get_sequence_parallel_group())
1095
+ torch.distributed.broadcast(prompt_attention_mask, global_src_rank, group=get_sequence_parallel_group())
1096
+
1097
+ # Create the initial random noise
1098
+ num_channels_latents = (self.dit.config.in_channels // 4) if self.model_name == "pyramid_flux" else self.dit.config.in_channels
1099
+ latents = self.prepare_latents(
1100
+ batch_size * num_images_per_prompt,
1101
+ num_channels_latents,
1102
+ temp,
1103
+ height,
1104
+ width,
1105
+ prompt_embeds.dtype,
1106
+ device,
1107
+ generator,
1108
+ )
1109
+
1110
+ temp, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
1111
+
1112
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w')
1113
+ # by default, we needs to start from the block noise
1114
+ for _ in range(len(self.stages)-1):
1115
+ height //= 2;width //= 2
1116
+ latents = F.interpolate(latents, size=(height, width), mode='bilinear') * 2
1117
+
1118
+ latents = rearrange(latents, '(b t) c h w -> b c t h w', t=temp)
1119
+
1120
+ num_units = 1 + (temp - 1) // self.frame_per_unit
1121
+ stages = self.stages
1122
+
1123
+ generated_latents_list = [] # The generated results
1124
+ last_generated_latents = None
1125
+
1126
+ for unit_index in tqdm(range(num_units)):
1127
+ gc.collect()
1128
+ torch.cuda.empty_cache()
1129
+
1130
+ if callback:
1131
+ callback(unit_index, num_units)
1132
+
1133
+ if use_linear_guidance:
1134
+ self._guidance_scale = guidance_scale_list[unit_index]
1135
+ self._video_guidance_scale = guidance_scale_list[unit_index]
1136
+
1137
+ if unit_index == 0:
1138
+ past_condition_latents = [[] for _ in range(len(stages))]
1139
+ intermed_latents = self.generate_one_unit(
1140
+ latents[:,:,:1],
1141
+ past_condition_latents,
1142
+ prompt_embeds,
1143
+ prompt_attention_mask,
1144
+ pooled_prompt_embeds,
1145
+ num_inference_steps,
1146
+ height,
1147
+ width,
1148
+ 1,
1149
+ device,
1150
+ dtype,
1151
+ generator,
1152
+ is_first_frame=True,
1153
+ )
1154
+ else:
1155
+ # prepare the condition latents
1156
+ past_condition_latents = []
1157
+ clean_latents_list = self.get_pyramid_latent(torch.cat(generated_latents_list, dim=2), len(stages) - 1)
1158
+
1159
+ for i_s in range(len(stages)):
1160
+ last_cond_latent = clean_latents_list[i_s][:,:,-(self.frame_per_unit):]
1161
+
1162
+ stage_input = [torch.cat([last_cond_latent] * 2) if self.do_classifier_free_guidance else last_cond_latent]
1163
+
1164
+ # pad the past clean latents
1165
+ cur_unit_num = unit_index
1166
+ cur_stage = i_s
1167
+ cur_unit_ptx = 1
1168
+
1169
+ while cur_unit_ptx < cur_unit_num:
1170
+ cur_stage = max(cur_stage - 1, 0)
1171
+ if cur_stage == 0:
1172
+ break
1173
+ cur_unit_ptx += 1
1174
+ cond_latents = clean_latents_list[cur_stage][:, :, -(cur_unit_ptx * self.frame_per_unit) : -((cur_unit_ptx - 1) * self.frame_per_unit)]
1175
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
1176
+
1177
+ if cur_stage == 0 and cur_unit_ptx < cur_unit_num:
1178
+ cond_latents = clean_latents_list[0][:, :, :-(cur_unit_ptx * self.frame_per_unit)]
1179
+ stage_input.append(torch.cat([cond_latents] * 2) if self.do_classifier_free_guidance else cond_latents)
1180
+
1181
+ stage_input = list(reversed(stage_input))
1182
+ past_condition_latents.append(stage_input)
1183
+
1184
+ intermed_latents = self.generate_one_unit(
1185
+ latents[:,:, 1 + (unit_index - 1) * self.frame_per_unit:1 + unit_index * self.frame_per_unit],
1186
+ past_condition_latents,
1187
+ prompt_embeds,
1188
+ prompt_attention_mask,
1189
+ pooled_prompt_embeds,
1190
+ video_num_inference_steps,
1191
+ height,
1192
+ width,
1193
+ self.frame_per_unit,
1194
+ device,
1195
+ dtype,
1196
+ generator,
1197
+ is_first_frame=False,
1198
+ )
1199
+
1200
+ generated_latents_list.append(intermed_latents[-1])
1201
+ last_generated_latents = intermed_latents
1202
+
1203
+ generated_latents = torch.cat(generated_latents_list, dim=2)
1204
+
1205
+ if output_type == "latent":
1206
+ image = generated_latents
1207
+ else:
1208
+ if cpu_offloading:
1209
+ if not self.sequential_offload_enabled:
1210
+ self.dit.to("cpu")
1211
+ self.vae.to("cuda")
1212
+ torch.cuda.empty_cache()
1213
+ image = self.decode_latent(generated_latents, save_memory=save_memory, inference_multigpu=inference_multigpu)
1214
+ if cpu_offloading:
1215
+ self.vae.to("cpu")
1216
+ torch.cuda.empty_cache()
1217
+ # not technically necessary, but returns the pipeline to its original state
1218
+
1219
+ return image
1220
+
1221
+ def decode_latent(self, latents, save_memory=True, inference_multigpu=False):
1222
+ # only the main process needs vae decoding
1223
+ if inference_multigpu and get_rank() != 0:
1224
+ return None
1225
+
1226
+ if latents.shape[2] == 1:
1227
+ latents = (latents / self.vae_scale_factor) + self.vae_shift_factor
1228
+ else:
1229
+ latents[:, :, :1] = (latents[:, :, :1] / self.vae_scale_factor) + self.vae_shift_factor
1230
+ latents[:, :, 1:] = (latents[:, :, 1:] / self.vae_video_scale_factor) + self.vae_video_shift_factor
1231
+
1232
+ if save_memory:
1233
+ # reducing the tile size and temporal chunk window size
1234
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=1, tile_sample_min_size=256).sample
1235
+ else:
1236
+ image = self.vae.decode(latents, temporal_chunk=True, window_size=2, tile_sample_min_size=512).sample
1237
+
1238
+ image = image.mul(127.5).add(127.5).clamp(0, 255).byte()
1239
+ image = rearrange(image, "B C T H W -> (B T) H W C")
1240
+ image = image.cpu().numpy()
1241
+ image = self.numpy_to_pil(image)
1242
+
1243
+ return image
1244
+
1245
+ @staticmethod
1246
+ def numpy_to_pil(images):
1247
+ """
1248
+ Convert a numpy image or a batch of images to a PIL image.
1249
+ """
1250
+ if images.ndim == 3:
1251
+ images = images[None, ...]
1252
+
1253
+ if images.shape[-1] == 1:
1254
+ # special case for grayscale (single channel) images
1255
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
1256
+ else:
1257
+ pil_images = [Image.fromarray(image) for image in images]
1258
+
1259
+ return pil_images
1260
+
1261
+ @property
1262
+ def device(self):
1263
+ return next(self.dit.parameters()).device
1264
+
1265
+ @property
1266
+ def dtype(self):
1267
+ return next(self.dit.parameters()).dtype
1268
+
1269
+ @property
1270
+ def guidance_scale(self):
1271
+ return self._guidance_scale
1272
+
1273
+ @property
1274
+ def video_guidance_scale(self):
1275
+ return self._video_guidance_scale
1276
+
1277
+ @property
1278
+ def do_classifier_free_guidance(self):
1279
+ return self._guidance_scale > 0
pyramid_flow_model.lnk ADDED
Binary file (982 Bytes). View file
 
pyramid_flow_model/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pyramid_flow_model/README.md ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ pipeline_tag: text-to-video
4
+ tags:
5
+ - text-to-image
6
+ - image-to-video
7
+ - flux
8
+ ---
9
+
10
+ # ⚡️Pyramid Flow miniFLUX⚡️
11
+
12
+ [[Paper]](https://arxiv.org/abs/2410.05954) [[Project Page ✨]](https://pyramid-flow.github.io) [[Code 🚀]](https://github.com/jy0205/Pyramid-Flow) [[SD3 Model ⚡️]](https://huggingface.co/rain1011/pyramid-flow-sd3) [[demo 🤗](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow)]
13
+
14
+ This is the model repository for Pyramid Flow, a training-efficient **Autoregressive Video Generation** method based on **Flow Matching**. By training only on open-source datasets, it generates high-quality 10-second videos at 768p resolution and 24 FPS, and naturally supports image-to-video generation.
15
+
16
+ <table class="center" border="0" style="width: 100%; text-align: left;">
17
+ <tr>
18
+ <th>10s, 768p, 24fps</th>
19
+ <th>5s, 768p, 24fps</th>
20
+ <th>Image-to-video</th>
21
+ </tr>
22
+ <tr>
23
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v_10s/fireworks.mp4" autoplay muted loop playsinline></video></td>
24
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v/trailer.mp4" autoplay muted loop playsinline></video></td>
25
+ <td><video src="https://pyramid-flow.github.io/static/videos/i2v/sunday.mp4" autoplay muted loop playsinline></video></td>
26
+ </tr>
27
+ </table>
28
+
29
+ ## News
30
+
31
+ * `2024.11.13` 🚀🚀🚀 We release the [768p miniFLUX checkpoint](https://huggingface.co/rain1011/pyramid-flow-miniflux) (up to 10s).
32
+
33
+ > We have switched the model structure from SD3 to a mini FLUX to fix human structure issues, please try our 1024p image checkpoint, 384p video checkpoint (up to 5s) and 768p video checkpoint (up to 10s). The new miniflux model shows great improvement on human structure and motion stability
34
+ * `2024.10.29` ⚡️⚡️⚡️ We release [training code](https://github.com/jy0205/Pyramid-Flow?tab=readme-ov-file#training) and [new model checkpoints](https://huggingface.co/rain1011/pyramid-flow-miniflux) with FLUX structure trained from scratch.
35
+ * `2024.10.11` 🤗🤗🤗 [Hugging Face demo](https://huggingface.co/spaces/Pyramid-Flow/pyramid-flow) is available. Thanks [@multimodalart](https://huggingface.co/multimodalart) for the commit!
36
+ * `2024.10.10` 🚀🚀🚀 We release the [technical report](https://arxiv.org/abs/2410.05954), [project page](https://pyramid-flow.github.io) and [model checkpoint](https://huggingface.co/rain1011/pyramid-flow-sd3) of Pyramid Flow.
37
+
38
+ ## Installation
39
+
40
+ We recommend setting up the environment with conda. The codebase currently uses Python 3.8.10 and PyTorch 2.1.2 ([guide](https://pytorch.org/get-started/previous-versions/#v212)), and we are actively working to support a wider range of versions.
41
+
42
+ ```bash
43
+ git clone https://github.com/jy0205/Pyramid-Flow
44
+ cd Pyramid-Flow
45
+
46
+ # create env using conda
47
+ conda create -n pyramid python==3.8.10
48
+ conda activate pyramid
49
+ pip install -r requirements.txt
50
+ ```
51
+
52
+ Then, download the model from [Huggingface](https://huggingface.co/rain1011) (there are two variants: [miniFLUX](https://huggingface.co/rain1011/pyramid-flow-miniflux) or [SD3](https://huggingface.co/rain1011/pyramid-flow-sd3)). The miniFLUX models support 1024p image, 384p and 768p video generation, and the SD3-based models support 768p and 384p video generation. The 384p checkpoint generates 5-second video at 24FPS, while the 768p checkpoint generates up to 10-second video at 24FPS.
53
+
54
+ ```python
55
+ from huggingface_hub import snapshot_download
56
+
57
+ model_path = 'PATH' # The local directory to save downloaded checkpoint
58
+ snapshot_download("rain1011/pyramid-flow-miniflux", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
59
+ ```
60
+
61
+ ## Usage
62
+
63
+ For inference, we provide Gradio demo, single-GPU, multi-GPU, and Apple Silicon inference code, as well as VRAM-efficient features such as CPU offloading. Please check our [code repository](https://github.com/jy0205/Pyramid-Flow?tab=readme-ov-file#inference) for usage.
64
+
65
+ Below is a simplified two-step usage procedure. First, load the downloaded model:
66
+
67
+ ```python
68
+ import torch
69
+ from PIL import Image
70
+ from pyramid_dit import PyramidDiTForVideoGeneration
71
+ from diffusers.utils import load_image, export_to_video
72
+
73
+ torch.cuda.set_device(0)
74
+ model_dtype, torch_dtype = 'bf16', torch.bfloat16 # Use bf16 (not support fp16 yet)
75
+
76
+ model = PyramidDiTForVideoGeneration(
77
+ 'PATH', # The downloaded checkpoint dir
78
+ model_name="pyramid_flux",
79
+ model_dtype,
80
+ model_variant='diffusion_transformer_768p',
81
+ )
82
+
83
+ model.vae.enable_tiling()
84
+ # model.vae.to("cuda")
85
+ # model.dit.to("cuda")
86
+ # model.text_encoder.to("cuda")
87
+
88
+ # if you're not using sequential offloading bellow uncomment the lines above ^
89
+ model.enable_sequential_cpu_offload()
90
+ ```
91
+
92
+ Then, you can try text-to-video generation on your own prompts:
93
+
94
+ ```python
95
+ prompt = "A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors"
96
+
97
+ # used for 384p model variant
98
+ # width = 640
99
+ # height = 384
100
+
101
+ # used for 768p model variant
102
+ width = 1280
103
+ height = 768
104
+
105
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
106
+ frames = model.generate(
107
+ prompt=prompt,
108
+ num_inference_steps=[20, 20, 20],
109
+ video_num_inference_steps=[10, 10, 10],
110
+ height=height,
111
+ width=width,
112
+ temp=16, # temp=16: 5s, temp=31: 10s
113
+ guidance_scale=7.0, # The guidance for the first frame, set it to 7 for 384p variant
114
+ video_guidance_scale=5.0, # The guidance for the other video latent
115
+ output_type="pil",
116
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
117
+ )
118
+
119
+ export_to_video(frames, "./text_to_video_sample.mp4", fps=24)
120
+ ```
121
+
122
+ As an autoregressive model, our model also supports (text conditioned) image-to-video generation:
123
+
124
+ ```python
125
+ # used for 384p model variant
126
+ # width = 640
127
+ # height = 384
128
+
129
+ # used for 768p model variant
130
+ width = 1280
131
+ height = 768
132
+
133
+ image = Image.open('assets/the_great_wall.jpg').convert("RGB").resize((width, height))
134
+ prompt = "FPV flying over the Great Wall"
135
+
136
+ with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
137
+ frames = model.generate_i2v(
138
+ prompt=prompt,
139
+ input_image=image,
140
+ num_inference_steps=[10, 10, 10],
141
+ temp=16,
142
+ video_guidance_scale=4.0,
143
+ output_type="pil",
144
+ save_memory=True, # If you have enough GPU memory, set it to `False` to improve vae decoding speed
145
+ )
146
+
147
+ export_to_video(frames, "./image_to_video_sample.mp4", fps=24)
148
+ ```
149
+
150
+ ## Usage tips
151
+
152
+ * The `guidance_scale` parameter controls the visual quality. We suggest using a guidance within [7, 9] for the 768p checkpoint during text-to-video generation, and 7 for the 384p checkpoint.
153
+ * The `video_guidance_scale` parameter controls the motion. A larger value increases the dynamic degree and mitigates the autoregressive generation degradation, while a smaller value stabilizes the video.
154
+ * For 10-second video generation, we recommend using a guidance scale of 7 and a video guidance scale of 5.
155
+
156
+ ## Gallery
157
+
158
+ The following video examples are generated at 5s, 768p, 24fps. For more results, please visit our [project page](https://pyramid-flow.github.io).
159
+
160
+ <table class="center" border="0" style="width: 100%; text-align: left;">
161
+ <tr>
162
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v/tokyo.mp4" autoplay muted loop playsinline></video></td>
163
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v/eiffel.mp4" autoplay muted loop playsinline></video></td>
164
+ </tr>
165
+ <tr>
166
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v/waves.mp4" autoplay muted loop playsinline></video></td>
167
+ <td><video src="https://pyramid-flow.github.io/static/videos/t2v/rail.mp4" autoplay muted loop playsinline></video></td>
168
+ </tr>
169
+ </table>
170
+
171
+ ## Acknowledgement
172
+
173
+ We are grateful for the following awesome projects when implementing Pyramid Flow:
174
+
175
+ * [SD3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium) and [Flux 1.0](https://huggingface.co/black-forest-labs/FLUX.1-dev): State-of-the-art image generation models based on flow matching.
176
+ * [Diffusion Forcing](https://boyuan.space/diffusion-forcing) and [GameNGen](https://gamengen.github.io): Next-token prediction meets full-sequence diffusion.
177
+ * [WebVid-10M](https://github.com/m-bain/webvid), [OpenVid-1M](https://github.com/NJU-PCALab/OpenVid-1M) and [Open-Sora Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan): Large-scale datasets for text-to-video generation.
178
+ * [CogVideoX](https://github.com/THUDM/CogVideo): An open-source text-to-video generation model that shares many training details.
179
+ * [Video-LLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2): An open-source video LLM for our video recaptioning.
180
+
181
+ ## Citation
182
+
183
+ Consider giving this repository a star and cite Pyramid Flow in your publications if it helps your research.
184
+ ```
185
+ @article{jin2024pyramidal,
186
+ title={Pyramidal Flow Matching for Efficient Video Generative Modeling},
187
+ author={Jin, Yang and Sun, Zhicheng and Li, Ningyuan and Xu, Kun and Xu, Kun and Jiang, Hao and Zhuang, Nan and Huang, Quzhe and Song, Yang and Mu, Yadong and Lin, Zhouchen},
188
+ jounal={arXiv preprint arXiv:2410.05954},
189
+ year={2024}
190
+ }
191
+ ```
pyramid_flow_model/causal_video_vae/config.json ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "CausalVideoVAE",
3
+ "_diffusers_version": "0.29.2",
4
+ "add_post_quant_conv": true,
5
+ "decoder_act_fn": "silu",
6
+ "decoder_block_dropout": [
7
+ 0.0,
8
+ 0.0,
9
+ 0.0,
10
+ 0.0
11
+ ],
12
+ "decoder_block_out_channels": [
13
+ 128,
14
+ 256,
15
+ 512,
16
+ 512
17
+ ],
18
+ "decoder_in_channels": 16,
19
+ "decoder_layers_per_block": [
20
+ 3,
21
+ 3,
22
+ 3,
23
+ 3
24
+ ],
25
+ "decoder_norm_num_groups": 32,
26
+ "decoder_out_channels": 3,
27
+ "decoder_spatial_up_sample": [
28
+ true,
29
+ true,
30
+ true,
31
+ false
32
+ ],
33
+ "decoder_temporal_up_sample": [
34
+ true,
35
+ true,
36
+ true,
37
+ false
38
+ ],
39
+ "decoder_type": "causal_vae_conv",
40
+ "decoder_up_block_types": [
41
+ "UpDecoderBlockCausal3D",
42
+ "UpDecoderBlockCausal3D",
43
+ "UpDecoderBlockCausal3D",
44
+ "UpDecoderBlockCausal3D"
45
+ ],
46
+ "downsample_scale": 8,
47
+ "encoder_act_fn": "silu",
48
+ "encoder_block_dropout": [
49
+ 0.0,
50
+ 0.0,
51
+ 0.0,
52
+ 0.0
53
+ ],
54
+ "encoder_block_out_channels": [
55
+ 128,
56
+ 256,
57
+ 512,
58
+ 512
59
+ ],
60
+ "encoder_double_z": true,
61
+ "encoder_down_block_types": [
62
+ "DownEncoderBlockCausal3D",
63
+ "DownEncoderBlockCausal3D",
64
+ "DownEncoderBlockCausal3D",
65
+ "DownEncoderBlockCausal3D"
66
+ ],
67
+ "encoder_in_channels": 3,
68
+ "encoder_layers_per_block": [
69
+ 2,
70
+ 2,
71
+ 2,
72
+ 2
73
+ ],
74
+ "encoder_norm_num_groups": 32,
75
+ "encoder_out_channels": 16,
76
+ "encoder_spatial_down_sample": [
77
+ true,
78
+ true,
79
+ true,
80
+ false
81
+ ],
82
+ "encoder_temporal_down_sample": [
83
+ true,
84
+ true,
85
+ true,
86
+ false
87
+ ],
88
+ "encoder_type": "causal_vae_conv",
89
+ "interpolate": false,
90
+ "sample_size": 256,
91
+ "scaling_factor": 0.13025
92
+ }
pyramid_flow_model/causal_video_vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8385177ef6dc62f9e0665213c1509f750a59b496ddf573b4524d7a641b21d260
3
+ size 1341696682
pyramid_flow_model/diffusion_transformer_384p/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PyramidFluxTransformer",
3
+ "_diffusers_version": "0.30.3",
4
+ "attention_head_dim": 64,
5
+ "axes_dims_rope": [
6
+ 16,
7
+ 24,
8
+ 24
9
+ ],
10
+ "in_channels": 64,
11
+ "interp_condition_pos": true,
12
+ "joint_attention_dim": 4096,
13
+ "num_attention_heads": 30,
14
+ "num_layers": 8,
15
+ "num_single_layers": 16,
16
+ "patch_size": 1,
17
+ "pooled_projection_dim": 768,
18
+ "use_flash_attn": false,
19
+ "use_gradient_checkpointing": false,
20
+ "use_temporal_causal": true
21
+ }
pyramid_flow_model/diffusion_transformer_384p/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76ab1a5f81f5c69285ad8040e8282a6260dae5ca601d7f614bd9de38a46316b5
3
+ size 7888294568
pyramid_flow_model/diffusion_transformer_768p/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PyramidFluxTransformer",
3
+ "_diffusers_version": "0.30.3",
4
+ "attention_head_dim": 64,
5
+ "axes_dims_rope": [
6
+ 16,
7
+ 24,
8
+ 24
9
+ ],
10
+ "in_channels": 64,
11
+ "interp_condition_pos": true,
12
+ "joint_attention_dim": 4096,
13
+ "num_attention_heads": 30,
14
+ "num_layers": 8,
15
+ "num_single_layers": 16,
16
+ "patch_size": 1,
17
+ "pooled_projection_dim": 768,
18
+ "use_flash_attn": false,
19
+ "use_gradient_checkpointing": false,
20
+ "use_temporal_causal": true
21
+ }
pyramid_flow_model/diffusion_transformer_768p/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:864de0e1afd9dd2c373d957ac2c54346f5006036dc7aa8ec7605db80eea2272c
3
+ size 7888294568
pyramid_flow_model/diffusion_transformer_image/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "PyramidFluxTransformer",
3
+ "_diffusers_version": "0.30.3",
4
+ "attention_head_dim": 64,
5
+ "axes_dims_rope": [
6
+ 16,
7
+ 24,
8
+ 24
9
+ ],
10
+ "in_channels": 64,
11
+ "interp_condition_pos": true,
12
+ "joint_attention_dim": 4096,
13
+ "num_attention_heads": 30,
14
+ "num_layers": 8,
15
+ "num_single_layers": 16,
16
+ "patch_size": 1,
17
+ "pooled_projection_dim": 768,
18
+ "use_flash_attn": false,
19
+ "use_gradient_checkpointing": false,
20
+ "use_temporal_causal": true
21
+ }
pyramid_flow_model/diffusion_transformer_image/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a16bcc2f50fe52de93d6a7aa13a31dde384dda13a98007e3a5b17e02257697e
3
+ size 7888294568