maxin-cn committed
Commit 94bafa8
1 Parent(s): c6fae7d

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. .gitattributes +2 -0
  2. .gitignore +2 -0
  3. LICENSE +201 -0
  4. README.md +167 -12
  5. configs/ffs/ffs_img_train.yaml +45 -0
  6. configs/ffs/ffs_sample.yaml +30 -0
  7. configs/ffs/ffs_train.yaml +42 -0
  8. configs/sky/sky_img_train.yaml +43 -0
  9. configs/sky/sky_sample.yaml +32 -0
  10. configs/sky/sky_train.yaml +42 -0
  11. configs/t2x/t2i_sample.yaml +37 -0
  12. configs/t2x/t2v_sample.yaml +37 -0
  13. configs/taichi/taichi_img_train.yaml +43 -0
  14. configs/taichi/taichi_sample.yaml +30 -0
  15. configs/taichi/taichi_train.yaml +42 -0
  16. configs/ucf101/ucf101_img_train.yaml +44 -0
  17. configs/ucf101/ucf101_sample.yaml +33 -0
  18. configs/ucf101/ucf101_train.yaml +42 -0
  19. datasets/__init__.py +79 -0
  20. datasets/ffs_datasets.py +164 -0
  21. datasets/ffs_image_datasets.py +246 -0
  22. datasets/sky_datasets.py +110 -0
  23. datasets/sky_image_datasets.py +137 -0
  24. datasets/taichi_datasets.py +108 -0
  25. datasets/taichi_image_datasets.py +139 -0
  26. datasets/ucf101_datasets.py +229 -0
  27. datasets/ucf101_image_datasets.py +279 -0
  28. datasets/video_transforms.py +482 -0
  29. demo.py +284 -0
  30. diffusion/__init__.py +47 -0
  31. diffusion/diffusion_utils.py +88 -0
  32. diffusion/gaussian_diffusion.py +881 -0
  33. diffusion/respace.py +130 -0
  34. diffusion/timestep_sampler.py +150 -0
  35. docs/datasets_evaluation.md +53 -0
  36. docs/latte_diffusers.md +106 -0
  37. environment.yml +25 -0
  38. models/__init__.py +52 -0
  39. models/__pycache__/__init__.cpython-312.pyc +0 -0
  40. models/__pycache__/latte.cpython-312.pyc +0 -0
  41. models/__pycache__/latte_img.cpython-312.pyc +0 -0
  42. models/__pycache__/latte_t2v.cpython-312.pyc +0 -0
  43. models/clip.py +126 -0
  44. models/latte.py +526 -0
  45. models/latte_img.py +552 -0
  46. models/latte_t2v.py +945 -0
  47. models/utils.py +215 -0
  48. sample/__pycache__/pipeline_latte.cpython-312.pyc +0 -0
  49. sample/ffs.sh +7 -0
  50. sample/ffs_ddp.sh +7 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ visuals/latte.gif filter=lfs diff=lfs merge=lfs -text
+ visuals/latteT2V.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .vscode
+ preprocess
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,167 @@
- ---
- title: Latte
- emoji: 🏆
- colorFrom: blue
- colorTo: pink
- sdk: gradio
- sdk_version: 4.39.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Latte
+ app_file: demo.py
+ sdk: gradio
+ sdk_version: 4.37.2
+ ---
+ ## Latte: Latent Diffusion Transformer for Video Generation<br><sub>Official PyTorch Implementation</sub>
+
+ <!-- ### [Paper](https://arxiv.org/abs/2401.03048v1) | [Project Page](https://maxin-cn.github.io/latte_project/) -->
+
+ <!-- [![arXiv](https://img.shields.io/badge/arXiv-2401.03048-b31b1b.svg)](https://arxiv.org/abs/2401.03048) -->
+ [![Arxiv](https://img.shields.io/badge/Arxiv-b31b1b.svg)](https://arxiv.org/abs/2401.03048)
+ [![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://maxin-cn.github.io/latte_project/)
+ [![HF Demo](https://img.shields.io/static/v1?label=Demo&message=OpenBayes%E8%B4%9D%E5%BC%8F%E8%AE%A1%E7%AE%97&color=green)](https://openbayes.com/console/public/tutorials/UOeU0ywVxl7)
+
+ [![Static Badge](https://img.shields.io/badge/Latte--1%20checkpoint%20(T2V)-HuggingFace-yellow?logoColor=violet%20Latte-1%20checkpoint)](https://huggingface.co/maxin-cn/Latte-1)
+ [![Static Badge](https://img.shields.io/badge/Latte%20checkpoint%20-HuggingFace-yellow?logoColor=violet%20Latte%20checkpoint)](https://huggingface.co/maxin-cn/Latte)
+
+ This repo contains PyTorch model definitions, pre-trained weights, training/sampling code and evaluation code for our paper exploring
+ latent diffusion models with transformers (Latte). You can find more visualizations on our [project page](https://maxin-cn.github.io/latte_project/).
+
+ > [**Latte: Latent Diffusion Transformer for Video Generation**](https://maxin-cn.github.io/latte_project/)<br>
+ > [Xin Ma](https://maxin-cn.github.io/), [Yaohui Wang*](https://wyhsirius.github.io/), [Xinyuan Chen](https://scholar.google.com/citations?user=3fWSC8YAAAAJ), [Gengyun Jia](https://scholar.google.com/citations?user=_04pkGgAAAAJ&hl=zh-CN), [Ziwei Liu](https://liuziwei7.github.io/), [Yuan-Fang Li](https://users.monash.edu/~yli/), [Cunjian Chen](https://cunjian.github.io/), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ&hl=zh-CN)
+ > (*Corresponding Author & Project Lead)
+ <!-- > <br>Monash University, Shanghai Artificial Intelligence Laboratory,<br> NJUPT, S-Lab, Nanyang Technological University
+
+ We propose a novel Latent Diffusion Transformer, namely Latte, for video generation. Latte first extracts spatio-temporal tokens from input videos and then adopts a series of Transformer blocks to model video distribution in the latent space. In order to model a substantial number of tokens extracted from videos, four efficient variants are introduced from the perspective of decomposing the spatial and temporal dimensions of input videos. To improve the quality of generated videos, we determine the best practices of Latte through rigorous experimental analysis, including video clip patch embedding, model variants, timestep-class information injection, temporal positional embedding, and learning strategies. Our comprehensive evaluation demonstrates that Latte achieves state-of-the-art performance across four standard video generation datasets, i.e., FaceForensics, SkyTimelapse, UCF101, and Taichi-HD. In addition, we extend Latte to text-to-video generation (T2V) task, where Latte achieves comparable results compared to recent T2V models. We strongly believe that Latte provides valuable insights for future research on incorporating Transformers into diffusion models for video generation.
+
+ ![The architecture of Latte](visuals/architecture.svg){width=20}
+ -->
+
+ <!--
+ <div align="center">
+ <img src="visuals/architecture.svg" width="650">
+ </div>
+
+ This repository contains:
+
+ * 🪐 A simple PyTorch [implementation](models/latte.py) of Latte
+ * ⚡️ **Pre-trained Latte models** trained on FaceForensics, SkyTimelapse, Taichi-HD and UCF101 (256x256). In addition, we provide a T2V checkpoint (512x512). All checkpoints can be found [here](https://huggingface.co/maxin-cn/Latte/tree/main).
+
+ * 🛸 A Latte [training script](train.py) using PyTorch DDP.
+ -->
+
+ <video controls loop src="https://github.com/Vchitect/Latte/assets/7929326/a650cd84-2378-4303-822b-56a441e1733b" type="video/mp4"></video>
+
+ ## News
+ - (🔥 New) **Jul 11, 2024** 💥 **Latte-1 is now integrated into [diffusers](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/latte_transformer_3d.py). Thanks to [@yiyixuxu](https://github.com/yiyixuxu), [@sayakpaul](https://github.com/sayakpaul), [@a-r-r-o-w](https://github.com/a-r-r-o-w) and [@DN6](https://github.com/DN6).** You can easily run Latte using the following code. We also support inference with 4/8-bit quantization, which can reduce GPU memory from 17 GB to 9 GB. Please refer to this [tutorial](docs/latte_diffusers.md) for more information.
+
+ ```python
+ from diffusers import LattePipeline
+ from diffusers.models import AutoencoderKLTemporalDecoder
+ from torchvision.utils import save_image
+ import torch
+ import imageio
+
+ torch.manual_seed(0)
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ video_length = 16 # 1 (text-to-image) or 16 (text-to-video)
+ pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to(device)
+
+ # Using temporal decoder of VAE
+ vae = AutoencoderKLTemporalDecoder.from_pretrained("maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+ pipe.vae = vae
+
+ prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
+ videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()
+ ```
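The snippet above imports `save_image` and `imageio` but stops at the denoised frames. A minimal, hedged sketch of writing those frames to disk follows; it assumes `videos` comes back as a `(batch, frames, channels, height, width)` tensor in `[0, 1]` (which is what `output_type='pt'` typically yields) and that the `imageio-ffmpeg` backend is installed for MP4 output:

```python
# Illustrative sketch, not part of this commit: persist the sampled frames.
# Assumes `videos` has shape (batch, frames, channels, height, width) in [0, 1].
import imageio
import torch
from torchvision.utils import save_image

clip = (videos[0].clamp(0, 1) * 255).to(torch.uint8)    # F, C, H, W
clip = clip.permute(0, 2, 3, 1).numpy()                 # F, H, W, C for imageio

if video_length == 1:
    save_image(videos[0], "./latte_sample.png")          # text-to-image case
else:
    imageio.mimwrite("./latte_sample.mp4", clip, fps=8)  # text-to-video case
```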
+
+ - (🔥 New) **May 23, 2024** 💥 **Latte-1** is released! The pre-trained model can be downloaded [here](https://huggingface.co/maxin-cn/Latte-1/tree/main/transformer). **We support both T2V and T2I**. Please run `bash sample/t2v.sh` and `bash sample/t2i.sh` respectively.
+
+ <!--
+ <div align="center">
+ <img src="visuals/latteT2V.gif" width=88%>
+ </div>
+ -->
+
+ - (🔥 New) **Feb 24, 2024** 💥 We are very grateful that researchers and developers are interested in our work. We will continue to update our LatteT2V model, and we hope our efforts can help the community grow. Our Latte Discord channel <a href="https://discord.gg/RguYqhVU92" style="text-decoration:none;">
+ <img src="https://user-images.githubusercontent.com/25839884/218347213-c080267f-cbb6-443e-8532-8e1ed9a58ea9.png" width="3%" alt="" /></a> has been created for discussions; contributions are welcome.
+
+ - (🔥 New) **Jan 9, 2024** 💥 An updated LatteT2V model initialized with [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha) is released; the checkpoint can be found [here](https://huggingface.co/maxin-cn/Latte-0/tree/main/transformer).
+
+ - (🔥 New) **Oct 31, 2023** 💥 The training and inference code is released. All checkpoints (including FaceForensics, SkyTimelapse, UCF101, and Taichi-HD) can be found [here](https://huggingface.co/maxin-cn/Latte/tree/main). In addition, the LatteT2V inference code is provided.
+
+
+ ## Setup
+
+ First, download and set up the repo:
+
+ ```bash
+ git clone https://github.com/Vchitect/Latte
+ cd Latte
+ ```
+
+ We provide an [`environment.yml`](environment.yml) file that can be used to create a Conda environment. If you only want
+ to run pre-trained models locally on CPU, you can remove the `cudatoolkit` and `pytorch-cuda` requirements from the file.
+
+ ```bash
+ conda env create -f environment.yml
+ conda activate latte
+ ```
+
+
+ ## Sampling
+
+ You can sample from our **pre-trained Latte models** with [`sample.py`](sample/sample.py). Weights for our pre-trained Latte models can be found [here](https://huggingface.co/maxin-cn/Latte). The script has various arguments to adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our model on FaceForensics, you can use:
+
+ ```bash
+ bash sample/ffs.sh
+ ```
+
+ or, if you want to sample hundreds of videos, you can use the following script with PyTorch DDP:
+
+ ```bash
+ bash sample/ffs_ddp.sh
+ ```
+
+ If you want to try generating videos from text, just run `bash sample/t2v.sh`. All related checkpoints will be downloaded automatically.
+
+ If you would like to measure the quantitative metrics of your generated results, please refer to [here](docs/datasets_evaluation.md).
+
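The classifier-free guidance scale mentioned above surfaces as `cfg_scale` (or `guidance_scale`) in the sampling configs added in this commit. A minimal sketch of how such a scale is conventionally applied to the conditional and unconditional noise predictions; this is the standard formulation, shown for illustration rather than as the repo's exact sampler code:

```python
import torch

def apply_cfg(eps_uncond: torch.Tensor, eps_cond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Standard classifier-free guidance: push the prediction away from the
    unconditional branch by cfg_scale. cfg_scale == 1.0 disables guidance."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)
```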
+ ## Training
+
+ We provide a training script for Latte in [`train.py`](train.py). The structure of the datasets can be found [here](docs/datasets_evaluation.md). This script can be used to train class-conditional and unconditional
+ Latte models. To launch Latte (256x256) training with `N` GPUs on the FaceForensics dataset, run:
+
+ ```bash
+ torchrun --nnodes=1 --nproc_per_node=N train.py --config ./configs/ffs/ffs_train.yaml
+ ```
+
+ Or, if you have a cluster that uses Slurm, you can also train Latte using the following script:
+
+ ```bash
+ sbatch slurm_scripts/ffs.slurm
+ ```
+
+ We also provide a video-image joint training script, [`train_with_img.py`](train_with_img.py). Similar to [`train.py`](train.py), it can also be used to train class-conditional and unconditional
+ Latte models. For example, to train the Latte model on the FaceForensics dataset, you can use:
+
+ ```bash
+ torchrun --nnodes=1 --nproc_per_node=N train_with_img.py --config ./configs/ffs/ffs_img_train.yaml
+ ```
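Both training entry points are driven by the YAML files under `configs/` that this commit adds. A hedged sketch of turning such a file into an attribute-style config is shown below; OmegaConf is an assumption here (the repo's scripts may use a different loader), and the fields read are the ones visible in `configs/ffs/ffs_train.yaml`:

```python
from omegaconf import OmegaConf

# Illustrative only: load a training config and read a few of its fields.
args = OmegaConf.load("./configs/ffs/ffs_train.yaml")
print(args.model, args.num_frames, args.image_size)      # e.g. Latte-XL/2, 16, 256
print(args.learning_rate, args.local_batch_size)          # training hyperparameters
```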
+
+ ## Contact Us
+ **Yaohui Wang**: [wangyaohui@pjlab.org.cn](mailto:wangyaohui@pjlab.org.cn)
+ **Xin Ma**: [xin.ma1@monash.edu](mailto:xin.ma1@monash.edu)
+
+ ## Citation
+ If you find this work useful for your research, please consider citing it.
+ ```bibtex
+ @article{ma2024latte,
+   title={Latte: Latent Diffusion Transformer for Video Generation},
+   author={Ma, Xin and Wang, Yaohui and Jia, Gengyun and Chen, Xinyuan and Liu, Ziwei and Li, Yuan-Fang and Chen, Cunjian and Qiao, Yu},
+   journal={arXiv preprint arXiv:2401.03048},
+   year={2024}
+ }
+ ```
+
+
+ ## Acknowledgments
+ Latte has been greatly inspired by the following amazing works and teams: [DiT](https://github.com/facebookresearch/DiT) and [PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha); we thank all the contributors for open-sourcing their work.
+
+
+ ## License
+ The code and model weights are licensed under the terms in [LICENSE](LICENSE).
configs/ffs/ffs_img_train.yaml ADDED
@@ -0,0 +1,45 @@
+ # dataset
+ dataset: "ffs_img"
+
+ data_path: "/path/to/datasets/preprocessed_ffs/train/videos/"
+ frame_data_path: "/path/to/datasets/preprocessed_ffs/train/images/"
+ frame_data_txt: "/path/to/datasets/preprocessed_ffs/train_list.txt"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results_img"
+ pretrained:
+
+ # model config:
+ model: LatteIMG-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True # important
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ use_image_num: 8
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 500000
+ local_batch_size: 4 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 100
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/ffs/ffs_sample.yaml ADDED
@@ -0,0 +1,30 @@
+ # path:
+ ckpt: # will be overwritten
+ save_img_path: "./sample_videos" # will be overwritten
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ frame_interval: 2
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+ num_classes:
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed:
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ cfg_scale: 1.0
+ negative_name:
+
+ # ddp sample config
+ per_proc_batch_size: 2
+ num_fvd_samples: 2048
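For the DDP sampling fields at the bottom of this config, total throughput per step is `world_size × per_proc_batch_size` videos. A small hedged sketch of the bookkeeping a DDP sampling run typically does to collect `num_fvd_samples`; the `world_size` value is an assumption about how many GPUs `sample/ffs_ddp.sh` launches:

```python
import math

world_size = 8             # number of GPUs launched by the DDP script (assumption)
per_proc_batch_size = 2    # from this config
num_fvd_samples = 2048     # from this config

global_batch = world_size * per_proc_batch_size
iterations = math.ceil(num_fvd_samples / global_batch)
print(iterations)          # 128 sampling iterations to reach 2048 videos
```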
configs/ffs/ffs_train.yaml ADDED
@@ -0,0 +1,42 @@
+ # dataset
+ dataset: "ffs"
+
+ data_path: "/path/to/datasets/preprocess_ffs/train/videos/" # s
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results"
+ pretrained:
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True # important
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 20000
+ local_batch_size: 5 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 100
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/sky/sky_img_train.yaml ADDED
@@ -0,0 +1,43 @@
+ # dataset
+ dataset: "sky_img"
+
+ data_path: "/path/to/datasets/sky_timelapse/sky_train/" # s/p
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results_img"
+ pretrained:
+
+ # model config:
+ model: LatteIMG-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ use_image_num: 8 # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 20000
+ local_batch_size: 4 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/sky/sky_sample.yaml ADDED
@@ -0,0 +1,32 @@
+ # path:
+ ckpt: # will be overwritten
+ save_img_path: "./sample_videos/" # will be overwritten
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ frame_interval: 2
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+ num_classes:
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed:
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ cfg_scale: 1.0
+ run_time: 12
+ num_sample: 1
+ negative_name:
+
+ # ddp sample config
+ per_proc_batch_size: 1
+ num_fvd_samples: 2
configs/sky/sky_train.yaml ADDED
@@ -0,0 +1,42 @@
+ # dataset
+ dataset: "sky"
+
+ data_path: "/path/to/datasets/sky_timelapse/sky_train/"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results"
+ pretrained:
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 20000
+ local_batch_size: 5 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/t2x/t2i_sample.yaml ADDED
@@ -0,0 +1,37 @@
+ # path:
+ save_img_path: "./sample_videos/t2i-"
+ pretrained_model_path: "maxin-cn/Latte-1"
+
+ # model config:
+ # maxin-cn/Latte-0: the first released version
+ # maxin-cn/Latte-1: the second version with better performance (released on May 23, 2024)
+ model: LatteT2V
+ video_length: 1
+ image_size: [512, 512]
+ # # beta schedule
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ variance_type: "learned_range"
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed:
+ run_time: 0
+ guidance_scale: 7.5
+ sample_method: 'DDIM'
+ num_sampling_steps: 50
+ enable_temporal_attentions: True # LatteT2V-V0: set to False; LatteT2V-V1: set to True
+ enable_vae_temporal_decoder: False
+
+ text_prompt: [
+   'Yellow and black tropical fish dart through the sea.',
+   'An epic tornado attacking above a glowing city at night.',
+   'Slow pan upward of blazing oak fire in an indoor fireplace.',
+   'a cat wearing sunglasses and working as a lifeguard at pool.',
+   'Sunset over the sea.',
+   'A dog in astronaut suit and sunglasses floating in space.',
+ ]
configs/t2x/t2v_sample.yaml ADDED
@@ -0,0 +1,37 @@
+ # path:
+ save_img_path: "./sample_videos/t2v-"
+ pretrained_model_path: "/data/monash_vidgen/pretrained/Latte-1"
+
+ # model config:
+ # maxin-cn/Latte-0: the first released version
+ # maxin-cn/Latte-1: the second version with better performance (released on May 23, 2024)
+ model: LatteT2V
+ video_length: 16
+ image_size: [512, 512]
+ # # beta schedule
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: "linear"
+ variance_type: "learned_range"
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed: 0
+ run_time: 0
+ guidance_scale: 7.5
+ sample_method: 'DDIM'
+ num_sampling_steps: 50
+ enable_temporal_attentions: True
+ enable_vae_temporal_decoder: True # use the temporal VAE decoder from SVD; may reduce video flicker (not widely tested)
+
+ text_prompt: [
+   'Yellow and black tropical fish dart through the sea.',
+   'An epic tornado attacking above a glowing city at night.',
+   'Slow pan upward of blazing oak fire in an indoor fireplace.',
+   'a cat wearing sunglasses and working as a lifeguard at pool.',
+   'Sunset over the sea.',
+   'A dog in astronaut suit and sunglasses floating in space.',
+ ]
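The `beta_start`/`beta_end`/`beta_schedule`/`variance_type` fields above describe a standard linear DDPM noise schedule. A minimal sketch of the quantities they imply, for illustration rather than as the repo's diffusion code (the 1000-step horizon is an assumption):

```python
import torch

beta_start, beta_end, num_train_timesteps = 0.0001, 0.02, 1000  # step count is an assumption

betas = torch.linspace(beta_start, beta_end, num_train_timesteps)  # "linear" schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Forward process at step t: x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise
# With variance_type "learned_range", the model predicts an interpolation between beta_t and the
# posterior variance rather than using a fixed value.
```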
configs/taichi/taichi_img_train.yaml ADDED
@@ -0,0 +1,43 @@
+ # dataset
+ dataset: "taichi_img"
+
+ data_path: "/path/to/datasets/taichi"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results_img"
+ pretrained:
+
+ # model config:
+ model: LatteIMG-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ load_from_ceph: False # important
+ use_image_num: 8
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 500000
+ local_batch_size: 4 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training TODO
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/taichi/taichi_sample.yaml ADDED
@@ -0,0 +1,30 @@
+ # path:
+ ckpt: # will be overwritten
+ save_img_path: "./sample_videos/" # will be overwritten
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ frame_interval: 2
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+ num_classes:
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed:
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ cfg_scale: 1.0
+ negative_name:
+
+ # ddp sample config
+ per_proc_batch_size: 1
+ num_fvd_samples: 2
configs/taichi/taichi_train.yaml ADDED
@@ -0,0 +1,42 @@
+ # dataset
+ dataset: "taichi"
+
+ data_path: "/path/to/datasets/taichi"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results"
+ pretrained:
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 1 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ load_from_ceph: False # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 500000
+ local_batch_size: 5 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes:
+
+ # low VRAM and speed up training TODO
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/ucf101/ucf101_img_train.yaml ADDED
@@ -0,0 +1,44 @@
+ # dataset
+ dataset: "ucf101_img"
+
+ data_path: "/path/to/datasets/UCF101/videos/"
+ frame_data_txt: "/path/to/datasets/UCF101/train_256_list.txt"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results_img"
+ pretrained:
+
+ # model config:
+ model: LatteIMG-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ use_image_num: 8 # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 100000
+ local_batch_size: 4 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes: 101
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
configs/ucf101/ucf101_sample.yaml ADDED
@@ -0,0 +1,33 @@
+ # path:
+ ckpt:
+ save_img_path: "./sample_videos/"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+ num_classes: 101
+
+ # model speedup
+ use_compile: False
+ use_fp16: True
+
+ # sample config:
+ seed:
+ sample_method: 'ddpm'
+ num_sampling_steps: 250
+ cfg_scale: 7.0
+ run_time: 12
+ num_sample: 1
+ sample_names: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ negative_name: 101
+
+ # ddp sample config
+ per_proc_batch_size: 2
+ num_fvd_samples: 2
configs/ucf101/ucf101_train.yaml ADDED
@@ -0,0 +1,42 @@
+ # dataset
+ dataset: "ucf101"
+
+ data_path: "/path/to/datasets/UCF101/videos/"
+ pretrained_model_path: "/path/to/pretrained/Latte/"
+
+ # save and load
+ results_dir: "./results"
+ pretrained:
+
+ # model config:
+ model: Latte-XL/2
+ num_frames: 16
+ image_size: 256 # choices=[256, 512]
+ num_sampling_steps: 250
+ frame_interval: 3
+ fixed_spatial: False
+ attention_bias: True
+ learn_sigma: True
+ extras: 2 # [1, 2] 1 unconditional generation, 2 class-conditional generation
+
+ # train config:
+ save_ceph: True # important
+ learning_rate: 1e-4
+ ckpt_every: 10000
+ clip_max_norm: 0.1
+ start_clip_iter: 100000
+ local_batch_size: 5 # important
+ max_train_steps: 1000000
+ global_seed: 3407
+ num_workers: 8
+ log_every: 50
+ lr_warmup_steps: 0
+ resume_from_checkpoint:
+ gradient_accumulation_steps: 1 # TODO
+ num_classes: 101
+
+ # low VRAM and speed up training
+ use_compile: False
+ mixed_precision: False
+ enable_xformers_memory_efficient_attention: False
+ gradient_checkpointing: False
datasets/__init__.py ADDED
@@ -0,0 +1,79 @@
+ from .sky_datasets import Sky
+ from torchvision import transforms
+ from .taichi_datasets import Taichi
+ from datasets import video_transforms
+ from .ucf101_datasets import UCF101
+ from .ffs_datasets import FaceForensics
+ from .ffs_image_datasets import FaceForensicsImages
+ from .sky_image_datasets import SkyImages
+ from .ucf101_image_datasets import UCF101Images
+ from .taichi_image_datasets import TaichiImages
+
+
+ def get_dataset(args):
+     temporal_sample = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) # 16 1
+
+     if args.dataset == 'ffs':
+         transform_ffs = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             video_transforms.UCFCenterCropVideo(args.image_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return FaceForensics(args, transform=transform_ffs, temporal_sample=temporal_sample)
+     elif args.dataset == 'ffs_img':
+         transform_ffs = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             video_transforms.UCFCenterCropVideo(args.image_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return FaceForensicsImages(args, transform=transform_ffs, temporal_sample=temporal_sample)
+     elif args.dataset == 'ucf101':
+         transform_ucf101 = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             video_transforms.UCFCenterCropVideo(args.image_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return UCF101(args, transform=transform_ucf101, temporal_sample=temporal_sample)
+     elif args.dataset == 'ucf101_img':
+         transform_ucf101 = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             video_transforms.UCFCenterCropVideo(args.image_size),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return UCF101Images(args, transform=transform_ucf101, temporal_sample=temporal_sample)
+     elif args.dataset == 'taichi':
+         transform_taichi = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return Taichi(args, transform=transform_taichi, temporal_sample=temporal_sample)
+     elif args.dataset == 'taichi_img':
+         transform_taichi = transforms.Compose([
+             video_transforms.ToTensorVideo(), # TCHW
+             video_transforms.RandomHorizontalFlipVideo(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return TaichiImages(args, transform=transform_taichi, temporal_sample=temporal_sample)
+     elif args.dataset == 'sky':
+         transform_sky = transforms.Compose([
+             video_transforms.ToTensorVideo(),
+             video_transforms.CenterCropResizeVideo(args.image_size),
+             # video_transforms.RandomHorizontalFlipVideo(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return Sky(args, transform=transform_sky, temporal_sample=temporal_sample)
+     elif args.dataset == 'sky_img':
+         transform_sky = transforms.Compose([
+             video_transforms.ToTensorVideo(),
+             video_transforms.CenterCropResizeVideo(args.image_size),
+             # video_transforms.RandomHorizontalFlipVideo(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+         ])
+         return SkyImages(args, transform=transform_sky, temporal_sample=temporal_sample)
+     else:
+         raise NotImplementedError(args.dataset)
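`get_dataset` only reads a handful of attributes from `args` (`dataset`, `num_frames`, `frame_interval`, `image_size`, `data_path` via the dataset constructors). A hedged usage sketch with a plain namespace and a DataLoader is shown below; the field values are placeholders, and the individual dataset classes may read further config fields not listed here:

```python
import argparse
import torch
from datasets import get_dataset   # the package defined by this __init__.py

args = argparse.Namespace(
    dataset='sky',                 # one of the branches handled above
    num_frames=16,
    frame_interval=3,
    image_size=256,
    data_path='/path/to/datasets/sky_timelapse/sky_train/',
)

dataset = get_dataset(args)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
batch = next(iter(loader))
# If Sky returns clips like FaceForensics does, batch['video'] should be roughly [2, 16, 3, 256, 256].
print(batch['video'].shape)
```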
datasets/ffs_datasets.py ADDED
@@ -0,0 +1,164 @@
+ import os
+ import json
+ import torch
+ import decord
+ import torchvision
+
+ import numpy as np
+
+
+ from PIL import Image
+ from einops import rearrange
+ from typing import Dict, List, Tuple
+
+ class_labels_map = None
+ cls_sample_cnt = None
+
+ def temporal_sampling(frames, start_idx, end_idx, num_samples):
+     """
+     Given the start and end frame index, sample num_samples frames between
+     the start and end with equal interval.
+     Args:
+         frames (tensor): a tensor of video frames, dimension is
+             `num video frames` x `channel` x `height` x `width`.
+         start_idx (int): the index of the start frame.
+         end_idx (int): the index of the end frame.
+         num_samples (int): number of frames to sample.
+     Returns:
+         frames (tensor): a tensor of temporally sampled video frames, dimension is
+             `num clip frames` x `channel` x `height` x `width`.
+     """
+     index = torch.linspace(start_idx, end_idx, num_samples)
+     index = torch.clamp(index, 0, frames.shape[0] - 1).long()
+     frames = torch.index_select(frames, 0, index)
+     return frames
+
+
+ def numpy2tensor(x):
+     return torch.from_numpy(x)
+
+
+ def get_filelist(file_path):
+     Filelist = []
+     for home, dirs, files in os.walk(file_path):
+         for filename in files:
+             Filelist.append(os.path.join(home, filename))
+             # Filelist.append(filename)
+     return Filelist
+
+
+ def load_annotation_data(data_file_path):
+     with open(data_file_path, 'r') as data_file:
+         return json.load(data_file)
+
+
+ def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
+     global class_labels_map, cls_sample_cnt
+
+     if class_labels_map is not None:
+         return class_labels_map, cls_sample_cnt
+     else:
+         cls_sample_cnt = {}
+         class_labels_map = load_annotation_data(anno_pth)
+         for cls in class_labels_map:
+             cls_sample_cnt[cls] = 0
+         return class_labels_map, cls_sample_cnt
+
+
+ def load_annotations(ann_file, num_class, num_samples_per_cls):
+     dataset = []
+     class_to_idx, cls_sample_cnt = get_class_labels(num_class)
+     with open(ann_file, 'r') as fin:
+         for line in fin:
+             line_split = line.strip().split('\t')
+             sample = {}
+             idx = 0
+             # idx for frame_dir
+             frame_dir = line_split[idx]
+             sample['video'] = frame_dir
+             idx += 1
+
+             # idx for label[s]
+             label = [x for x in line_split[idx:]]
+             assert label, f'missing label in line: {line}'
+             assert len(label) == 1
+             class_name = label[0]
+             class_index = int(class_to_idx[class_name])
+
+             # choose a class subset of the whole dataset
+             if class_index < num_class:
+                 sample['label'] = class_index
+                 if cls_sample_cnt[class_name] < num_samples_per_cls:
+                     dataset.append(sample)
+                     cls_sample_cnt[class_name] += 1
+
+     return dataset
+
+
+ class DecordInit(object):
+     """Using Decord (https://github.com/dmlc/decord) to initialize the video_reader."""
+
+     def __init__(self, num_threads=1, **kwargs):
+         self.num_threads = num_threads
+         self.ctx = decord.cpu(0)
+         self.kwargs = kwargs
+
+     def __call__(self, filename):
+         """Perform the Decord initialization.
+         Args:
+             results (dict): The resulting dict to be modified and passed
+                 to the next transform in the pipeline.
+         """
+         reader = decord.VideoReader(filename,
+                                     ctx=self.ctx,
+                                     num_threads=self.num_threads)
+         return reader
+
+     def __repr__(self):
+         repr_str = (f'{self.__class__.__name__}('
+                     f'sr={self.sr},'
+                     f'num_threads={self.num_threads})')
+         return repr_str
+
+
+ class FaceForensics(torch.utils.data.Dataset):
+     """Load the FaceForensics video files
+
+     Args:
+         target_video_len (int): the number of video frames to be loaded.
+         align_transform (callable): Align different videos to a specified size.
+         temporal_sample (callable): Sample the target length of a video.
+     """
+
+     def __init__(self,
+                  configs,
+                  transform=None,
+                  temporal_sample=None):
+         self.configs = configs
+         self.data_path = configs.data_path
+         self.video_lists = get_filelist(configs.data_path)
+         self.transform = transform
+         self.temporal_sample = temporal_sample
+         self.target_video_len = self.configs.num_frames
+         self.v_decoder = DecordInit()
+
+     def __getitem__(self, index):
+         path = self.video_lists[index]
+         vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
+         total_frames = len(vframes)
+
+         # Sampling video frames
+         start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
+         assert end_frame_ind - start_frame_ind >= self.target_video_len
+         frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.target_video_len, dtype=int)
+         video = vframes[frame_indice]
+         # video transformer data preprocessing
+         video = self.transform(video)  # T C H W
+         return {'video': video, 'video_name': 1}
+
+     def __len__(self):
+         return len(self.video_lists)
+
+
+ if __name__ == '__main__':
+     pass
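The `__getitem__` above first draws a window with `TemporalRandomCrop(num_frames * frame_interval)` and then spreads `num_frames` indices evenly across it with `np.linspace`. A small numeric sketch of that index selection; the window start is an arbitrary illustrative value:

```python
import numpy as np

num_frames, frame_interval = 16, 3
start_frame_ind = 40                                        # illustrative window start
end_frame_ind = start_frame_ind + num_frames * frame_interval

frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, num_frames, dtype=int)
print(frame_indice)   # 16 indices from 40 to 87, roughly frame_interval frames apart
```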
datasets/ffs_image_datasets.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import decord
5
+ import torchvision
6
+
7
+ import numpy as np
8
+
9
+ import random
10
+ from PIL import Image
11
+ from einops import rearrange
12
+ from typing import Dict, List, Tuple
13
+ from torchvision import transforms
14
+ import traceback
15
+
16
+ class_labels_map = None
17
+ cls_sample_cnt = None
18
+
19
+ def temporal_sampling(frames, start_idx, end_idx, num_samples):
20
+ """
21
+ Given the start and end frame index, sample num_samples frames between
22
+ the start and end with equal interval.
23
+ Args:
24
+ frames (tensor): a tensor of video frames, dimension is
25
+ `num video frames` x `channel` x `height` x `width`.
26
+ start_idx (int): the index of the start frame.
27
+ end_idx (int): the index of the end frame.
28
+ num_samples (int): number of frames to sample.
29
+ Returns:
30
+ frames (tersor): a tensor of temporal sampled video frames, dimension is
31
+ `num clip frames` x `channel` x `height` x `width`.
32
+ """
33
+ index = torch.linspace(start_idx, end_idx, num_samples)
34
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
35
+ frames = torch.index_select(frames, 0, index)
36
+ return frames
37
+
38
+
39
+ def numpy2tensor(x):
40
+ return torch.from_numpy(x)
41
+
42
+
43
+ def get_filelist(file_path):
44
+ Filelist = []
45
+ for home, dirs, files in os.walk(file_path):
46
+ for filename in files:
47
+ # 文件名列表,包含完整路径
48
+ Filelist.append(os.path.join(home, filename))
49
+ # # 文件名列表,只包含文件名
50
+ # Filelist.append( filename)
51
+ return Filelist
52
+
53
+
54
+ def load_annotation_data(data_file_path):
55
+ with open(data_file_path, 'r') as data_file:
56
+ return json.load(data_file)
57
+
58
+
59
+ def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
60
+ global class_labels_map, cls_sample_cnt
61
+
62
+ if class_labels_map is not None:
63
+ return class_labels_map, cls_sample_cnt
64
+ else:
65
+ cls_sample_cnt = {}
66
+ class_labels_map = load_annotation_data(anno_pth)
67
+ for cls in class_labels_map:
68
+ cls_sample_cnt[cls] = 0
69
+ return class_labels_map, cls_sample_cnt
70
+
71
+
72
+ def load_annotations(ann_file, num_class, num_samples_per_cls):
73
+ dataset = []
74
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
75
+ with open(ann_file, 'r') as fin:
76
+ for line in fin:
77
+ line_split = line.strip().split('\t')
78
+ sample = {}
79
+ idx = 0
80
+ # idx for frame_dir
81
+ frame_dir = line_split[idx]
82
+ sample['video'] = frame_dir
83
+ idx += 1
84
+
85
+ # idx for label[s]
86
+ label = [x for x in line_split[idx:]]
87
+ assert label, f'missing label in line: {line}'
88
+ assert len(label) == 1
89
+ class_name = label[0]
90
+ class_index = int(class_to_idx[class_name])
91
+
92
+ # choose a class subset of whole dataset
93
+ if class_index < num_class:
94
+ sample['label'] = class_index
95
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
96
+ dataset.append(sample)
97
+ cls_sample_cnt[class_name]+=1
98
+
99
+ return dataset
100
+
101
+
102
+ class DecordInit(object):
103
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
104
+
105
+ def __init__(self, num_threads=1, **kwargs):
106
+ self.num_threads = num_threads
107
+ self.ctx = decord.cpu(0)
108
+ self.kwargs = kwargs
109
+
110
+ def __call__(self, filename):
111
+ """Perform the Decord initialization.
112
+ Args:
113
+ results (dict): The resulting dict to be modified and passed
114
+ to the next transform in pipeline.
115
+ """
116
+ reader = decord.VideoReader(filename,
117
+ ctx=self.ctx,
118
+ num_threads=self.num_threads)
119
+ return reader
120
+
121
+ def __repr__(self):
122
+ repr_str = (f'{self.__class__.__name__}('
123
+ f'sr={self.sr},'
124
+ f'num_threads={self.num_threads})')
125
+ return repr_str
126
+
127
+
128
+ class FaceForensicsImages(torch.utils.data.Dataset):
129
+ """Load the FaceForensics video files
130
+
131
+ Args:
132
+ target_video_len (int): the number of video frames will be load.
133
+ align_transform (callable): Align different videos in a specified size.
134
+ temporal_sample (callable): Sample the target length of a video.
135
+ """
136
+
137
+ def __init__(self,
138
+ configs,
139
+ transform=None,
140
+ temporal_sample=None):
141
+ self.configs = configs
142
+ self.data_path = configs.data_path
143
+ self.video_lists = get_filelist(configs.data_path)
144
+ self.transform = transform
145
+ self.temporal_sample = temporal_sample
146
+ self.target_video_len = self.configs.num_frames
147
+ self.v_decoder = DecordInit()
148
+ self.video_length = len(self.video_lists)
149
+
150
+ # ffs video frames
151
+ self.video_frame_path = configs.frame_data_path
152
+ self.video_frame_txt = configs.frame_data_txt
153
+ self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
154
+ random.shuffle(self.video_frame_files)
155
+ self.use_image_num = configs.use_image_num
156
+ self.image_tranform = transforms.Compose([
157
+ transforms.ToTensor(),
158
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
159
+ ])
160
+
161
+ def __getitem__(self, index):
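+ # __len__ counts individual frame images, so the video index wraps around the smaller
+ # video list; each sample packs num_frames video frames plus use_image_num still images
+ # along the first dimension for joint image-video training.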
162
+ video_index = index % self.video_length
163
+ path = self.video_lists[video_index]
164
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
165
+ total_frames = len(vframes)
166
+
167
+ # Sampling video frames
168
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
169
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
170
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
171
+ video = vframes[frame_indice]
172
+ # video transformer data preprocessing
173
+ video = self.transform(video) # T C H W
174
+
175
+ # get video frames
176
+ images = []
177
+ for i in range(self.use_image_num):
178
+ while True:
179
+ try:
180
+ image = Image.open(os.path.join(self.video_frame_path, self.video_frame_files[index+i])).convert("RGB")
181
+ image = self.image_tranform(image).unsqueeze(0)
182
+ images.append(image)
183
+ break
184
+ except Exception as e:
185
+ traceback.print_exc()
186
+ index = random.randint(0, len(self.video_frame_files) - self.use_image_num)
187
+ images = torch.cat(images, dim=0)
188
+
189
+ assert len(images) == self.use_image_num
190
+
191
+ video_cat = torch.cat([video, images], dim=0)
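+ # video_cat has shape (num_frames + use_image_num, C, H, W)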
192
+
193
+ return {'video': video_cat, 'video_name': 1}
194
+
195
+ def __len__(self):
196
+ return len(self.video_frame_files)
197
+
198
+
199
+ if __name__ == '__main__':
200
+ import argparse
201
+ import torchvision
202
+ import video_transforms
203
+
204
+ import torch.utils.data as Data
205
+ import torchvision.transforms as transform
206
+
207
+ from PIL import Image
208
+
209
+
210
+ parser = argparse.ArgumentParser()
211
+ parser.add_argument("--num_frames", type=int, default=16)
212
+ parser.add_argument("--use-image-num", type=int, default=5)
213
+ parser.add_argument("--frame_interval", type=int, default=3)
214
+ parser.add_argument("--dataset", type=str, default='webvideo10m')
215
+ parser.add_argument("--test-run", type=bool, default='')
216
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/videos/")
217
+ parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
218
+ parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/faceForensics_v1/train_list.txt")
219
+ config = parser.parse_args()
220
+
221
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
222
+
223
+ transform_webvideo = transform.Compose([
224
+ video_transforms.ToTensorVideo(),
225
+ transform.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
226
+ ])
227
+
228
+ dataset = FaceForensicsImages(config, transform=transform_webvideo, temporal_sample=temporal_sample)
229
+ dataloader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=True, num_workers=4)
230
+
231
+ for i, video_data in enumerate(dataloader):
232
+ video, video_label = video_data['video'], video_data['video_name']
233
+ # print(video_label)
234
+ # print(image_label)
235
+ print(video.shape)
236
+ print(video_label)
237
+ # video_ = ((video[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
238
+ # print(video_.shape)
239
+ # try:
240
+ # torchvision.io.write_video(f'./test/{i:03d}_{video_label}.mp4', video_[:16], fps=8)
241
+ # except:
242
+ # pass
243
+
244
+ # if i % 100 == 0 and i != 0:
245
+ # break
246
+ print('Done!')
datasets/sky_datasets.py ADDED
@@ -0,0 +1,110 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import torch.utils.data as data
5
+
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+
10
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
11
+
12
+ def is_image_file(filename):
13
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14
+
15
+ class Sky(data.Dataset):
16
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
17
+
18
+ self.configs = configs
19
+ self.data_path = configs.data_path
20
+ self.transform = transform
21
+ self.temporal_sample = temporal_sample
22
+ self.target_video_len = self.configs.num_frames
23
+ self.frame_interval = self.configs.frame_interval
24
+ self.data_all = self.load_video_frames(self.data_path)
25
+
26
+ def __getitem__(self, index):
27
+
28
+ vframes = self.data_all[index]
29
+ total_frames = len(vframes)
30
+
31
+ # Sampling video frames
32
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
33
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
34
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.target_video_len, dtype=int) # start, stop, num=50
35
+
36
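+ # Note: only the endpoints of frame_indice are used below; frames are then taken every
+ # frame_interval steps between them, so the clip length follows the sampled window.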
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
37
+
38
+ video_frames = []
39
+ for path in select_video_frames:
40
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
41
+ video_frames.append(video_frame)
42
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
43
+ video_clip = self.transform(video_clip)
44
+
45
+ return {'video': video_clip, 'video_name': 1}
46
+
47
+ def __len__(self):
48
+ return self.video_num
49
+
50
+ def load_video_frames(self, dataroot):
51
+ data_all = []
52
+ frame_list = os.walk(dataroot)
53
+ for _, meta in enumerate(frame_list):
54
+ root = meta[0]
55
+ try:
56
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
57
+ except:
58
+ print(meta[0]) # root
59
+ print(meta[2]) # files
+ continue # skip directories whose frame names cannot be parsed
60
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
61
+ if len(frames) > max(0, self.target_video_len * self.frame_interval): # need all > (16 * frame-interval) videos
62
+ # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
63
+ data_all.append(frames)
64
+ self.video_num = len(data_all)
65
+ return data_all
66
+
67
+
68
+ if __name__ == '__main__':
69
+
70
+ import argparse
71
+ import torchvision
72
+ import video_transforms
73
+ import torch.utils.data as data
74
+
75
+ from torchvision import transforms
76
+ from torchvision.utils import save_image
77
+
78
+
79
+ parser = argparse.ArgumentParser()
80
+ parser.add_argument("--num_frames", type=int, default=16)
81
+ parser.add_argument("--frame_interval", type=int, default=4)
82
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
83
+ config = parser.parse_args()
84
+
85
+
86
+ target_video_len = config.num_frames
87
+
88
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
89
+ trans = transforms.Compose([
90
+ video_transforms.ToTensorVideo(),
91
+ # video_transforms.CenterCropVideo(256),
92
+ video_transforms.CenterCropResizeVideo(256),
93
+ # video_transforms.RandomHorizontalFlipVideo(),
94
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
95
+ ])
96
+
97
+ taichi_dataset = Sky(config, transform=trans, temporal_sample=temporal_sample)
98
+ print(len(taichi_dataset))
99
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
100
+
101
+ for i, video_data in enumerate(taichi_dataloader):
102
+ print(video_data['video'].shape)
103
+
104
+ # print(video_data.dtype)
105
+ # for i in range(target_video_len):
106
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
107
+
108
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
109
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
110
+ # exit()
datasets/sky_image_datasets.py ADDED
@@ -0,0 +1,137 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import torch.utils.data as data
5
+ import numpy as np
6
+ import copy
7
+ from PIL import Image
8
+
9
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
10
+
11
+ def is_image_file(filename):
12
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13
+
14
+ class SkyImages(data.Dataset):
15
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
16
+
17
+ self.configs = configs
18
+ self.data_path = configs.data_path
19
+ self.transform = transform
20
+ self.temporal_sample = temporal_sample
21
+ self.target_video_len = self.configs.num_frames
22
+ self.frame_interval = self.configs.frame_interval
23
+ self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)
24
+
25
+ # sky video frames
26
+ random.shuffle(self.video_frame_all)
27
+ self.use_image_num = configs.use_image_num
28
+
29
+ def __getitem__(self, index):
30
+
31
+ video_index = index % self.video_num
32
+ vframes = self.data_all[video_index]
33
+ total_frames = len(vframes)
34
+
35
+ # Sampling video frames
36
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
37
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
38
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, num=self.target_video_len, dtype=int) # start, stop, num=50
39
+
40
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
41
+
42
+ video_frames = []
43
+ for path in select_video_frames:
44
+ video_frame = torch.as_tensor(np.array(Image.open(path), dtype=np.uint8, copy=True)).unsqueeze(0)
45
+ video_frames.append(video_frame)
46
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
47
+ video_clip = self.transform(video_clip)
48
+
49
+ # get video frames
50
+ images = []
51
+
52
+ for i in range(self.use_image_num):
53
+ while True:
54
+ try:
55
+ video_frame_path = self.video_frame_all[index+i]
56
+ image = torch.as_tensor(np.array(Image.open(video_frame_path), dtype=np.uint8, copy=True)).unsqueeze(0)
57
+ images.append(image)
58
+ break
59
+ except Exception as e:
60
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
61
+
62
+ images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
63
+ images = self.transform(images)
64
+ assert len(images) == self.use_image_num
65
+
66
+ video_cat = torch.cat([video_clip, images], dim=0)
67
+
68
+ return {'video': video_cat, 'video_name': 1}
69
+
70
+ def __len__(self):
71
+ return self.video_frame_num
72
+
73
+ def load_video_frames(self, dataroot):
74
+ data_all = []
75
+ frames_all = []
76
+ frame_list = os.walk(dataroot)
77
+ for _, meta in enumerate(frame_list):
78
+ root = meta[0]
79
+ try:
80
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
81
+ except:
82
+ print(meta[0]) # root
83
+ print(meta[2]) # files
+ continue # skip directories whose frame names cannot be parsed
84
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
85
+ if len(frames) > max(0, self.target_video_len * self.frame_interval): # need all > (16 * frame-interval) videos
86
+ # if len(frames) >= max(0, self.target_video_len): # need all > 16 frames videos
87
+ data_all.append(frames)
88
+ for frame in frames:
89
+ frames_all.append(frame)
90
+ self.video_num = len(data_all)
91
+ self.video_frame_num = len(frames_all)
92
+ return data_all, frames_all
93
+
94
+
95
+ if __name__ == '__main__':
96
+
97
+ import argparse
98
+ import torchvision
99
+ import video_transforms
100
+ import torch.utils.data as data
101
+
102
+ from torchvision import transforms
103
+ from torchvision.utils import save_image
104
+
105
+
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument("--num_frames", type=int, default=16)
108
+ parser.add_argument("--frame_interval", type=int, default=3)
109
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/sky_timelapse/sky_train/")
110
+ parser.add_argument("--use-image-num", type=int, default=5)
111
+ config = parser.parse_args()
112
+
113
+ target_video_len = config.num_frames
114
+
115
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
116
+ trans = transforms.Compose([
117
+ video_transforms.ToTensorVideo(),
118
+ # video_transforms.CenterCropVideo(256),
119
+ video_transforms.CenterCropResizeVideo(256),
120
+ # video_transforms.RandomHorizontalFlipVideo(),
121
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
122
+ ])
123
+
124
+ taichi_dataset = SkyImages(config, transform=trans, temporal_sample=temporal_sample)
125
+ print(len(taichi_dataset))
126
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
127
+
128
+ for i, video_data in enumerate(taichi_dataloader):
129
+ print(video_data['video'].shape)
130
+
131
+ # print(video_data.dtype)
132
+ # for i in range(target_video_len):
133
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
134
+
135
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
136
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
137
+ # exit()
datasets/taichi_datasets.py ADDED
@@ -0,0 +1,108 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import torch.utils.data as data
5
+
6
+ import numpy as np
7
+ import io
8
+ import json
9
+ from PIL import Image
10
+
11
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
12
+
13
+ def is_image_file(filename):
14
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15
+
16
+ class Taichi(data.Dataset):
17
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
18
+
19
+ self.configs = configs
20
+ self.data_path = configs.data_path
21
+ self.transform = transform
22
+ self.temporal_sample = temporal_sample
23
+ self.target_video_len = self.configs.num_frames
24
+ self.frame_interval = self.configs.frame_interval
25
+ self.data_all = self.load_video_frames(self.data_path)
26
+ self.video_num = len(self.data_all)
27
+
28
+ def __getitem__(self, index):
29
+
30
+ vframes = self.data_all[index]
31
+ total_frames = len(vframes)
32
+
33
+ # Sampling video frames
34
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
35
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
36
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
37
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
38
+
39
+ video_frames = []
40
+ for path in select_video_frames:
41
+ image = Image.open(path).convert('RGB')
42
+ video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
43
+ video_frames.append(video_frame)
44
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
45
+ video_clip = self.transform(video_clip)
46
+
47
+ # return video_clip, 1
48
+ return {'video': video_clip, 'video_name': 1}
49
+
50
+ def __len__(self):
51
+ return self.video_num
52
+
53
+ def load_video_frames(self, dataroot):
54
+ data_all = []
55
+ frame_list = os.walk(dataroot)
56
+ for _, meta in enumerate(frame_list):
57
+ root = meta[0]
58
+ try:
59
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
60
+ except:
61
+ print(meta[0], meta[2])
+ continue # skip directories whose frame names cannot be parsed
62
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
63
+ # if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
64
+ if len(frames) != 0:
65
+ data_all.append(frames)
66
+ # self.video_num = len(data_all)
67
+ return data_all
68
+
69
+
70
+ if __name__ == '__main__':
71
+
72
+ import argparse
73
+ import torchvision
74
+ import video_transforms
75
+ import torch.utils.data as data
76
+
77
+ from torchvision import transforms
78
+ from torchvision.utils import save_image
79
+
80
+ parser = argparse.ArgumentParser()
81
+ parser.add_argument("--num_frames", type=int, default=16)
82
+ parser.add_argument("--frame_interval", type=int, default=4)
83
+ parser.add_argument("--load_from_ceph", type=bool, default=True)
84
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
85
+ config = parser.parse_args()
86
+
87
+
88
+ target_video_len = config.num_frames
89
+
90
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
91
+ trans = transforms.Compose([
92
+ video_transforms.ToTensorVideo(),
93
+ video_transforms.RandomHorizontalFlipVideo(),
94
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
95
+ ])
96
+
97
+ taichi_dataset = Taichi(config, transform=trans, temporal_sample=temporal_sample)
98
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
99
+
100
+ for i, video_data in enumerate(taichi_dataloader):
101
+ print(video_data['video'].shape)
102
+ # print(video_data.dtype)
103
+ # for i in range(target_video_len):
104
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
105
+
106
+ # video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
107
+ # torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
108
+ # exit()
datasets/taichi_image_datasets.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import torch
3
+ import random
4
+ import torch.utils.data as data
5
+
6
+ import numpy as np
7
+ import io
8
+ import json
9
+ from PIL import Image
10
+
11
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
12
+
13
+ def is_image_file(filename):
14
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
15
+
16
+ class TaichiImages(data.Dataset):
17
+ def __init__(self, configs, transform, temporal_sample=None, train=True):
18
+
19
+ self.configs = configs
20
+ self.data_path = configs.data_path
21
+ self.transform = transform
22
+ self.temporal_sample = temporal_sample
23
+ self.target_video_len = self.configs.num_frames
24
+ self.frame_interval = self.configs.frame_interval
25
+ self.data_all, self.video_frame_all = self.load_video_frames(self.data_path)
26
+ self.video_num = len(self.data_all)
27
+ self.video_frame_num = len(self.video_frame_all)
28
+
29
+ # sky video frames
30
+ random.shuffle(self.video_frame_all)
31
+ self.use_image_num = configs.use_image_num
32
+
33
+ def __getitem__(self, index):
34
+
35
+ video_index = index % self.video_num
36
+ vframes = self.data_all[video_index]
37
+ total_frames = len(vframes)
38
+
39
+ # Sampling video frames
40
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
41
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
42
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
43
+ # print(frame_indice)
44
+ select_video_frames = vframes[frame_indice[0]: frame_indice[-1]+1: self.frame_interval]
45
+
46
+ video_frames = []
47
+ for path in select_video_frames:
48
+ image = Image.open(path).convert('RGB')
49
+ video_frame = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
50
+ video_frames.append(video_frame)
51
+ video_clip = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2)
52
+ video_clip = self.transform(video_clip)
53
+
54
+ # get video frames
55
+ images = []
56
+ for i in range(self.use_image_num):
57
+ while True:
58
+ try:
59
+ video_frame_path = self.video_frame_all[index+i]
60
+ image_path = os.path.join(self.data_path, video_frame_path)
61
+ image = Image.open(image_path).convert('RGB')
62
+ image = torch.as_tensor(np.array(image, dtype=np.uint8, copy=True)).unsqueeze(0)
63
+ images.append(image)
64
+ break
65
+ except Exception as e:
66
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
67
+
68
+ images = torch.cat(images, dim=0).permute(0, 3, 1, 2)
69
+ images = self.transform(images)
70
+ assert len(images) == self.use_image_num
71
+
72
+ video_cat = torch.cat([video_clip, images], dim=0)
73
+
74
+ return {'video': video_cat, 'video_name': 1}
75
+
76
+ def __len__(self):
77
+ return self.video_frame_num
78
+
79
+ def load_video_frames(self, dataroot):
80
+ data_all = []
81
+ frames_all = []
82
+ frame_list = os.walk(dataroot)
83
+ for _, meta in enumerate(frame_list):
84
+ root = meta[0]
85
+ try:
86
+ frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
87
+ except:
88
+ print(meta[0], meta[2])
+ continue # skip directories whose frame names cannot be parsed
89
+ frames = [os.path.join(root, item) for item in frames if is_image_file(item)]
90
+ # if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
91
+ if len(frames) != 0:
92
+ data_all.append(frames)
93
+ for frame in frames:
94
+ frames_all.append(frame)
95
+ # self.video_num = len(data_all)
96
+ return data_all, frames_all
97
+
98
+
99
+ if __name__ == '__main__':
100
+
101
+ import argparse
102
+ import torchvision
103
+ import video_transforms
104
+ import torch.utils.data as data
105
+
106
+ from torchvision import transforms
107
+ from torchvision.utils import save_image
108
+
109
+ parser = argparse.ArgumentParser()
110
+ parser.add_argument("--num_frames", type=int, default=16)
111
+ parser.add_argument("--frame_interval", type=int, default=4)
112
+ parser.add_argument("--load_from_ceph", type=bool, default=True)
113
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/taichi/taichi-256/frames/train")
114
+ parser.add_argument("--use-image-num", type=int, default=5)
115
+ config = parser.parse_args()
116
+
117
+
118
+ target_video_len = config.num_frames
119
+
120
+ temporal_sample = video_transforms.TemporalRandomCrop(target_video_len * config.frame_interval)
121
+ trans = transforms.Compose([
122
+ video_transforms.ToTensorVideo(),
123
+ video_transforms.RandomHorizontalFlipVideo(),
124
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
125
+ ])
126
+
127
+ taichi_dataset = TaichiImages(config, transform=trans, temporal_sample=temporal_sample)
128
+ print(len(taichi_dataset))
129
+ taichi_dataloader = data.DataLoader(dataset=taichi_dataset, batch_size=1, shuffle=False, num_workers=1)
130
+
131
+ for i, video_data in enumerate(taichi_dataloader):
132
+ print(video_data['video'].shape)
133
+ # print(video_data.dtype)
134
+ # for i in range(target_video_len):
135
+ # save_image(video_data[0][i], os.path.join('./test_data', '%04d.png' % i), normalize=True, value_range=(-1, 1))
136
+
137
+ video_ = ((video_data[0] * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
138
+ torchvision.io.write_video('./test_data' + 'test.mp4', video_, fps=8)
139
+ exit()
datasets/ucf101_datasets.py ADDED
@@ -0,0 +1,229 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import torch
5
+ import decord
6
+ import torchvision
7
+ import numpy as np
8
+
9
+
10
+ from PIL import Image
11
+ from einops import rearrange
12
+ from typing import Dict, List, Tuple
13
+
14
+ class_labels_map = None
15
+ cls_sample_cnt = None
16
+
19
+
20
+
21
+ def temporal_sampling(frames, start_idx, end_idx, num_samples):
22
+ """
23
+ Given the start and end frame index, sample num_samples frames between
24
+ the start and end with equal interval.
25
+ Args:
26
+ frames (tensor): a tensor of video frames, dimension is
27
+ `num video frames` x `channel` x `height` x `width`.
28
+ start_idx (int): the index of the start frame.
29
+ end_idx (int): the index of the end frame.
30
+ num_samples (int): number of frames to sample.
31
+ Returns:
32
+ frames (tensor): a tensor of temporally sampled video frames, dimension is
33
+ `num clip frames` x `channel` x `height` x `width`.
34
+ """
35
+ index = torch.linspace(start_idx, end_idx, num_samples)
36
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
37
+ frames = torch.index_select(frames, 0, index)
38
+ return frames
39
+
40
+
41
+ def get_filelist(file_path):
42
+ Filelist = []
43
+ for home, dirs, files in os.walk(file_path):
44
+ for filename in files:
45
+ # list of file names, including the full path
46
+ Filelist.append(os.path.join(home, filename))
47
+ # # list of file names only, without the path
48
+ # Filelist.append( filename)
49
+ return Filelist
50
+
51
+
52
+ def load_annotation_data(data_file_path):
53
+ with open(data_file_path, 'r') as data_file:
54
+ return json.load(data_file)
55
+
56
+
57
+ def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
58
+ global class_labels_map, cls_sample_cnt
59
+
60
+ if class_labels_map is not None:
61
+ return class_labels_map, cls_sample_cnt
62
+ else:
63
+ cls_sample_cnt = {}
64
+ class_labels_map = load_annotation_data(anno_pth)
65
+ for cls in class_labels_map:
66
+ cls_sample_cnt[cls] = 0
67
+ return class_labels_map, cls_sample_cnt
68
+
69
+
70
+ def load_annotations(ann_file, num_class, num_samples_per_cls):
71
+ dataset = []
72
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
73
+ with open(ann_file, 'r') as fin:
74
+ for line in fin:
75
+ line_split = line.strip().split('\t')
76
+ sample = {}
77
+ idx = 0
78
+ # idx for frame_dir
79
+ frame_dir = line_split[idx]
80
+ sample['video'] = frame_dir
81
+ idx += 1
82
+
83
+ # idx for label[s]
84
+ label = [x for x in line_split[idx:]]
85
+ assert label, f'missing label in line: {line}'
86
+ assert len(label) == 1
87
+ class_name = label[0]
88
+ class_index = int(class_to_idx[class_name])
89
+
90
+ # choose a class subset of whole dataset
91
+ if class_index < num_class:
92
+ sample['label'] = class_index
93
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
94
+ dataset.append(sample)
95
+ cls_sample_cnt[class_name]+=1
96
+
97
+ return dataset
98
+
99
+
100
+ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
101
+ """Finds the class folders in a dataset.
102
+
103
+ See :class:`DatasetFolder` for details.
104
+ """
105
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
106
+ if not classes:
107
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
108
+
109
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
110
+ return classes, class_to_idx
111
+
112
+
113
+ class DecordInit(object):
114
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
115
+
116
+ def __init__(self, num_threads=1):
117
+ self.num_threads = num_threads
118
+ self.ctx = decord.cpu(0)
119
+
120
+ def __call__(self, filename):
121
+ """Perform the Decord initialization.
122
+ Args:
123
+ filename (str): path of the video file to be read.
125
+ """
126
+ reader = decord.VideoReader(filename,
127
+ ctx=self.ctx,
128
+ num_threads=self.num_threads)
129
+ return reader
130
+
131
+ def __repr__(self):
132
+ repr_str = (f'{self.__class__.__name__}('
134
+ f'num_threads={self.num_threads})')
135
+ return repr_str
136
+
137
+
138
+ class UCF101(torch.utils.data.Dataset):
139
+ """Load the UCF101 video files
140
+
141
+ Args:
142
+ target_video_len (int): the number of video frames to be loaded.
143
+ align_transform (callable): Align different videos in a specified size.
144
+ temporal_sample (callable): Sample the target length of a video.
145
+ """
146
+
147
+ def __init__(self,
148
+ configs,
149
+ transform=None,
150
+ temporal_sample=None):
151
+ self.configs = configs
152
+ self.data_path = configs.data_path
153
+ self.video_lists = get_filelist(configs.data_path)
154
+ self.transform = transform
155
+ self.temporal_sample = temporal_sample
156
+ self.target_video_len = self.configs.num_frames
157
+ self.v_decoder = DecordInit()
158
+ self.classes, self.class_to_idx = find_classes(self.data_path)
159
+ # print(self.class_to_idx)
160
+ # exit()
161
+
162
+ def __getitem__(self, index):
163
+ path = self.video_lists[index]
164
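+ # the parent directory name of the video file is the UCF101 class label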
+ class_name = path.split('/')[-2]
165
+ class_index = self.class_to_idx[class_name]
166
+
167
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
168
+ total_frames = len(vframes)
169
+
170
+ # Sampling video frames
171
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
172
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
173
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
174
+ # print(frame_indice)
175
+ video = vframes[frame_indice] #
176
+ video = self.transform(video) # T C H W
177
+
178
+ return {'video': video, 'video_name': class_index}
179
+
180
+ def __len__(self):
181
+ return len(self.video_lists)
182
+
183
+
184
+ if __name__ == '__main__':
185
+
186
+ import argparse
187
+ import video_transforms
188
+ import torch.utils.data as Data
189
+ import torchvision.transforms as transforms
190
+
191
+ from PIL import Image
192
+
193
+ parser = argparse.ArgumentParser()
194
+ parser.add_argument("--num_frames", type=int, default=16)
195
+ parser.add_argument("--frame_interval", type=int, default=1)
196
+ # parser.add_argument("--data-path", type=str, default="/nvme/share_data/datasets/UCF101/videos")
197
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
198
+ config = parser.parse_args()
199
+
200
+
201
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
202
+
203
+ transform_ucf101 = transforms.Compose([
204
+ video_transforms.ToTensorVideo(), # TCHW
205
+ video_transforms.RandomHorizontalFlipVideo(),
206
+ video_transforms.UCFCenterCropVideo(256),
207
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
208
+ ])
209
+
210
+
211
+ ffs_dataset = UCF101(config, transform=transform_ucf101, temporal_sample=temporal_sample)
212
+ ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)
213
+
214
+ # for i, video_data in enumerate(ffs_dataloader):
215
+ for video_data in ffs_dataloader:
216
+ print(type(video_data))
217
+ video = video_data['video']
218
+ video_name = video_data['video_name']
219
+ print(video.shape)
220
+ print(video_name)
221
+ # print(video_data[2])
222
+
223
+ # for i in range(16):
224
+ # img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
225
+ # print('Label: {}'.format(video_data[1]))
226
+ # print(img0.shape)
227
+ # img0 = Image.fromarray(np.uint8(img0 * 255))
228
+ # img0.save('./img{}.jpg'.format(i))
229
+ exit()
datasets/ucf101_image_datasets.py ADDED
@@ -0,0 +1,279 @@
1
+ import os, io
2
+ import re
3
+ import json
4
+ import torch
5
+ import decord
6
+ import torchvision
7
+ import numpy as np
8
+
9
+
10
+ from PIL import Image
11
+ from einops import rearrange
12
+ from typing import Dict, List, Tuple
13
+ from torchvision import transforms
14
+ import random
15
+
16
+
17
+ class_labels_map = None
18
+ cls_sample_cnt = None
19
+
22
+
23
+
24
+ def temporal_sampling(frames, start_idx, end_idx, num_samples):
25
+ """
26
+ Given the start and end frame index, sample num_samples frames between
27
+ the start and end with equal interval.
28
+ Args:
29
+ frames (tensor): a tensor of video frames, dimension is
30
+ `num video frames` x `channel` x `height` x `width`.
31
+ start_idx (int): the index of the start frame.
32
+ end_idx (int): the index of the end frame.
33
+ num_samples (int): number of frames to sample.
34
+ Returns:
35
+ frames (tensor): a tensor of temporally sampled video frames, dimension is
36
+ `num clip frames` x `channel` x `height` x `width`.
37
+ """
38
+ index = torch.linspace(start_idx, end_idx, num_samples)
39
+ index = torch.clamp(index, 0, frames.shape[0] - 1).long()
40
+ frames = torch.index_select(frames, 0, index)
41
+ return frames
42
+
43
+
44
+ def get_filelist(file_path):
45
+ Filelist = []
46
+ for home, dirs, files in os.walk(file_path):
47
+ for filename in files:
48
+ Filelist.append(os.path.join(home, filename))
49
+ # Filelist.append( filename)
50
+ return Filelist
51
+
52
+
53
+ def load_annotation_data(data_file_path):
54
+ with open(data_file_path, 'r') as data_file:
55
+ return json.load(data_file)
56
+
57
+
58
+ def get_class_labels(num_class, anno_pth='./k400_classmap.json'):
59
+ global class_labels_map, cls_sample_cnt
60
+
61
+ if class_labels_map is not None:
62
+ return class_labels_map, cls_sample_cnt
63
+ else:
64
+ cls_sample_cnt = {}
65
+ class_labels_map = load_annotation_data(anno_pth)
66
+ for cls in class_labels_map:
67
+ cls_sample_cnt[cls] = 0
68
+ return class_labels_map, cls_sample_cnt
69
+
70
+
71
+ def load_annotations(ann_file, num_class, num_samples_per_cls):
72
+ dataset = []
73
+ class_to_idx, cls_sample_cnt = get_class_labels(num_class)
74
+ with open(ann_file, 'r') as fin:
75
+ for line in fin:
76
+ line_split = line.strip().split('\t')
77
+ sample = {}
78
+ idx = 0
79
+ # idx for frame_dir
80
+ frame_dir = line_split[idx]
81
+ sample['video'] = frame_dir
82
+ idx += 1
83
+
84
+ # idx for label[s]
85
+ label = [x for x in line_split[idx:]]
86
+ assert label, f'missing label in line: {line}'
87
+ assert len(label) == 1
88
+ class_name = label[0]
89
+ class_index = int(class_to_idx[class_name])
90
+
91
+ # choose a class subset of whole dataset
92
+ if class_index < num_class:
93
+ sample['label'] = class_index
94
+ if cls_sample_cnt[class_name] < num_samples_per_cls:
95
+ dataset.append(sample)
96
+ cls_sample_cnt[class_name]+=1
97
+
98
+ return dataset
99
+
100
+
101
+ def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
102
+ """Finds the class folders in a dataset.
103
+
104
+ See :class:`DatasetFolder` for details.
105
+ """
106
+ classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
107
+ if not classes:
108
+ raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
109
+
110
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
111
+ return classes, class_to_idx
112
+
113
+
114
+ class DecordInit(object):
115
+ """Using Decord(https://github.com/dmlc/decord) to initialize the video_reader."""
116
+
117
+ def __init__(self, num_threads=1):
118
+ self.num_threads = num_threads
119
+ self.ctx = decord.cpu(0)
120
+
121
+ def __call__(self, filename):
122
+ """Perform the Decord initialization.
123
+ Args:
124
+ filename (str): path of the video file to be read.
126
+ """
127
+ reader = decord.VideoReader(filename,
128
+ ctx=self.ctx,
129
+ num_threads=self.num_threads)
130
+ return reader
131
+
132
+ def __repr__(self):
133
+ repr_str = (f'{self.__class__.__name__}('
135
+ f'num_threads={self.num_threads})')
136
+ return repr_str
137
+
138
+
139
+ class UCF101Images(torch.utils.data.Dataset):
140
+ """Load the UCF101 video files
141
+
142
+ Args:
143
+ target_video_len (int): the number of video frames to be loaded.
144
+ align_transform (callable): Align different videos in a specified size.
145
+ temporal_sample (callable): Sample the target length of a video.
146
+ """
147
+
148
+ def __init__(self,
149
+ configs,
150
+ transform=None,
151
+ temporal_sample=None):
152
+ self.configs = configs
153
+ self.data_path = configs.data_path
154
+ self.video_lists = get_filelist(configs.data_path)
155
+ self.transform = transform
156
+ self.temporal_sample = temporal_sample
157
+ self.target_video_len = self.configs.num_frames
158
+ self.v_decoder = DecordInit()
159
+ self.classes, self.class_to_idx = find_classes(self.data_path)
160
+ self.video_num = len(self.video_lists)
161
+
162
+ # ucf101 video frames
163
+ self.frame_data_path = configs.frame_data_path # important
164
+
165
+ self.video_frame_txt = configs.frame_data_txt
166
+ self.video_frame_files = [frame_file.strip() for frame_file in open(self.video_frame_txt)]
167
+ random.shuffle(self.video_frame_files)
168
+ self.use_image_num = configs.use_image_num
169
+ self.image_tranform = transforms.Compose([
170
+ transforms.ToTensor(),
171
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
172
+ ])
173
+ self.video_frame_num = len(self.video_frame_files)
174
+
175
+
176
+ def __getitem__(self, index):
177
+
178
+ video_index = index % self.video_num
179
+ path = self.video_lists[video_index]
180
+ class_name = path.split('/')[-2]
181
+ class_index = self.class_to_idx[class_name]
182
+
183
+ vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
184
+ total_frames = len(vframes)
185
+
186
+ # Sampling video frames
187
+ start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
188
+ assert end_frame_ind - start_frame_ind >= self.target_video_len
189
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, self.target_video_len, dtype=int)
190
+ video = vframes[frame_indice]
191
+
192
+ # video transformer data preprocessing
193
+ video = self.transform(video) # T C H W
194
+ images = []
195
+ image_names = []
196
+ for i in range(self.use_image_num):
197
+ while True:
198
+ try:
199
+ video_frame_path = self.video_frame_files[index+i]
200
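+ # frame file names are assumed to follow the UCF101 pattern "v_ClassName_gXX_cYY_...",
+ # so the class name is the second underscore-separated field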
+ image_class_name = video_frame_path.split('_')[1]
201
+ image_class_index = self.class_to_idx[image_class_name]
202
+ video_frame_path = os.path.join(self.frame_data_path, video_frame_path)
203
+ image = Image.open(video_frame_path).convert('RGB')
204
+ image = self.image_tranform(image).unsqueeze(0)
205
+ images.append(image)
206
+ image_names.append(str(image_class_index))
207
+ break
208
+ except Exception as e:
209
+ index = random.randint(0, self.video_frame_num - self.use_image_num)
210
+ images = torch.cat(images, dim=0)
211
+ assert len(images) == self.use_image_num
212
+ assert len(image_names) == self.use_image_num
213
+
214
+ image_names = '====='.join(image_names)
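+ # the per-image class indices are joined into one string so the default collate_fn can
+ # batch them; consumers split on '=====' to recover the list (see the __main__ block below)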
215
+
216
+ video_cat = torch.cat([video, images], dim=0)
217
+
218
+ return {'video': video_cat,
219
+ 'video_name': class_index,
220
+ 'image_name': image_names}
221
+
222
+ def __len__(self):
223
+ return self.video_frame_num
224
+
225
+
226
+ if __name__ == '__main__':
227
+
228
+ import argparse
229
+ import video_transforms
230
+ import torch.utils.data as Data
231
+ import torchvision.transforms as transforms
232
+
233
+ from PIL import Image
234
+
235
+ parser = argparse.ArgumentParser()
236
+ parser.add_argument("--num_frames", type=int, default=16)
237
+ parser.add_argument("--frame_interval", type=int, default=3)
238
+ parser.add_argument("--use-image-num", type=int, default=5)
239
+ parser.add_argument("--data-path", type=str, default="/path/to/datasets/UCF101/videos/")
240
+ parser.add_argument("--frame-data-path", type=str, default="/path/to/datasets/preprocessed_ffs/train/images/")
241
+ parser.add_argument("--frame-data-txt", type=str, default="/path/to/datasets/UCF101/train_256_list.txt")
242
+ config = parser.parse_args()
243
+
244
+
245
+ temporal_sample = video_transforms.TemporalRandomCrop(config.num_frames * config.frame_interval)
246
+
247
+ transform_ucf101 = transforms.Compose([
248
+ video_transforms.ToTensorVideo(), # TCHW
249
+ video_transforms.RandomHorizontalFlipVideo(),
250
+ video_transforms.UCFCenterCropVideo(256),
251
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
252
+ ])
253
+
254
+
255
+ ffs_dataset = UCF101Images(config, transform=transform_ucf101, temporal_sample=temporal_sample)
256
+ ffs_dataloader = Data.DataLoader(dataset=ffs_dataset, batch_size=6, shuffle=False, num_workers=1)
257
+
258
+ # for i, video_data in enumerate(ffs_dataloader):
259
+ for video_data in ffs_dataloader:
260
+ # print(type(video_data))
261
+ video = video_data['video']
262
+ # video_name = video_data['video_name']
263
+ print(video.shape)
264
+ print(video_data['image_name'])
265
+ image_name = video_data['image_name']
266
+ image_names = []
267
+ for caption in image_name:
268
+ single_caption = [int(item) for item in caption.split('=====')]
269
+ image_names.append(torch.as_tensor(single_caption))
270
+ print(image_names)
271
+ # print(video_name)
272
+ # print(video_data[2])
273
+
274
+ # for i in range(16):
275
+ # img0 = rearrange(video_data[0][0][i], 'c h w -> h w c')
276
+ # print('Label: {}'.format(video_data[1]))
277
+ # print(img0.shape)
278
+ # img0 = Image.fromarray(np.uint8(img0 * 255))
279
+ # img0.save('./img{}.jpg'.format(i))
datasets/video_transforms.py ADDED
@@ -0,0 +1,482 @@
1
+ import torch
2
+ import random
3
+ import numbers
4
+ import numpy as np
+ from PIL import Image
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+ def _is_tensor_video_clip(clip):
7
+ if not torch.is_tensor(clip):
8
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
9
+
10
+ if not clip.ndimension() == 4:
11
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12
+
13
+ return True
14
+
15
+
16
+ def center_crop_arr(pil_image, image_size):
17
+ """
18
+ Center cropping implementation from ADM.
19
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
20
+ """
21
+ while min(*pil_image.size) >= 2 * image_size:
22
+ pil_image = pil_image.resize(
23
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
24
+ )
25
+
26
+ scale = image_size / min(*pil_image.size)
27
+ pil_image = pil_image.resize(
28
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
29
+ )
30
+
31
+ arr = np.array(pil_image)
32
+ crop_y = (arr.shape[0] - image_size) // 2
33
+ crop_x = (arr.shape[1] - image_size) // 2
34
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
35
+
36
+
37
+ def crop(clip, i, j, h, w):
38
+ """
39
+ Args:
40
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
41
+ """
42
+ if len(clip.size()) != 4:
43
+ raise ValueError("clip should be a 4D tensor")
44
+ return clip[..., i : i + h, j : j + w]
45
+
46
+
47
+ def resize(clip, target_size, interpolation_mode):
48
+ if len(target_size) != 2:
49
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
50
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
51
+
52
+ def resize_scale(clip, target_size, interpolation_mode):
53
+ if len(target_size) != 2:
54
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
55
+ H, W = clip.size(-2), clip.size(-1)
56
+ scale_ = target_size[0] / min(H, W)
57
+ return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
58
+
59
+
60
+ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
61
+ """
62
+ Do spatial cropping and resizing to the video clip
63
+ Args:
64
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
65
+ i (int): i in (i,j) i.e coordinates of the upper left corner.
66
+ j (int): j in (i,j) i.e coordinates of the upper left corner.
67
+ h (int): Height of the cropped region.
68
+ w (int): Width of the cropped region.
69
+ size (tuple(int, int)): height and width of resized clip
70
+ Returns:
71
+ clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
72
+ """
73
+ if not _is_tensor_video_clip(clip):
74
+ raise ValueError("clip should be a 4D torch.tensor")
75
+ clip = crop(clip, i, j, h, w)
76
+ clip = resize(clip, size, interpolation_mode)
77
+ return clip
78
+
79
+
80
+ def center_crop(clip, crop_size):
81
+ if not _is_tensor_video_clip(clip):
82
+ raise ValueError("clip should be a 4D torch.tensor")
83
+ h, w = clip.size(-2), clip.size(-1)
84
+ th, tw = crop_size
85
+ if h < th or w < tw:
86
+ raise ValueError("height and width must be no smaller than crop_size")
87
+
88
+ i = int(round((h - th) / 2.0))
89
+ j = int(round((w - tw) / 2.0))
90
+ return crop(clip, i, j, th, tw)
91
+
92
+
93
+ def center_crop_using_short_edge(clip):
94
+ if not _is_tensor_video_clip(clip):
95
+ raise ValueError("clip should be a 4D torch.tensor")
96
+ h, w = clip.size(-2), clip.size(-1)
97
+ if h < w:
98
+ th, tw = h, h
99
+ i = 0
100
+ j = int(round((w - tw) / 2.0))
101
+ else:
102
+ th, tw = w, w
103
+ i = int(round((h - th) / 2.0))
104
+ j = 0
105
+ return crop(clip, i, j, th, tw)
106
+
107
+
108
+ def random_shift_crop(clip):
109
+ '''
110
+ Slide along the long edge, with the short edge as crop size
111
+ '''
112
+ if not _is_tensor_video_clip(clip):
113
+ raise ValueError("clip should be a 4D torch.tensor")
114
+ h, w = clip.size(-2), clip.size(-1)
115
+
116
+ if h <= w:
117
+ long_edge = w
118
+ short_edge = h
119
+ else:
120
+ long_edge = h
121
+ short_edge = w
122
+
123
+ th, tw = short_edge, short_edge
124
+
125
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
126
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
127
+ return crop(clip, i, j, th, tw)
128
+
129
+
130
+ def to_tensor(clip):
131
+ """
132
+ Convert tensor data type from uint8 to float and divide values by 255.0
133
+ (the clip is expected to already be in (T, C, H, W) order)
134
+ Args:
135
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
136
+ Return:
137
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
138
+ """
139
+ _is_tensor_video_clip(clip)
140
+ if not clip.dtype == torch.uint8:
141
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
142
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
143
+ return clip.float() / 255.0
144
+
145
+
146
+ def normalize(clip, mean, std, inplace=False):
147
+ """
148
+ Args:
149
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
150
+ mean (tuple): pixel RGB mean. Size is (3)
151
+ std (tuple): pixel standard deviation. Size is (3)
152
+ Returns:
153
+ normalized clip (torch.tensor): Size is (T, C, H, W)
154
+ """
155
+ if not _is_tensor_video_clip(clip):
156
+ raise ValueError("clip should be a 4D torch.tensor")
157
+ if not inplace:
158
+ clip = clip.clone()
159
+ mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
160
+ # print(mean)
161
+ std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
162
+ clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
163
+ return clip
164
+
165
+
166
+ def hflip(clip):
167
+ """
168
+ Args:
169
+ clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
170
+ Returns:
171
+ flipped clip (torch.tensor): Size is (T, C, H, W)
172
+ """
173
+ if not _is_tensor_video_clip(clip):
174
+ raise ValueError("clip should be a 4D torch.tensor")
175
+ return clip.flip(-1)
176
+
177
+
178
+ class RandomCropVideo:
179
+ def __init__(self, size):
180
+ if isinstance(size, numbers.Number):
181
+ self.size = (int(size), int(size))
182
+ else:
183
+ self.size = size
184
+
185
+ def __call__(self, clip):
186
+ """
187
+ Args:
188
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
189
+ Returns:
190
+ torch.tensor: randomly cropped video clip.
191
+ size is (T, C, OH, OW)
192
+ """
193
+ i, j, h, w = self.get_params(clip)
194
+ return crop(clip, i, j, h, w)
195
+
196
+ def get_params(self, clip):
197
+ h, w = clip.shape[-2:]
198
+ th, tw = self.size
199
+
200
+ if h < th or w < tw:
201
+ raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
202
+
203
+ if w == tw and h == th:
204
+ return 0, 0, h, w
205
+
206
+ i = torch.randint(0, h - th + 1, size=(1,)).item()
207
+ j = torch.randint(0, w - tw + 1, size=(1,)).item()
208
+
209
+ return i, j, th, tw
210
+
211
+ def __repr__(self) -> str:
212
+ return f"{self.__class__.__name__}(size={self.size})"
213
+
214
+ class CenterCropResizeVideo:
215
+ '''
216
+ Center crop the video using the short side as the crop size,
217
+ then resize it to the specified size
218
+ '''
219
+ def __init__(
220
+ self,
221
+ size,
222
+ interpolation_mode="bilinear",
223
+ ):
224
+ if isinstance(size, tuple):
225
+ if len(size) != 2:
226
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
227
+ self.size = size
228
+ else:
229
+ self.size = (size, size)
230
+
231
+ self.interpolation_mode = interpolation_mode
232
+
233
+
234
+ def __call__(self, clip):
235
+ """
236
+ Args:
237
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
238
+ Returns:
239
+ torch.tensor: scale resized / center cropped video clip.
240
+ size is (T, C, crop_size, crop_size)
241
+ """
242
+ clip_center_crop = center_crop_using_short_edge(clip)
243
+ clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode)
244
+ return clip_center_crop_resize
245
+
246
+ def __repr__(self) -> str:
247
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
248
+
249
+ class UCFCenterCropVideo:
250
+ '''
251
+ First scale the short edge to the specified size, keeping the aspect ratio,
252
+ then center crop to that size
253
+ '''
254
+ def __init__(
255
+ self,
256
+ size,
257
+ interpolation_mode="bilinear",
258
+ ):
259
+ if isinstance(size, tuple):
260
+ if len(size) != 2:
261
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
262
+ self.size = size
263
+ else:
264
+ self.size = (size, size)
265
+
266
+ self.interpolation_mode = interpolation_mode
267
+
268
+
269
+ def __call__(self, clip):
270
+ """
271
+ Args:
272
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
273
+ Returns:
274
+ torch.tensor: scale resized / center cropped video clip.
275
+ size is (T, C, crop_size, crop_size)
276
+ """
277
+ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
278
+ clip_center_crop = center_crop(clip_resize, self.size)
279
+ return clip_center_crop
280
+
281
+ def __repr__(self) -> str:
282
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
283
+
284
+ class KineticsRandomCropResizeVideo:
285
+ '''
286
+ Slide along the long edge, with the short edge as the crop size, and resize to the desired size.
287
+ '''
288
+ def __init__(
289
+ self,
290
+ size,
291
+ interpolation_mode="bilinear",
292
+ ):
293
+ if isinstance(size, tuple):
294
+ if len(size) != 2:
295
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
296
+ self.size = size
297
+ else:
298
+ self.size = (size, size)
299
+
300
+ self.interpolation_mode = interpolation_mode
301
+
302
+ def __call__(self, clip):
303
+ clip_random_crop = random_shift_crop(clip)
304
+ clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
305
+ return clip_resize
306
+
307
+
308
+ class CenterCropVideo:
309
+ def __init__(
310
+ self,
311
+ size,
312
+ interpolation_mode="bilinear",
313
+ ):
314
+ if isinstance(size, tuple):
315
+ if len(size) != 2:
316
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
317
+ self.size = size
318
+ else:
319
+ self.size = (size, size)
320
+
321
+ self.interpolation_mode = interpolation_mode
322
+
323
+
324
+ def __call__(self, clip):
325
+ """
326
+ Args:
327
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
328
+ Returns:
329
+ torch.tensor: center cropped video clip.
330
+ size is (T, C, crop_size, crop_size)
331
+ """
332
+ clip_center_crop = center_crop(clip, self.size)
333
+ return clip_center_crop
334
+
335
+ def __repr__(self) -> str:
336
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
337
+
338
+
339
+ class NormalizeVideo:
340
+ """
341
+ Normalize the video clip by mean subtraction and division by standard deviation
342
+ Args:
343
+ mean (3-tuple): pixel RGB mean
344
+ std (3-tuple): pixel RGB standard deviation
345
+ inplace (boolean): whether do in-place normalization
346
+ """
347
+
348
+ def __init__(self, mean, std, inplace=False):
349
+ self.mean = mean
350
+ self.std = std
351
+ self.inplace = inplace
352
+
353
+ def __call__(self, clip):
354
+ """
355
+ Args:
356
+ clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
357
+ """
358
+ return normalize(clip, self.mean, self.std, self.inplace)
359
+
360
+ def __repr__(self) -> str:
361
+ return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
362
+
363
+
364
+ class ToTensorVideo:
365
+ """
366
+ Convert tensor data type from uint8 to float and divide values by 255.0
367
+ without changing the (T, C, H, W) dimension order
368
+ """
369
+
370
+ def __init__(self):
371
+ pass
372
+
373
+ def __call__(self, clip):
374
+ """
375
+ Args:
376
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
377
+ Return:
378
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
379
+ """
380
+ return to_tensor(clip)
381
+
382
+ def __repr__(self) -> str:
383
+ return self.__class__.__name__
384
+
385
+
386
+ class RandomHorizontalFlipVideo:
387
+ """
388
+ Flip the video clip along the horizontal direction with a given probability
389
+ Args:
390
+ p (float): probability of the clip being flipped. Default value is 0.5
391
+ """
392
+
393
+ def __init__(self, p=0.5):
394
+ self.p = p
395
+
396
+ def __call__(self, clip):
397
+ """
398
+ Args:
399
+ clip (torch.tensor): Size is (T, C, H, W)
400
+ Return:
401
+ clip (torch.tensor): Size is (T, C, H, W)
402
+ """
403
+ if random.random() < self.p:
404
+ clip = hflip(clip)
405
+ return clip
406
+
407
+ def __repr__(self) -> str:
408
+ return f"{self.__class__.__name__}(p={self.p})"
409
+
410
+ # ------------------------------------------------------------
411
+ # --------------------- Sampling ---------------------------
412
+ # ------------------------------------------------------------
413
+ class TemporalRandomCrop(object):
414
+ """Temporally crop the given frame indices at a random location.
415
+
416
+ Args:
417
+ size (int): Desired number of frames to be seen by the model.
418
+ """
419
+
420
+ def __init__(self, size):
421
+ self.size = size
422
+
423
+ def __call__(self, total_frames):
424
+ rand_end = max(0, total_frames - self.size - 1)
425
+ begin_index = random.randint(0, rand_end)
426
+ end_index = min(begin_index + self.size, total_frames)
427
+ return begin_index, end_index
428
+
429
+
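+ # Note: TemporalRandomCrop returns a (begin, end) pair of frame *indices*, not frames;
+ # the caller is expected to slice the decoded clip itself, as the demo block below does.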
430
+ if __name__ == '__main__':
431
+ from torchvision import transforms
432
+ import torchvision.io as io
433
+ import numpy as np
434
+ from torchvision.utils import save_image
435
+ import os
436
+
437
+ vframes, aframes, info = io.read_video(
438
+ filename='./v_Archery_g01_c03.avi',
439
+ pts_unit='sec',
440
+ output_format='TCHW'
441
+ )
442
+
443
+ trans = transforms.Compose([
444
+ ToTensorVideo(),
445
+ RandomHorizontalFlipVideo(),
446
+ UCFCenterCropVideo(512),
447
+ # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
448
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
449
+ ])
450
+
451
+ target_video_len = 32
452
+ frame_interval = 1
453
+ total_frames = len(vframes)
454
+ print(total_frames)
455
+
456
+ temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
457
+
458
+
459
+ # Sampling video frames
460
+ start_frame_ind, end_frame_ind = temporal_sample(total_frames)
461
+ # print(start_frame_ind)
462
+ # print(end_frame_ind)
463
+ assert end_frame_ind - start_frame_ind >= target_video_len
464
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
465
+ print(frame_indice)
466
+
467
+ select_vframes = vframes[frame_indice]
468
+ print(select_vframes.shape)
469
+ print(select_vframes.dtype)
470
+
471
+ select_vframes_trans = trans(select_vframes)
472
+ print(select_vframes_trans.shape)
473
+ print(select_vframes_trans.dtype)
474
+
475
+ select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
476
+ print(select_vframes_trans_int.dtype)
477
+ print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
478
+
479
+ io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
480
+
481
+ for i in range(target_video_len):
482
+ save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))
demo.py ADDED
@@ -0,0 +1,284 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import argparse
5
+ import torchvision
6
+
7
+
8
+ from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
9
+ EulerDiscreteScheduler, DPMSolverMultistepScheduler,
10
+ HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
11
+ DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
12
+ from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
13
+ from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
14
+ from omegaconf import OmegaConf
15
+ from transformers import T5EncoderModel, T5Tokenizer
16
+
17
+ import os, sys
18
+ sys.path.append(os.path.split(sys.path[0])[0])
19
+ from sample.pipeline_latte import LattePipeline
20
+ from models import get_models
21
+ # import imageio
22
+ from torchvision.utils import save_image
23
+ import spaces
24
+
25
+
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--config", type=str, default="./configs/t2x/t2v_sample.yaml")
28
+ args = parser.parse_args()
29
+ args = OmegaConf.load(args.config)
30
+
31
+ torch.set_grad_enabled(False)
32
+ device = "cuda" if torch.cuda.is_available() else "cpu"
33
+
34
+ transformer_model = get_models(args).to(device, dtype=torch.float16)
35
+ # state_dict = find_model(args.ckpt)
36
+ # msg, unexp = transformer_model.load_state_dict(state_dict, strict=False)
37
+
38
+ if args.enable_vae_temporal_decoder:
39
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
40
+ else:
41
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
42
+ tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
43
+ text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
44
+
45
+ # set eval mode
46
+ transformer_model.eval()
47
+ vae.eval()
48
+ text_encoder.eval()
49
+
50
+ @spaces.GPU
51
+ def gen_video(text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step):
52
+ torch.manual_seed(seed)
53
+ if sample_method == 'DDIM':
54
+ scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
55
+ subfolder="scheduler",
56
+ beta_start=args.beta_start,
57
+ beta_end=args.beta_end,
58
+ beta_schedule=args.beta_schedule,
59
+ variance_type=args.variance_type,
60
+ clip_sample=False)
61
+ elif sample_method == 'EulerDiscrete':
62
+ scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path,
63
+ subfolder="scheduler",
64
+ beta_start=args.beta_start,
65
+ beta_end=args.beta_end,
66
+ beta_schedule=args.beta_schedule,
67
+ variance_type=args.variance_type)
68
+ elif sample_method == 'DDPM':
69
+ scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path,
70
+ subfolder="scheduler",
71
+ beta_start=args.beta_start,
72
+ beta_end=args.beta_end,
73
+ beta_schedule=args.beta_schedule,
74
+ variance_type=args.variance_type,
75
+ clip_sample=False)
76
+ elif sample_method == 'DPMSolverMultistep':
77
+ scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path,
78
+ subfolder="scheduler",
79
+ beta_start=args.beta_start,
80
+ beta_end=args.beta_end,
81
+ beta_schedule=args.beta_schedule,
82
+ variance_type=args.variance_type)
83
+ elif sample_method == 'DPMSolverSinglestep':
84
+ scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path,
85
+ subfolder="scheduler",
86
+ beta_start=args.beta_start,
87
+ beta_end=args.beta_end,
88
+ beta_schedule=args.beta_schedule,
89
+ variance_type=args.variance_type)
90
+ elif sample_method == 'PNDM':
91
+ scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path,
92
+ subfolder="scheduler",
93
+ beta_start=args.beta_start,
94
+ beta_end=args.beta_end,
95
+ beta_schedule=args.beta_schedule,
96
+ variance_type=args.variance_type)
97
+ elif sample_method == 'HeunDiscrete':
98
+ scheduler = HeunDiscreteScheduler.from_pretrained(args.pretrained_model_path,
99
+ subfolder="scheduler",
100
+ beta_start=args.beta_start,
101
+ beta_end=args.beta_end,
102
+ beta_schedule=args.beta_schedule,
103
+ variance_type=args.variance_type)
104
+ elif sample_method == 'EulerAncestralDiscrete':
105
+ scheduler = EulerAncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
106
+ subfolder="scheduler",
107
+ beta_start=args.beta_start,
108
+ beta_end=args.beta_end,
109
+ beta_schedule=args.beta_schedule,
110
+ variance_type=args.variance_type)
111
+ elif sample_method == 'DEISMultistep':
112
+ scheduler = DEISMultistepScheduler.from_pretrained(args.pretrained_model_path,
113
+ subfolder="scheduler",
114
+ beta_start=args.beta_start,
115
+ beta_end=args.beta_end,
116
+ beta_schedule=args.beta_schedule,
117
+ variance_type=args.variance_type)
118
+ elif sample_method == 'KDPM2AncestralDiscrete':
119
+ scheduler = KDPM2AncestralDiscreteScheduler.from_pretrained(args.pretrained_model_path,
120
+ subfolder="scheduler",
121
+ beta_start=args.beta_start,
122
+ beta_end=args.beta_end,
123
+ beta_schedule=args.beta_schedule,
124
+ variance_type=args.variance_type)
125
+
126
+
127
+ videogen_pipeline = LattePipeline(vae=vae,
128
+ text_encoder=text_encoder,
129
+ tokenizer=tokenizer,
130
+ scheduler=scheduler,
131
+ transformer=transformer_model).to(device)
132
+ # videogen_pipeline.enable_xformers_memory_efficient_attention()
133
+
134
+ videos = videogen_pipeline(text_input,
135
+ video_length=video_length,
136
+ height=height,
137
+ width=width,
138
+ num_inference_steps=diffusion_step,
139
+ guidance_scale=scfg_scale,
140
+ enable_temporal_attentions=args.enable_temporal_attentions,
141
+ num_images_per_prompt=1,
142
+ mask_feature=True,
143
+ enable_vae_temporal_decoder=args.enable_vae_temporal_decoder
144
+ ).video
145
+
146
+ save_path = args.save_img_path + 'temp' + '.mp4'
147
+ torchvision.io.write_video(save_path, videos[0], fps=8)
148
+ return save_path
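+ # Note: torchvision.io.write_video expects a uint8 tensor of shape (T, H, W, C); the
+ # pipeline output indexed above is assumed to already be in that layout and value range.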
149
+
150
+
151
+ if not os.path.exists(args.save_img_path):
152
+ os.makedirs(args.save_img_path)
153
+
154
+ intro = """
155
+ <div style="display: flex;align-items: center;justify-content: center">
156
+ <h1 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte: Latent Diffusion Transformer for Video Generation</h1>
157
+ </div>
158
+ """
159
+
160
+ with gr.Blocks() as demo:
161
+ # gr.HTML(intro)
162
+ # with gr.Accordion("README", open=False):
163
+ # gr.HTML(
164
+ # """
165
+ # <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
166
+ # <a href="https://maxin-cn.github.io/latte_project/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2401.03048" target="_blank">paper</a>
167
+ # </p>
168
+
169
+ # We will continue update Latte.
170
+ # """
171
+ # )
172
+ gr.Markdown("<font color=red size=10><center>Latte: Latent Diffusion Transformer for Video Generation</center></font>")
173
+ gr.Markdown(
174
+ """<div style="display: flex;align-items: center;justify-content: center">
175
+ <h2 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Latte supports both T2I and T2V, and will be continuously updated, so stay tuned!</h2></div>
176
+ """
177
+ )
178
+ gr.Markdown(
179
+ """<div style="display: flex;align-items: center;justify-content: center">
180
+ [<a href="https://arxiv.org/abs/2401.03048">Arxiv Report</a>] | [<a href="https://maxin-cn.github.io/latte_project/">Project Page</a>] | [<a href="https://github.com/Vchitect/Latte">Github</a>]</div>
181
+ """
182
+ )
183
+
184
+
185
+ with gr.Row():
186
+ with gr.Column(visible=True) as input_raws:
187
+ with gr.Row():
188
+ with gr.Column(scale=1.0):
189
+ # text_input = gr.Textbox(show_label=True, interactive=True, label="Text prompt").style(container=False)
190
+ text_input = gr.Textbox(show_label=True, interactive=True, label="Prompt")
191
+ # with gr.Row():
192
+ # with gr.Column(scale=0.5):
193
+ # image_input = gr.Image(show_label=True, interactive=True, label="Reference image").style(container=False)
194
+ # with gr.Column(scale=0.5):
195
+ # preframe_input = gr.Image(show_label=True, interactive=True, label="First frame").style(container=False)
196
+ with gr.Row():
197
+ with gr.Column(scale=0.5):
198
+ sample_method = gr.Dropdown(choices=["DDIM", "EulerDiscrete", "PNDM"], label="Sample Method", value="DDIM")
199
+ # with gr.Row():
200
+ # with gr.Column(scale=1.0):
201
+ # video_length = gr.Slider(
202
+ # minimum=1,
203
+ # maximum=24,
204
+ # value=1,
205
+ # step=1,
206
+ # interactive=True,
207
+ # label="Video Length (1 for T2I and 16 for T2V)",
208
+ # )
209
+ with gr.Column(scale=0.5):
210
+ video_length = gr.Dropdown(choices=[1, 16], label="Video Length (1 for T2I and 16 for T2V)", value=16)
211
+ with gr.Row():
212
+ with gr.Column(scale=1.0):
213
+ scfg_scale = gr.Slider(
214
+ minimum=1,
215
+ maximum=50,
216
+ value=7.5,
217
+ step=0.1,
218
+ interactive=True,
219
+ label="Guidance Scale",
220
+ )
221
+ with gr.Row():
222
+ with gr.Column(scale=1.0):
223
+ seed = gr.Slider(
224
+ minimum=1,
225
+ maximum=2147483647,
226
+ value=100,
227
+ step=1,
228
+ interactive=True,
229
+ label="Seed",
230
+ )
231
+ with gr.Row():
232
+ with gr.Column(scale=0.5):
233
+ height = gr.Slider(
234
+ minimum=256,
235
+ maximum=768,
236
+ value=512,
237
+ step=16,
238
+ interactive=False,
239
+ label="Height",
240
+ )
241
+ # with gr.Row():
242
+ with gr.Column(scale=0.5):
243
+ width = gr.Slider(
244
+ minimum=256,
245
+ maximum=768,
246
+ value=512,
247
+ step=16,
248
+ interactive=False,
249
+ label="Width",
250
+ )
251
+ with gr.Row():
252
+ with gr.Column(scale=1.0):
253
+ diffusion_step = gr.Slider(
254
+ minimum=20,
255
+ maximum=250,
256
+ value=50,
257
+ step=1,
258
+ interactive=True,
259
+ label="Sampling Step",
260
+ )
261
+
262
+
263
+ with gr.Column(scale=0.6, visible=True) as video_upload:
264
+ # with gr.Column(visible=True) as video_upload:
265
+ output = gr.Video(interactive=False, include_audio=True, elem_id="output_video") #.style(height=360)
266
+ # with gr.Column(elem_id="image", scale=0.5) as img_part:
267
+ # with gr.Tab("Video", elem_id='video_tab'):
268
+
269
+ # with gr.Tab("Image", elem_id='image_tab'):
270
+ # up_image = gr.Image(type="pil", interactive=True, elem_id="image_upload").style(height=360)
271
+ # upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
272
+ # clear = gr.Button("Restart")
273
+
274
+ with gr.Row():
275
+ with gr.Column(scale=1.0, min_width=0):
276
+ run = gr.Button("💭Run")
277
+ # with gr.Column(scale=0.5, min_width=0):
278
+ # clear = gr.Button("🔄Clear️")
279
+
280
+ run.click(gen_video, [text_input, sample_method, scfg_scale, seed, height, width, video_length, diffusion_step], [output])
281
+
282
+ demo.launch(debug=False, share=True)
283
+
284
+ # demo.launch(server_name="0.0.0.0", server_port=10034, enable_queue=True)
diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ # learn_sigma=False,
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
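+ # Illustrative usage (argument values here are assumptions, not defaults taken from a config):
+ #   diffusion = create_diffusion(timestep_respacing="")        # full 1000-step training chain
+ #   sampler   = create_diffusion(timestep_respacing="ddim50")  # 50 evenly strided sampling steps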
diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
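+ # The expression above is the standard closed-form KL between two diagonal Gaussians,
+ # KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))), evaluated elementwise.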
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
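+ # This is the cheap tanh-based approximation of the standard normal CDF Phi(x),
+ # the same functional form used in the tanh approximation of GELU.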
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,881 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "squaredcos_cap_v2":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
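+ # Used by the "squaredcos_cap_v2" schedule above: beta_t = 1 - alpha_bar(t2) / alpha_bar(t1),
+ # capped at max_beta, which gives the cosine noise schedule of Nichol & Dhariwal (2021).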
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Originally ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
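+ # posterior_mean_coef1/2 are the closed-form coefficients of the DDPM posterior mean,
+ # mu_tilde(x_t, x_0) = coef1 * x_0 + coef2 * x_t (Ho et al. 2020, Eq. 7).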
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
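+ # i.e. the closed-form forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.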
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, F, C = x.shape[:3]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ # try:
281
+ # model_output = model_output.sample # for tav unet
282
+ # except:
283
+ # model_output = model(x, t, **model_kwargs)
284
+ if isinstance(model_output, tuple):
285
+ model_output, extra = model_output
286
+ else:
287
+ extra = None
288
+
289
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
290
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
291
+ model_output, model_var_values = th.split(model_output, C, dim=2)
292
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
293
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
294
+ # The model_var_values is [-1, 1] for [min_var, max_var].
295
+ frac = (model_var_values + 1) / 2
296
+ model_log_variance = frac * max_log + (1 - frac) * min_log
297
+ model_variance = th.exp(model_log_variance)
298
+ else:
299
+ model_variance, model_log_variance = {
300
+ # for fixedlarge, we set the initial (log-)variance like so
301
+ # to get a better decoder log likelihood.
302
+ ModelVarType.FIXED_LARGE: (
303
+ np.append(self.posterior_variance[1], self.betas[1:]),
304
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
305
+ ),
306
+ ModelVarType.FIXED_SMALL: (
307
+ self.posterior_variance,
308
+ self.posterior_log_variance_clipped,
309
+ ),
310
+ }[self.model_var_type]
311
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
312
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
313
+
314
+ def process_xstart(x):
315
+ if denoised_fn is not None:
316
+ x = denoised_fn(x)
317
+ if clip_denoised:
318
+ return x.clamp(-1, 1)
319
+ return x
320
+
321
+ if self.model_mean_type == ModelMeanType.START_X:
322
+ pred_xstart = process_xstart(model_output)
323
+ else:
324
+ pred_xstart = process_xstart(
325
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
326
+ )
327
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
328
+
329
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
330
+ return {
331
+ "mean": model_mean,
332
+ "variance": model_variance,
333
+ "log_variance": model_log_variance,
334
+ "pred_xstart": pred_xstart,
335
+ "extra": extra,
336
+ }
337
+
338
+ def _predict_xstart_from_eps(self, x_t, t, eps):
339
+ assert x_t.shape == eps.shape
340
+ return (
341
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
342
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
343
+ )
344
+
345
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
346
+ return (
347
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
348
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
349
+
350
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
351
+ """
352
+ Compute the mean for the previous step, given a function cond_fn that
353
+ computes the gradient of a conditional log probability with respect to
354
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
355
+ condition on y.
356
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
357
+ """
358
+ gradient = cond_fn(x, t, **model_kwargs)
359
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
360
+ return new_mean
361
+
362
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
363
+ """
364
+ Compute what the p_mean_variance output would have been, should the
365
+ model's score function be conditioned by cond_fn.
366
+ See condition_mean() for details on cond_fn.
367
+ Unlike condition_mean(), this instead uses the conditioning strategy
368
+ from Song et al (2020).
369
+ """
370
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
371
+
372
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
373
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
374
+
375
+ out = p_mean_var.copy()
376
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
377
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
378
+ return out
379
+
380
+ def p_sample(
381
+ self,
382
+ model,
383
+ x,
384
+ t,
385
+ clip_denoised=True,
386
+ denoised_fn=None,
387
+ cond_fn=None,
388
+ model_kwargs=None,
389
+ ):
390
+ """
391
+ Sample x_{t-1} from the model at the given timestep.
392
+ :param model: the model to sample from.
393
+ :param x: the current tensor at x_{t-1}.
394
+ :param t: the value of t, starting at 0 for the first diffusion step.
395
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
396
+ :param denoised_fn: if not None, a function which applies to the
397
+ x_start prediction before it is used to sample.
398
+ :param cond_fn: if not None, this is a gradient function that acts
399
+ similarly to the model.
400
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
401
+ pass to the model. This can be used for conditioning.
402
+ :return: a dict containing the following keys:
403
+ - 'sample': a random sample from the model.
404
+ - 'pred_xstart': a prediction of x_0.
405
+ """
406
+ out = self.p_mean_variance(
407
+ model,
408
+ x,
409
+ t,
410
+ clip_denoised=clip_denoised,
411
+ denoised_fn=denoised_fn,
412
+ model_kwargs=model_kwargs,
413
+ )
414
+ noise = th.randn_like(x)
415
+ nonzero_mask = (
416
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
417
+ ) # no noise when t == 0
418
+ if cond_fn is not None:
419
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
420
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
421
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
422
+
423
+ def p_sample_loop(
424
+ self,
425
+ model,
426
+ shape,
427
+ noise=None,
428
+ clip_denoised=True,
429
+ denoised_fn=None,
430
+ cond_fn=None,
431
+ model_kwargs=None,
432
+ device=None,
433
+ progress=False,
434
+ ):
435
+ """
436
+ Generate samples from the model.
437
+ :param model: the model module.
438
+ :param shape: the shape of the samples, (N, C, H, W).
439
+ :param noise: if specified, the noise from the encoder to sample.
440
+ Should be of the same shape as `shape`.
441
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
442
+ :param denoised_fn: if not None, a function which applies to the
443
+ x_start prediction before it is used to sample.
444
+ :param cond_fn: if not None, this is a gradient function that acts
445
+ similarly to the model.
446
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
447
+ pass to the model. This can be used for conditioning.
448
+ :param device: if specified, the device to create the samples on.
449
+ If not specified, use a model parameter's device.
450
+ :param progress: if True, show a tqdm progress bar.
451
+ :return: a non-differentiable batch of samples.
452
+ """
453
+ final = None
454
+ for sample in self.p_sample_loop_progressive(
455
+ model,
456
+ shape,
457
+ noise=noise,
458
+ clip_denoised=clip_denoised,
459
+ denoised_fn=denoised_fn,
460
+ cond_fn=cond_fn,
461
+ model_kwargs=model_kwargs,
462
+ device=device,
463
+ progress=progress,
464
+ ):
465
+ final = sample
466
+ return final["sample"]
467
+
468
+ def p_sample_loop_progressive(
469
+ self,
470
+ model,
471
+ shape,
472
+ noise=None,
473
+ clip_denoised=True,
474
+ denoised_fn=None,
475
+ cond_fn=None,
476
+ model_kwargs=None,
477
+ device=None,
478
+ progress=False,
479
+ ):
480
+ """
481
+ Generate samples from the model and yield intermediate samples from
482
+ each timestep of diffusion.
483
+ Arguments are the same as p_sample_loop().
484
+ Returns a generator over dicts, where each dict is the return value of
485
+ p_sample().
486
+ """
487
+ if device is None:
488
+ device = next(model.parameters()).device
489
+ assert isinstance(shape, (tuple, list))
490
+ if noise is not None:
491
+ img = noise
492
+ else:
493
+ img = th.randn(*shape, device=device)
494
+ indices = list(range(self.num_timesteps))[::-1]
495
+
496
+ if progress:
497
+ # Lazy import so that we don't depend on tqdm.
498
+ from tqdm.auto import tqdm
499
+
500
+ indices = tqdm(indices)
501
+
502
+ for i in indices:
503
+ t = th.tensor([i] * shape[0], device=device)
504
+ with th.no_grad():
505
+ out = self.p_sample(
506
+ model,
507
+ img,
508
+ t,
509
+ clip_denoised=clip_denoised,
510
+ denoised_fn=denoised_fn,
511
+ cond_fn=cond_fn,
512
+ model_kwargs=model_kwargs,
513
+ )
514
+ yield out
515
+ img = out["sample"]
516
+
517
+ def ddim_sample(
518
+ self,
519
+ model,
520
+ x,
521
+ t,
522
+ clip_denoised=True,
523
+ denoised_fn=None,
524
+ cond_fn=None,
525
+ model_kwargs=None,
526
+ eta=0.0,
527
+ ):
528
+ """
529
+ Sample x_{t-1} from the model using DDIM.
530
+ Same usage as p_sample().
531
+ """
532
+ out = self.p_mean_variance(
533
+ model,
534
+ x,
535
+ t,
536
+ clip_denoised=clip_denoised,
537
+ denoised_fn=denoised_fn,
538
+ model_kwargs=model_kwargs,
539
+ )
540
+ if cond_fn is not None:
541
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
542
+
543
+ # Usually our model outputs epsilon, but we re-derive it
544
+ # in case we used x_start or x_prev prediction.
545
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
546
+
547
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
548
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
549
+ sigma = (
550
+ eta
551
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
552
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
553
+ )
554
+ # Equation 12.
555
+ noise = th.randn_like(x)
556
+ mean_pred = (
557
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
558
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
559
+ )
560
+ nonzero_mask = (
561
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
562
+ ) # no noise when t == 0
563
+ sample = mean_pred + nonzero_mask * sigma * noise
564
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
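+ # With eta == 0 the sigma term vanishes and the update is the deterministic DDIM step;
+ # eta == 1 recovers the ancestral (DDPM-like) variance (Song et al. 2020, Eq. 12).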
565
+
566
+ def ddim_reverse_sample(
567
+ self,
568
+ model,
569
+ x,
570
+ t,
571
+ clip_denoised=True,
572
+ denoised_fn=None,
573
+ cond_fn=None,
574
+ model_kwargs=None,
575
+ eta=0.0,
576
+ ):
577
+ """
578
+ Sample x_{t+1} from the model using DDIM reverse ODE.
579
+ """
580
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
581
+ out = self.p_mean_variance(
582
+ model,
583
+ x,
584
+ t,
585
+ clip_denoised=clip_denoised,
586
+ denoised_fn=denoised_fn,
587
+ model_kwargs=model_kwargs,
588
+ )
589
+ if cond_fn is not None:
590
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
591
+ # Usually our model outputs epsilon, but we re-derive it
592
+ # in case we used x_start or x_prev prediction.
593
+ eps = (
594
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
595
+ - out["pred_xstart"]
596
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
597
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
598
+
599
+ # Equation 12. reversed
600
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
601
+
602
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
603
+
604
+ def ddim_sample_loop(
605
+ self,
606
+ model,
607
+ shape,
608
+ noise=None,
609
+ clip_denoised=True,
610
+ denoised_fn=None,
611
+ cond_fn=None,
612
+ model_kwargs=None,
613
+ device=None,
614
+ progress=False,
615
+ eta=0.0,
616
+ ):
617
+ """
618
+ Generate samples from the model using DDIM.
619
+ Same usage as p_sample_loop().
620
+ """
621
+ final = None
622
+ for sample in self.ddim_sample_loop_progressive(
623
+ model,
624
+ shape,
625
+ noise=noise,
626
+ clip_denoised=clip_denoised,
627
+ denoised_fn=denoised_fn,
628
+ cond_fn=cond_fn,
629
+ model_kwargs=model_kwargs,
630
+ device=device,
631
+ progress=progress,
632
+ eta=eta,
633
+ ):
634
+ final = sample
635
+ return final["sample"]
636
+
637
+ def ddim_sample_loop_progressive(
638
+ self,
639
+ model,
640
+ shape,
641
+ noise=None,
642
+ clip_denoised=True,
643
+ denoised_fn=None,
644
+ cond_fn=None,
645
+ model_kwargs=None,
646
+ device=None,
647
+ progress=False,
648
+ eta=0.0,
649
+ ):
650
+ """
651
+ Use DDIM to sample from the model and yield intermediate samples from
652
+ each timestep of DDIM.
653
+ Same usage as p_sample_loop_progressive().
654
+ """
655
+ if device is None:
656
+ device = next(model.parameters()).device
657
+ assert isinstance(shape, (tuple, list))
658
+ if noise is not None:
659
+ img = noise
660
+ else:
661
+ img = th.randn(*shape, device=device)
662
+ indices = list(range(self.num_timesteps))[::-1]
663
+
664
+ if progress:
665
+ # Lazy import so that we don't depend on tqdm.
666
+ from tqdm.auto import tqdm
667
+
668
+ indices = tqdm(indices)
669
+
670
+ for i in indices:
671
+ t = th.tensor([i] * shape[0], device=device)
672
+ with th.no_grad():
673
+ out = self.ddim_sample(
674
+ model,
675
+ img,
676
+ t,
677
+ clip_denoised=clip_denoised,
678
+ denoised_fn=denoised_fn,
679
+ cond_fn=cond_fn,
680
+ model_kwargs=model_kwargs,
681
+ eta=eta,
682
+ )
683
+ yield out
684
+ img = out["sample"]
685
+
686
+ def _vb_terms_bpd(
687
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
688
+ ):
689
+ """
690
+ Get a term for the variational lower-bound.
691
+ The resulting units are bits (rather than nats, as one might expect).
692
+ This allows for comparison to other papers.
693
+ :return: a dict with the following keys:
694
+ - 'output': a shape [N] tensor of NLLs or KLs.
695
+ - 'pred_xstart': the x_0 predictions.
696
+ """
697
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
698
+ x_start=x_start, x_t=x_t, t=t
699
+ )
700
+ out = self.p_mean_variance(
701
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
702
+ )
703
+ kl = normal_kl(
704
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
705
+ )
706
+ kl = mean_flat(kl) / np.log(2.0)
707
+
708
+ decoder_nll = -discretized_gaussian_log_likelihood(
709
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
710
+ )
711
+ assert decoder_nll.shape == x_start.shape
712
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
713
+
714
+ # At the first timestep return the decoder NLL,
715
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
716
+ output = th.where((t == 0), decoder_nll, kl)
717
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
718
+
719
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
720
+ """
721
+ Compute training losses for a single timestep.
722
+ :param model: the model to evaluate loss on.
723
+ :param x_start: the [N x C x ...] tensor of inputs.
724
+ :param t: a batch of timestep indices.
725
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
726
+ pass to the model. This can be used for conditioning.
727
+ :param noise: if specified, the specific Gaussian noise to try to remove.
728
+ :return: a dict with the key "loss" containing a tensor of shape [N].
729
+ Some mean or variance settings may also have other keys.
730
+ """
731
+ if model_kwargs is None:
732
+ model_kwargs = {}
733
+ if noise is None:
734
+ noise = th.randn_like(x_start)
735
+ x_t = self.q_sample(x_start, t, noise=noise)
736
+
737
+ terms = {}
738
+
739
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
740
+ terms["loss"] = self._vb_terms_bpd(
741
+ model=model,
742
+ x_start=x_start,
743
+ x_t=x_t,
744
+ t=t,
745
+ clip_denoised=False,
746
+ model_kwargs=model_kwargs,
747
+ )["output"]
748
+ if self.loss_type == LossType.RESCALED_KL:
749
+ terms["loss"] *= self.num_timesteps
750
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
751
+ model_output = model(x_t, t, **model_kwargs)
752
+ # try:
753
+ # model_output = model(x_t, t, **model_kwargs).sample # for tav unet
754
+ # except:
755
+ # model_output = model(x_t, t, **model_kwargs)
756
+
757
+ if self.model_var_type in [
758
+ ModelVarType.LEARNED,
759
+ ModelVarType.LEARNED_RANGE,
760
+ ]:
761
+ B, F, C = x_t.shape[:3]
762
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
763
+ model_output, model_var_values = th.split(model_output, C, dim=2)
764
+ # Learn the variance using the variational bound, but don't let
765
+ # it affect our mean prediction.
766
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
767
+ terms["vb"] = self._vb_terms_bpd(
768
+ model=lambda *args, r=frozen_out: r,
769
+ x_start=x_start,
770
+ x_t=x_t,
771
+ t=t,
772
+ clip_denoised=False,
773
+ )["output"]
774
+ if self.loss_type == LossType.RESCALED_MSE:
775
+ # Divide by 1000 for equivalence with initial implementation.
776
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
777
+ terms["vb"] *= self.num_timesteps / 1000.0
778
+
779
+ target = {
780
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
781
+ x_start=x_start, x_t=x_t, t=t
782
+ )[0],
783
+ ModelMeanType.START_X: x_start,
784
+ ModelMeanType.EPSILON: noise,
785
+ }[self.model_mean_type]
786
+ assert model_output.shape == target.shape == x_start.shape
787
+ terms["mse"] = mean_flat((target - model_output) ** 2)
788
+ if "vb" in terms:
789
+ terms["loss"] = terms["mse"] + terms["vb"]
790
+ else:
791
+ terms["loss"] = terms["mse"]
792
+ else:
793
+ raise NotImplementedError(self.loss_type)
794
+
795
+ return terms
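+ # Summary: for (RESCALED_)MSE the loss is an MSE against the target chosen by model_mean_type
+ # (noise for EPSILON, x_0 for START_X, posterior mean for PREVIOUS_X), plus a variational-bound
+ # term on the learned variance when model_var_type is LEARNED or LEARNED_RANGE.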
796
+
797
+ def _prior_bpd(self, x_start):
798
+ """
799
+ Get the prior KL term for the variational lower-bound, measured in
800
+ bits-per-dim.
801
+ This term can't be optimized, as it only depends on the encoder.
802
+ :param x_start: the [N x C x ...] tensor of inputs.
803
+ :return: a batch of [N] KL values (in bits), one per batch element.
804
+ """
805
+ batch_size = x_start.shape[0]
806
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
807
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
808
+ kl_prior = normal_kl(
809
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
810
+ )
811
+ return mean_flat(kl_prior) / np.log(2.0)
812
+
813
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
814
+ """
815
+ Compute the entire variational lower-bound, measured in bits-per-dim,
816
+ as well as other related quantities.
817
+ :param model: the model to evaluate loss on.
818
+ :param x_start: the [N x C x ...] tensor of inputs.
819
+ :param clip_denoised: if True, clip denoised samples.
820
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
821
+ pass to the model. This can be used for conditioning.
822
+ :return: a dict containing the following keys:
823
+ - total_bpd: the total variational lower-bound, per batch element.
824
+ - prior_bpd: the prior term in the lower-bound.
825
+ - vb: an [N x T] tensor of terms in the lower-bound.
826
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
827
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
828
+ """
829
+ device = x_start.device
830
+ batch_size = x_start.shape[0]
831
+
832
+ vb = []
833
+ xstart_mse = []
834
+ mse = []
835
+ for t in list(range(self.num_timesteps))[::-1]:
836
+ t_batch = th.tensor([t] * batch_size, device=device)
837
+ noise = th.randn_like(x_start)
838
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
839
+ # Calculate VLB term at the current timestep
840
+ with th.no_grad():
841
+ out = self._vb_terms_bpd(
842
+ model,
843
+ x_start=x_start,
844
+ x_t=x_t,
845
+ t=t_batch,
846
+ clip_denoised=clip_denoised,
847
+ model_kwargs=model_kwargs,
848
+ )
849
+ vb.append(out["output"])
850
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
851
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
852
+ mse.append(mean_flat((eps - noise) ** 2))
853
+
854
+ vb = th.stack(vb, dim=1)
855
+ xstart_mse = th.stack(xstart_mse, dim=1)
856
+ mse = th.stack(mse, dim=1)
857
+
858
+ prior_bpd = self._prior_bpd(x_start)
859
+ total_bpd = vb.sum(dim=1) + prior_bpd
860
+ return {
861
+ "total_bpd": total_bpd,
862
+ "prior_bpd": prior_bpd,
863
+ "vb": vb,
864
+ "xstart_mse": xstart_mse,
865
+ "mse": mse,
866
+ }
867
+
868
+
869
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
870
+ """
871
+ Extract values from a 1-D numpy array for a batch of indices.
872
+ :param arr: the 1-D numpy array.
873
+ :param timesteps: a tensor of indices into the array to extract.
874
+ :param broadcast_shape: a larger shape of K dimensions with the batch
875
+ dimension equal to the length of timesteps.
876
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
877
+ """
878
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
879
+ while len(res.shape) < len(broadcast_shape):
880
+ res = res[..., None]
881
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
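+ # Adding th.zeros(broadcast_shape) is an explicit broadcast so the result has exactly
+ # broadcast_shape instead of [batch_size, 1, ..., 1].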
diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+ import torch
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
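+ # Examples (sketch): space_timesteps(1000, "ddim50") == set(range(0, 1000, 20)),
+ # i.e. 50 evenly strided steps, while space_timesteps(100, [10]) picks
+ # {0, 11, 22, 33, 44, 55, 66, 77, 88, 99} from a single 100-step section.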
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
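+ # Note on the beta re-derivation above (sketch): alpha_bar_t is the running product
+ # of (1 - beta_t), so setting beta'_j = 1 - alpha_bar_{t_j} / alpha_bar_{t_{j-1}} for
+ # the retained steps reproduces the same cumulative alphas on the shortened schedule.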
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ # @torch.compile
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def condition_mean(self, cond_fn, *args, **kwargs):
101
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102
+
103
+ def condition_score(self, cond_fn, *args, **kwargs):
104
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105
+
106
+ def _wrap_model(self, model):
107
+ if isinstance(model, _WrappedModel):
108
+ return model
109
+ return _WrappedModel(
110
+ model, self.timestep_map, self.original_num_steps
111
+ )
112
+
113
+ def _scale_timesteps(self, t):
114
+ # Scaling is done by the wrapped model.
115
+ return t
116
+
117
+
118
+ class _WrappedModel:
119
+ def __init__(self, model, timestep_map, original_num_steps):
120
+ self.model = model
121
+ self.timestep_map = timestep_map
122
+ # self.rescale_timesteps = rescale_timesteps
123
+ self.original_num_steps = original_num_steps
124
+
125
+ def __call__(self, x, ts, **kwargs):
126
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127
+ new_ts = map_tensor[ts]
128
+ # if self.rescale_timesteps:
129
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130
+ return self.model(x, new_ts, **kwargs)
diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
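+ # Example (sketch): for UniformSampler over T timesteps, p[i] == 1/T, so the returned
+ # loss weights are all 1.0 and sample() reduces to uniform sampling of timesteps.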
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
docs/datasets_evaluation.md ADDED
@@ -0,0 +1,53 @@
1
+ ## Download datasets
2
+
3
+ Here are the links to download the datasets [FaceForensics](https://huggingface.co/datasets/maxin-cn/FaceForensics), [SkyTimelapse](https://huggingface.co/datasets/maxin-cn/SkyTimelapse/tree/main), [UCF101](https://www.crcv.ucf.edu/data/UCF101/UCF101.rar), and [Taichi-HD](https://huggingface.co/datasets/maxin-cn/Taichi-HD).
4
+
5
+
6
+ ## Dataset structure
7
+
8
+ All datasets retain their original structure. For video-image joint training, there is a `train_list.txt` file, where each line is formatted as `video_name/frame.jpg`. Below is an example from the FaceForensics dataset, followed by a minimal sketch for generating such a list.
11
+
12
+ ```bash
13
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000306.jpg
14
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000111.jpg
15
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000007.jpg
16
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000057.jpg
17
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000084.jpg
18
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000268.jpg
19
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000270.jpg
20
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000259.jpg
21
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000127.jpg
22
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000099.jpg
23
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000189.jpg
24
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000228.jpg
25
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000026.jpg
26
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000081.jpg
27
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000094.jpg
28
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000223.jpg
29
+ aS62n5PdTIU_1_8WGsQ0Y7uyU_1/000055.jpg
30
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000486.jpg
31
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000396.jpg
32
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000475.jpg
33
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000028.jpg
34
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000261.jpg
35
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000294.jpg
36
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000257.jpg
37
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000490.jpg
38
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000143.jpg
39
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000190.jpg
40
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000476.jpg
41
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000397.jpg
42
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000437.jpg
43
+ qEnKi82wWgE_2_rJPM8EdWShs_1/000071.jpg
44
+ ```
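+
+ A minimal sketch (not part of the repository) for generating such a `train_list.txt` from a root folder of extracted frames, assuming one sub-folder of `.jpg` frames per video, could look like this:
+
+ ```python
+ import os
+
+ frames_root = "path/to/frames"  # placeholder: one sub-folder of .jpg frames per video
+
+ with open("train_list.txt", "w") as f:
+     for video_name in sorted(os.listdir(frames_root)):
+         video_dir = os.path.join(frames_root, video_name)
+         if not os.path.isdir(video_dir):
+             continue
+         for frame in sorted(os.listdir(video_dir)):
+             if frame.endswith(".jpg"):
+                 f.write(f"{video_name}/{frame}\n")
+ ```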
45
+
46
+ ## Evaluation
47
+
48
+ We follow [StyleGAN-V](https://github.com/universome/stylegan-v) to measure the quality of the generated videos. The code for computing the relevant metrics is located in the [tools](../tools/) folder. To compute quantitative metrics for your generated results, put all real videos into one folder and convert them into video frames (do the same for the generated videos); a minimal frame-extraction sketch is shown after the command. Then run the following command on a single GPU:
49
+
50
+ ```bash
51
+ # cd Latte
52
+ bash tools/eval_metrics.sh
53
+ ```
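+
+ For reference, here is a minimal sketch (not part of the repository) for turning a folder of videos into per-video frame folders; `video_dir` and `frame_dir` are placeholder paths:
+
+ ```python
+ import os
+ import imageio
+
+ video_dir = "path/to/videos"   # real or generated videos
+ frame_dir = "path/to/frames"   # output root, one sub-folder per video
+
+ for name in sorted(os.listdir(video_dir)):
+     if not name.endswith(".mp4"):
+         continue
+     out_dir = os.path.join(frame_dir, os.path.splitext(name)[0])
+     os.makedirs(out_dir, exist_ok=True)
+     reader = imageio.get_reader(os.path.join(video_dir, name))
+     for idx, frame in enumerate(reader):
+         imageio.imwrite(os.path.join(out_dir, f"{idx:06d}.jpg"), frame)
+ ```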
docs/latte_diffusers.md ADDED
@@ -0,0 +1,106 @@
1
+ ## Requirements
2
+
3
+ Please follow the [README](../README.md) to install the environment. After installation, upgrade `diffusers` to at least version 0.30.0.
4
+
5
+ ## Inference
6
+
7
+ ```bash
8
+ from diffusers import LattePipeline
9
+ from diffusers.models import AutoencoderKLTemporalDecoder
10
+
11
+ from torchvision.utils import save_image
12
+
13
+ import torch
14
+ import imageio
15
+
16
+ torch.manual_seed(0)
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ video_length = 1 # 1 or 16
20
+ pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to(device)
21
+
22
+ # to use the temporal decoder of the VAE, uncomment the following two lines
23
+ # vae = AutoencoderKLTemporalDecoder.from_pretrained("maxin-cn/Latte-1", subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
24
+ # pipe.vae = vae
25
+
26
+ prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
27
+ videos = pipe(prompt, video_length=video_length, output_type='pt').frames.cpu()
28
+
29
+ if video_length > 1:
30
+ videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8
31
+ imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0
32
+ else:
33
+ save_image(videos[0], './latte_output.png')
34
+ ```
35
+
36
+ ## Inference with 4/8-bit quantization
37
+ [@Aryan](https://github.com/a-r-r-o-w) provides a quantization solution for inference, which reduces GPU memory usage from 17 GB to 9 GB. Note that `bitsandbytes` must be installed first (`pip install bitsandbytes`).
38
+
39
+ ```python
40
+ import gc
41
+
42
+ import torch
43
+ from diffusers import LattePipeline
44
+ from transformers import T5EncoderModel, BitsAndBytesConfig
45
+ import imageio
46
+ from torchvision.utils import save_image
47
+
48
+ torch.manual_seed(0)
49
+
50
+ def flush():
51
+ gc.collect()
52
+ torch.cuda.empty_cache()
53
+
54
+ def bytes_to_giga_bytes(bytes):
55
+ return bytes / 1024 / 1024 / 1024
56
+
57
+ video_length = 16
58
+ model_id = "maxin-cn/Latte-1/"
59
+
60
+ text_encoder = T5EncoderModel.from_pretrained(
61
+ model_id,
62
+ subfolder="text_encoder",
63
+ quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
64
+ device_map="auto",
65
+ )
66
+ pipe = LattePipeline.from_pretrained(
67
+ model_id,
68
+ text_encoder=text_encoder,
69
+ transformer=None,
70
+ device_map="balanced",
71
+ )
72
+
73
+ with torch.no_grad():
74
+ prompt = "a cat wearing sunglasses and working as a lifeguard at pool."
75
+ negative_prompt = ""
76
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt, negative_prompt=negative_prompt)
77
+
78
+ del text_encoder
79
+ del pipe
80
+ flush()
81
+
82
+ pipe = LattePipeline.from_pretrained(
83
+ model_id,
84
+ text_encoder=None,
85
+ torch_dtype=torch.float16,
86
+ ).to("cuda")
87
+ # pipe.enable_vae_tiling()
88
+ # pipe.enable_vae_slicing()
89
+
90
+ videos = pipe(
91
+ video_length=video_length,
92
+ num_inference_steps=50,
93
+ negative_prompt=None,
94
+ prompt_embeds=prompt_embeds,
95
+ negative_prompt_embeds=negative_prompt_embeds,
96
+ output_type="pt",
97
+ ).frames.cpu()
98
+
99
+ print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
100
+
101
+ if video_length > 1:
102
+ videos = (videos.clamp(0, 1) * 255).to(dtype=torch.uint8) # convert to uint8
103
+ imageio.mimwrite('./latte_output.mp4', videos[0].permute(0, 2, 3, 1), fps=8, quality=5) # highest quality is 10, lowest is 0
104
+ else:
105
+ save_image(videos[0], './latte_output.png')
106
+ ```
environment.yml ADDED
@@ -0,0 +1,25 @@
1
+ name: latte
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python >= 3.10
7
+ - pytorch > 2.0.0
8
+ - torchvision
9
+ - pytorch-cuda >= 11.7
10
+ - pip:
11
+ - timm
12
+ - diffusers[torch]==0.24.0
13
+ - accelerate
14
+ - tensorboard
15
+ - einops
16
+ - transformers
17
+ - av
18
+ - scikit-image
19
+ - decord
20
+ - pandas
21
+ - imageio-ffmpeg
22
+ - sentencepiece
23
+ - beautifulsoup4
24
+ - ftfy
25
+ - omegaconf
models/__init__.py ADDED
@@ -0,0 +1,52 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .latte import Latte_models
6
+ from .latte_img import LatteIMG_models
7
+ from .latte_t2v import LatteT2V
8
+
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+
11
+
12
+ def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ def fn(step):
15
+ if warmup_steps > 0:
16
+ return min(step / warmup_steps, 1)
17
+ else:
18
+ return 1
19
+ return LambdaLR(optimizer, fn)
20
+
21
+
22
+ def get_lr_scheduler(optimizer, name, **kwargs):
23
+ if name == 'warmup':
24
+ return customized_lr_scheduler(optimizer, **kwargs)
25
+ elif name == 'cosine':
26
+ from torch.optim.lr_scheduler import CosineAnnealingLR
27
+ return CosineAnnealingLR(optimizer, **kwargs)
28
+ else:
29
+ raise NotImplementedError(name)
30
+
31
+ def get_models(args):
32
+ if 'LatteIMG' in args.model:
33
+ return LatteIMG_models[args.model](
34
+ input_size=args.latent_size,
35
+ num_classes=args.num_classes,
36
+ num_frames=args.num_frames,
37
+ learn_sigma=args.learn_sigma,
38
+ extras=args.extras
39
+ )
40
+ elif 'LatteT2V' in args.model:
41
+ return LatteT2V.from_pretrained(args.pretrained_model_path, subfolder="transformer", video_length=args.video_length)
42
+ elif 'Latte' in args.model:
43
+ return Latte_models[args.model](
44
+ input_size=args.latent_size,
45
+ num_classes=args.num_classes,
46
+ num_frames=args.num_frames,
47
+ learn_sigma=args.learn_sigma,
48
+ extras=args.extras
49
+ )
50
+ else:
51
+ raise '{} Model Not Supported!'.format(args.model)
52
+
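+ # Illustrative usage (sketch; the field names follow the training configs):
+ #     from argparse import Namespace
+ #     args = Namespace(model='Latte-XL/2', latent_size=32, num_classes=1000,
+ #                      num_frames=16, learn_sigma=True, extras=1)
+ #     model = get_models(args)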
models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (2.54 kB).
 
models/__pycache__/latte.cpython-312.pyc ADDED
Binary file (28.8 kB).
 
models/__pycache__/latte_img.cpython-312.pyc ADDED
Binary file (30.3 kB).
 
models/__pycache__/latte_t2v.cpython-312.pyc ADDED
Binary file (39 kB).
 
models/clip.py ADDED
@@ -0,0 +1,126 @@
1
+ import numpy
2
+ import torch.nn as nn
3
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ """
9
+ You will encounter the following warning:
10
+ - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
11
+ or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
12
+ - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
13
+ that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
14
+
15
+ https://github.com/CompVis/stable-diffusion/issues/97
16
+ According to this issue, the warning is safe to ignore.
17
+
18
+ This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
19
+ You can safely ignore the warning, it is not an error.
20
+
21
+ This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
22
+ """
23
+
24
+ class AbstractEncoder(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def encode(self, *args, **kwargs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class FrozenCLIPEmbedder(AbstractEncoder):
33
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
34
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
35
+ def __init__(self, path, device="cuda", max_length=77):
36
+ super().__init__()
37
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
38
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
39
+ self.device = device
40
+ self.max_length = max_length
41
+ self.freeze()
42
+
43
+ def freeze(self):
44
+ self.transformer = self.transformer.eval()
45
+ for param in self.parameters():
46
+ param.requires_grad = False
47
+
48
+ def forward(self, text):
49
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
50
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
51
+ tokens = batch_encoding["input_ids"].to(self.device)
52
+ outputs = self.transformer(input_ids=tokens)
53
+
54
+ z = outputs.last_hidden_state
55
+ pooled_z = outputs.pooler_output
56
+ return z, pooled_z
57
+
58
+ def encode(self, text):
59
+ return self(text)
60
+
61
+
62
+ class TextEmbedder(nn.Module):
63
+ """
64
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
65
+ """
66
+ def __init__(self, path, dropout_prob=0.1):
67
+ super().__init__()
68
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
69
+ self.dropout_prob = dropout_prob
70
+
71
+ def token_drop(self, text_prompts, force_drop_ids=None):
72
+ """
73
+ Drops text to enable classifier-free guidance.
74
+ """
75
+ if force_drop_ids is None:
76
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
77
+ else:
78
+ # TODO
79
+ drop_ids = force_drop_ids == 1
80
+ labels = list(numpy.where(drop_ids, "", text_prompts))
81
+ # print(labels)
82
+ return labels
83
+
84
+ def forward(self, text_prompts, train, force_drop_ids=None):
85
+ use_dropout = self.dropout_prob > 0
86
+ if (train and use_dropout) or (force_drop_ids is not None):
87
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
88
+ embeddings, pooled_embeddings = self.text_encodder(text_prompts)
89
+ # return embeddings, pooled_embeddings
90
+ return pooled_embeddings
91
+
92
+
93
+ if __name__ == '__main__':
94
+
95
+ r"""
96
+ Returns:
97
+
98
+ Examples from CLIPTextModel:
99
+
100
+ ```python
101
+ >>> from transformers import AutoTokenizer, CLIPTextModel
102
+
103
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
104
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
105
+
106
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
107
+
108
+ >>> outputs = model(**inputs)
109
+ >>> last_hidden_state = outputs.last_hidden_state
110
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
111
+ ```"""
112
+
113
+ import torch
114
+
115
+ device = "cuda" if torch.cuda.is_available() else "cpu"
116
+
117
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
118
+ dropout_prob=0.00001).to(device)
119
+
120
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
121
+ # text_prompt = ('None', 'None', 'None')
122
+ output, pooled_output = text_encoder(text_prompts=text_prompt, train=False)
123
+ # print(output)
124
+ print(output.shape)
125
+ print(pooled_output.shape)
126
+ # print(output.shape)
models/latte.py ADDED
@@ -0,0 +1,526 @@
1
+ # All rights reserved.
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # GLIDE: https://github.com/openai/glide-text2im
8
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
9
+ # --------------------------------------------------------
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+
15
+ from einops import rearrange, repeat
16
+ from timm.models.vision_transformer import Mlp, PatchEmbed
17
+
18
+ # the xformers lib allows less memory, faster training and inference
19
+ try:
20
+ import xformers
21
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
22
+ except ImportError:
23
+ XFORMERS_IS_AVAILBLE = False
24
+
25
+ # from timm.models.layers.helpers import to_2tuple
26
+ # from timm.models.layers.trace_utils import _assert
27
+
28
+ def modulate(x, shift, scale):
29
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
30
+
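+ # Sketch: x is (B, N, D) and shift/scale are (B, D); unsqueeze(1) broadcasts them over
+ # the token dimension, giving the adaLN-style shift-and-scale used inside each block.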
31
+ #################################################################################
32
+ # Attention Layers from TIMM #
33
+ #################################################################################
34
+
35
+ class Attention(nn.Module):
36
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
37
+ super().__init__()
38
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
39
+ self.num_heads = num_heads
40
+ head_dim = dim // num_heads
41
+ self.scale = head_dim ** -0.5
42
+ self.attention_mode = attention_mode
43
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
44
+ self.attn_drop = nn.Dropout(attn_drop)
45
+ self.proj = nn.Linear(dim, dim)
46
+ self.proj_drop = nn.Dropout(proj_drop)
47
+
48
+ def forward(self, x):
49
+ B, N, C = x.shape
50
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
51
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
52
+
53
+ if self.attention_mode == 'xformers': # cause loss nan while using with amp
54
+ # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135
55
+ q_xf = q.transpose(1,2).contiguous()
56
+ k_xf = k.transpose(1,2).contiguous()
57
+ v_xf = v.transpose(1,2).contiguous()
58
+ x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C)
59
+
60
+ elif self.attention_mode == 'flash':
61
+ # cause loss nan while using with amp
62
+ # Optionally use the context manager to ensure one of the fused kernels is run
63
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
64
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(B, N, C) # requires pytorch 2.0
65
+
66
+ elif self.attention_mode == 'math':
67
+ attn = (q @ k.transpose(-2, -1)) * self.scale
68
+ attn = attn.softmax(dim=-1)
69
+ attn = self.attn_drop(attn)
70
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
71
+
72
+ else:
73
+ raise NotImplementedError(f'attention mode {self.attention_mode} is not supported')
74
+
75
+ x = self.proj(x)
76
+ x = self.proj_drop(x)
77
+ return x
78
+
79
+
80
+ #################################################################################
81
+ # Embedding Layers for Timesteps and Class Labels #
82
+ #################################################################################
83
+
84
+ class TimestepEmbedder(nn.Module):
85
+ """
86
+ Embeds scalar timesteps into vector representations.
87
+ """
88
+ def __init__(self, hidden_size, frequency_embedding_size=256):
89
+ super().__init__()
90
+ self.mlp = nn.Sequential(
91
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
92
+ nn.SiLU(),
93
+ nn.Linear(hidden_size, hidden_size, bias=True),
94
+ )
95
+ self.frequency_embedding_size = frequency_embedding_size
96
+
97
+ @staticmethod
98
+ def timestep_embedding(t, dim, max_period=10000):
99
+ """
100
+ Create sinusoidal timestep embeddings.
101
+ :param t: a 1-D Tensor of N indices, one per batch element.
102
+ These may be fractional.
103
+ :param dim: the dimension of the output.
104
+ :param max_period: controls the minimum frequency of the embeddings.
105
+ :return: an (N, D) Tensor of positional embeddings.
106
+ """
107
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
108
+ half = dim // 2
109
+ freqs = torch.exp(
110
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
111
+ ).to(device=t.device)
112
+ args = t[:, None].float() * freqs[None]
113
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
114
+ if dim % 2:
115
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
116
+ return embedding
117
+
118
+ def forward(self, t, use_fp16=False):
119
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
120
+ if use_fp16:
121
+ t_freq = t_freq.to(dtype=torch.float16)
122
+ t_emb = self.mlp(t_freq)
123
+ return t_emb
124
+
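+ # Example (sketch): timestep_embedding(torch.tensor([0, 250, 999]), 256) returns a
+ # (3, 256) tensor of concatenated cos/sin features, which the MLP above then maps
+ # to a (3, hidden_size) conditioning vector.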
125
+
126
+ class LabelEmbedder(nn.Module):
127
+ """
128
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
129
+ """
130
+ def __init__(self, num_classes, hidden_size, dropout_prob):
131
+ super().__init__()
132
+ use_cfg_embedding = dropout_prob > 0
133
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
134
+ self.num_classes = num_classes
135
+ self.dropout_prob = dropout_prob
136
+
137
+ def token_drop(self, labels, force_drop_ids=None):
138
+ """
139
+ Drops labels to enable classifier-free guidance.
140
+ """
141
+ if force_drop_ids is None:
142
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
143
+ else:
144
+ drop_ids = force_drop_ids == 1
145
+ labels = torch.where(drop_ids, self.num_classes, labels)
146
+ return labels
147
+
148
+ def forward(self, labels, train, force_drop_ids=None):
149
+ use_dropout = self.dropout_prob > 0
150
+ if (train and use_dropout) or (force_drop_ids is not None):
151
+ labels = self.token_drop(labels, force_drop_ids)
152
+ embeddings = self.embedding_table(labels)
153
+ return embeddings
154
+
155
+
156
+ #################################################################################
157
+ # Core Latte Model #
158
+ #################################################################################
159
+
160
+ class TransformerBlock(nn.Module):
161
+ """
162
+ A Latte transformer block with adaptive layer norm zero (adaLN-Zero) conditioning.
163
+ """
164
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
165
+ super().__init__()
166
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
167
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
168
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
169
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
170
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
171
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
172
+ self.adaLN_modulation = nn.Sequential(
173
+ nn.SiLU(),
174
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
175
+ )
176
+
177
+ def forward(self, x, c):
178
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
179
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
180
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
181
+ return x
182
+
183
+
184
+ class FinalLayer(nn.Module):
185
+ """
186
+ The final layer of Latte.
187
+ """
188
+ def __init__(self, hidden_size, patch_size, out_channels):
189
+ super().__init__()
190
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
191
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
192
+ self.adaLN_modulation = nn.Sequential(
193
+ nn.SiLU(),
194
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
195
+ )
196
+
197
+ def forward(self, x, c):
198
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
199
+ x = modulate(self.norm_final(x), shift, scale)
200
+ x = self.linear(x)
201
+ return x
202
+
203
+
204
+ class Latte(nn.Module):
205
+ """
206
+ Diffusion model with a Transformer backbone.
207
+ """
208
+ def __init__(
209
+ self,
210
+ input_size=32,
211
+ patch_size=2,
212
+ in_channels=4,
213
+ hidden_size=1152,
214
+ depth=28,
215
+ num_heads=16,
216
+ mlp_ratio=4.0,
217
+ num_frames=16,
218
+ class_dropout_prob=0.1,
219
+ num_classes=1000,
220
+ learn_sigma=True,
221
+ extras=1,
222
+ attention_mode='math',
223
+ ):
224
+ super().__init__()
225
+ self.learn_sigma = learn_sigma
226
+ self.in_channels = in_channels
227
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
228
+ self.patch_size = patch_size
229
+ self.num_heads = num_heads
230
+ self.extras = extras
231
+ self.num_frames = num_frames
232
+
233
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
234
+ self.t_embedder = TimestepEmbedder(hidden_size)
235
+
236
+ if self.extras == 2:
237
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
238
+ if self.extras == 78: # timestep + text_embedding
239
+ self.text_embedding_projection = nn.Sequential(
240
+ nn.SiLU(),
241
+ nn.Linear(77 * 768, hidden_size, bias=True)
242
+ )
243
+
244
+ num_patches = self.x_embedder.num_patches
245
+ # Will use fixed sin-cos embedding:
246
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
247
+ self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
248
+ self.hidden_size = hidden_size
249
+
250
+ self.blocks = nn.ModuleList([
251
+ TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
252
+ ])
253
+
254
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
255
+ self.initialize_weights()
256
+
257
+ def initialize_weights(self):
258
+ # Initialize transformer layers:
259
+ def _basic_init(module):
260
+ if isinstance(module, nn.Linear):
261
+ torch.nn.init.xavier_uniform_(module.weight)
262
+ if module.bias is not None:
263
+ nn.init.constant_(module.bias, 0)
264
+ self.apply(_basic_init)
265
+
266
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
267
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
268
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
269
+
270
+ temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2])
271
+ self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0))
272
+
273
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
274
+ w = self.x_embedder.proj.weight.data
275
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
276
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
277
+
278
+ if self.extras == 2:
279
+ # Initialize label embedding table:
280
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
281
+
282
+ # Initialize timestep embedding MLP:
283
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
284
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
285
+
286
+ # Zero-out adaLN modulation layers in Latte blocks:
287
+ for block in self.blocks:
288
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
289
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
290
+
291
+ # Zero-out output layers:
292
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
293
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
294
+ nn.init.constant_(self.final_layer.linear.weight, 0)
295
+ nn.init.constant_(self.final_layer.linear.bias, 0)
296
+
297
+ def unpatchify(self, x):
298
+ """
299
+ x: (N, T, patch_size**2 * C)
300
+ imgs: (N, H, W, C)
301
+ """
302
+ c = self.out_channels
303
+ p = self.x_embedder.patch_size[0]
304
+ h = w = int(x.shape[1] ** 0.5)
305
+ assert h * w == x.shape[1]
306
+
307
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
308
+ x = torch.einsum('nhwpqc->nchpwq', x)
309
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
310
+ return imgs
311
+
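+ # Sketch: with input_size=32 and patch_size=2 there are 16 * 16 = 256 tokens per frame,
+ # so unpatchify maps (N, 256, 2 * 2 * out_channels) back to (N, out_channels, 32, 32).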
312
+ # @torch.cuda.amp.autocast()
313
+ # @torch.compile
314
+ def forward(self,
315
+ x,
316
+ t,
317
+ y=None,
318
+ text_embedding=None,
319
+ use_fp16=False):
320
+ """
321
+ Forward pass of Latte.
322
+ x: (N, F, C, H, W) tensor of video inputs
323
+ t: (N,) tensor of diffusion timesteps
324
+ y: (N,) tensor of class labels
325
+ """
326
+ if use_fp16:
327
+ x = x.to(dtype=torch.float16)
328
+
329
+ batches, frames, channels, high, weight = x.shape
330
+ x = rearrange(x, 'b f c h w -> (b f) c h w')
331
+ x = self.x_embedder(x) + self.pos_embed
332
+ t = self.t_embedder(t, use_fp16=use_fp16)
333
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1])
334
+ timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])
335
+
336
+ if self.extras == 2:
337
+ y = self.y_embedder(y, self.training)
338
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.temp_embed.shape[1])
339
+ y_temp = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])
340
+ elif self.extras == 78:
341
+ text_embedding = self.text_embedding_projection(text_embedding.reshape(batches, -1))
342
+ text_embedding_spatial = repeat(text_embedding, 'n d -> (n c) d', c=self.temp_embed.shape[1])
343
+ text_embedding_temp = repeat(text_embedding, 'n d -> (n c) d', c=self.pos_embed.shape[1])
344
+
345
+ for i in range(0, len(self.blocks), 2):
346
+ spatial_block, temp_block = self.blocks[i:i+2]
347
+ if self.extras == 2:
348
+ c = timestep_spatial + y_spatial
349
+ elif self.extras == 78:
350
+ c = timestep_spatial + text_embedding_spatial
351
+ else:
352
+ c = timestep_spatial
353
+ x = spatial_block(x, c)
354
+
355
+ x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)
356
+ # Add Time Embedding
357
+ if i == 0:
358
+ x = x + self.temp_embed
359
+
360
+ if self.extras == 2:
361
+ c = timestep_temp + y_temp
362
+ elif self.extras == 78:
363
+ c = timestep_temp + text_embedding_temp
364
+ else:
365
+ c = timestep_temp
366
+
367
+ x = temp_block(x, c)
368
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
369
+
370
+ if self.extras == 2:
371
+ c = timestep_spatial + y_spatial
372
+ else:
373
+ c = timestep_spatial
374
+ x = self.final_layer(x, c)
375
+ x = self.unpatchify(x)
376
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
377
+ return x
378
+
379
+ def forward_with_cfg(self, x, t, y=None, cfg_scale=7.0, use_fp16=False, text_embedding=None):
380
+ """
381
+ Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance.
382
+ """
383
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
384
+ half = x[: len(x) // 2]
385
+ combined = torch.cat([half, half], dim=0)
386
+ if use_fp16:
387
+ combined = combined.to(dtype=torch.float16)
388
+ model_out = self.forward(combined, t, y=y, use_fp16=use_fp16, text_embedding=text_embedding)
389
+ # For exact reproducibility reasons, we apply classifier-free guidance on only
390
+ # the noise-prediction channels (the first four latent channels) by default; the standard approach to cfg applies it to all channels.
391
+ # This can be done by uncommenting the following line and commenting-out the line following that.
392
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
393
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
394
+ eps, rest = model_out[:, :, :4, ...], model_out[:, :, 4:, ...]
395
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
396
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
397
+ eps = torch.cat([half_eps, half_eps], dim=0)
398
+ return torch.cat([eps, rest], dim=2)
399
+
400
+
401
+ #################################################################################
402
+ # Sine/Cosine Positional Embedding Functions #
403
+ #################################################################################
404
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
405
+
406
+ def get_1d_sincos_temp_embed(embed_dim, length):
407
+ pos = torch.arange(0, length).unsqueeze(1)
408
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
409
+
410
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
411
+ """
412
+ grid_size: int of the grid height and width
413
+ return:
414
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
415
+ """
416
+ grid_h = np.arange(grid_size, dtype=np.float32)
417
+ grid_w = np.arange(grid_size, dtype=np.float32)
418
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
419
+ grid = np.stack(grid, axis=0)
420
+
421
+ grid = grid.reshape([2, 1, grid_size, grid_size])
422
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
423
+ if cls_token and extra_tokens > 0:
424
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
425
+ return pos_embed
426
+
427
+
428
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
429
+ assert embed_dim % 2 == 0
430
+
431
+ # use half of dimensions to encode grid_h
432
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
433
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
434
+
435
+ emb = np.concatenate([emb_h, emb_w], axis=1)
436
+ return emb
437
+
438
+
439
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
440
+ """
441
+ embed_dim: output dimension for each position
442
+ pos: a list of positions to be encoded: size (M,)
443
+ out: (M, D)
444
+ """
445
+ assert embed_dim % 2 == 0
446
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
447
+ omega /= embed_dim / 2.
448
+ omega = 1. / 10000**omega
449
+
450
+ pos = pos.reshape(-1)
451
+ out = np.einsum('m,d->md', pos, omega)
452
+
453
+ emb_sin = np.sin(out)
454
+ emb_cos = np.cos(out)
455
+
456
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
457
+ return emb
458
+
459
+
460
+ #################################################################################
461
+ # Latte Configs #
462
+ #################################################################################
463
+
464
+ def Latte_XL_2(**kwargs):
465
+ return Latte(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
466
+
467
+ def Latte_XL_4(**kwargs):
468
+ return Latte(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
469
+
470
+ def Latte_XL_8(**kwargs):
471
+ return Latte(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
472
+
473
+ def Latte_L_2(**kwargs):
474
+ return Latte(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
475
+
476
+ def Latte_L_4(**kwargs):
477
+ return Latte(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
478
+
479
+ def Latte_L_8(**kwargs):
480
+ return Latte(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
481
+
482
+ def Latte_B_2(**kwargs):
483
+ return Latte(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
484
+
485
+ def Latte_B_4(**kwargs):
486
+ return Latte(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
487
+
488
+ def Latte_B_8(**kwargs):
489
+ return Latte(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
490
+
491
+ def Latte_S_2(**kwargs):
492
+ return Latte(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
493
+
494
+ def Latte_S_4(**kwargs):
495
+ return Latte(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
496
+
497
+ def Latte_S_8(**kwargs):
498
+ return Latte(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
499
+
500
+
501
+ Latte_models = {
502
+ 'Latte-XL/2': Latte_XL_2, 'Latte-XL/4': Latte_XL_4, 'Latte-XL/8': Latte_XL_8,
503
+ 'Latte-L/2': Latte_L_2, 'Latte-L/4': Latte_L_4, 'Latte-L/8': Latte_L_8,
504
+ 'Latte-B/2': Latte_B_2, 'Latte-B/4': Latte_B_4, 'Latte-B/8': Latte_B_8,
505
+ 'Latte-S/2': Latte_S_2, 'Latte-S/4': Latte_S_4, 'Latte-S/8': Latte_S_8,
506
+ }
507
+
508
+ if __name__ == '__main__':
509
+
510
+ import torch
511
+
512
+ device = "cuda" if torch.cuda.is_available() else "cpu"
513
+
514
+ img = torch.randn(3, 16, 4, 32, 32).to(device)
515
+ t = torch.tensor([1, 2, 3]).to(device)
516
+ y = torch.tensor([1, 2, 3]).to(device)
517
+ network = Latte_XL_2().to(device)
518
+ from thop import profile
519
+ flops, params = profile(network, inputs=(img, t))
520
+ print('FLOPs = ' + str(flops/1000**3) + 'G')
521
+ print('Params = ' + str(params/1000**2) + 'M')
522
+ # y_embeder = LabelEmbedder(num_classes=101, hidden_size=768, dropout_prob=0.5).to(device)
523
+ # lora.mark_only_lora_as_trainable(network)
524
+ # out = y_embeder(y, True)
525
+ # out = network(img, t, y)
526
+ # print(out.shape)
models/latte_img.py ADDED
@@ -0,0 +1,552 @@
1
+ # All rights reserved.
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ # --------------------------------------------------------
6
+ # References:
7
+ # GLIDE: https://github.com/openai/glide-text2im
8
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
9
+ # --------------------------------------------------------
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+
15
+ from einops import rearrange, repeat
16
+ from timm.models.vision_transformer import Mlp, PatchEmbed
17
+
18
+ import os
19
+ import sys
20
+ sys.path.append(os.path.split(sys.path[0])[0])
21
+
22
+ # the xformers lib allows less memory, faster training and inference
23
+ try:
24
+ import xformers
25
+ import xformers.ops
26
+ except:
27
+ XFORMERS_IS_AVAILBLE = False
28
+
29
+ # from timm.models.layers.helpers import to_2tuple
30
+ # from timm.models.layers.trace_utils import _assert
31
+
32
+ def modulate(x, shift, scale):
33
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
34
+
35
+ #################################################################################
36
+ # Attention Layers from TIMM #
37
+ #################################################################################
38
+
39
+ class Attention(nn.Module):
40
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
41
+ super().__init__()
42
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ self.scale = head_dim ** -0.5
46
+ self.attention_mode = attention_mode
47
+
48
+
49
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
50
+
51
+ self.attn_drop = nn.Dropout(attn_drop)
52
+ self.proj = nn.Linear(dim, dim)
53
+ self.proj_drop = nn.Dropout(proj_drop)
54
+
55
+ def forward(self, x):
56
+ B, N, C = x.shape
57
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
58
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
59
+
60
+ if self.attention_mode == 'xformers': # cause loss nan while using with amp
61
+ x = xformers.ops.memory_efficient_attention(q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous(), v.transpose(1, 2).contiguous()).reshape(B, N, C)
62
+
63
+ elif self.attention_mode == 'flash':
64
+ # cause loss nan while using with amp
65
+ # Optionally use the context manager to ensure one of the fused kerenels is run
66
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
67
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C) # require pytorch 2.0
68
+
69
+ elif self.attention_mode == 'math':
70
+ attn = (q @ k.transpose(-2, -1)) * self.scale
71
+ attn = attn.softmax(dim=-1)
72
+ attn = self.attn_drop(attn)
73
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
74
+
75
+ else:
76
+ raise NotImplemented
77
+
78
+ x = self.proj(x)
79
+ x = self.proj_drop(x)
80
+ return x
81
+
82
+
83
+ #################################################################################
84
+ # Embedding Layers for Timesteps and Class Labels #
85
+ #################################################################################
86
+
87
+ class TimestepEmbedder(nn.Module):
88
+ """
89
+ Embeds scalar timesteps into vector representations.
90
+ """
91
+ def __init__(self, hidden_size, frequency_embedding_size=256):
92
+ super().__init__()
93
+ self.mlp = nn.Sequential(
94
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
95
+ nn.SiLU(),
96
+ nn.Linear(hidden_size, hidden_size, bias=True),
97
+ )
98
+ self.frequency_embedding_size = frequency_embedding_size
99
+
100
+ @staticmethod
101
+ def timestep_embedding(t, dim, max_period=10000):
102
+ """
103
+ Create sinusoidal timestep embeddings.
104
+ :param t: a 1-D Tensor of N indices, one per batch element.
105
+ These be fractional.
106
+ :param dim: the dimension of the output.
107
+ :param max_period: controls the minimum frequency of the embeddings.
108
+ :return: an (N, D) Tensor of positional embeddings.
109
+ """
110
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
111
+ half = dim // 2
112
+ freqs = torch.exp(
113
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
114
+ ).to(device=t.device)
115
+ args = t[:, None].float() * freqs[None]
116
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
117
+ if dim % 2:
118
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
119
+ return embedding
120
+
121
+ def forward(self, t, use_fp16=False):
122
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
123
+ if use_fp16:
124
+ t_freq = t_freq.to(dtype=torch.float16)
125
+ t_emb = self.mlp(t_freq)
126
+ return t_emb
127
+
128
+
129
+ class LabelEmbedder(nn.Module):
130
+ """
131
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
132
+ """
133
+ def __init__(self, num_classes, hidden_size, dropout_prob):
134
+ super().__init__()
135
+ use_cfg_embedding = dropout_prob > 0
136
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
137
+ self.num_classes = num_classes
138
+ self.dropout_prob = dropout_prob
139
+
140
+ def token_drop(self, labels, force_drop_ids=None):
141
+ """
142
+ Drops labels to enable classifier-free guidance.
143
+ """
144
+ if force_drop_ids is None:
145
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
146
+ else:
147
+ drop_ids = force_drop_ids == 1
148
+ labels = torch.where(drop_ids, self.num_classes, labels)
149
+ return labels
150
+
151
+ def forward(self, labels, train, force_drop_ids=None):
152
+ use_dropout = self.dropout_prob > 0
153
+ if (train and use_dropout) or (force_drop_ids is not None):
154
+ labels = self.token_drop(labels, force_drop_ids)
155
+ embeddings = self.embedding_table(labels)
156
+ return embeddings
157
+
158
+
159
+ #################################################################################
160
+ # Core Latte Model #
161
+ #################################################################################
162
+
163
+ class TransformerBlock(nn.Module):
164
+ """
165
+ A Latte block with adaptive layer norm zero (adaLN-Zero) conditioning.
166
+ """
167
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
168
+ super().__init__()
169
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
170
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
171
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
172
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
173
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
174
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
175
+ self.adaLN_modulation = nn.Sequential(
176
+ nn.SiLU(),
177
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
178
+ )
179
+
180
+ def forward(self, x, c):
181
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
182
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
183
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
184
+ return x
185
+
186
+
187
+ class FinalLayer(nn.Module):
188
+ """
189
+ The final layer of Latte.
190
+ """
191
+ def __init__(self, hidden_size, patch_size, out_channels):
192
+ super().__init__()
193
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
194
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
195
+ self.adaLN_modulation = nn.Sequential(
196
+ nn.SiLU(),
197
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
198
+ )
199
+
200
+ def forward(self, x, c):
201
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
202
+ x = modulate(self.norm_final(x), shift, scale)
203
+ x = self.linear(x)
204
+ return x
205
+
206
+
207
+ class Latte(nn.Module):
208
+ """
209
+ Diffusion model with a Transformer backbone.
210
+ """
211
+ def __init__(
212
+ self,
213
+ input_size=32,
214
+ patch_size=2,
215
+ in_channels=4,
216
+ hidden_size=1152,
217
+ depth=28,
218
+ num_heads=16,
219
+ mlp_ratio=4.0,
220
+ num_frames=16,
221
+ class_dropout_prob=0.1,
222
+ num_classes=1000,
223
+ learn_sigma=True,
224
+ extras=2,
225
+ attention_mode='math',
226
+ ):
227
+ super().__init__()
228
+ self.learn_sigma = learn_sigma
229
+ self.in_channels = in_channels
230
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
231
+ self.patch_size = patch_size
232
+ self.num_heads = num_heads
233
+ self.extras = extras
234
+ self.num_frames = num_frames
235
+
236
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
237
+ self.t_embedder = TimestepEmbedder(hidden_size)
238
+
239
+ if self.extras == 2:
240
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
241
+ if self.extras == 78: # timestep + text_embedding
242
+ self.text_embedding_projection = nn.Sequential(
243
+ nn.SiLU(),
244
+ nn.Linear(1024, hidden_size, bias=True)
245
+ )
246
+
247
+ num_patches = self.x_embedder.num_patches
248
+ # Will use fixed sin-cos embedding:
249
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
250
+ self.temp_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)
251
+
252
+ self.blocks = nn.ModuleList([
253
+ TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
254
+ ])
255
+
256
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
257
+ self.initialize_weights()
258
+
259
+ def initialize_weights(self):
260
+ # Initialize transformer layers:
261
+ def _basic_init(module):
262
+ if isinstance(module, nn.Linear):
263
+ torch.nn.init.xavier_uniform_(module.weight)
264
+ if module.bias is not None:
265
+ nn.init.constant_(module.bias, 0)
266
+ self.apply(_basic_init)
267
+
268
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
269
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
270
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
271
+
272
+ temp_embed = get_1d_sincos_temp_embed(self.temp_embed.shape[-1], self.temp_embed.shape[-2])
273
+ self.temp_embed.data.copy_(torch.from_numpy(temp_embed).float().unsqueeze(0))
274
+
275
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
276
+ w = self.x_embedder.proj.weight.data
277
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
278
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
279
+
280
+ if self.extras == 2:
281
+ # Initialize label embedding table:
282
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
283
+
284
+ # Initialize timestep embedding MLP:
285
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
286
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
287
+
288
+ # Zero-out adaLN modulation layers in Latte blocks:
289
+ for block in self.blocks:
290
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
291
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
292
+
293
+ # Zero-out output layers:
294
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
295
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
296
+ nn.init.constant_(self.final_layer.linear.weight, 0)
297
+ nn.init.constant_(self.final_layer.linear.bias, 0)
298
+
299
+ def unpatchify(self, x):
300
+ """
301
+ x: (N, T, patch_size**2 * C)
302
+ imgs: (N, C, H, W)
303
+ """
304
+ c = self.out_channels
305
+ p = self.x_embedder.patch_size[0]
306
+ h = w = int(x.shape[1] ** 0.5)
307
+ assert h * w == x.shape[1]
308
+
309
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
310
+ x = torch.einsum('nhwpqc->nchpwq', x)
311
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
312
+ return imgs
313
+
314
+ # @torch.cuda.amp.autocast()
315
+ # @torch.compile
316
+ def forward(self, x, t, y=None, text_embedding=None, use_fp16=False, y_image=None, use_image_num=0):
317
+ """
318
+ Forward pass of Latte.
319
+ x: (N, F, C, H, W) tensor of video inputs
320
+ t: (N,) tensor of diffusion timesteps
321
+ y: (N,) tensor of class labels
322
+ y_image: list of class-label tensors for the extra images used in image-video joint training
323
+ use_image_num: number of extra image frames appended after the video frames
324
+ """
325
+ if use_fp16:
326
+ x = x.to(dtype=torch.float16)
327
+ batches, frames, channels, height, width = x.shape
328
+ x = rearrange(x, 'b f c h w -> (b f) c h w')
329
+ x = self.x_embedder(x) + self.pos_embed
330
+ t = self.t_embedder(t, use_fp16=use_fp16)
331
+ timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.temp_embed.shape[1] + use_image_num)
332
+ timestep_temp = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])
333
+
334
+ if self.extras == 2:
335
+ y = self.y_embedder(y, self.training)
336
+ if self.training:
337
+ y_image_emb = []
338
+ # print(y_image)
339
+ for y_image_single in y_image:
340
+ # print(y_image_single)
341
+ y_image_single = y_image_single.reshape(1, -1)
342
+ y_image_emb.append(self.y_embedder(y_image_single, self.training))
343
+ y_image_emb = torch.cat(y_image_emb, dim=0)
344
+ y_spatial = repeat(y, 'n d -> n c d', c=self.temp_embed.shape[1])
345
+ y_spatial = torch.cat([y_spatial, y_image_emb], dim=1)
346
+ y_spatial = rearrange(y_spatial, 'n c d -> (n c) d')
347
+ else:
348
+ y_spatial = repeat(y, 'n d -> (n c) d', c=self.temp_embed.shape[1])
349
+
350
+ y_temp = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])
351
+ elif self.extras == 78:
352
+ text_embedding = self.text_embedding_projection(text_embedding)
353
+ text_embedding_video = text_embedding[:, :1, :]
354
+ text_embedding_image = text_embedding[:, 1:, :]
355
+ text_embedding_temp = repeat(text_embedding_video, 'n t d -> n (t c) d', c=self.pos_embed.shape[1])
356
+ text_embedding_temp = rearrange(text_embedding_temp, 'n t d -> (n t) d')
357
+ text_embedding_video = repeat(text_embedding_video, 'n t d -> n (t c) d', c=self.temp_embed.shape[1])
358
+ text_embedding_spatial = torch.cat([text_embedding_video, text_embedding_image], dim=1)
359
+ text_embedding_spatial = rearrange(text_embedding_spatial, 'n t d -> (n t) d')
360
+
361
+ for i in range(0, len(self.blocks), 2):
362
+ spatial_block, temp_block = self.blocks[i:i+2]
363
+
364
+ if self.extras == 2:
365
+ c = timestep_spatial + y_spatial
366
+ elif self.extras == 78:
367
+ c = timestep_spatial + text_embedding_spatial
368
+ else:
369
+ c = timestep_spatial
370
+ x = spatial_block(x, c)
371
+
372
+ x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)
373
+ x_video = x[:, :(frames-use_image_num), :]
374
+ x_image = x[:, (frames-use_image_num):, :]
375
+
376
+ # Add Time Embedding
377
+ if i == 0:
378
+ x_video = x_video + self.temp_embed
379
+
380
+ if self.extras == 2:
381
+ c = timestep_temp + y_temp
382
+ elif self.extras == 78:
383
+ c = timestep_temp + text_embedding_temp
384
+ else:
385
+ c = timestep_temp
386
+
387
+ x_video = temp_block(x_video, c)
388
+ x = torch.cat([x_video, x_image], dim=1)
389
+ x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)
390
+
391
+ if self.extras == 2:
392
+ c = timestep_spatial + y_spatial
393
+ else:
394
+ c = timestep_spatial
395
+ x = self.final_layer(x, c)
396
+ x = self.unpatchify(x)
397
+ x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
398
+ # print(x.shape)
399
+ return x
400
+
401
+
402
+ def forward_with_cfg(self, x, t, y, cfg_scale, use_fp16=False):
403
+ """
404
+ Forward pass of Latte, but also batches the unconditional forward pass for classifier-free guidance.
405
+ """
406
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
407
+ half = x[: len(x) // 2]
408
+ combined = torch.cat([half, half], dim=0)
409
+ if use_fp16:
410
+ combined = combined.to(dtype=torch.float16)
411
+ model_out = self.forward(combined, t, y, use_fp16=use_fp16)
412
+ # For exact reproducibility reasons, we apply classifier-free guidance only to the first
413
+ # four (latent) channels by default. The standard approach to cfg applies it to all channels.
414
+ # This can be done by uncommenting the following line and commenting-out the line following that.
415
+ # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
416
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
417
+ eps, rest = model_out[:, :, :4, ...], model_out[:, :, 4:, ...] # split the guided latent channels from the rest, e.g. eps: (2, 16, 4, 32, 32)
418
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
419
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
420
+ eps = torch.cat([half_eps, half_eps], dim=0)
421
+ return torch.cat([eps, rest], dim=2)
422
+
423
+
424
+ #################################################################################
425
+ # Sine/Cosine Positional Embedding Functions #
426
+ #################################################################################
427
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
428
+
429
+ def get_1d_sincos_temp_embed(embed_dim, length):
430
+ pos = torch.arange(0, length).unsqueeze(1)
431
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
432
+
433
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
434
+ """
435
+ grid_size: int of the grid height and width
436
+ return:
437
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
438
+ """
439
+ grid_h = np.arange(grid_size, dtype=np.float32)
440
+ grid_w = np.arange(grid_size, dtype=np.float32)
441
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
442
+ grid = np.stack(grid, axis=0)
443
+
444
+ grid = grid.reshape([2, 1, grid_size, grid_size])
445
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
446
+ if cls_token and extra_tokens > 0:
447
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
448
+ return pos_embed
449
+
450
+
451
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
452
+ assert embed_dim % 2 == 0
453
+
454
+ # use half of dimensions to encode grid_h
455
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
456
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
457
+
458
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
459
+ return emb
460
+
461
+
462
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
463
+ """
464
+ embed_dim: output dimension for each position
465
+ pos: a list of positions to be encoded: size (M,)
466
+ out: (M, D)
467
+ """
468
+ assert embed_dim % 2 == 0
469
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
470
+ omega /= embed_dim / 2.
471
+ omega = 1. / 10000**omega # (D/2,)
472
+
473
+ pos = pos.reshape(-1) # (M,)
474
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
475
+
476
+ emb_sin = np.sin(out) # (M, D/2)
477
+ emb_cos = np.cos(out) # (M, D/2)
478
+
479
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
480
+ return emb
481
+
482
+
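+ # A quick shape note for the helpers above (illustrative, not part of the original file):
+ # with hidden_size=1152, a 16x16 patch grid and 16 frames,
+ # get_2d_sincos_pos_embed(1152, 16).shape -> (256, 1152) # one row per spatial patch
+ # get_1d_sincos_temp_embed(1152, 16).shape -> (16, 1152) # one row per frame
+ # Both tables are copied into frozen (requires_grad=False) parameters in initialize_weights(),
+ # so they stay fixed for the whole of training.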
483
+ #################################################################################
484
+ # Latte Configs #
485
+ #################################################################################
486
+
487
+ def Latte_XL_2(**kwargs):
488
+ return Latte(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
489
+
490
+ def Latte_XL_4(**kwargs):
491
+ return Latte(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
492
+
493
+ def Latte_XL_8(**kwargs):
494
+ return Latte(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
495
+
496
+ def Latte_L_2(**kwargs):
497
+ return Latte(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
498
+
499
+ def Latte_L_4(**kwargs):
500
+ return Latte(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
501
+
502
+ def Latte_L_8(**kwargs):
503
+ return Latte(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
504
+
505
+ def Latte_B_2(**kwargs):
506
+ return Latte(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
507
+
508
+ def Latte_B_4(**kwargs):
509
+ return Latte(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
510
+
511
+ def Latte_B_8(**kwargs):
512
+ return Latte(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
513
+
514
+ def Latte_S_2(**kwargs):
515
+ return Latte(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
516
+
517
+ def Latte_S_4(**kwargs):
518
+ return Latte(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
519
+
520
+ def Latte_S_8(**kwargs):
521
+ return Latte(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
522
+
523
+
524
+ LatteIMG_models = {
525
+ 'LatteIMG-XL/2': Latte_XL_2, 'LatteIMG-XL/4': Latte_XL_4, 'LatteIMG-XL/8': Latte_XL_8,
526
+ 'LatteIMG-L/2': Latte_L_2, 'LatteIMG-L/4': Latte_L_4, 'LatteIMG-L/8': Latte_L_8,
527
+ 'LatteIMG-B/2': Latte_B_2, 'LatteIMG-B/4': Latte_B_4, 'LatteIMG-B/8': Latte_B_8,
528
+ 'LatteIMG-S/2': Latte_S_2, 'LatteIMG-S/4': Latte_S_4, 'LatteIMG-S/8': Latte_S_8,
529
+ }
530
+
531
+ if __name__ == '__main__':
532
+ import torch
533
+
534
+ device = "cuda" if torch.cuda.is_available() else "cpu"
535
+
536
+ use_image_num = 8
537
+
538
+ img = torch.randn(3, 16+use_image_num, 4, 32, 32).to(device)
539
+
540
+ t = torch.tensor([1, 2, 3]).to(device)
541
+ y = torch.tensor([1, 2, 3]).to(device)
542
+ y_image = [torch.tensor([48, 37, 72, 63, 74, 6, 7, 8]).to(device),
543
+ torch.tensor([37, 72, 63, 74, 70, 1, 2, 3]).to(device),
544
+ torch.tensor([72, 63, 74, 70, 71, 5, 8, 7]).to(device),
545
+ ]
546
+
547
+
548
+ network = Latte_XL_2().to(device)
549
+ network.train()
550
+
551
+ out = network(img, t, y=y, y_image=y_image, use_image_num=use_image_num)
552
+ print(out.shape)
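+
+ # Illustrative classifier-free-guidance call (not part of the original file). It assumes the
+ # DiT-style LabelEmbedder, where index num_classes (1000 here) is the learned null/unconditional
+ # embedding; forward_with_cfg reuses the first half of the batch and overwrites the second half.
+ network.eval()
+ with torch.no_grad():
+ z = torch.randn(1, 16, 4, 32, 32).to(device)
+ z = torch.cat([z, z], dim=0)
+ y_cfg = torch.tensor([207, 1000]).to(device) # [conditional label, null class]
+ t_cfg = torch.tensor([999, 999]).to(device)
+ out_cfg = network.forward_with_cfg(z, t_cfg, y_cfg, cfg_scale=4.0)
+ print(out_cfg.shape) # torch.Size([2, 16, 8, 32, 32]) since learn_sigma=True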
models/latte_t2v.py ADDED
@@ -0,0 +1,945 @@
1
+ import torch
2
+
3
+ import os
4
+ import json
5
+
6
+ from dataclasses import dataclass
7
+ from einops import rearrange, repeat
8
+ from typing import Any, Dict, Optional, Tuple
9
+ from diffusers.models import Transformer2DModel
10
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate
11
+ from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings, CaptionProjection, PatchEmbed, CombinedTimestepSizeEmbeddings
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.models.attention import BasicTransformerBlock
15
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
16
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
17
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
18
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
19
+ from diffusers.models.attention_processor import Attention
20
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
21
+
22
+ from dataclasses import dataclass
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from torch import nn
27
+
28
+ @maybe_allow_in_graph
29
+ class GatedSelfAttentionDense(nn.Module):
30
+ r"""
31
+ A gated self-attention dense layer that combines visual features and object features.
32
+
33
+ Parameters:
34
+ query_dim (`int`): The number of channels in the query.
35
+ context_dim (`int`): The number of channels in the context.
36
+ n_heads (`int`): The number of heads to use for attention.
37
+ d_head (`int`): The number of channels in each head.
38
+ """
39
+
40
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
41
+ super().__init__()
42
+
43
+ # we need a linear projection since we concatenate the visual features and the object features
44
+ self.linear = nn.Linear(context_dim, query_dim)
45
+
46
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
47
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
48
+
49
+ self.norm1 = nn.LayerNorm(query_dim)
50
+ self.norm2 = nn.LayerNorm(query_dim)
51
+
52
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
53
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
54
+
55
+ self.enabled = True
56
+
57
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
58
+ if not self.enabled:
59
+ return x
60
+
61
+ n_visual = x.shape[1]
62
+ objs = self.linear(objs)
63
+
64
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
65
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
66
+
67
+ return x
68
+
69
+ class FeedForward(nn.Module):
70
+ r"""
71
+ A feed-forward layer.
72
+
73
+ Parameters:
74
+ dim (`int`): The number of channels in the input.
75
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
76
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
77
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
78
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
79
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ dim: int,
85
+ dim_out: Optional[int] = None,
86
+ mult: int = 4,
87
+ dropout: float = 0.0,
88
+ activation_fn: str = "geglu",
89
+ final_dropout: bool = False,
90
+ ):
91
+ super().__init__()
92
+ inner_dim = int(dim * mult)
93
+ dim_out = dim_out if dim_out is not None else dim
94
+ linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
95
+
96
+ if activation_fn == "gelu":
97
+ act_fn = GELU(dim, inner_dim)
98
+ if activation_fn == "gelu-approximate":
99
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
100
+ elif activation_fn == "geglu":
101
+ act_fn = GEGLU(dim, inner_dim)
102
+ elif activation_fn == "geglu-approximate":
103
+ act_fn = ApproximateGELU(dim, inner_dim)
104
+
105
+ self.net = nn.ModuleList([])
106
+ # project in
107
+ self.net.append(act_fn)
108
+ # project dropout
109
+ self.net.append(nn.Dropout(dropout))
110
+ # project out
111
+ self.net.append(linear_cls(inner_dim, dim_out))
112
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
113
+ if final_dropout:
114
+ self.net.append(nn.Dropout(dropout))
115
+
116
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
117
+ compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
118
+ for module in self.net:
119
+ if isinstance(module, compatible_cls):
120
+ hidden_states = module(hidden_states, scale)
121
+ else:
122
+ hidden_states = module(hidden_states)
123
+ return hidden_states
124
+
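+ # Illustrative usage (not part of the original file): with the default "geglu" activation the
+ # hidden width is dim * mult, and the output keeps the input shape:
+ # ff = FeedForward(128) # GEGLU(128 -> 512), Dropout, Linear(512 -> 128)
+ # ff(torch.randn(2, 10, 128)).shape -> torch.Size([2, 10, 128])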
125
+ @maybe_allow_in_graph
126
+ class BasicTransformerBlock_(nn.Module):
127
+ r"""
128
+ A basic Transformer block.
129
+
130
+ Parameters:
131
+ dim (`int`): The number of channels in the input and output.
132
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
133
+ attention_head_dim (`int`): The number of channels in each head.
134
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
135
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
136
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
137
+ num_embeds_ada_norm (`int`, *optional*):
138
+ The number of diffusion steps used during training. See `Transformer2DModel`.
139
+ attention_bias (`bool`, *optional*, defaults to `False`):
140
+ Configure if the attentions should contain a bias parameter.
141
+ only_cross_attention (`bool`, *optional*):
142
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
143
+ double_self_attention (`bool`, *optional*):
144
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
145
+ upcast_attention (`bool`, *optional*):
146
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
147
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
148
+ Whether to use learnable elementwise affine parameters for normalization.
149
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
150
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
151
+ final_dropout (`bool` *optional*, defaults to False):
152
+ Whether to apply a final dropout after the last feed-forward layer.
153
+ attention_type (`str`, *optional*, defaults to `"default"`):
154
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
155
+ positional_embeddings (`str`, *optional*, defaults to `None`):
156
+ The type of positional embeddings to apply to.
157
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
158
+ The maximum number of positional embeddings to apply.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ dim: int,
164
+ num_attention_heads: int,
165
+ attention_head_dim: int,
166
+ dropout=0.0,
167
+ cross_attention_dim: Optional[int] = None,
168
+ activation_fn: str = "geglu",
169
+ num_embeds_ada_norm: Optional[int] = None,
170
+ attention_bias: bool = False,
171
+ only_cross_attention: bool = False,
172
+ double_self_attention: bool = False,
173
+ upcast_attention: bool = False,
174
+ norm_elementwise_affine: bool = True,
175
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
176
+ norm_eps: float = 1e-5,
177
+ final_dropout: bool = False,
178
+ attention_type: str = "default",
179
+ positional_embeddings: Optional[str] = None,
180
+ num_positional_embeddings: Optional[int] = None,
181
+ ):
182
+ super().__init__()
183
+ self.only_cross_attention = only_cross_attention
184
+
185
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
186
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
187
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
188
+ self.use_layer_norm = norm_type == "layer_norm"
189
+
190
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
191
+ raise ValueError(
192
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
193
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
194
+ )
195
+
196
+ if positional_embeddings and (num_positional_embeddings is None):
197
+ raise ValueError(
198
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
199
+ )
200
+
201
+ if positional_embeddings == "sinusoidal":
202
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
203
+ else:
204
+ self.pos_embed = None
205
+
206
+ # Define 3 blocks. Each block has its own normalization layer.
207
+ # 1. Self-Attn
208
+ if self.use_ada_layer_norm:
209
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
210
+ elif self.use_ada_layer_norm_zero:
211
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
212
+ else:
213
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) # go here
214
+
215
+ self.attn1 = Attention(
216
+ query_dim=dim,
217
+ heads=num_attention_heads,
218
+ dim_head=attention_head_dim,
219
+ dropout=dropout,
220
+ bias=attention_bias,
221
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
222
+ upcast_attention=upcast_attention,
223
+ )
224
+
225
+ # # 2. Cross-Attn
226
+ # if cross_attention_dim is not None or double_self_attention:
227
+ # # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
228
+ # # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
229
+ # # the second cross attention block.
230
+ # self.norm2 = (
231
+ # AdaLayerNorm(dim, num_embeds_ada_norm)
232
+ # if self.use_ada_layer_norm
233
+ # else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
234
+ # )
235
+ # self.attn2 = Attention(
236
+ # query_dim=dim,
237
+ # cross_attention_dim=cross_attention_dim if not double_self_attention else None,
238
+ # heads=num_attention_heads,
239
+ # dim_head=attention_head_dim,
240
+ # dropout=dropout,
241
+ # bias=attention_bias,
242
+ # upcast_attention=upcast_attention,
243
+ # ) # is self-attn if encoder_hidden_states is none
244
+ # else:
245
+ # self.norm2 = None
246
+ # self.attn2 = None
247
+
248
+ # 3. Feed-forward
249
+ # if not self.use_ada_layer_norm_single:
250
+ # self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
251
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
252
+
253
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
254
+
255
+ # 4. Fuser
256
+ if attention_type == "gated" or attention_type == "gated-text-image":
257
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
258
+
259
+ # 5. Scale-shift for PixArt-Alpha.
260
+ if self.use_ada_layer_norm_single:
261
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
262
+
263
+ # let chunk size default to None
264
+ self._chunk_size = None
265
+ self._chunk_dim = 0
266
+
267
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
268
+ # Sets chunk feed-forward
269
+ self._chunk_size = chunk_size
270
+ self._chunk_dim = dim
271
+
272
+ def forward(
273
+ self,
274
+ hidden_states: torch.FloatTensor,
275
+ attention_mask: Optional[torch.FloatTensor] = None,
276
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
277
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
278
+ timestep: Optional[torch.LongTensor] = None,
279
+ cross_attention_kwargs: Dict[str, Any] = None,
280
+ class_labels: Optional[torch.LongTensor] = None,
281
+ ) -> torch.FloatTensor:
282
+ # Notice that normalization is always applied before the real computation in the following blocks.
283
+ # 0. Self-Attention
284
+ batch_size = hidden_states.shape[0]
285
+
286
+ if self.use_ada_layer_norm:
287
+ norm_hidden_states = self.norm1(hidden_states, timestep)
288
+ elif self.use_ada_layer_norm_zero:
289
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
290
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
291
+ )
292
+ elif self.use_layer_norm:
293
+ norm_hidden_states = self.norm1(hidden_states)
294
+ elif self.use_ada_layer_norm_single: # go here
295
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
296
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
297
+ ).chunk(6, dim=1)
298
+ norm_hidden_states = self.norm1(hidden_states)
299
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
300
+ # norm_hidden_states = norm_hidden_states.squeeze(1)
301
+ else:
302
+ raise ValueError("Incorrect norm used")
303
+
304
+ if self.pos_embed is not None:
305
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
306
+
307
+ # 1. Retrieve lora scale.
308
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
309
+
310
+ # 2. Prepare GLIGEN inputs
311
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
312
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
313
+
314
+ attn_output = self.attn1(
315
+ norm_hidden_states,
316
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
317
+ attention_mask=attention_mask,
318
+ **cross_attention_kwargs,
319
+ )
320
+ if self.use_ada_layer_norm_zero:
321
+ attn_output = gate_msa.unsqueeze(1) * attn_output
322
+ elif self.use_ada_layer_norm_single:
323
+ attn_output = gate_msa * attn_output
324
+
325
+ hidden_states = attn_output + hidden_states
326
+ if hidden_states.ndim == 4:
327
+ hidden_states = hidden_states.squeeze(1)
328
+
329
+ # 2.5 GLIGEN Control
330
+ if gligen_kwargs is not None:
331
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
332
+
333
+ # # 3. Cross-Attention
334
+ # if self.attn2 is not None:
335
+ # if self.use_ada_layer_norm:
336
+ # norm_hidden_states = self.norm2(hidden_states, timestep)
337
+ # elif self.use_ada_layer_norm_zero or self.use_layer_norm:
338
+ # norm_hidden_states = self.norm2(hidden_states)
339
+ # elif self.use_ada_layer_norm_single:
340
+ # # For PixArt norm2 isn't applied here:
341
+ # # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
342
+ # norm_hidden_states = hidden_states
343
+ # else:
344
+ # raise ValueError("Incorrect norm")
345
+
346
+ # if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
347
+ # norm_hidden_states = self.pos_embed(norm_hidden_states)
348
+
349
+ # attn_output = self.attn2(
350
+ # norm_hidden_states,
351
+ # encoder_hidden_states=encoder_hidden_states,
352
+ # attention_mask=encoder_attention_mask,
353
+ # **cross_attention_kwargs,
354
+ # )
355
+ # hidden_states = attn_output + hidden_states
356
+
357
+ # 4. Feed-forward
358
+ # if not self.use_ada_layer_norm_single:
359
+ # norm_hidden_states = self.norm3(hidden_states)
360
+
361
+ if self.use_ada_layer_norm_zero:
362
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
363
+
364
+ if self.use_ada_layer_norm_single:
365
+ # norm_hidden_states = self.norm2(hidden_states)
366
+ norm_hidden_states = self.norm3(hidden_states)
367
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
368
+
369
+ if self._chunk_size is not None:
370
+ # "feed_forward_chunk_size" can be used to save memory
371
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
372
+ raise ValueError(
373
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
374
+ )
375
+
376
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
377
+ ff_output = torch.cat(
378
+ [
379
+ self.ff(hid_slice, scale=lora_scale)
380
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
381
+ ],
382
+ dim=self._chunk_dim,
383
+ )
384
+ else:
385
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
386
+
387
+ if self.use_ada_layer_norm_zero:
388
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
389
+ elif self.use_ada_layer_norm_single:
390
+ ff_output = gate_mlp * ff_output
391
+
392
+ hidden_states = ff_output + hidden_states
393
+ if hidden_states.ndim == 4:
394
+ hidden_states = hidden_states.squeeze(1)
395
+
396
+ return hidden_states
397
+
398
+ class AdaLayerNormSingle(nn.Module):
399
+ r"""
400
+ Norm layer adaptive layer norm single (adaLN-single).
401
+
402
+ As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
403
+
404
+ Parameters:
405
+ embedding_dim (`int`): The size of each embedding vector.
406
+ use_additional_conditions (`bool`): To use additional conditions for normalization or not.
407
+ """
408
+
409
+ def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
410
+ super().__init__()
411
+
412
+ self.emb = CombinedTimestepSizeEmbeddings(
413
+ embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
414
+ )
415
+
416
+ self.silu = nn.SiLU()
417
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
418
+
419
+ def forward(
420
+ self,
421
+ timestep: torch.Tensor,
422
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
423
+ batch_size: int = None,
424
+ hidden_dtype: Optional[torch.dtype] = None,
425
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
426
+ # No modulation happening here.
427
+ embedded_timestep = self.emb(timestep, batch_size=batch_size, hidden_dtype=hidden_dtype, resolution=None, aspect_ratio=None)
428
+ return self.linear(self.silu(embedded_timestep)), embedded_timestep
429
+
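+ # Shape sketch (illustrative, not part of the original file): for embedding_dim=1152,
+ # mod, emb_t = AdaLayerNormSingle(1152)(torch.tensor([999]), batch_size=1, hidden_dtype=torch.float32)
+ # mod.shape -> (1, 6 * 1152) # split into six modulation chunks inside the transformer blocks
+ # emb_t.shape -> (1, 1152) # reused by LatteT2V for the final scale/shift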
430
+ @dataclass
431
+ class Transformer3DModelOutput(BaseOutput):
432
+ """
433
+ The output of [`Transformer2DModel`].
434
+
435
+ Args:
436
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
437
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
438
+ distributions for the unnoised latent pixels.
439
+ """
440
+
441
+ sample: torch.FloatTensor
442
+
443
+
444
+ class LatteT2V(ModelMixin, ConfigMixin):
445
+ _supports_gradient_checkpointing = True
446
+
447
+ """
448
+ A spatio-temporal Transformer model for video-like data, extending diffusers' Transformer2DModel interface with temporal attention blocks.
449
+
450
+ Parameters:
451
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
452
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
453
+ in_channels (`int`, *optional*):
454
+ The number of channels in the input and output (specify if the input is **continuous**).
455
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
456
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
457
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
458
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
459
+ This is fixed during training since it is used to learn a number of position embeddings.
460
+ num_vector_embeds (`int`, *optional*):
461
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
462
+ Includes the class for the masked latent pixel.
463
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
464
+ num_embeds_ada_norm ( `int`, *optional*):
465
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
466
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
467
+ added to the hidden states.
468
+
469
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
470
+ attention_bias (`bool`, *optional*):
471
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
472
+ """
473
+
474
+ @register_to_config
475
+ def __init__(
476
+ self,
477
+ num_attention_heads: int = 16,
478
+ attention_head_dim: int = 88,
479
+ in_channels: Optional[int] = None,
480
+ out_channels: Optional[int] = None,
481
+ num_layers: int = 1,
482
+ dropout: float = 0.0,
483
+ norm_num_groups: int = 32,
484
+ cross_attention_dim: Optional[int] = None,
485
+ attention_bias: bool = False,
486
+ sample_size: Optional[int] = None,
487
+ num_vector_embeds: Optional[int] = None,
488
+ patch_size: Optional[int] = None,
489
+ activation_fn: str = "geglu",
490
+ num_embeds_ada_norm: Optional[int] = None,
491
+ use_linear_projection: bool = False,
492
+ only_cross_attention: bool = False,
493
+ double_self_attention: bool = False,
494
+ upcast_attention: bool = False,
495
+ norm_type: str = "layer_norm",
496
+ norm_elementwise_affine: bool = True,
497
+ norm_eps: float = 1e-5,
498
+ attention_type: str = "default",
499
+ caption_channels: int = None,
500
+ video_length: int = 16,
501
+ ):
502
+ super().__init__()
503
+ self.use_linear_projection = use_linear_projection
504
+ self.num_attention_heads = num_attention_heads
505
+ self.attention_head_dim = attention_head_dim
506
+ inner_dim = num_attention_heads * attention_head_dim
507
+ self.video_length = video_length
508
+
509
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
510
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
511
+
512
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
513
+ # Define whether input is continuous or discrete depending on configuration
514
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
515
+ self.is_input_vectorized = num_vector_embeds is not None
516
+ self.is_input_patches = in_channels is not None and patch_size is not None
517
+
518
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
519
+ deprecation_message = (
520
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
521
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
522
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
523
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
524
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
525
+ )
526
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
527
+ norm_type = "ada_norm"
528
+
529
+ if self.is_input_continuous and self.is_input_vectorized:
530
+ raise ValueError(
531
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
532
+ " sure that either `in_channels` or `num_vector_embeds` is None."
533
+ )
534
+ elif self.is_input_vectorized and self.is_input_patches:
535
+ raise ValueError(
536
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
537
+ " sure that either `num_vector_embeds` or `num_patches` is None."
538
+ )
539
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
540
+ raise ValueError(
541
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
542
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
543
+ )
544
+
545
+ # 2. Define input layers
546
+ if self.is_input_continuous:
547
+ self.in_channels = in_channels
548
+
549
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
550
+ if use_linear_projection:
551
+ self.proj_in = linear_cls(in_channels, inner_dim)
552
+ else:
553
+ self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
554
+ elif self.is_input_vectorized:
555
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
556
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
557
+
558
+ self.height = sample_size
559
+ self.width = sample_size
560
+ self.num_vector_embeds = num_vector_embeds
561
+ self.num_latent_pixels = self.height * self.width
562
+
563
+ self.latent_image_embedding = ImagePositionalEmbeddings(
564
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
565
+ )
566
+ elif self.is_input_patches:
567
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
568
+
569
+ self.height = sample_size
570
+ self.width = sample_size
571
+
572
+ self.patch_size = patch_size
573
+ interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
574
+ interpolation_scale = max(interpolation_scale, 1)
575
+ self.pos_embed = PatchEmbed(
576
+ height=sample_size,
577
+ width=sample_size,
578
+ patch_size=patch_size,
579
+ in_channels=in_channels,
580
+ embed_dim=inner_dim,
581
+ interpolation_scale=interpolation_scale,
582
+ )
583
+
584
+ # 3. Define transformers blocks
585
+ self.transformer_blocks = nn.ModuleList(
586
+ [
587
+ BasicTransformerBlock(
588
+ inner_dim,
589
+ num_attention_heads,
590
+ attention_head_dim,
591
+ dropout=dropout,
592
+ cross_attention_dim=cross_attention_dim,
593
+ activation_fn=activation_fn,
594
+ num_embeds_ada_norm=num_embeds_ada_norm,
595
+ attention_bias=attention_bias,
596
+ only_cross_attention=only_cross_attention,
597
+ double_self_attention=double_self_attention,
598
+ upcast_attention=upcast_attention,
599
+ norm_type=norm_type,
600
+ norm_elementwise_affine=norm_elementwise_affine,
601
+ norm_eps=norm_eps,
602
+ attention_type=attention_type,
603
+ )
604
+ for d in range(num_layers)
605
+ ]
606
+ )
607
+
608
+ # Define temporal transformers blocks
609
+ self.temporal_transformer_blocks = nn.ModuleList(
610
+ [
611
+ BasicTransformerBlock_( # temporal block: a single self-attention, no cross-attention
612
+ inner_dim,
613
+ num_attention_heads, # num_attention_heads
614
+ attention_head_dim, # attention_head_dim 72
615
+ dropout=dropout,
616
+ cross_attention_dim=None,
617
+ activation_fn=activation_fn,
618
+ num_embeds_ada_norm=num_embeds_ada_norm,
619
+ attention_bias=attention_bias,
620
+ only_cross_attention=only_cross_attention,
621
+ double_self_attention=False,
622
+ upcast_attention=upcast_attention,
623
+ norm_type=norm_type,
624
+ norm_elementwise_affine=norm_elementwise_affine,
625
+ norm_eps=norm_eps,
626
+ attention_type=attention_type,
627
+ )
628
+ for d in range(num_layers)
629
+ ]
630
+ )
631
+
632
+
633
+ # 4. Define output layers
634
+ self.out_channels = in_channels if out_channels is None else out_channels
635
+ if self.is_input_continuous:
636
+ # TODO: should use out_channels for continuous projections
637
+ if use_linear_projection:
638
+ self.proj_out = linear_cls(inner_dim, in_channels)
639
+ else:
640
+ self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
641
+ elif self.is_input_vectorized:
642
+ self.norm_out = nn.LayerNorm(inner_dim)
643
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
644
+ elif self.is_input_patches and norm_type != "ada_norm_single":
645
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
646
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
647
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
648
+ elif self.is_input_patches and norm_type == "ada_norm_single":
649
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
650
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
651
+ self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
652
+
653
+ # 5. PixArt-Alpha blocks.
654
+ self.adaln_single = None
655
+ self.use_additional_conditions = False
656
+ if norm_type == "ada_norm_single":
657
+ self.use_additional_conditions = self.config.sample_size == 128 # False, 128 -> 1024
658
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
659
+ # additional conditions until we find better name
660
+ self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
661
+
662
+ self.caption_projection = None
663
+ if caption_channels is not None:
664
+ self.caption_projection = CaptionProjection(in_features=caption_channels, hidden_size=inner_dim)
665
+
666
+ self.gradient_checkpointing = False
667
+
668
+ # define temporal positional embedding
669
+ temp_pos_embed = self.get_1d_sincos_temp_embed(inner_dim, video_length) # 1152 hidden size
670
+ self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
671
+
672
+
673
+ def _set_gradient_checkpointing(self, module, value=False):
674
+ self.gradient_checkpointing = value
675
+
676
+
677
+ def forward(
678
+ self,
679
+ hidden_states: torch.Tensor,
680
+ timestep: Optional[torch.LongTensor] = None,
681
+ encoder_hidden_states: Optional[torch.Tensor] = None,
682
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
683
+ class_labels: Optional[torch.LongTensor] = None,
684
+ cross_attention_kwargs: Dict[str, Any] = None,
685
+ attention_mask: Optional[torch.Tensor] = None,
686
+ encoder_attention_mask: Optional[torch.Tensor] = None,
687
+ use_image_num: int = 0,
688
+ enable_temporal_attentions: bool = True,
689
+ return_dict: bool = True,
690
+ ):
691
+ """
692
+ The [`Transformer2DModel`] forward method.
693
+
694
+ Args:
695
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, frame, height, width)` if continuous):
696
+ Input `hidden_states`.
697
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
698
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
699
+ self-attention.
700
+ timestep ( `torch.LongTensor`, *optional*):
701
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
702
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
703
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
704
+ `AdaLayerZeroNorm`.
705
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
706
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
707
+ `self.processor` in
708
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
709
+ attention_mask ( `torch.Tensor`, *optional*):
710
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
711
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
712
+ negative values to the attention scores corresponding to "discard" tokens.
713
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
714
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
715
+
716
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
717
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
718
+
719
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
720
+ above. This bias will be added to the cross-attention scores.
721
+ return_dict (`bool`, *optional*, defaults to `True`):
722
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
723
+ tuple.
724
+
725
+ Returns:
726
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
727
+ `tuple` where the first element is the sample tensor.
728
+ """
729
+ input_batch_size, c, frame, h, w = hidden_states.shape
730
+ frame = frame - use_image_num
731
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w').contiguous()
732
+
733
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
734
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
735
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
736
+ # expects mask of shape:
737
+ # [batch, key_tokens]
738
+ # adds singleton query_tokens dimension:
739
+ # [batch, 1, key_tokens]
740
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
741
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
742
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
743
+ if attention_mask is not None and attention_mask.ndim == 2:
744
+ # assume that mask is expressed as:
745
+ # (1 = keep, 0 = discard)
746
+ # convert mask into a bias that can be added to attention scores:
747
+ # (keep = +0, discard = -10000.0)
748
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
749
+ attention_mask = attention_mask.unsqueeze(1)
750
+
751
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
752
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: # ndim == 2 means no image joint
753
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
754
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
755
+ encoder_attention_mask = repeat(encoder_attention_mask, 'b 1 l -> (b f) 1 l', f=frame).contiguous()
756
+ elif encoder_attention_mask is not None and encoder_attention_mask.ndim == 3: # ndim == 3 means image joint
757
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
758
+ encoder_attention_mask_video = encoder_attention_mask[:, :1, ...]
759
+ encoder_attention_mask_video = repeat(encoder_attention_mask_video, 'b 1 l -> b (1 f) l', f=frame).contiguous()
760
+ encoder_attention_mask_image = encoder_attention_mask[:, 1:, ...]
761
+ encoder_attention_mask = torch.cat([encoder_attention_mask_video, encoder_attention_mask_image], dim=1)
762
+ encoder_attention_mask = rearrange(encoder_attention_mask, 'b n l -> (b n) l').contiguous().unsqueeze(1)
763
+
764
+
765
+ # Retrieve lora scale.
766
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
767
+
768
+ # 1. Input
769
+ if self.is_input_patches: # here
770
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
771
+ num_patches = height * width
772
+
773
+ hidden_states = self.pos_embed(hidden_states) # already adds the spatial positional embeddings
774
+
775
+ if self.adaln_single is not None:
776
+ if self.use_additional_conditions and added_cond_kwargs is None:
777
+ raise ValueError(
778
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
779
+ )
780
+ # batch_size = hidden_states.shape[0]
781
+ batch_size = input_batch_size
782
+ timestep, embedded_timestep = self.adaln_single(
783
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
784
+ )
785
+
786
+ # 2. Blocks
787
+ if self.caption_projection is not None:
788
+ batch_size = hidden_states.shape[0]
789
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states) # (b, num_tokens, inner_dim), e.g. (3, 120, 1152)
790
+
791
+ if use_image_num != 0 and self.training:
792
+ encoder_hidden_states_video = encoder_hidden_states[:, :1, ...]
793
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b 1 t d -> b (1 f) t d', f=frame).contiguous()
794
+ encoder_hidden_states_image = encoder_hidden_states[:, 1:, ...]
795
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
796
+ encoder_hidden_states_spatial = rearrange(encoder_hidden_states, 'b f t d -> (b f) t d').contiguous()
797
+ else:
798
+ encoder_hidden_states_spatial = repeat(encoder_hidden_states, 'b t d -> (b f) t d', f=frame).contiguous()
799
+
800
+ # prepare timesteps for spatial and temporal block
801
+ timestep_spatial = repeat(timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous()
802
+ timestep_temp = repeat(timestep, 'b d -> (b p) d', p=num_patches).contiguous()
803
+
804
+ for i, (spatial_block, temp_block) in enumerate(zip(self.transformer_blocks, self.temporal_transformer_blocks)):
805
+
806
+ if self.training and self.gradient_checkpointing:
807
+ hidden_states = torch.utils.checkpoint.checkpoint(
808
+ spatial_block,
809
+ hidden_states,
810
+ attention_mask,
811
+ encoder_hidden_states_spatial,
812
+ encoder_attention_mask,
813
+ timestep_spatial,
814
+ cross_attention_kwargs,
815
+ class_labels,
816
+ use_reentrant=False,
817
+ )
818
+
819
+ if enable_temporal_attentions:
820
+ hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous()
821
+
822
+ if use_image_num != 0: # image-video joint training
823
+ hidden_states_video = hidden_states[:, :frame, ...]
824
+ hidden_states_image = hidden_states[:, frame:, ...]
825
+
826
+ if i == 0:
827
+ hidden_states_video = hidden_states_video + self.temp_pos_embed
828
+
829
+ hidden_states_video = torch.utils.checkpoint.checkpoint(
830
+ temp_block,
831
+ hidden_states_video,
832
+ None, # attention_mask
833
+ None, # encoder_hidden_states
834
+ None, # encoder_attention_mask
835
+ timestep_temp,
836
+ cross_attention_kwargs,
837
+ class_labels,
838
+ use_reentrant=False,
839
+ )
840
+
841
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
842
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
843
+
844
+ else:
845
+ if i == 0:
846
+ hidden_states = hidden_states + self.temp_pos_embed
847
+
848
+ hidden_states = torch.utils.checkpoint.checkpoint(
849
+ temp_block,
850
+ hidden_states,
851
+ None, # attention_mask
852
+ None, # encoder_hidden_states
853
+ None, # encoder_attention_mask
854
+ timestep_temp,
855
+ cross_attention_kwargs,
856
+ class_labels,
857
+ use_reentrant=False,
858
+ )
859
+
860
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
861
+ else:
862
+ hidden_states = spatial_block(
863
+ hidden_states,
864
+ attention_mask,
865
+ encoder_hidden_states_spatial,
866
+ encoder_attention_mask,
867
+ timestep_spatial,
868
+ cross_attention_kwargs,
869
+ class_labels,
870
+ )
871
+
872
+ if enable_temporal_attentions:
873
+
874
+ hidden_states = rearrange(hidden_states, '(b f) t d -> (b t) f d', b=input_batch_size).contiguous()
875
+
876
+ if use_image_num != 0 and self.training:
877
+ hidden_states_video = hidden_states[:, :frame, ...]
878
+ hidden_states_image = hidden_states[:, frame:, ...]
879
+
880
+ hidden_states_video = temp_block(
881
+ hidden_states_video,
882
+ None, # attention_mask
883
+ None, # encoder_hidden_states
884
+ None, # encoder_attention_mask
885
+ timestep_temp,
886
+ cross_attention_kwargs,
887
+ class_labels,
888
+ )
889
+
890
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
891
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
892
+
893
+ else:
894
+ if i == 0 and frame > 1:
895
+ hidden_states = hidden_states + self.temp_pos_embed
896
+
897
+ hidden_states = temp_block(
898
+ hidden_states,
899
+ None, # attention_mask
900
+ None, # encoder_hidden_states
901
+ None, # encoder_attention_mask
902
+ timestep_temp,
903
+ cross_attention_kwargs,
904
+ class_labels,
905
+ )
906
+
907
+ hidden_states = rearrange(hidden_states, '(b t) f d -> (b f) t d', b=input_batch_size).contiguous()
908
+
909
+
910
+ if self.is_input_patches:
911
+ if self.config.norm_type != "ada_norm_single":
912
+ conditioning = self.transformer_blocks[0].norm1.emb(
913
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
914
+ )
915
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
916
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
917
+ hidden_states = self.proj_out_2(hidden_states)
918
+ elif self.config.norm_type == "ada_norm_single":
919
+ embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame + use_image_num).contiguous()
920
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
921
+ hidden_states = self.norm_out(hidden_states)
922
+ # Modulation
923
+ hidden_states = hidden_states * (1 + scale) + shift
924
+ hidden_states = self.proj_out(hidden_states)
925
+
926
+ # unpatchify
927
+ if self.adaln_single is None:
928
+ height = width = int(hidden_states.shape[1] ** 0.5)
929
+ hidden_states = hidden_states.reshape(
930
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
931
+ )
932
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
933
+ output = hidden_states.reshape(
934
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
935
+ )
936
+ output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous()
937
+
938
+ if not return_dict:
939
+ return (output,)
940
+
941
+ return Transformer3DModelOutput(sample=output)
942
+
943
+ def get_1d_sincos_temp_embed(self, embed_dim, length):
944
+ pos = torch.arange(0, length).unsqueeze(1)
945
+ return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
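+
+ if __name__ == '__main__':
+ # Minimal shape check with a toy configuration (illustrative only; these are not the released
+ # Latte-1 hyper-parameters, and it assumes the diffusers version pinned by this repo, which
+ # provides CaptionProjection and CombinedTimestepSizeEmbeddings).
+ model = LatteT2V(
+ num_attention_heads=4,
+ attention_head_dim=32, # inner_dim = 128
+ in_channels=4,
+ out_channels=8,
+ num_layers=2,
+ cross_attention_dim=128,
+ sample_size=32, # 32x32 latents
+ patch_size=2,
+ norm_type="ada_norm_single",
+ caption_channels=64,
+ video_length=8,
+ )
+ x = torch.randn(1, 4, 8, 32, 32) # (B, C, F, H, W)
+ text = torch.randn(1, 20, 64) # (B, num_tokens, caption_channels)
+ t = torch.tensor([10])
+ out = model(x, timestep=t, encoder_hidden_states=text).sample
+ print(out.shape) # expected: torch.Size([1, 8, 8, 32, 32])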
models/utils.py ADDED
@@ -0,0 +1,215 @@
+ # adopted from
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+ # and
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+ # and
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+ #
+ # thanks!
+
+
+ import os
+ import math
+ import torch
+
+ import numpy as np
+ import torch.nn as nn
+
+ from einops import repeat
+
+
+ #################################################################################
+ #                                  Unet Utils                                   #
+ #################################################################################
+
+ def checkpoint(func, inputs, params, flag):
+     """
+     Evaluate a function without caching intermediate activations, allowing for
+     reduced memory at the expense of extra compute in the backward pass.
+     :param func: the function to evaluate.
+     :param inputs: the argument sequence to pass to `func`.
+     :param params: a sequence of parameters `func` depends on but does not
+                    explicitly take as arguments.
+     :param flag: if False, disable gradient checkpointing.
+     """
+     if flag:
+         args = tuple(inputs) + tuple(params)
+         return CheckpointFunction.apply(func, len(inputs), *args)
+     else:
+         return func(*inputs)
+
+
+ class CheckpointFunction(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, run_function, length, *args):
+         ctx.run_function = run_function
+         ctx.input_tensors = list(args[:length])
+         ctx.input_params = list(args[length:])
+
+         with torch.no_grad():
+             output_tensors = ctx.run_function(*ctx.input_tensors)
+         return output_tensors
+
+     @staticmethod
+     def backward(ctx, *output_grads):
+         ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+         with torch.enable_grad():
+             # Fixes a bug where the first op in run_function modifies the
+             # Tensor storage in place, which is not allowed for detach()'d
+             # Tensors.
+             shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+             output_tensors = ctx.run_function(*shallow_copies)
+         input_grads = torch.autograd.grad(
+             output_tensors,
+             ctx.input_tensors + ctx.input_params,
+             output_grads,
+             allow_unused=True,
+         )
+         del ctx.input_tensors
+         del ctx.input_params
+         del output_tensors
+         return (None, None) + input_grads
+
+
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+     """
+     Create sinusoidal timestep embeddings.
+     :param timesteps: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an [N x dim] Tensor of positional embeddings.
+     """
+     if not repeat_only:
+         half = dim // 2
+         freqs = torch.exp(
+             -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+         ).to(device=timesteps.device)
+         args = timesteps[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     else:
+         embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+     return embedding
+
+
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
+ def scale_module(module, scale):
+     """
+     Scale the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().mul_(scale)
+     return module
+
+
+ def mean_flat(tensor):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def normalization(channels):
+     """
+     Make a standard normalization layer.
+     :param channels: number of input channels.
+     :return: an nn.Module for normalization.
+     """
+     return GroupNorm32(32, channels)
+
+
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+ class SiLU(nn.Module):
+     def forward(self, x):
+         return x * torch.sigmoid(x)
+
+
+ class GroupNorm32(nn.GroupNorm):
+     def forward(self, x):
+         return super().forward(x.float()).type(x.dtype)
+
+ def conv_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D convolution module.
+     """
+     if dims == 1:
+         return nn.Conv1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.Conv2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.Conv3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ def linear(*args, **kwargs):
+     """
+     Create a linear module.
+     """
+     return nn.Linear(*args, **kwargs)
+
+
+ def avg_pool_nd(dims, *args, **kwargs):
+     """
+     Create a 1D, 2D, or 3D average pooling module.
+     """
+     if dims == 1:
+         return nn.AvgPool1d(*args, **kwargs)
+     elif dims == 2:
+         return nn.AvgPool2d(*args, **kwargs)
+     elif dims == 3:
+         return nn.AvgPool3d(*args, **kwargs)
+     raise ValueError(f"unsupported dimensions: {dims}")
+
+
+ # class HybridConditioner(nn.Module):
+
+ #     def __init__(self, c_concat_config, c_crossattn_config):
+ #         super().__init__()
+ #         self.concat_conditioner = instantiate_from_config(c_concat_config)
+ #         self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ #     def forward(self, c_concat, c_crossattn):
+ #         c_concat = self.concat_conditioner(c_concat)
+ #         c_crossattn = self.crossattn_conditioner(c_crossattn)
+ #         return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+ def noise_like(shape, device, repeat=False):
+     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+     noise = lambda: torch.randn(shape, device=device)
+     return repeat_noise() if repeat else noise()
+
+ def count_flops_attn(model, _x, y):
+     """
+     A counter for the `thop` package to count the operations in an
+     attention operation.
+     Meant to be used like:
+         macs, params = thop.profile(
+             model,
+             inputs=(inputs, timestamps),
+             custom_ops={QKVAttention: QKVAttention.count_flops},
+         )
+     """
+     b, c, *spatial = y[0].shape
+     num_spatial = int(np.prod(spatial))
+     # We perform two matmuls with the same number of ops.
+     # The first computes the weight matrix, the second computes
+     # the combination of the value vectors.
+     matmul_ops = 2 * b * (num_spatial ** 2) * c
+     model.total_ops += torch.DoubleTensor([matmul_ops])
+
+ def count_params(model, verbose=False):
+     total_params = sum(p.numel() for p in model.parameters())
+     if verbose:
+         print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+     return total_params
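The helpers above are self-contained. A minimal usage sketch for timestep_embedding, checkpoint, and count_params, assuming the repository root is on PYTHONPATH so that models.utils is importable (the mlp/x/y names below are illustrative only):

    import torch
    from models.utils import timestep_embedding, checkpoint, count_params

    # Sinusoidal embeddings for a batch of four diffusion timesteps.
    t = torch.randint(0, 1000, (4,))
    emb = timestep_embedding(t, dim=256)                           # shape (4, 256)

    # Recompute activations during the backward pass instead of storing them.
    mlp = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.SiLU())
    x = emb.requires_grad_(True)
    y = checkpoint(mlp, (x,), tuple(mlp.parameters()), flag=True)
    y.sum().backward()

    count_params(mlp, verbose=True)                                # prints the parameter count in millions

With flag=False the call simply evaluates mlp(x) and activations are kept as usual.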
sample/__pycache__/pipeline_latte.cpython-312.pyc ADDED
Binary file (35.4 kB).
 
sample/ffs.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+ export CUDA_VISIBLE_DEVICES=7
+
+ python sample/sample.py \
+ --config ./configs/ffs/ffs_sample.yaml \
+ --ckpt ./share_ckpts/ffs.pt \
+ --save_video_path ./test
sample/ffs_ddp.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+ export CUDA_VISIBLE_DEVICES=6,7
+
+ torchrun --nnodes=1 --nproc_per_node=2 sample/sample_ddp.py \
+ --config ./configs/ffs/ffs_sample.yaml \
+ --ckpt ./share_ckpts/ffs.pt \
+ --save_video_path ./test