Zhouyan248 committed on
Commit
24d19d7
1 Parent(s): 44037ea

add gradio

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. LICENSE +201 -0
  2. README.md +86 -12
  3. base/__pycache__/download.cpython-311.pyc +0 -0
  4. base/app.py +116 -0
  5. base/app.sh +1 -0
  6. base/configs/sample.yaml +28 -0
  7. base/download.py +18 -0
  8. base/huggingface-t2v/.DS_Store +0 -0
  9. base/huggingface-t2v/__init__.py +0 -0
  10. base/huggingface-t2v/requirements.txt +0 -0
  11. base/models/__init__.py +33 -0
  12. base/models/__pycache__/__init__.cpython-311.pyc +0 -0
  13. base/models/__pycache__/attention.cpython-311.pyc +0 -0
  14. base/models/__pycache__/resnet.cpython-311.pyc +0 -0
  15. base/models/__pycache__/unet.cpython-311.pyc +0 -0
  16. base/models/__pycache__/unet_blocks.cpython-311.pyc +0 -0
  17. base/models/attention.py +707 -0
  18. base/models/clip.py +120 -0
  19. base/models/resnet.py +212 -0
  20. base/models/temporal_attention.py +388 -0
  21. base/models/transformer_3d.py +367 -0
  22. base/models/unet.py +617 -0
  23. base/models/unet_blocks.py +648 -0
  24. base/models/utils.py +215 -0
  25. base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc +0 -0
  26. base/pipelines/pipeline_videogen.py +677 -0
  27. base/pipelines/sample.py +88 -0
  28. base/pipelines/sample.sh +2 -0
  29. base/text_to_video/__init__.py +44 -0
  30. base/text_to_video/__pycache__/__init__.cpython-311.pyc +0 -0
  31. base/try.py +5 -0
  32. environment.yml +27 -0
  33. interpolation/configs/sample.yaml +36 -0
  34. interpolation/datasets/__init__.py +1 -0
  35. interpolation/datasets/video_transforms.py +109 -0
  36. interpolation/diffusion/__init__.py +47 -0
  37. interpolation/diffusion/diffusion_utils.py +88 -0
  38. interpolation/diffusion/gaussian_diffusion.py +1000 -0
  39. interpolation/diffusion/respace.py +130 -0
  40. interpolation/diffusion/timestep_sampler.py +150 -0
  41. interpolation/download.py +22 -0
  42. interpolation/models/__init__.py +33 -0
  43. interpolation/models/attention.py +665 -0
  44. interpolation/models/clip.py +124 -0
  45. interpolation/models/resnet.py +212 -0
  46. interpolation/models/unet.py +576 -0
  47. interpolation/models/unet_blocks.py +619 -0
  48. interpolation/models/utils.py +215 -0
  49. interpolation/sample.py +312 -0
  50. interpolation/utils.py +371 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,86 @@
1
- ---
2
- title: Lavie Gradio
3
- emoji: 📊
4
- colorFrom: gray
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 4.7.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # LaVie: High-Quality Video Generation with Cascaded Latent Diffusion Models
2
+
3
+ This repository is the official PyTorch implementation of [LaVie](https://arxiv.org/abs/2309.15103).
4
+
5
+ **LaVie** is a Text-to-Video (T2V) generation framework and the main component of the video generation system [Vchitect](http://vchitect.intern-ai.org.cn/).
6
+
7
+ [![arXiv](https://img.shields.io/badge/arXiv-2309.15103-b31b1b.svg)](https://arxiv.org/abs/2309.15103)
8
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://vchitect.github.io/LaVie-project/)
9
+ <!--
10
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)]()
11
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)]()
12
+ -->
13
+
14
+ <img src="lavie.gif" width="800">
15
+
16
+ ## Installation
17
+ ```
18
+ conda env create -f environment.yml
19
+ conda activate lavie
20
+ ```
21
+
22
+ ## Download Pre-Trained models
23
+ Download the [pre-trained models](https://huggingface.co/YaohuiW/LaVie/tree/main), [Stable Diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main), and the [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main) into `./pretrained_models`. You should then see the following layout:
24
+ ```
25
+ ├── pretrained_models
26
+ │ ├── lavie_base.pt
27
+ │ ├── lavie_interpolation.pt
28
+ │ ├── lavie_vsr.pt
29
+ │ ├── stable-diffusion-v1-4
30
+ │ │ ├── ...
31
+ └── └── stable-diffusion-x4-upscaler
32
+ ├── ...
33
+ ```
34
+
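+ If you prefer to script the download, here is a minimal sketch using `huggingface_hub` (an assumption of this note, not part of the official instructions; it requires a recent version with `local_dir` support). The repo IDs are taken from the links above:
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # LaVie checkpoints (lavie_base.pt, lavie_interpolation.pt, lavie_vsr.pt)
+ snapshot_download("YaohuiW/LaVie", local_dir="./pretrained_models")
+ # Stable Diffusion 1.4 and the x4 upscaler, each into its own subfolder
+ snapshot_download("CompVis/stable-diffusion-v1-4", local_dir="./pretrained_models/stable-diffusion-v1-4")
+ snapshot_download("stabilityai/stable-diffusion-x4-upscaler", local_dir="./pretrained_models/stable-diffusion-x4-upscaler")
+ ```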
35
+ ## Inference
36
+ Inference consists of three steps: **Base T2V**, **Video Interpolation**, and **Video Super-Resolution**. We provide several options for generating videos:
37
+ * **Step1**: 320 x 512 resolution, 16 frames
38
+ * **Step1+Step2**: 320 x 512 resolution, 61 frames
39
+ * **Step1+Step3**: 1280 x 2048 resolution, 16 frames
40
+ * **Step1+Step2+Step3**: 1280 x 2048 resolution, 61 frames
41
+
42
+ Feel free to try different options:)
43
+
44
+
45
+ ### Step1. Base T2V
46
+ Run the following command to generate videos with the base T2V model.
47
+ ```
48
+ cd base
49
+ python pipelines/sample.py --config configs/sample.yaml
50
+ ```
51
+ Edit `text_prompt` in `configs/sample.yaml` to change the prompt; results will be saved under `./res/base`.
52
+
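+ For example, the `base/configs/sample.yaml` added in this commit ships a single prompt:
+ ```yaml
+ text_prompt: [
+ 'a teddy bear walking on the street, high quality, 2k',
+ ]
+ ```
+ You can list multiple prompts; each entry should produce its own output video.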
53
+ ### Step2 (optional). Video Interpolation
54
+ Run the following command to perform video interpolation.
55
+ ```
56
+ cd interpolation
57
+ python sample.py --config configs/sample.yaml
58
+ ```
59
+ The default input video path is `./res/base`, and results will be saved under `./res/interpolation`. You can change the default `input_folder` to `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Input videos should be named `prompt1.mp4`, `prompt2.mp4`, ... and placed under `YOUR_INPUT_FOLDER` (see the layout sketch below). Launching the code will process all input videos in `input_folder`.
60
+
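+ For reference, a custom input folder is expected to look like this (illustrative file names):
+ ```
+ ├── YOUR_INPUT_FOLDER
+ │ ├── prompt1.mp4
+ │ ├── prompt2.mp4
+ │ └── ...
+ ```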
61
+
62
+ ### Step3 (optional). Video Super-Resolution
63
+ Run the following command to perform video super-resolution.
64
+ ```
65
+ cd vsr
66
+ python sample.py --config configs/sample.yaml
67
+ ```
68
+ The default input video path is `./res/base`, and results will be saved under `./res/vsr`. You can change the default `input_path` to `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Similar to Step 2, input videos should be named `prompt1.mp4`, `prompt2.mp4`, ... and placed under `YOUR_INPUT_FOLDER`. Launching the code will process all input videos in that folder.
69
+
70
+
71
+ ## BibTex
72
+ ```bibtex
73
+ @article{wang2023lavie,
74
+ title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models},
75
+ author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others},
76
+ journal={arXiv preprint arXiv:2309.15103},
77
+ year={2023}
78
+ }
79
+ ```
80
+
81
+ ## Acknowledgements
82
+ The code is built upon [diffusers](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion); we thank all the contributors for open-sourcing their work.
83
+
84
+
85
+ ## License
86
+ The code is licensed under Apache-2.0. The model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please fill in the [application form]().
base/__pycache__/download.cpython-311.pyc ADDED
Binary file (815 Bytes). View file
 
base/app.py ADDED
@@ -0,0 +1,116 @@
1
+ import gradio as gr
2
+ from text_to_video import model_t2v_fun,setup_seed
3
+ from omegaconf import OmegaConf
4
+ import torch
5
+ import imageio
6
+ import os
7
+ import cv2
8
+ import torchvision
9
+ config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml"
10
+ args = OmegaConf.load(config_path)
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ # ------- get model ---------------
13
+ model_t2V = model_t2v_fun(args)
14
+ model_t2V.to(device)
15
+ if device == "cuda":
16
+ model_t2V.enable_xformers_memory_efficient_attention()
17
+
18
+ # model_t2V.enable_xformers_memory_efficient_attention()
19
+ css = """
20
+ h1 {
21
+ text-align: center;
22
+ }
23
+ #component-0 {
24
+ max-width: 730px;
25
+ margin: auto;
26
+ }
27
+ """
28
+
29
+ def infer(prompt, seed_inp, ddim_steps):
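+ # Generate a 16-frame, 320x512 video for `prompt` using the given seed and number of DDIM steps,
+ # write it to args.output_folder, and return the resulting .mp4 path for the Gradio video output.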
30
+
31
+
32
+ setup_seed(seed_inp)
33
+ videos = model_t2V(prompt, video_length=16, height = 320, width= 512, num_inference_steps=ddim_steps, guidance_scale=7).video
34
+ print(videos[0].shape)
35
+ os.makedirs(args.output_folder, exist_ok=True)
36
+ output_path = os.path.join(args.output_folder, prompt.replace(' ', '_') + '-.mp4')
37
+ torchvision.io.write_video(output_path, videos[0], fps=8)
38
+ # imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8)
39
+ # video = cv2.VideoCapture(args.output_folder + prompt.replace(' ', '_') + '.mp4')
40
+ # video = imageio.get_reader(args.output_folder + prompt.replace(' ', '_') + '.mp4', 'ffmpeg')
41
+
42
+
43
+ # video = model_t2V(prompt, seed_inp, ddim_steps)
44
+
45
+ return output_path
46
+
47
+ print(1)
48
+
49
+ # def clean():
50
+ # return gr.Image.update(value=None, visible=False), gr.Video.update(value=None)
51
+ def clean():
52
+ return gr.Video.update(value=None)
53
+
54
+ title = """
55
+ <div style="text-align: center; max-width: 700px; margin: 0 auto;">
56
+ <div
57
+ style="
58
+ display: inline-flex;
59
+ align-items: center;
60
+ gap: 0.8rem;
61
+ font-size: 1.75rem;
62
+ "
63
+ >
64
+ <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px;">
65
+ Intern·Vchitect (Text-to-Video)
66
+ </h1>
67
+ </div>
68
+ <p style="margin-bottom: 10px; font-size: 94%">
69
+ Apply Intern·Vchitect to generate a video
70
+ </p>
71
+ </div>
72
+ """
73
+
74
+ # print(1)
75
+ with gr.Blocks(css='style.css') as demo:
76
+ gr.Markdown("<font color=red size=10><center>LaVie</center></font>")
77
+ with gr.Row(elem_id="col-container"):
78
+
79
+ with gr.Column():
80
+
81
+ prompt = gr.Textbox(value="a teddy bear walking on the street", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
82
+
83
+ ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
84
+ seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=400, elem_id="seed-in")
85
+ # with gr.Row():
86
+ # # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in")
87
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
88
+ # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=123456, elem_id="seed-in")
89
+
90
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
91
+ with gr.Column():
92
+ submit_btn = gr.Button("Generate video")
93
+ clean_btn = gr.Button("Clean video")
94
+ # submit_btn = gr.Button("Generate video", size='sm')
95
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
96
+ video_out = gr.Video(label="Video result", elem_id="video-output")
97
+ # with gr.Row():
98
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
99
+ # submit_btn = gr.Button("Generate video", size='sm')
100
+
101
+
102
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
103
+ inputs = [prompt, seed_inp, ddim_steps]
104
+ outputs = [video_out]
105
+
106
+
107
+ # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False)
108
+ # submit_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
109
+ clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
110
+ submit_btn.click(infer, inputs, outputs)
111
+ # share_button.click(None, [], [], _js=share_js)
112
+
113
+ print(2)
114
+ demo.queue(max_size=12).launch(server_name="0.0.0.0", server_port=7860)
115
+
116
+
base/app.sh ADDED
@@ -0,0 +1 @@
1
+ srun -p aigc-video --gres=gpu:1 -n1 -N1 python app.py
base/configs/sample.yaml ADDED
@@ -0,0 +1,28 @@
1
+ # path:
2
+ output_folder: "/mnt/petrelfs/share_data/zhouyan/gradio/lavie"
3
+ pretrained_path: "/mnt/petrelfs/zhouyan/models"
4
+
5
+ # model config:
6
+ model: UNet
7
+ video_length: 16
8
+ image_size: [320, 512]
9
+
10
+ # beta schedule
11
+ beta_start: 0.0001
12
+ beta_end: 0.02
13
+ beta_schedule: "linear"
14
+
15
+ # model speedup
16
+ use_compile: False
17
+ use_fp16: True
18
+
19
+ # sample config:
20
+ seed: 3
21
+ run_time: 0
22
+ guidance_scale: 7.0
23
+ sample_method: 'ddpm'
24
+ num_sampling_steps: 250
25
+ text_prompt: [
26
+ 'a teddy bear walking on the street, high quality, 2k',
27
+
28
+ ]
base/download.py ADDED
@@ -0,0 +1,18 @@
1
+ # All rights reserved.
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import os
8
+
9
+
10
+ def find_model(model_name):
11
+ """
12
+ Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path.
13
+ """
14
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
15
+ if "ema" in checkpoint: # supports checkpoints from train.py
16
+ print('EMA weights found in checkpoint, using them.')
17
+ checkpoint = checkpoint["ema"]
18
+ return checkpoint
base/huggingface-t2v/.DS_Store ADDED
Binary file (6.15 kB). View file
 
base/huggingface-t2v/__init__.py ADDED
File without changes
base/huggingface-t2v/requirements.txt ADDED
File without changes
base/models/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .unet import UNet3DConditionModel
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
+ def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
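+ # Returns a LambdaLR whose multiplier ramps linearly from 0 to 1 over `warmup_steps`, then stays at 1.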
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ def fn(step):
11
+ if warmup_steps > 0:
12
+ return min(step / warmup_steps, 1)
13
+ else:
14
+ return 1
15
+ return LambdaLR(optimizer, fn)
16
+
17
+
18
+ def get_lr_scheduler(optimizer, name, **kwargs):
19
+ if name == 'warmup':
20
+ return customized_lr_scheduler(optimizer, **kwargs)
21
+ elif name == 'cosine':
22
+ from torch.optim.lr_scheduler import CosineAnnealingLR
23
+ return CosineAnnealingLR(optimizer, **kwargs)
24
+ else:
25
+ raise NotImplementedError(name)
26
+
27
+ def get_models(args, sd_path):
28
+
29
+ if 'UNet' in args.model:
30
+ return UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet")
31
+ else:
32
+ raise '{} Model Not Supported!'.format(args.model)
33
+
base/models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.9 kB). View file
 
base/models/__pycache__/attention.cpython-311.pyc ADDED
Binary file (33.7 kB). View file
 
base/models/__pycache__/resnet.cpython-311.pyc ADDED
Binary file (9.76 kB). View file
 
base/models/__pycache__/unet.cpython-311.pyc ADDED
Binary file (27.3 kB). View file
 
base/models/__pycache__/unet_blocks.cpython-311.pyc ADDED
Binary file (20.3 kB). View file
 
base/models/attention.py ADDED
@@ -0,0 +1,707 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import math
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.utils import BaseOutput
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
18
+ from rotary_embedding_torch import RotaryEmbedding
19
+ from typing import Callable, Optional
20
+ from einops import rearrange, repeat
21
+
22
+ try:
23
+ from diffusers.models.modeling_utils import ModelMixin
24
+ except ImportError:
25
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
26
+
27
+
28
+ @dataclass
29
+ class Transformer3DModelOutput(BaseOutput):
30
+ sample: torch.FloatTensor
31
+
32
+
33
+ if is_xformers_available():
34
+ import xformers
35
+ import xformers.ops
36
+ else:
37
+ xformers = None
38
+
39
+ def exists(x):
40
+ return x is not None
41
+
42
+
43
+ class CrossAttention(nn.Module):
44
+ r"""
45
+ Copied from diffusers 0.11.1.
46
+ A cross attention layer.
47
+ Parameters:
48
+ query_dim (`int`): The number of channels in the query.
49
+ cross_attention_dim (`int`, *optional*):
50
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
51
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
52
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
53
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
54
+ bias (`bool`, *optional*, defaults to False):
55
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ query_dim: int,
61
+ cross_attention_dim: Optional[int] = None,
62
+ heads: int = 8,
63
+ dim_head: int = 64,
64
+ dropout: float = 0.0,
65
+ bias=False,
66
+ upcast_attention: bool = False,
67
+ upcast_softmax: bool = False,
68
+ added_kv_proj_dim: Optional[int] = None,
69
+ norm_num_groups: Optional[int] = None,
70
+ use_relative_position: bool = False,
71
+ ):
72
+ super().__init__()
73
+ inner_dim = dim_head * heads
74
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
75
+ self.upcast_attention = upcast_attention
76
+ self.upcast_softmax = upcast_softmax
77
+
78
+ self.scale = dim_head**-0.5
79
+
80
+ self.heads = heads
81
+ self.dim_head = dim_head
82
+ # for slice_size > 0 the attention score computation
83
+ # is split across the batch axis to save memory
84
+ # You can set slice_size with `set_attention_slice`
85
+ self.sliceable_head_dim = heads
86
+ self._slice_size = None
87
+ self._use_memory_efficient_attention_xformers = False
88
+ self.added_kv_proj_dim = added_kv_proj_dim
89
+
90
+ if norm_num_groups is not None:
91
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
92
+ else:
93
+ self.group_norm = None
94
+
95
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
96
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
97
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
98
+
99
+ if self.added_kv_proj_dim is not None:
100
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
101
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
102
+
103
+ self.to_out = nn.ModuleList([])
104
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
105
+ self.to_out.append(nn.Dropout(dropout))
106
+
107
+ self.use_relative_position = use_relative_position
108
+ if self.use_relative_position:
109
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
110
+
111
+
112
+ def reshape_heads_to_batch_dim(self, tensor):
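+ # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head)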
113
+ batch_size, seq_len, dim = tensor.shape
114
+ head_size = self.heads
115
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
116
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
117
+ return tensor
118
+
119
+ def reshape_batch_dim_to_heads(self, tensor):
120
+ batch_size, seq_len, dim = tensor.shape
121
+ head_size = self.heads
122
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
123
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
124
+ return tensor
125
+
126
+ def reshape_for_scores(self, tensor):
127
+ # split heads and dims
128
+ # tensor should be [b (h w)] f (d nd)
129
+ batch_size, seq_len, dim = tensor.shape
130
+ head_size = self.heads
131
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
132
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
133
+ return tensor
134
+
135
+ def same_batch_dim_to_heads(self, tensor):
136
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
137
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
138
+ return tensor
139
+
140
+ def set_attention_slice(self, slice_size):
141
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
142
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
143
+
144
+ self._slice_size = slice_size
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
147
+ batch_size, sequence_length, _ = hidden_states.shape
148
+
149
+ encoder_hidden_states = encoder_hidden_states
150
+
151
+ if self.group_norm is not None:
152
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
153
+
154
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
155
+
156
+ # print('before reshpape query shape', query.shape)
157
+ dim = query.shape[-1]
158
+ if not self.use_relative_position:
159
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
160
+ # print('after reshape query shape', query.shape)
161
+
162
+ if self.added_kv_proj_dim is not None:
163
+ key = self.to_k(hidden_states)
164
+ value = self.to_v(hidden_states)
165
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
166
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
167
+
168
+ key = self.reshape_heads_to_batch_dim(key)
169
+ value = self.reshape_heads_to_batch_dim(value)
170
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
171
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
172
+
173
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
174
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
175
+ else:
176
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
177
+ key = self.to_k(encoder_hidden_states)
178
+ value = self.to_v(encoder_hidden_states)
179
+
180
+ if not self.use_relative_position:
181
+ key = self.reshape_heads_to_batch_dim(key)
182
+ value = self.reshape_heads_to_batch_dim(value)
183
+
184
+ if attention_mask is not None:
185
+ if attention_mask.shape[-1] != query.shape[1]:
186
+ target_length = query.shape[1]
187
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
188
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
189
+
190
+ # attention, what we cannot get enough of
191
+ if self._use_memory_efficient_attention_xformers:
192
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
193
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
194
+ hidden_states = hidden_states.to(query.dtype)
195
+ else:
196
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
197
+ hidden_states = self._attention(query, key, value, attention_mask)
198
+ else:
199
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
200
+
201
+ # linear proj
202
+ hidden_states = self.to_out[0](hidden_states)
203
+
204
+ # dropout
205
+ hidden_states = self.to_out[1](hidden_states)
206
+ return hidden_states
207
+
208
+
209
+ def _attention(self, query, key, value, attention_mask=None):
210
+ if self.upcast_attention:
211
+ query = query.float()
212
+ key = key.float()
213
+
214
+ attention_scores = torch.baddbmm(
215
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
216
+ query,
217
+ key.transpose(-1, -2),
218
+ beta=0,
219
+ alpha=self.scale,
220
+ )
221
+
222
+ if attention_mask is not None:
223
+ attention_scores = attention_scores + attention_mask
224
+
225
+ if self.upcast_softmax:
226
+ attention_scores = attention_scores.float()
227
+
228
+ attention_probs = attention_scores.softmax(dim=-1)
229
+
230
+ # cast back to the original dtype
231
+ attention_probs = attention_probs.to(value.dtype)
232
+
233
+ # compute attention output
234
+ hidden_states = torch.bmm(attention_probs, value)
235
+
236
+ # reshape hidden_states
237
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
238
+
239
+ return hidden_states
240
+
241
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
242
+ batch_size_attention = query.shape[0]
243
+ hidden_states = torch.zeros(
244
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
245
+ )
246
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
247
+ for i in range(hidden_states.shape[0] // slice_size):
248
+ start_idx = i * slice_size
249
+ end_idx = (i + 1) * slice_size
250
+
251
+ query_slice = query[start_idx:end_idx]
252
+ key_slice = key[start_idx:end_idx]
253
+
254
+ if self.upcast_attention:
255
+ query_slice = query_slice.float()
256
+ key_slice = key_slice.float()
257
+
258
+ attn_slice = torch.baddbmm(
259
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
260
+ query_slice,
261
+ key_slice.transpose(-1, -2),
262
+ beta=0,
263
+ alpha=self.scale,
264
+ )
265
+
266
+ if attention_mask is not None:
267
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
268
+
269
+ if self.upcast_softmax:
270
+ attn_slice = attn_slice.float()
271
+
272
+ attn_slice = attn_slice.softmax(dim=-1)
273
+
274
+ # cast back to the original dtype
275
+ attn_slice = attn_slice.to(value.dtype)
276
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
277
+
278
+ hidden_states[start_idx:end_idx] = attn_slice
279
+
280
+ # reshape hidden_states
281
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
282
+ return hidden_states
283
+
284
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
285
+ # TODO attention_mask
286
+ query = query.contiguous()
287
+ key = key.contiguous()
288
+ value = value.contiguous()
289
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
290
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
291
+ return hidden_states
292
+
293
+
294
+ class Transformer3DModel(ModelMixin, ConfigMixin):
295
+ @register_to_config
296
+ def __init__(
297
+ self,
298
+ num_attention_heads: int = 16,
299
+ attention_head_dim: int = 88,
300
+ in_channels: Optional[int] = None,
301
+ num_layers: int = 1,
302
+ dropout: float = 0.0,
303
+ norm_num_groups: int = 32,
304
+ cross_attention_dim: Optional[int] = None,
305
+ attention_bias: bool = False,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ use_linear_projection: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ use_first_frame: bool = False,
312
+ use_relative_position: bool = False,
313
+ rotary_emb: bool = None,
314
+ ):
315
+ super().__init__()
316
+ self.use_linear_projection = use_linear_projection
317
+ self.num_attention_heads = num_attention_heads
318
+ self.attention_head_dim = attention_head_dim
319
+ inner_dim = num_attention_heads * attention_head_dim
320
+
321
+ # Define input layers
322
+ self.in_channels = in_channels
323
+
324
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
325
+ if use_linear_projection:
326
+ self.proj_in = nn.Linear(in_channels, inner_dim)
327
+ else:
328
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
329
+
330
+ # Define transformers blocks
331
+ self.transformer_blocks = nn.ModuleList(
332
+ [
333
+ BasicTransformerBlock(
334
+ inner_dim,
335
+ num_attention_heads,
336
+ attention_head_dim,
337
+ dropout=dropout,
338
+ cross_attention_dim=cross_attention_dim,
339
+ activation_fn=activation_fn,
340
+ num_embeds_ada_norm=num_embeds_ada_norm,
341
+ attention_bias=attention_bias,
342
+ only_cross_attention=only_cross_attention,
343
+ upcast_attention=upcast_attention,
344
+ use_first_frame=use_first_frame,
345
+ use_relative_position=use_relative_position,
346
+ rotary_emb=rotary_emb,
347
+ )
348
+ for d in range(num_layers)
349
+ ]
350
+ )
351
+
352
+ # 4. Define output layers
353
+ if use_linear_projection:
354
+ self.proj_out = nn.Linear(in_channels, inner_dim)
355
+ else:
356
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
357
+
358
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True):
359
+ # Input
360
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
361
+
362
+ video_length = hidden_states.shape[2]
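+ # Fold the frame axis into the batch so the 2D spatial layers run per frame; the text embedding is repeated once per frame.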
363
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
364
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
365
+
366
+ batch, channel, height, weight = hidden_states.shape
367
+ residual = hidden_states
368
+
369
+ hidden_states = self.norm(hidden_states)
370
+ if not self.use_linear_projection:
371
+ hidden_states = self.proj_in(hidden_states)
372
+ inner_dim = hidden_states.shape[1]
373
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
374
+ else:
375
+ inner_dim = hidden_states.shape[1]
376
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
377
+ hidden_states = self.proj_in(hidden_states)
378
+
379
+ # Blocks
380
+ for block in self.transformer_blocks:
381
+ hidden_states = block(
382
+ hidden_states,
383
+ encoder_hidden_states=encoder_hidden_states,
384
+ timestep=timestep,
385
+ video_length=video_length,
386
+ use_image_num=use_image_num,
387
+ )
388
+
389
+ # Output
390
+ if not self.use_linear_projection:
391
+ hidden_states = (
392
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
393
+ )
394
+ hidden_states = self.proj_out(hidden_states)
395
+ else:
396
+ hidden_states = self.proj_out(hidden_states)
397
+ hidden_states = (
398
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
399
+ )
400
+
401
+ output = hidden_states + residual
402
+
403
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
404
+ if not return_dict:
405
+ return (output,)
406
+
407
+ return Transformer3DModelOutput(sample=output)
408
+
409
+
410
+ class BasicTransformerBlock(nn.Module):
411
+ def __init__(
412
+ self,
413
+ dim: int,
414
+ num_attention_heads: int,
415
+ attention_head_dim: int,
416
+ dropout=0.0,
417
+ cross_attention_dim: Optional[int] = None,
418
+ activation_fn: str = "geglu",
419
+ num_embeds_ada_norm: Optional[int] = None,
420
+ attention_bias: bool = False,
421
+ only_cross_attention: bool = False,
422
+ upcast_attention: bool = False,
423
+ use_first_frame: bool = False,
424
+ use_relative_position: bool = False,
425
+ rotary_emb: bool = False,
426
+ ):
427
+ super().__init__()
428
+ self.only_cross_attention = only_cross_attention
429
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
430
+ self.use_first_frame = use_first_frame
431
+
432
+ # Spatial-Attn
433
+ self.attn1 = CrossAttention(
434
+ query_dim=dim,
435
+ heads=num_attention_heads,
436
+ dim_head=attention_head_dim,
437
+ dropout=dropout,
438
+ bias=attention_bias,
439
+ cross_attention_dim=None,
440
+ upcast_attention=upcast_attention,
441
+ )
442
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
443
+
444
+ # Text Cross-Attn
445
+ if cross_attention_dim is not None:
446
+ self.attn2 = CrossAttention(
447
+ query_dim=dim,
448
+ cross_attention_dim=cross_attention_dim,
449
+ heads=num_attention_heads,
450
+ dim_head=attention_head_dim,
451
+ dropout=dropout,
452
+ bias=attention_bias,
453
+ upcast_attention=upcast_attention,
454
+ )
455
+ else:
456
+ self.attn2 = None
457
+
458
+ if cross_attention_dim is not None:
459
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
460
+ else:
461
+ self.norm2 = None
462
+
463
+ # Temp
464
+ self.attn_temp = TemporalAttention(
465
+ query_dim=dim,
466
+ heads=num_attention_heads,
467
+ dim_head=attention_head_dim,
468
+ dropout=dropout,
469
+ bias=attention_bias,
470
+ cross_attention_dim=None,
471
+ upcast_attention=upcast_attention,
472
+ rotary_emb=rotary_emb,
473
+ )
474
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
475
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
476
+
477
+
478
+ # Feed-forward
479
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
480
+ self.norm3 = nn.LayerNorm(dim)
481
+
482
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
483
+
484
+ if not is_xformers_available():
485
+ print("Here is how to install it")
486
+ raise ModuleNotFoundError(
487
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
488
+ " xformers",
489
+ name="xformers",
490
+ )
491
+ elif not torch.cuda.is_available():
492
+ raise ValueError(
493
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
494
+ " available for GPU "
495
+ )
496
+ else:
497
+ try:
498
+ # Make sure we can run the memory efficient attention
499
+ _ = xformers.ops.memory_efficient_attention(
500
+ torch.randn((1, 2, 40), device="cuda"),
501
+ torch.randn((1, 2, 40), device="cuda"),
502
+ torch.randn((1, 2, 40), device="cuda"),
503
+ )
504
+ except Exception as e:
505
+ raise e
506
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
507
+ if self.attn2 is not None:
508
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
509
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
510
+
511
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None):
512
+ # Spatial Self-Attention
513
+ norm_hidden_states = (
514
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
515
+ )
516
+
517
+ if self.only_cross_attention:
518
+ hidden_states = (
519
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
520
+ )
521
+ else:
522
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
523
+
524
+ if self.attn2 is not None:
525
+ # Cross-Attention
526
+ norm_hidden_states = (
527
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
528
+ )
529
+ hidden_states = (
530
+ self.attn2(
531
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
532
+ )
533
+ + hidden_states
534
+ )
535
+
536
+ # Temporal Attention
537
+ if self.training:
538
+ d = hidden_states.shape[1]
539
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
540
+ hidden_states_video = hidden_states[:, :video_length, :]
541
+ hidden_states_image = hidden_states[:, video_length:, :]
542
+ norm_hidden_states_video = (
543
+ self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
544
+ )
545
+ hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
546
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
547
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
548
+ else:
549
+ d = hidden_states.shape[1]
550
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
551
+ norm_hidden_states = (
552
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
553
+ )
554
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
555
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
556
+
557
+ # Feed-forward
558
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
559
+
560
+ return hidden_states
561
+
562
+ class TemporalAttention(CrossAttention):
563
+ def __init__(self,
564
+ query_dim: int,
565
+ cross_attention_dim: Optional[int] = None,
566
+ heads: int = 8,
567
+ dim_head: int = 64,
568
+ dropout: float = 0.0,
569
+ bias=False,
570
+ upcast_attention: bool = False,
571
+ upcast_softmax: bool = False,
572
+ added_kv_proj_dim: Optional[int] = None,
573
+ norm_num_groups: Optional[int] = None,
574
+ rotary_emb=None):
575
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
576
+ # relative time positional embeddings
577
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
578
+ self.rotary_emb = rotary_emb
579
+
580
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
581
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
582
+ batch_size, sequence_length, _ = hidden_states.shape
583
+
584
+ encoder_hidden_states = encoder_hidden_states
585
+
586
+ if self.group_norm is not None:
587
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
588
+
589
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
590
+ dim = query.shape[-1]
591
+
592
+ if self.added_kv_proj_dim is not None:
593
+ key = self.to_k(hidden_states)
594
+ value = self.to_v(hidden_states)
595
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
596
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
597
+
598
+ key = self.reshape_heads_to_batch_dim(key)
599
+ value = self.reshape_heads_to_batch_dim(value)
600
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
601
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
602
+
603
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
604
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
605
+ else:
606
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
607
+ key = self.to_k(encoder_hidden_states)
608
+ value = self.to_v(encoder_hidden_states)
609
+
610
+ if attention_mask is not None:
611
+ if attention_mask.shape[-1] != query.shape[1]:
612
+ target_length = query.shape[1]
613
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
614
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
615
+
616
+ # attention, what we cannot get enough of
617
+ if self._use_memory_efficient_attention_xformers:
618
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
619
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
620
+ hidden_states = hidden_states.to(query.dtype)
621
+ else:
622
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
623
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
624
+ else:
625
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
626
+
627
+ # linear proj
628
+ hidden_states = self.to_out[0](hidden_states)
629
+
630
+ # dropout
631
+ hidden_states = self.to_out[1](hidden_states)
632
+ return hidden_states
633
+
634
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
635
+ if self.upcast_attention:
636
+ query = query.float()
637
+ key = key.float()
638
+
639
+ # reshape for adding time positional bias
640
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
641
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
642
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
643
+
644
+ if exists(self.rotary_emb):
645
+ query = self.rotary_emb.rotate_queries_or_keys(query)
646
+ key = self.rotary_emb.rotate_queries_or_keys(key)
647
+
648
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
649
+
650
+ attention_scores = attention_scores + time_rel_pos_bias
651
+
652
+ if attention_mask is not None:
653
+ # add attention mask
654
+ attention_scores = attention_scores + attention_mask
655
+
656
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
657
+
658
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
659
+ # print(attention_probs[0][0])
660
+
661
+ # cast back to the original dtype
662
+ attention_probs = attention_probs.to(value.dtype)
663
+
664
+ # compute attention output
665
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
666
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
667
+ return hidden_states
668
+
669
+ class RelativePositionBias(nn.Module):
670
+ def __init__(
671
+ self,
672
+ heads=8,
673
+ num_buckets=32,
674
+ max_distance=128,
675
+ ):
676
+ super().__init__()
677
+ self.num_buckets = num_buckets
678
+ self.max_distance = max_distance
679
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
680
+
681
+ @staticmethod
682
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
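+ # Map signed relative positions to bucket indices: half the buckets encode the sign, small offsets get exact buckets, larger offsets are log-spaced up to max_distance.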
683
+ ret = 0
684
+ n = -relative_position
685
+
686
+ num_buckets //= 2
687
+ ret += (n < 0).long() * num_buckets
688
+ n = torch.abs(n)
689
+
690
+ max_exact = num_buckets // 2
691
+ is_small = n < max_exact
692
+
693
+ val_if_large = max_exact + (
694
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
695
+ ).long()
696
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
697
+
698
+ ret += torch.where(is_small, n, val_if_large)
699
+ return ret
700
+
701
+ def forward(self, n, device):
702
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
703
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
704
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
705
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
706
+ values = self.relative_attention_bias(rp_bucket)
707
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
base/models/clip.py ADDED
@@ -0,0 +1,120 @@
1
+ import numpy
2
+ import torch.nn as nn
3
+ from transformers import CLIPTokenizer, CLIPTextModel
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ """
9
+ You will encounter the following warning:
10
+ - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
11
+ or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
12
+ - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
13
+ that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
14
+
15
+ https://github.com/CompVis/stable-diffusion/issues/97
16
+ According to this issue, the warning is safe.
17
+
18
+ This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
19
+ You can safely ignore the warning, it is not an error.
20
+
21
+ This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
22
+ """
23
+
24
+ class AbstractEncoder(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def encode(self, *args, **kwargs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class FrozenCLIPEmbedder(AbstractEncoder):
33
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
34
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
35
+ def __init__(self, path, device="cuda", max_length=77):
36
+ super().__init__()
37
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
38
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
39
+ self.device = device
40
+ self.max_length = max_length
41
+ self.freeze()
42
+
43
+ def freeze(self):
44
+ self.transformer = self.transformer.eval()
45
+ for param in self.parameters():
46
+ param.requires_grad = False
47
+
48
+ def forward(self, text):
49
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
50
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
51
+ tokens = batch_encoding["input_ids"].to(self.device)
52
+ outputs = self.transformer(input_ids=tokens)
53
+
54
+ z = outputs.last_hidden_state
55
+ return z
56
+
57
+ def encode(self, text):
58
+ return self(text)
59
+
60
+
61
+ class TextEmbedder(nn.Module):
62
+ """
63
+ Embeds text prompts into vector representations. Also handles text dropout for classifier-free guidance.
64
+ """
65
+ def __init__(self, path, dropout_prob=0.1):
66
+ super().__init__()
67
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
68
+ self.dropout_prob = dropout_prob
69
+
70
+ def token_drop(self, text_prompts, force_drop_ids=None):
71
+ """
72
+ Drops text to enable classifier-free guidance.
73
+ """
74
+ if force_drop_ids is None:
75
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
76
+ else:
77
+ # TODO
78
+ drop_ids = force_drop_ids == 1
79
+ labels = list(numpy.where(drop_ids, "", text_prompts))
80
+ # print(labels)
81
+ return labels
82
+
83
+ def forward(self, text_prompts, train, force_drop_ids=None):
84
+ use_dropout = self.dropout_prob > 0
85
+ if (train and use_dropout) or (force_drop_ids is not None):
86
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
87
+ embeddings = self.text_encodder(text_prompts)
88
+ return embeddings
89
+
90
+
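Note: `token_drop` above implements the text dropout needed for classifier-free guidance by replacing a random subset of prompts with the empty string (the unconditional input). A tiny standalone illustration (the prompt list and 10% rate are made up for this example):

```python
import numpy

prompts = ["a corgi on the beach", "a rocket launching", "a timelapse of clouds"]
dropout_prob = 0.1  # same default as above; purely illustrative here

drop_ids = numpy.random.uniform(0, 1, len(prompts)) < dropout_prob
conditioned = list(numpy.where(drop_ids, "", prompts))
print(conditioned)  # dropped entries become "", i.e. the unconditional prompt
```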
91
+ if __name__ == '__main__':
92
+
93
+ r"""
94
+ Returns:
95
+
96
+ Examples from CLIPTextModel:
97
+
98
+ ```python
99
+ >>> from transformers import AutoTokenizer, CLIPTextModel
100
+
101
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
102
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
103
+
104
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
105
+
106
+ >>> outputs = model(**inputs)
107
+ >>> last_hidden_state = outputs.last_hidden_state
108
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
109
+ ```"""
110
+
111
+ import torch
112
+
113
+ device = "cuda" if torch.cuda.is_available() else "cpu"
114
+
115
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
116
+ dropout_prob=0.00001).to(device)
117
+
118
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
119
+ output = text_encoder(text_prompts=text_prompt, train=False)
120
+ print(output.shape)
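For reference, the frozen CLIP text encoder above boils down to tokenizer + text model returning per-token hidden states of shape `(batch, 77, hidden)`. A self-contained sketch using the public `openai/clip-vit-base-patch32` checkpoint (the repo instead loads the `tokenizer`/`text_encoder` subfolders of a Stable Diffusion checkpoint whose path is environment-specific):

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

name = "openai/clip-vit-base-patch32"  # public checkpoint used only for this sketch
tokenizer = CLIPTokenizer.from_pretrained(name)
text_encoder = CLIPTextModel.from_pretrained(name).eval()

prompts = ["a photo of a cat", "a corgi running on the beach"]
batch = tokenizer(prompts, truncation=True, max_length=77,
                  padding="max_length", return_tensors="pt")
with torch.no_grad():
    hidden = text_encoder(input_ids=batch["input_ids"]).last_hidden_state

print(hidden.shape)  # torch.Size([2, 77, 512]) for this checkpoint
```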
base/models/resnet.py ADDED
@@ -0,0 +1,212 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ class InflatedConv3d(nn.Conv2d):
14
+ def forward(self, x):
15
+ video_length = x.shape[2]
16
+
17
+ x = rearrange(x, "b c f h w -> (b f) c h w")
18
+ x = super().forward(x)
19
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
20
+
21
+ return x
22
+
23
+
24
+ class Upsample3D(nn.Module):
25
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
26
+ super().__init__()
27
+ self.channels = channels
28
+ self.out_channels = out_channels or channels
29
+ self.use_conv = use_conv
30
+ self.use_conv_transpose = use_conv_transpose
31
+ self.name = name
32
+
33
+ conv = None
34
+ if use_conv_transpose:
35
+ raise NotImplementedError
36
+ elif use_conv:
37
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
38
+
39
+ if name == "conv":
40
+ self.conv = conv
41
+ else:
42
+ self.Conv2d_0 = conv
43
+
44
+ def forward(self, hidden_states, output_size=None):
45
+ assert hidden_states.shape[1] == self.channels
46
+
47
+ if self.use_conv_transpose:
48
+ raise NotImplementedError
49
+
50
+ # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
51
+ dtype = hidden_states.dtype
52
+ if dtype == torch.bfloat16:
53
+ hidden_states = hidden_states.to(torch.float32)
54
+
55
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
56
+ if hidden_states.shape[0] >= 64:
57
+ hidden_states = hidden_states.contiguous()
58
+
59
+ # if `output_size` is passed we force the interpolation output
60
+ # size and do not make use of `scale_factor=2`
61
+ if output_size is None:
62
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
63
+ else:
64
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
65
+
66
+ # If the input is bfloat16, we cast back to bfloat16
67
+ if dtype == torch.bfloat16:
68
+ hidden_states = hidden_states.to(dtype)
69
+
70
+ if self.use_conv:
71
+ if self.name == "conv":
72
+ hidden_states = self.conv(hidden_states)
73
+ else:
74
+ hidden_states = self.Conv2d_0(hidden_states)
75
+
76
+ return hidden_states
77
+
78
+
79
+ class Downsample3D(nn.Module):
80
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
81
+ super().__init__()
82
+ self.channels = channels
83
+ self.out_channels = out_channels or channels
84
+ self.use_conv = use_conv
85
+ self.padding = padding
86
+ stride = 2
87
+ self.name = name
88
+
89
+ if use_conv:
90
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ if name == "conv":
95
+ self.Conv2d_0 = conv
96
+ self.conv = conv
97
+ elif name == "Conv2d_0":
98
+ self.conv = conv
99
+ else:
100
+ self.conv = conv
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
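The key trick in `resnet.py` is `InflatedConv3d`: a plain 2D convolution applied frame by frame by folding the frame axis into the batch axis, so pretrained 2D weights can be reused on video tensors. A minimal sketch with toy shapes (not part of the diff):

```python
import torch
import torch.nn as nn
from einops import rearrange

conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)  # shared across all frames

x = torch.randn(2, 4, 16, 32, 32)                   # (batch, channels, frames, height, width)
f = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")        # fold frames into the batch dimension
x = conv2d(x)                                       # ordinary 2D convolution per frame
x = rearrange(x, "(b f) c h w -> b c f h w", f=f)   # restore the video layout
print(x.shape)                                      # torch.Size([2, 8, 16, 32, 32])
```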
base/models/temporal_attention.py ADDED
@@ -0,0 +1,388 @@
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+ from rotary_embedding_torch import RotaryEmbedding
5
+ from dataclasses import dataclass
6
+ from diffusers.utils import BaseOutput
7
+ from diffusers.utils.import_utils import is_xformers_available
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+ import math
11
+
12
+ @dataclass
13
+ class Transformer3DModelOutput(BaseOutput):
14
+ sample: torch.FloatTensor
15
+
16
+
17
+ if is_xformers_available():
18
+ import xformers
19
+ import xformers.ops
20
+ else:
21
+ xformers = None
22
+
23
+ def exists(x):
24
+ return x is not None
25
+
26
+ class CrossAttention(nn.Module):
27
+ r"""
28
+ Copied from diffusers 0.11.1.
29
+ A cross attention layer.
30
+ Parameters:
31
+ query_dim (`int`): The number of channels in the query.
32
+ cross_attention_dim (`int`, *optional*):
33
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
34
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
35
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
36
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
37
+ bias (`bool`, *optional*, defaults to False):
38
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ query_dim: int,
44
+ cross_attention_dim: Optional[int] = None,
45
+ heads: int = 8,
46
+ dim_head: int = 64,
47
+ dropout: float = 0.0,
48
+ bias=False,
49
+ upcast_attention: bool = False,
50
+ upcast_softmax: bool = False,
51
+ added_kv_proj_dim: Optional[int] = None,
52
+ norm_num_groups: Optional[int] = None,
53
+ use_relative_position: bool = False,
54
+ ):
55
+ super().__init__()
56
+ # print('num head', heads)
57
+ inner_dim = dim_head * heads
58
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
59
+ self.upcast_attention = upcast_attention
60
+ self.upcast_softmax = upcast_softmax
61
+
62
+ self.scale = dim_head**-0.5
63
+
64
+ self.heads = heads
65
+ self.dim_head = dim_head
66
+ # for slice_size > 0 the attention score computation
67
+ # is split across the batch axis to save memory
68
+ # You can set slice_size with `set_attention_slice`
69
+ self.sliceable_head_dim = heads
70
+ self._slice_size = None
71
+ self._use_memory_efficient_attention_xformers = False # do not use xformers for temporal attention
72
+ self.added_kv_proj_dim = added_kv_proj_dim
73
+
74
+ if norm_num_groups is not None:
75
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
76
+ else:
77
+ self.group_norm = None
78
+
79
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
80
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
81
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
82
+
83
+ if self.added_kv_proj_dim is not None:
84
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
85
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
86
+
87
+ self.to_out = nn.ModuleList([])
88
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
89
+ self.to_out.append(nn.Dropout(dropout))
90
+
91
+ def reshape_heads_to_batch_dim(self, tensor):
92
+ batch_size, seq_len, dim = tensor.shape
93
+ head_size = self.heads
94
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
95
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
96
+ return tensor
97
+
98
+ def reshape_batch_dim_to_heads(self, tensor):
99
+ batch_size, seq_len, dim = tensor.shape
100
+ head_size = self.heads
101
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
102
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
103
+ return tensor
104
+
105
+ def reshape_for_scores(self, tensor):
106
+ # split heads and dims
107
+ # tensor should be [b (h w)] f (d nd)
108
+ batch_size, seq_len, dim = tensor.shape
109
+ head_size = self.heads
110
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
111
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
112
+ return tensor
113
+
114
+ def same_batch_dim_to_heads(self, tensor):
115
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
116
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
117
+ return tensor
118
+
119
+ def set_attention_slice(self, slice_size):
120
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
121
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
122
+
123
+ self._slice_size = slice_size
124
+
125
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
126
+ batch_size, sequence_length, _ = hidden_states.shape
127
+
128
+ encoder_hidden_states = encoder_hidden_states
129
+
130
+ if self.group_norm is not None:
131
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
134
+
135
+ # print('before reshpape query shape', query.shape)
136
+ dim = query.shape[-1]
137
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
138
+ # print('after reshape query shape', query.shape)
139
+
140
+ if self.added_kv_proj_dim is not None:
141
+ key = self.to_k(hidden_states)
142
+ value = self.to_v(hidden_states)
143
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
144
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
145
+
146
+ key = self.reshape_heads_to_batch_dim(key)
147
+ value = self.reshape_heads_to_batch_dim(value)
148
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
149
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
150
+
151
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
152
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
153
+ else:
154
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
155
+ key = self.to_k(encoder_hidden_states)
156
+ value = self.to_v(encoder_hidden_states)
157
+
158
+ key = self.reshape_heads_to_batch_dim(key)
159
+ value = self.reshape_heads_to_batch_dim(value)
160
+
161
+ if attention_mask is not None:
162
+ if attention_mask.shape[-1] != query.shape[1]:
163
+ target_length = query.shape[1]
164
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
165
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
166
+
167
+ hidden_states = self._attention(query, key, value, attention_mask)
168
+
169
+ # linear proj
170
+ hidden_states = self.to_out[0](hidden_states)
171
+
172
+ # dropout
173
+ hidden_states = self.to_out[1](hidden_states)
174
+ return hidden_states
175
+
176
+
177
+ def _attention(self, query, key, value, attention_mask=None):
178
+ if self.upcast_attention:
179
+ query = query.float()
180
+ key = key.float()
181
+
182
+ attention_scores = torch.baddbmm(
183
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
184
+ query,
185
+ key.transpose(-1, -2),
186
+ beta=0,
187
+ alpha=self.scale,
188
+ )
189
+
190
+ if attention_mask is not None:
191
+ attention_scores = attention_scores + attention_mask
192
+
193
+ if self.upcast_softmax:
194
+ attention_scores = attention_scores.float()
195
+
196
+ attention_probs = attention_scores.softmax(dim=-1)
197
+ attention_probs = attention_probs.to(value.dtype)
198
+ # compute attention output
199
+ hidden_states = torch.bmm(attention_probs, value)
200
+ # reshape hidden_states
201
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
202
+ return hidden_states
203
+
204
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
205
+ batch_size_attention = query.shape[0]
206
+ hidden_states = torch.zeros(
207
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
208
+ )
209
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
210
+ for i in range(hidden_states.shape[0] // slice_size):
211
+ start_idx = i * slice_size
212
+ end_idx = (i + 1) * slice_size
213
+
214
+ query_slice = query[start_idx:end_idx]
215
+ key_slice = key[start_idx:end_idx]
216
+
217
+ if self.upcast_attention:
218
+ query_slice = query_slice.float()
219
+ key_slice = key_slice.float()
220
+
221
+ attn_slice = torch.baddbmm(
222
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
223
+ query_slice,
224
+ key_slice.transpose(-1, -2),
225
+ beta=0,
226
+ alpha=self.scale,
227
+ )
228
+
229
+ if attention_mask is not None:
230
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
231
+
232
+ if self.upcast_softmax:
233
+ attn_slice = attn_slice.float()
234
+
235
+ attn_slice = attn_slice.softmax(dim=-1)
236
+
237
+ # cast back to the original dtype
238
+ attn_slice = attn_slice.to(value.dtype)
239
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
240
+
241
+ hidden_states[start_idx:end_idx] = attn_slice
242
+
243
+ # reshape hidden_states
244
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
245
+ return hidden_states
246
+
247
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
248
+ # TODO attention_mask
249
+ query = query.contiguous()
250
+ key = key.contiguous()
251
+ value = value.contiguous()
252
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
253
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
254
+ return hidden_states
255
+
256
+ class TemporalAttention(CrossAttention):
257
+ def __init__(self,
258
+ query_dim: int,
259
+ cross_attention_dim: Optional[int] = None,
260
+ heads: int = 8,
261
+ dim_head: int = 64,
262
+ dropout: float = 0.0,
263
+ bias=False,
264
+ upcast_attention: bool = False,
265
+ upcast_softmax: bool = False,
266
+ added_kv_proj_dim: Optional[int] = None,
267
+ norm_num_groups: Optional[int] = None,
268
+ rotary_emb=None):
269
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
270
+ # relative time positional embeddings
271
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
272
+ self.rotary_emb = rotary_emb
273
+
274
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
275
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
276
+ batch_size, sequence_length, _ = hidden_states.shape
277
+
278
+ encoder_hidden_states = encoder_hidden_states
279
+
280
+ if self.group_norm is not None:
281
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
282
+
283
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
284
+ dim = query.shape[-1]
285
+
286
+ if self.added_kv_proj_dim is not None:
287
+ key = self.to_k(hidden_states)
288
+ value = self.to_v(hidden_states)
289
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
290
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
291
+
292
+ key = self.reshape_heads_to_batch_dim(key)
293
+ value = self.reshape_heads_to_batch_dim(value)
294
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
295
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
296
+
297
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
298
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
299
+ else:
300
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
301
+ key = self.to_k(encoder_hidden_states)
302
+ value = self.to_v(encoder_hidden_states)
303
+
304
+ if attention_mask is not None:
305
+ if attention_mask.shape[-1] != query.shape[1]:
306
+ target_length = query.shape[1]
307
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
308
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
309
+
310
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
311
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
312
+ else:
313
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
314
+
315
+ # linear proj
316
+ hidden_states = self.to_out[0](hidden_states)
317
+
318
+ # dropout
319
+ hidden_states = self.to_out[1](hidden_states)
320
+ return hidden_states
321
+
322
+
323
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
324
+ if self.upcast_attention:
325
+ query = query.float()
326
+ key = key.float()
327
+
328
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
329
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
330
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
331
+ if exists(self.rotary_emb):
332
+ query = self.rotary_emb.rotate_queries_or_keys(query)
333
+ key = self.rotary_emb.rotate_queries_or_keys(key)
334
+
335
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
336
+ attention_scores = attention_scores + time_rel_pos_bias
337
+
338
+ if attention_mask is not None:
339
+ # add attention mask
340
+ attention_scores = attention_scores + attention_mask
341
+
342
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
343
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
344
+
345
+ attention_probs = attention_probs.to(value.dtype)
346
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
347
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
348
+ return hidden_states
349
+
350
+ class RelativePositionBias(nn.Module):
351
+ def __init__(
352
+ self,
353
+ heads=8,
354
+ num_buckets=32,
355
+ max_distance=128,
356
+ ):
357
+ super().__init__()
358
+ self.num_buckets = num_buckets
359
+ self.max_distance = max_distance
360
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
361
+
362
+ @staticmethod
363
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
364
+ ret = 0
365
+ n = -relative_position
366
+
367
+ num_buckets //= 2
368
+ ret += (n < 0).long() * num_buckets
369
+ n = torch.abs(n)
370
+
371
+ max_exact = num_buckets // 2
372
+ is_small = n < max_exact
373
+
374
+ val_if_large = max_exact + (
375
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
376
+ ).long()
377
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
378
+
379
+ ret += torch.where(is_small, n, val_if_large)
380
+ return ret
381
+
382
+ def forward(self, n, device):
383
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
384
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
385
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
386
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
387
+ values = self.relative_attention_bias(rp_bucket)
388
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
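Shape-wise, `TemporalAttention` is applied per spatial location: the `(h, w)` positions are folded into the batch and attention runs over the frame axis only. A toy sketch of that folding (torch's built-in `scaled_dot_product_attention` stands in for the manual einsum path above; the rotary embedding and relative position bias are omitted):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

b, c, f, h, w, heads = 1, 64, 16, 8, 8, 8
video = torch.randn(b, c, f, h, w)

# each spatial location attends over its own sequence of frames
tokens = rearrange(video, "b c f h w -> (b h w) f c")
qkv = torch.nn.Linear(c, 3 * c, bias=False)(tokens)
q, k, v = rearrange(qkv, "n f (three nh d) -> three n nh f d", three=3, nh=heads)
out = F.scaled_dot_product_attention(q, k, v)             # attention over frames only
out = rearrange(out, "(b h w) nh f d -> b (nh d) f h w", b=b, h=h, w=w)
print(out.shape)                                          # torch.Size([1, 64, 16, 8, 8])
```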
base/models/transformer_3d.py ADDED
@@ -0,0 +1,367 @@
1
+
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ from dataclasses import dataclass
14
+ from typing import Any, Dict, Optional
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
21
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
22
+ from diffusers.utils import BaseOutput, deprecate
23
+ from diffusers.models.embeddings import PatchEmbed
24
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from einops import rearrange, repeat
27
+
28
+ try:
29
+ from attention import BasicTransformerBlock
30
+ except:
31
+ from .attention import BasicTransformerBlock
32
+
33
+ @dataclass
34
+ class Transformer3DModelOutput(BaseOutput):
35
+ """
36
+ The output of [`Transformer3DModel`].
37
+
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
40
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
41
+ distributions for the unnoised latent pixels.
42
+ """
43
+
44
+ sample: torch.FloatTensor
45
+
46
+
47
+ class Transformer3DModel(ModelMixin, ConfigMixin):
48
+ """
49
+ A 3D Transformer model for video-like data, adapted from [`Transformer2DModel`].
50
+
51
+ Parameters:
52
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
53
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
54
+ in_channels (`int`, *optional*):
55
+ The number of channels in the input and output (specify if the input is **continuous**).
56
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
57
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
58
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
59
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
60
+ This is fixed during training since it is used to learn a number of position embeddings.
61
+ num_vector_embeds (`int`, *optional*):
62
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
63
+ Includes the class for the masked latent pixel.
64
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
65
+ num_embeds_ada_norm ( `int`, *optional*):
66
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
67
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
68
+ added to the hidden states.
69
+
70
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
71
+ attention_bias (`bool`, *optional*):
72
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
73
+ """
74
+
75
+ @register_to_config
76
+ def __init__(
77
+ self,
78
+ num_attention_heads: int = 16,
79
+ attention_head_dim: int = 88,
80
+ in_channels: Optional[int] = None,
81
+ out_channels: Optional[int] = None,
82
+ num_layers: int = 1,
83
+ dropout: float = 0.0,
84
+ norm_num_groups: int = 32,
85
+ cross_attention_dim: Optional[int] = None,
86
+ attention_bias: bool = False,
87
+ sample_size: Optional[int] = None,
88
+ num_vector_embeds: Optional[int] = None,
89
+ patch_size: Optional[int] = None,
90
+ activation_fn: str = "geglu",
91
+ num_embeds_ada_norm: Optional[int] = None,
92
+ use_linear_projection: bool = False,
93
+ only_cross_attention: bool = False,
94
+ upcast_attention: bool = False,
95
+ norm_type: str = "layer_norm",
96
+ norm_elementwise_affine: bool = True,
97
+ rotary_emb=None,
98
+ ):
99
+ super().__init__()
100
+ self.use_linear_projection = use_linear_projection
101
+ self.num_attention_heads = num_attention_heads
102
+ self.attention_head_dim = attention_head_dim
103
+ inner_dim = num_attention_heads * attention_head_dim
104
+
105
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
106
+ # Define whether input is continuous or discrete depending on configuration
107
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
108
+ self.is_input_vectorized = num_vector_embeds is not None
109
+ self.is_input_patches = in_channels is not None and patch_size is not None
110
+
111
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
112
+ deprecation_message = (
113
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
114
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
115
+ " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
116
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
117
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
118
+ )
119
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
120
+ norm_type = "ada_norm"
121
+
122
+ if self.is_input_continuous and self.is_input_vectorized:
123
+ raise ValueError(
124
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
125
+ " sure that either `in_channels` or `num_vector_embeds` is None."
126
+ )
127
+ elif self.is_input_vectorized and self.is_input_patches:
128
+ raise ValueError(
129
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
130
+ " sure that either `num_vector_embeds` or `num_patches` is None."
131
+ )
132
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
133
+ raise ValueError(
134
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
135
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
136
+ )
137
+
138
+ # 2. Define input layers
139
+ if self.is_input_continuous:
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
143
+ if use_linear_projection:
144
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
145
+ else:
146
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
147
+ elif self.is_input_vectorized:
148
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
149
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
150
+
151
+ self.height = sample_size
152
+ self.width = sample_size
153
+ self.num_vector_embeds = num_vector_embeds
154
+ self.num_latent_pixels = self.height * self.width
155
+
156
+ self.latent_image_embedding = ImagePositionalEmbeddings(
157
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
158
+ )
159
+ elif self.is_input_patches:
160
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
161
+
162
+ self.height = sample_size
163
+ self.width = sample_size
164
+
165
+ self.patch_size = patch_size
166
+ self.pos_embed = PatchEmbed(
167
+ height=sample_size,
168
+ width=sample_size,
169
+ patch_size=patch_size,
170
+ in_channels=in_channels,
171
+ embed_dim=inner_dim,
172
+ )
173
+
174
+ # 3. Define transformers blocks
175
+ self.transformer_blocks = nn.ModuleList(
176
+ [
177
+ BasicTransformerBlock(
178
+ inner_dim,
179
+ num_attention_heads,
180
+ attention_head_dim,
181
+ dropout=dropout,
182
+ cross_attention_dim=cross_attention_dim,
183
+ activation_fn=activation_fn,
184
+ num_embeds_ada_norm=num_embeds_ada_norm,
185
+ attention_bias=attention_bias,
186
+ only_cross_attention=only_cross_attention,
187
+ upcast_attention=upcast_attention,
188
+ norm_type=norm_type,
189
+ norm_elementwise_affine=norm_elementwise_affine,
190
+ rotary_emb=rotary_emb,
191
+ )
192
+ for d in range(num_layers)
193
+ ]
194
+ )
195
+
196
+ # 4. Define output layers
197
+ self.out_channels = in_channels if out_channels is None else out_channels
198
+ if self.is_input_continuous:
199
+ # TODO: should use out_channels for continuous projections
200
+ if use_linear_projection:
201
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
202
+ else:
203
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
204
+ elif self.is_input_vectorized:
205
+ self.norm_out = nn.LayerNorm(inner_dim)
206
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
207
+ elif self.is_input_patches:
208
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
209
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
210
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ encoder_hidden_states: Optional[torch.Tensor] = None,
216
+ timestep: Optional[torch.LongTensor] = None,
217
+ class_labels: Optional[torch.LongTensor] = None,
218
+ cross_attention_kwargs: Dict[str, Any] = None,
219
+ attention_mask: Optional[torch.Tensor] = None,
220
+ encoder_attention_mask: Optional[torch.Tensor] = None,
221
+ return_dict: bool = True,
222
+ use_image_num=None,
223
+ ):
224
+ """
225
+ The [`Transformer2DModel`] forward method.
226
+
227
+ Args:
228
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
229
+ Input `hidden_states`.
230
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
231
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
232
+ self-attention.
233
+ timestep ( `torch.LongTensor`, *optional*):
234
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
235
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
236
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
237
+ `AdaLayerZeroNorm`.
238
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
239
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
240
+
241
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
242
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
243
+
244
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
245
+ above. This bias will be added to the cross-attention scores.
246
+ return_dict (`bool`, *optional*, defaults to `True`):
247
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
248
+ tuple.
249
+
250
+ Returns:
251
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
252
+ `tuple` where the first element is the sample tensor.
253
+ """
254
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
255
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
256
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
257
+ # expects mask of shape:
258
+ # [batch, key_tokens]
259
+ # adds singleton query_tokens dimension:
260
+ # [batch, 1, key_tokens]
261
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
262
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
263
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
264
+ if attention_mask is not None and attention_mask.ndim == 2:
265
+ # assume that mask is expressed as:
266
+ # (1 = keep, 0 = discard)
267
+ # convert mask into a bias that can be added to attention scores:
268
+ # (keep = +0, discard = -10000.0)
269
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
270
+ attention_mask = attention_mask.unsqueeze(1)
271
+
272
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
273
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
274
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
275
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
276
+
277
+ # 1. Input
278
+ if self.is_input_continuous: # True
279
+
280
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
281
+ if self.training:
282
+ video_length = hidden_states.shape[2] - use_image_num
283
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
284
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
285
+ encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
286
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
287
+ encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
288
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
289
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
290
+ else:
291
+ video_length = hidden_states.shape[2]
292
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
293
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
294
+
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = self.proj_in(hidden_states)
301
+ inner_dim = hidden_states.shape[1]
302
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
303
+ else:
304
+ inner_dim = hidden_states.shape[1]
305
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
+ hidden_states = self.proj_in(hidden_states)
307
+ elif self.is_input_vectorized:
308
+ hidden_states = self.latent_image_embedding(hidden_states)
309
+ elif self.is_input_patches:
310
+ hidden_states = self.pos_embed(hidden_states)
311
+
312
+ # 2. Blocks
313
+ for block in self.transformer_blocks:
314
+ hidden_states = block(
315
+ hidden_states,
316
+ attention_mask=attention_mask,
317
+ encoder_hidden_states=encoder_hidden_states,
318
+ encoder_attention_mask=encoder_attention_mask,
319
+ timestep=timestep,
320
+ cross_attention_kwargs=cross_attention_kwargs,
321
+ class_labels=class_labels,
322
+ video_length=video_length,
323
+ use_image_num=use_image_num,
324
+ )
325
+
326
+ # 3. Output
327
+ if self.is_input_continuous:
328
+ if not self.use_linear_projection:
329
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
330
+ hidden_states = self.proj_out(hidden_states)
331
+ else:
332
+ hidden_states = self.proj_out(hidden_states)
333
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
334
+
335
+ output = hidden_states + residual
336
+ elif self.is_input_vectorized:
337
+ hidden_states = self.norm_out(hidden_states)
338
+ logits = self.out(hidden_states)
339
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
340
+ logits = logits.permute(0, 2, 1)
341
+
342
+ # log(p(x_0))
343
+ output = F.log_softmax(logits.double(), dim=1).float()
344
+ elif self.is_input_patches:
345
+ # TODO: cleanup!
346
+ conditioning = self.transformer_blocks[0].norm1.emb(
347
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
348
+ )
349
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
350
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
351
+ hidden_states = self.proj_out_2(hidden_states)
352
+
353
+ # unpatchify
354
+ height = width = int(hidden_states.shape[1] ** 0.5)
355
+ hidden_states = hidden_states.reshape(
356
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
357
+ )
358
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
359
+ output = hidden_states.reshape(
360
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
361
+ )
362
+
363
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
364
+ if not return_dict:
365
+ return (output,)
366
+
367
+ return Transformer3DModelOutput(sample=output)
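The mask handling at the top of `forward` converts a boolean keep/discard mask into an additive bias before it reaches the attention scores. A minimal demonstration of that convention (toy shapes):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float32)  # 1 = keep token, 0 = discard

bias = (1 - mask) * -10000.0       # keep -> +0, discard -> -10000, as in forward() above
bias = bias.unsqueeze(1)           # (batch, 1, key_tokens), broadcastable over queries

scores = torch.zeros(1, 4, 5)      # (batch, query_tokens, key_tokens)
probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])                 # masked keys receive ~0 probability
```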
base/models/unet.py ADDED
@@ -0,0 +1,617 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import sys
8
+ sys.path.append(os.path.split(sys.path[0])[0])
9
+
10
+ import math
11
+ import json
12
+ import torch
13
+ import einops
14
+ import torch.nn as nn
15
+ import torch.utils.checkpoint
16
+
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+ from diffusers.utils import BaseOutput, logging
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+
21
+ try:
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ except:
24
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
25
+
26
+ try:
27
+ from .unet_blocks import (
28
+ CrossAttnDownBlock3D,
29
+ CrossAttnUpBlock3D,
30
+ DownBlock3D,
31
+ UNetMidBlock3DCrossAttn,
32
+ UpBlock3D,
33
+ get_down_block,
34
+ get_up_block,
35
+ )
36
+ from .resnet import InflatedConv3d
37
+ except:
38
+ from unet_blocks import (
39
+ CrossAttnDownBlock3D,
40
+ CrossAttnUpBlock3D,
41
+ DownBlock3D,
42
+ UNetMidBlock3DCrossAttn,
43
+ UpBlock3D,
44
+ get_down_block,
45
+ get_up_block,
46
+ )
47
+ from resnet import InflatedConv3d
48
+
49
+ from rotary_embedding_torch import RotaryEmbedding
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ class RelativePositionBias(nn.Module):
54
+ def __init__(
55
+ self,
56
+ heads=8,
57
+ num_buckets=32,
58
+ max_distance=128,
59
+ ):
60
+ super().__init__()
61
+ self.num_buckets = num_buckets
62
+ self.max_distance = max_distance
63
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
64
+
65
+ @staticmethod
66
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
67
+ ret = 0
68
+ n = -relative_position
69
+
70
+ num_buckets //= 2
71
+ ret += (n < 0).long() * num_buckets
72
+ n = torch.abs(n)
73
+
74
+ max_exact = num_buckets // 2
75
+ is_small = n < max_exact
76
+
77
+ val_if_large = max_exact + (
78
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
79
+ ).long()
80
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
81
+
82
+ ret += torch.where(is_small, n, val_if_large)
83
+ return ret
84
+
85
+ def forward(self, n, device):
86
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
87
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
88
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
89
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
90
+ values = self.relative_attention_bias(rp_bucket)
91
+ return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
92
+
93
+ @dataclass
94
+ class UNet3DConditionOutput(BaseOutput):
95
+ sample: torch.FloatTensor
96
+
97
+
98
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
99
+ _supports_gradient_checkpointing = True
100
+
101
+ @register_to_config
102
+ def __init__(
103
+ self,
104
+ sample_size: Optional[int] = None, # 64
105
+ in_channels: int = 4,
106
+ out_channels: int = 4,
107
+ center_input_sample: bool = False,
108
+ flip_sin_to_cos: bool = True,
109
+ freq_shift: int = 0,
110
+ down_block_types: Tuple[str] = (
111
+ "CrossAttnDownBlock3D",
112
+ "CrossAttnDownBlock3D",
113
+ "CrossAttnDownBlock3D",
114
+ "DownBlock3D",
115
+ ),
116
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
117
+ up_block_types: Tuple[str] = (
118
+ "UpBlock3D",
119
+ "CrossAttnUpBlock3D",
120
+ "CrossAttnUpBlock3D",
121
+ "CrossAttnUpBlock3D"
122
+ ),
123
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
124
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
125
+ layers_per_block: int = 2,
126
+ downsample_padding: int = 1,
127
+ mid_block_scale_factor: float = 1,
128
+ act_fn: str = "silu",
129
+ norm_num_groups: int = 32,
130
+ norm_eps: float = 1e-5,
131
+ cross_attention_dim: int = 1280,
132
+ attention_head_dim: Union[int, Tuple[int]] = 8,
133
+ dual_cross_attention: bool = False,
134
+ use_linear_projection: bool = False,
135
+ class_embed_type: Optional[str] = None,
136
+ num_class_embeds: Optional[int] = None,
137
+ upcast_attention: bool = False,
138
+ resnet_time_scale_shift: str = "default",
139
+ use_first_frame: bool = False,
140
+ use_relative_position: bool = False,
141
+ ):
142
+ super().__init__()
143
+
144
+ # print(use_first_frame)
145
+
146
+ self.sample_size = sample_size
147
+ time_embed_dim = block_out_channels[0] * 4
148
+
149
+ # input
150
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
151
+
152
+ # time
153
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
154
+ timestep_input_dim = block_out_channels[0]
155
+
156
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
157
+
158
+ # class embedding
159
+ if class_embed_type is None and num_class_embeds is not None:
160
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
161
+ elif class_embed_type == "timestep":
162
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
163
+ elif class_embed_type == "identity":
164
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
165
+ else:
166
+ self.class_embedding = None
167
+
168
+ self.down_blocks = nn.ModuleList([])
169
+ self.mid_block = None
170
+ self.up_blocks = nn.ModuleList([])
171
+
172
+ # print(only_cross_attention)
173
+ # print(type(only_cross_attention))
174
+ # exit()
175
+ if isinstance(only_cross_attention, bool):
176
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
177
+ # print(only_cross_attention)
178
+ # exit()
179
+
180
+ if isinstance(attention_head_dim, int):
181
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
182
+ # print(attention_head_dim)
183
+ # exit()
184
+
185
+ rotary_emb = RotaryEmbedding(32)
186
+
187
+ # down
188
+ output_channel = block_out_channels[0]
189
+ for i, down_block_type in enumerate(down_block_types):
190
+ input_channel = output_channel
191
+ output_channel = block_out_channels[i]
192
+ is_final_block = i == len(block_out_channels) - 1
193
+
194
+ down_block = get_down_block(
195
+ down_block_type,
196
+ num_layers=layers_per_block,
197
+ in_channels=input_channel,
198
+ out_channels=output_channel,
199
+ temb_channels=time_embed_dim,
200
+ add_downsample=not is_final_block,
201
+ resnet_eps=norm_eps,
202
+ resnet_act_fn=act_fn,
203
+ resnet_groups=norm_num_groups,
204
+ cross_attention_dim=cross_attention_dim,
205
+ attn_num_head_channels=attention_head_dim[i],
206
+ downsample_padding=downsample_padding,
207
+ dual_cross_attention=dual_cross_attention,
208
+ use_linear_projection=use_linear_projection,
209
+ only_cross_attention=only_cross_attention[i],
210
+ upcast_attention=upcast_attention,
211
+ resnet_time_scale_shift=resnet_time_scale_shift,
212
+ use_first_frame=use_first_frame,
213
+ use_relative_position=use_relative_position,
214
+ rotary_emb=rotary_emb,
215
+ )
216
+ self.down_blocks.append(down_block)
217
+
218
+ # mid
219
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
220
+ self.mid_block = UNetMidBlock3DCrossAttn(
221
+ in_channels=block_out_channels[-1],
222
+ temb_channels=time_embed_dim,
223
+ resnet_eps=norm_eps,
224
+ resnet_act_fn=act_fn,
225
+ output_scale_factor=mid_block_scale_factor,
226
+ resnet_time_scale_shift=resnet_time_scale_shift,
227
+ cross_attention_dim=cross_attention_dim,
228
+ attn_num_head_channels=attention_head_dim[-1],
229
+ resnet_groups=norm_num_groups,
230
+ dual_cross_attention=dual_cross_attention,
231
+ use_linear_projection=use_linear_projection,
232
+ upcast_attention=upcast_attention,
233
+ use_first_frame=use_first_frame,
234
+ use_relative_position=use_relative_position,
235
+ rotary_emb=rotary_emb,
236
+ )
237
+ else:
238
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
239
+
240
+ # count how many layers upsample the videos
241
+ self.num_upsamplers = 0
242
+
243
+ # up
244
+ reversed_block_out_channels = list(reversed(block_out_channels))
245
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
246
+ only_cross_attention = list(reversed(only_cross_attention))
247
+ output_channel = reversed_block_out_channels[0]
248
+ for i, up_block_type in enumerate(up_block_types):
249
+ is_final_block = i == len(block_out_channels) - 1
250
+
251
+ prev_output_channel = output_channel
252
+ output_channel = reversed_block_out_channels[i]
253
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
254
+
255
+ # add upsample block for all BUT final layer
256
+ if not is_final_block:
257
+ add_upsample = True
258
+ self.num_upsamplers += 1
259
+ else:
260
+ add_upsample = False
261
+
262
+ up_block = get_up_block(
263
+ up_block_type,
264
+ num_layers=layers_per_block + 1,
265
+ in_channels=input_channel,
266
+ out_channels=output_channel,
267
+ prev_output_channel=prev_output_channel,
268
+ temb_channels=time_embed_dim,
269
+ add_upsample=add_upsample,
270
+ resnet_eps=norm_eps,
271
+ resnet_act_fn=act_fn,
272
+ resnet_groups=norm_num_groups,
273
+ cross_attention_dim=cross_attention_dim,
274
+ attn_num_head_channels=reversed_attention_head_dim[i],
275
+ dual_cross_attention=dual_cross_attention,
276
+ use_linear_projection=use_linear_projection,
277
+ only_cross_attention=only_cross_attention[i],
278
+ upcast_attention=upcast_attention,
279
+ resnet_time_scale_shift=resnet_time_scale_shift,
280
+ use_first_frame=use_first_frame,
281
+ use_relative_position=use_relative_position,
282
+ rotary_emb=rotary_emb,
283
+ )
284
+ self.up_blocks.append(up_block)
285
+ prev_output_channel = output_channel
286
+
287
+ # out
288
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
289
+ self.conv_act = nn.SiLU()
290
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
291
+
292
+ # relative time positional embeddings
293
+ self.use_relative_position = use_relative_position
294
+ if self.use_relative_position:
295
+ self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
296
+
297
+ def set_attention_slice(self, slice_size):
298
+ r"""
299
+ Enable sliced attention computation.
300
+
301
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
302
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
303
+
304
+ Args:
305
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
306
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
307
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
308
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
309
+ must be a multiple of `slice_size`.
310
+ """
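+ # A minimal usage sketch (hypothetical values): `unet.set_attention_slice("auto")` halves each
+ # sliceable head dimension, `unet.set_attention_slice("max")` runs one slice at a time, and an
+ # integer such as `unet.set_attention_slice(2)` yields `attention_head_dim // 2` slices per layer.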
311
+ sliceable_head_dims = []
312
+
313
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
314
+ if hasattr(module, "set_attention_slice"):
315
+ sliceable_head_dims.append(module.sliceable_head_dim)
316
+
317
+ for child in module.children():
318
+ fn_recursive_retrieve_slicable_dims(child)
319
+
320
+ # retrieve number of attention layers
321
+ for module in self.children():
322
+ fn_recursive_retrieve_slicable_dims(module)
323
+
324
+ num_slicable_layers = len(sliceable_head_dims)
325
+
326
+ if slice_size == "auto":
327
+ # half the attention head size is usually a good trade-off between
328
+ # speed and memory
329
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
330
+ elif slice_size == "max":
331
+ # make smallest slice possible
332
+ slice_size = num_slicable_layers * [1]
333
+
334
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
335
+
336
+ if len(slice_size) != len(sliceable_head_dims):
337
+ raise ValueError(
338
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
339
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
340
+ )
341
+
342
+ for i in range(len(slice_size)):
343
+ size = slice_size[i]
344
+ dim = sliceable_head_dims[i]
345
+ if size is not None and size > dim:
346
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
347
+
348
+ # Recursively walk through all the children.
349
+ # Any child that exposes the set_attention_slice method
350
+ # gets the message
351
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
352
+ if hasattr(module, "set_attention_slice"):
353
+ module.set_attention_slice(slice_size.pop())
354
+
355
+ for child in module.children():
356
+ fn_recursive_set_attention_slice(child, slice_size)
357
+
358
+ reversed_slice_size = list(reversed(slice_size))
359
+ for module in self.children():
360
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
361
+
362
+ def _set_gradient_checkpointing(self, module, value=False):
363
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
364
+ module.gradient_checkpointing = value
365
+
366
+ def forward(
367
+ self,
368
+ sample: torch.FloatTensor,
369
+ timestep: Union[torch.Tensor, float, int],
370
+ encoder_hidden_states: torch.Tensor = None,
371
+ class_labels: Optional[torch.Tensor] = None,
372
+ attention_mask: Optional[torch.Tensor] = None,
373
+ use_image_num: int = 0,
374
+ return_dict: bool = True,
375
+ ) -> Union[UNet3DConditionOutput, Tuple]:
376
+ r"""
377
+ Args:
378
+ sample (`torch.FloatTensor`): (batch, channel, frame, height, width) noisy inputs tensor
379
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
380
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
381
+ return_dict (`bool`, *optional*, defaults to `True`):
382
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
383
+
384
+ Returns:
385
+ [`UNet3DConditionOutput`] or `tuple`:
386
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
387
+ returning a tuple, the first element is the sample tensor.
388
+ """
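+ # Shape note (inferred from the demo at the bottom of this file): `sample` is 5-D,
+ # (batch, channel, frame, height, width). When `use_image_num > 0`, the trailing
+ # `use_image_num` entries along the frame axis appear to be individual images used for
+ # joint image-video training, and `encoder_hidden_states` then carries 1 + use_image_num
+ # text embeddings per sample.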
389
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
390
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
391
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
392
+ # on the fly if necessary.
393
+ default_overall_up_factor = 2**self.num_upsamplers
394
+
395
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
396
+ forward_upsample_size = False
397
+ upsample_size = None
398
+
399
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
400
+ logger.info("Forward upsample size to force interpolation output size.")
401
+ forward_upsample_size = True
402
+
403
+ # prepare attention_mask
404
+ if attention_mask is not None:
405
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
406
+ attention_mask = attention_mask.unsqueeze(1)
407
+
408
+ # center input if necessary
409
+ if self.config.center_input_sample:
410
+ sample = 2 * sample - 1.0
411
+
412
+ # time
413
+ timesteps = timestep
414
+ if not torch.is_tensor(timesteps):
415
+ # This would be a good case for the `match` statement (Python 3.10+)
416
+ is_mps = sample.device.type == "mps"
417
+ if isinstance(timestep, float):
418
+ dtype = torch.float32 if is_mps else torch.float64
419
+ else:
420
+ dtype = torch.int32 if is_mps else torch.int64
421
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
422
+ elif len(timesteps.shape) == 0:
423
+ timesteps = timesteps[None].to(sample.device)
424
+
425
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
426
+ timesteps = timesteps.expand(sample.shape[0])
427
+
428
+ t_emb = self.time_proj(timesteps)
429
+
430
+ # timesteps does not contain any weights and will always return f32 tensors
431
+ # but time_embedding might actually be running in fp16. so we need to cast here.
432
+ # there might be better ways to encapsulate this.
433
+ t_emb = t_emb.to(dtype=self.dtype)
434
+ emb = self.time_embedding(t_emb)
435
+
436
+ if self.class_embedding is not None:
437
+ if class_labels is None:
438
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
439
+
440
+ if self.config.class_embed_type == "timestep":
441
+ class_labels = self.time_proj(class_labels)
442
+
443
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
444
+ # print(emb.shape) # torch.Size([3, 1280])
445
+ # print(class_emb.shape) # torch.Size([3, 1280])
446
+ emb = emb + class_emb
447
+
448
+ if self.use_relative_position:
449
+ frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
450
+ else:
451
+ frame_rel_pos_bias = None
452
+
453
+ # pre-process
454
+ sample = self.conv_in(sample)
455
+
456
+ # down
457
+ down_block_res_samples = (sample,)
458
+ for downsample_block in self.down_blocks:
459
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
460
+ sample, res_samples = downsample_block(
461
+ hidden_states=sample,
462
+ temb=emb,
463
+ encoder_hidden_states=encoder_hidden_states,
464
+ attention_mask=attention_mask,
465
+ use_image_num=use_image_num,
466
+ )
467
+ else:
468
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
469
+
470
+ down_block_res_samples += res_samples
471
+
472
+ # mid
473
+ sample = self.mid_block(
474
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num,
475
+ )
476
+
477
+ # up
478
+ for i, upsample_block in enumerate(self.up_blocks):
479
+ is_final_block = i == len(self.up_blocks) - 1
480
+
481
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
482
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
483
+
484
+ # if we have not reached the final block and need to forward the
485
+ # upsample size, we do it here
486
+ if not is_final_block and forward_upsample_size:
487
+ upsample_size = down_block_res_samples[-1].shape[2:]
488
+
489
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
490
+ sample = upsample_block(
491
+ hidden_states=sample,
492
+ temb=emb,
493
+ res_hidden_states_tuple=res_samples,
494
+ encoder_hidden_states=encoder_hidden_states,
495
+ upsample_size=upsample_size,
496
+ attention_mask=attention_mask,
497
+ use_image_num=use_image_num,
498
+ )
499
+ else:
500
+ sample = upsample_block(
501
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
502
+ )
503
+ # post-process
504
+ sample = self.conv_norm_out(sample)
505
+ sample = self.conv_act(sample)
506
+ sample = self.conv_out(sample)
507
+ # print(sample.shape)
508
+
509
+ if not return_dict:
510
+ return (sample,)
511
+ sample = UNet3DConditionOutput(sample=sample)
512
+ return sample
513
+
514
+ def forward_with_cfg(self,
515
+ x,
516
+ t,
517
+ encoder_hidden_states = None,
518
+ class_labels: Optional[torch.Tensor] = None,
519
+ cfg_scale=4.0,
520
+ use_fp16=False):
521
+ """
522
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
523
+ """
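+ # Layout assumption (following the GLIDE/DiT convention referenced below): the first and second
+ # halves of `x` share the same latents, while `encoder_hidden_states` holds the conditional text
+ # embeddings for the first half and the unconditional (null-prompt) embeddings for the second,
+ # so a single batched forward pass yields both predictions needed for guidance.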
524
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
525
+ half = x[: len(x) // 2]
526
+ combined = torch.cat([half, half], dim=0)
527
+ if use_fp16:
528
+ combined = combined.to(dtype=torch.float16)
529
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
530
+ # Following the DiT reference implementation, guidance is applied to the noise-prediction
531
+ # channels only; here the latents have four channels, so the first four are guided. The
532
+ # commented-out line below is the original three-channel variant from DiT.
533
+ eps, rest = model_out[:, :4], model_out[:, 4:]
534
+ # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
535
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
536
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
537
+ eps = torch.cat([half_eps, half_eps], dim=0)
538
+ return torch.cat([eps, rest], dim=1)
539
+
540
+ @classmethod
541
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
542
+ if subfolder is not None:
543
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
544
+
545
+
546
+ config_file = os.path.join(pretrained_model_path, 'config.json')
547
+ if not os.path.isfile(config_file):
548
+ raise RuntimeError(f"{config_file} does not exist")
549
+ with open(config_file, "r") as f:
550
+ config = json.load(f)
551
+ config["_class_name"] = cls.__name__
552
+ config["down_block_types"] = [
553
+ "CrossAttnDownBlock3D",
554
+ "CrossAttnDownBlock3D",
555
+ "CrossAttnDownBlock3D",
556
+ "DownBlock3D"
557
+ ]
558
+ config["up_block_types"] = [
559
+ "UpBlock3D",
560
+ "CrossAttnUpBlock3D",
561
+ "CrossAttnUpBlock3D",
562
+ "CrossAttnUpBlock3D"
563
+ ]
564
+
565
+ config["use_first_frame"] = False
566
+
567
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
568
+
569
+
570
+ model = cls.from_config(config)
571
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
572
+ if not os.path.isfile(model_file):
573
+ raise RuntimeError(f"{model_file} does not exist")
574
+ state_dict = torch.load(model_file, map_location="cpu")
575
+ for k, v in model.state_dict().items():
576
+ # print(k)
577
+ if '_temp' in k:
578
+ state_dict.update({k: v})
579
+ if 'attn_fcross' in k: # copy params of attn1 to attn_fcross
580
+ k_src = k.replace('attn_fcross', 'attn1')
581
+ state_dict.update({k: state_dict[k_src]})
582
+ if 'norm_fcross' in k: # copy params of norm1 to norm_fcross
583
+ k_src = k.replace('norm_fcross', 'norm1')
584
+ state_dict.update({k: state_dict[k_src]})
585
+
586
+ model.load_state_dict(state_dict)
587
+
588
+ return model
589
+
590
+ if __name__ == '__main__':
591
+ import torch
592
+ # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
593
+
594
+ device = "cuda" if torch.cuda.is_available() else "cpu"
595
+
596
+ pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-v1-4/" # p cluster
597
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
598
+ # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
599
+ unet.enable_xformers_memory_efficient_attention()
600
+ unet.enable_gradient_checkpointing()
601
+
602
+ unet.train()
603
+
604
+ use_image_num = 5
605
+ noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
606
+ bsz = noisy_latents.shape[0]
607
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
608
+ timesteps = timesteps.long()
609
+ encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
610
+ # class_labels = torch.randn((bsz, )).to(device)
611
+
612
+
613
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
614
+ encoder_hidden_states=encoder_hidden_states,
615
+ class_labels=None,
616
+ use_image_num=use_image_num).sample
617
+ print(model_pred.shape)
base/models/unet_blocks.py ADDED
@@ -0,0 +1,648 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ try:
10
+ from .attention import Transformer3DModel
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ except:
13
+ from attention import Transformer3DModel
14
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
15
+
16
+
17
+ def get_down_block(
18
+ down_block_type,
19
+ num_layers,
20
+ in_channels,
21
+ out_channels,
22
+ temb_channels,
23
+ add_downsample,
24
+ resnet_eps,
25
+ resnet_act_fn,
26
+ attn_num_head_channels,
27
+ resnet_groups=None,
28
+ cross_attention_dim=None,
29
+ downsample_padding=None,
30
+ dual_cross_attention=False,
31
+ use_linear_projection=False,
32
+ only_cross_attention=False,
33
+ upcast_attention=False,
34
+ resnet_time_scale_shift="default",
35
+ use_first_frame=False,
36
+ use_relative_position=False,
37
+ rotary_emb=False,
38
+ ):
39
+ # print(down_block_type)
40
+ # print(use_first_frame)
41
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
42
+ if down_block_type == "DownBlock3D":
43
+ return DownBlock3D(
44
+ num_layers=num_layers,
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ temb_channels=temb_channels,
48
+ add_downsample=add_downsample,
49
+ resnet_eps=resnet_eps,
50
+ resnet_act_fn=resnet_act_fn,
51
+ resnet_groups=resnet_groups,
52
+ downsample_padding=downsample_padding,
53
+ resnet_time_scale_shift=resnet_time_scale_shift,
54
+ )
55
+ elif down_block_type == "CrossAttnDownBlock3D":
56
+ if cross_attention_dim is None:
57
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
58
+ return CrossAttnDownBlock3D(
59
+ num_layers=num_layers,
60
+ in_channels=in_channels,
61
+ out_channels=out_channels,
62
+ temb_channels=temb_channels,
63
+ add_downsample=add_downsample,
64
+ resnet_eps=resnet_eps,
65
+ resnet_act_fn=resnet_act_fn,
66
+ resnet_groups=resnet_groups,
67
+ downsample_padding=downsample_padding,
68
+ cross_attention_dim=cross_attention_dim,
69
+ attn_num_head_channels=attn_num_head_channels,
70
+ dual_cross_attention=dual_cross_attention,
71
+ use_linear_projection=use_linear_projection,
72
+ only_cross_attention=only_cross_attention,
73
+ upcast_attention=upcast_attention,
74
+ resnet_time_scale_shift=resnet_time_scale_shift,
75
+ use_first_frame=use_first_frame,
76
+ use_relative_position=use_relative_position,
77
+ rotary_emb=rotary_emb,
78
+ )
79
+ raise ValueError(f"{down_block_type} does not exist.")
80
+
81
+
82
+ def get_up_block(
83
+ up_block_type,
84
+ num_layers,
85
+ in_channels,
86
+ out_channels,
87
+ prev_output_channel,
88
+ temb_channels,
89
+ add_upsample,
90
+ resnet_eps,
91
+ resnet_act_fn,
92
+ attn_num_head_channels,
93
+ resnet_groups=None,
94
+ cross_attention_dim=None,
95
+ dual_cross_attention=False,
96
+ use_linear_projection=False,
97
+ only_cross_attention=False,
98
+ upcast_attention=False,
99
+ resnet_time_scale_shift="default",
100
+ use_first_frame=False,
101
+ use_relative_position=False,
102
+ rotary_emb=False,
103
+ ):
104
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
105
+ if up_block_type == "UpBlock3D":
106
+ return UpBlock3D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ prev_output_channel=prev_output_channel,
111
+ temb_channels=temb_channels,
112
+ add_upsample=add_upsample,
113
+ resnet_eps=resnet_eps,
114
+ resnet_act_fn=resnet_act_fn,
115
+ resnet_groups=resnet_groups,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ )
118
+ elif up_block_type == "CrossAttnUpBlock3D":
119
+ if cross_attention_dim is None:
120
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
121
+ return CrossAttnUpBlock3D(
122
+ num_layers=num_layers,
123
+ in_channels=in_channels,
124
+ out_channels=out_channels,
125
+ prev_output_channel=prev_output_channel,
126
+ temb_channels=temb_channels,
127
+ add_upsample=add_upsample,
128
+ resnet_eps=resnet_eps,
129
+ resnet_act_fn=resnet_act_fn,
130
+ resnet_groups=resnet_groups,
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attn_num_head_channels,
133
+ dual_cross_attention=dual_cross_attention,
134
+ use_linear_projection=use_linear_projection,
135
+ only_cross_attention=only_cross_attention,
136
+ upcast_attention=upcast_attention,
137
+ resnet_time_scale_shift=resnet_time_scale_shift,
138
+ use_first_frame=use_first_frame,
139
+ use_relative_position=use_relative_position,
140
+ rotary_emb=rotary_emb,
141
+ )
142
+ raise ValueError(f"{up_block_type} does not exist.")
143
+
144
+
145
+ class UNetMidBlock3DCrossAttn(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels: int,
149
+ temb_channels: int,
150
+ dropout: float = 0.0,
151
+ num_layers: int = 1,
152
+ resnet_eps: float = 1e-6,
153
+ resnet_time_scale_shift: str = "default",
154
+ resnet_act_fn: str = "swish",
155
+ resnet_groups: int = 32,
156
+ resnet_pre_norm: bool = True,
157
+ attn_num_head_channels=1,
158
+ output_scale_factor=1.0,
159
+ cross_attention_dim=1280,
160
+ dual_cross_attention=False,
161
+ use_linear_projection=False,
162
+ upcast_attention=False,
163
+ use_first_frame=False,
164
+ use_relative_position=False,
165
+ rotary_emb=False,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.has_cross_attention = True
170
+ self.attn_num_head_channels = attn_num_head_channels
171
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
172
+
173
+ # there is always at least one resnet
174
+ resnets = [
175
+ ResnetBlock3D(
176
+ in_channels=in_channels,
177
+ out_channels=in_channels,
178
+ temb_channels=temb_channels,
179
+ eps=resnet_eps,
180
+ groups=resnet_groups,
181
+ dropout=dropout,
182
+ time_embedding_norm=resnet_time_scale_shift,
183
+ non_linearity=resnet_act_fn,
184
+ output_scale_factor=output_scale_factor,
185
+ pre_norm=resnet_pre_norm,
186
+ )
187
+ ]
188
+ attentions = []
189
+
190
+ for _ in range(num_layers):
191
+ if dual_cross_attention:
192
+ raise NotImplementedError
193
+ attentions.append(
194
+ Transformer3DModel(
195
+ attn_num_head_channels,
196
+ in_channels // attn_num_head_channels,
197
+ in_channels=in_channels,
198
+ num_layers=1,
199
+ cross_attention_dim=cross_attention_dim,
200
+ norm_num_groups=resnet_groups,
201
+ use_linear_projection=use_linear_projection,
202
+ upcast_attention=upcast_attention,
203
+ use_first_frame=use_first_frame,
204
+ use_relative_position=use_relative_position,
205
+ rotary_emb=rotary_emb,
206
+ )
207
+ )
208
+ resnets.append(
209
+ ResnetBlock3D(
210
+ in_channels=in_channels,
211
+ out_channels=in_channels,
212
+ temb_channels=temb_channels,
213
+ eps=resnet_eps,
214
+ groups=resnet_groups,
215
+ dropout=dropout,
216
+ time_embedding_norm=resnet_time_scale_shift,
217
+ non_linearity=resnet_act_fn,
218
+ output_scale_factor=output_scale_factor,
219
+ pre_norm=resnet_pre_norm,
220
+ )
221
+ )
222
+
223
+ self.attentions = nn.ModuleList(attentions)
224
+ self.resnets = nn.ModuleList(resnets)
225
+
226
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
227
+ hidden_states = self.resnets[0](hidden_states, temb)
228
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
229
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
230
+ hidden_states = resnet(hidden_states, temb)
231
+
232
+ return hidden_states
233
+
234
+
235
+ class CrossAttnDownBlock3D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels: int,
239
+ out_channels: int,
240
+ temb_channels: int,
241
+ dropout: float = 0.0,
242
+ num_layers: int = 1,
243
+ resnet_eps: float = 1e-6,
244
+ resnet_time_scale_shift: str = "default",
245
+ resnet_act_fn: str = "swish",
246
+ resnet_groups: int = 32,
247
+ resnet_pre_norm: bool = True,
248
+ attn_num_head_channels=1,
249
+ cross_attention_dim=1280,
250
+ output_scale_factor=1.0,
251
+ downsample_padding=1,
252
+ add_downsample=True,
253
+ dual_cross_attention=False,
254
+ use_linear_projection=False,
255
+ only_cross_attention=False,
256
+ upcast_attention=False,
257
+ use_first_frame=False,
258
+ use_relative_position=False,
259
+ rotary_emb=False,
260
+ ):
261
+ super().__init__()
262
+ resnets = []
263
+ attentions = []
264
+
265
+ # print(use_first_frame)
266
+
267
+ self.has_cross_attention = True
268
+ self.attn_num_head_channels = attn_num_head_channels
269
+
270
+ for i in range(num_layers):
271
+ in_channels = in_channels if i == 0 else out_channels
272
+ resnets.append(
273
+ ResnetBlock3D(
274
+ in_channels=in_channels,
275
+ out_channels=out_channels,
276
+ temb_channels=temb_channels,
277
+ eps=resnet_eps,
278
+ groups=resnet_groups,
279
+ dropout=dropout,
280
+ time_embedding_norm=resnet_time_scale_shift,
281
+ non_linearity=resnet_act_fn,
282
+ output_scale_factor=output_scale_factor,
283
+ pre_norm=resnet_pre_norm,
284
+ )
285
+ )
286
+ if dual_cross_attention:
287
+ raise NotImplementedError
288
+ attentions.append(
289
+ Transformer3DModel(
290
+ attn_num_head_channels,
291
+ out_channels // attn_num_head_channels,
292
+ in_channels=out_channels,
293
+ num_layers=1,
294
+ cross_attention_dim=cross_attention_dim,
295
+ norm_num_groups=resnet_groups,
296
+ use_linear_projection=use_linear_projection,
297
+ only_cross_attention=only_cross_attention,
298
+ upcast_attention=upcast_attention,
299
+ use_first_frame=use_first_frame,
300
+ use_relative_position=use_relative_position,
301
+ rotary_emb=rotary_emb,
302
+ )
303
+ )
304
+ self.attentions = nn.ModuleList(attentions)
305
+ self.resnets = nn.ModuleList(resnets)
306
+
307
+ if add_downsample:
308
+ self.downsamplers = nn.ModuleList(
309
+ [
310
+ Downsample3D(
311
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
312
+ )
313
+ ]
314
+ )
315
+ else:
316
+ self.downsamplers = None
317
+
318
+ self.gradient_checkpointing = False
319
+
320
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
321
+ output_states = ()
322
+
323
+ for resnet, attn in zip(self.resnets, self.attentions):
324
+ if self.training and self.gradient_checkpointing:
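+ # torch.utils.checkpoint re-runs the wrapped callable with tensor arguments only, so the
+ # closures below bake the non-tensor keyword arguments (return_dict, use_image_num) into
+ # the forward call before checkpointing.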
325
+
326
+ def create_custom_forward(module, return_dict=None):
327
+ def custom_forward(*inputs):
328
+ if return_dict is not None:
329
+ return module(*inputs, return_dict=return_dict)
330
+ else:
331
+ return module(*inputs)
332
+
333
+ return custom_forward
334
+
335
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
336
+ def custom_forward(*inputs):
337
+ if return_dict is not None:
338
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
339
+ else:
340
+ return module(*inputs, use_image_num=use_image_num)
341
+
342
+ return custom_forward
343
+
344
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
345
+ hidden_states = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
347
+ hidden_states,
348
+ encoder_hidden_states,
349
+ )[0]
350
+ else:
351
+ hidden_states = resnet(hidden_states, temb)
352
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
353
+
354
+ output_states += (hidden_states,)
355
+
356
+ if self.downsamplers is not None:
357
+ for downsampler in self.downsamplers:
358
+ hidden_states = downsampler(hidden_states)
359
+
360
+ output_states += (hidden_states,)
361
+
362
+ return hidden_states, output_states
363
+
364
+
365
+ class DownBlock3D(nn.Module):
366
+ def __init__(
367
+ self,
368
+ in_channels: int,
369
+ out_channels: int,
370
+ temb_channels: int,
371
+ dropout: float = 0.0,
372
+ num_layers: int = 1,
373
+ resnet_eps: float = 1e-6,
374
+ resnet_time_scale_shift: str = "default",
375
+ resnet_act_fn: str = "swish",
376
+ resnet_groups: int = 32,
377
+ resnet_pre_norm: bool = True,
378
+ output_scale_factor=1.0,
379
+ add_downsample=True,
380
+ downsample_padding=1,
381
+ ):
382
+ super().__init__()
383
+ resnets = []
384
+
385
+ for i in range(num_layers):
386
+ in_channels = in_channels if i == 0 else out_channels
387
+ resnets.append(
388
+ ResnetBlock3D(
389
+ in_channels=in_channels,
390
+ out_channels=out_channels,
391
+ temb_channels=temb_channels,
392
+ eps=resnet_eps,
393
+ groups=resnet_groups,
394
+ dropout=dropout,
395
+ time_embedding_norm=resnet_time_scale_shift,
396
+ non_linearity=resnet_act_fn,
397
+ output_scale_factor=output_scale_factor,
398
+ pre_norm=resnet_pre_norm,
399
+ )
400
+ )
401
+
402
+ self.resnets = nn.ModuleList(resnets)
403
+
404
+ if add_downsample:
405
+ self.downsamplers = nn.ModuleList(
406
+ [
407
+ Downsample3D(
408
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
409
+ )
410
+ ]
411
+ )
412
+ else:
413
+ self.downsamplers = None
414
+
415
+ self.gradient_checkpointing = False
416
+
417
+ def forward(self, hidden_states, temb=None):
418
+ output_states = ()
419
+
420
+ for resnet in self.resnets:
421
+ if self.training and self.gradient_checkpointing:
422
+
423
+ def create_custom_forward(module):
424
+ def custom_forward(*inputs):
425
+ return module(*inputs)
426
+
427
+ return custom_forward
428
+
429
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
430
+ else:
431
+ hidden_states = resnet(hidden_states, temb)
432
+
433
+ output_states += (hidden_states,)
434
+
435
+ if self.downsamplers is not None:
436
+ for downsampler in self.downsamplers:
437
+ hidden_states = downsampler(hidden_states)
438
+
439
+ output_states += (hidden_states,)
440
+
441
+ return hidden_states, output_states
442
+
443
+
444
+ class CrossAttnUpBlock3D(nn.Module):
445
+ def __init__(
446
+ self,
447
+ in_channels: int,
448
+ out_channels: int,
449
+ prev_output_channel: int,
450
+ temb_channels: int,
451
+ dropout: float = 0.0,
452
+ num_layers: int = 1,
453
+ resnet_eps: float = 1e-6,
454
+ resnet_time_scale_shift: str = "default",
455
+ resnet_act_fn: str = "swish",
456
+ resnet_groups: int = 32,
457
+ resnet_pre_norm: bool = True,
458
+ attn_num_head_channels=1,
459
+ cross_attention_dim=1280,
460
+ output_scale_factor=1.0,
461
+ add_upsample=True,
462
+ dual_cross_attention=False,
463
+ use_linear_projection=False,
464
+ only_cross_attention=False,
465
+ upcast_attention=False,
466
+ use_first_frame=False,
467
+ use_relative_position=False,
468
+ rotary_emb=False
469
+ ):
470
+ super().__init__()
471
+ resnets = []
472
+ attentions = []
473
+
474
+ self.has_cross_attention = True
475
+ self.attn_num_head_channels = attn_num_head_channels
476
+
477
+ for i in range(num_layers):
478
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
479
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
480
+
481
+ resnets.append(
482
+ ResnetBlock3D(
483
+ in_channels=resnet_in_channels + res_skip_channels,
484
+ out_channels=out_channels,
485
+ temb_channels=temb_channels,
486
+ eps=resnet_eps,
487
+ groups=resnet_groups,
488
+ dropout=dropout,
489
+ time_embedding_norm=resnet_time_scale_shift,
490
+ non_linearity=resnet_act_fn,
491
+ output_scale_factor=output_scale_factor,
492
+ pre_norm=resnet_pre_norm,
493
+ )
494
+ )
495
+ if dual_cross_attention:
496
+ raise NotImplementedError
497
+ attentions.append(
498
+ Transformer3DModel(
499
+ attn_num_head_channels,
500
+ out_channels // attn_num_head_channels,
501
+ in_channels=out_channels,
502
+ num_layers=1,
503
+ cross_attention_dim=cross_attention_dim,
504
+ norm_num_groups=resnet_groups,
505
+ use_linear_projection=use_linear_projection,
506
+ only_cross_attention=only_cross_attention,
507
+ upcast_attention=upcast_attention,
508
+ use_first_frame=use_first_frame,
509
+ use_relative_position=use_relative_position,
510
+ rotary_emb=rotary_emb,
511
+ )
512
+ )
513
+
514
+ self.attentions = nn.ModuleList(attentions)
515
+ self.resnets = nn.ModuleList(resnets)
516
+
517
+ if add_upsample:
518
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
519
+ else:
520
+ self.upsamplers = None
521
+
522
+ self.gradient_checkpointing = False
523
+
524
+ def forward(
525
+ self,
526
+ hidden_states,
527
+ res_hidden_states_tuple,
528
+ temb=None,
529
+ encoder_hidden_states=None,
530
+ upsample_size=None,
531
+ attention_mask=None,
532
+ use_image_num=None,
533
+ ):
534
+ for resnet, attn in zip(self.resnets, self.attentions):
535
+ # pop res hidden states
536
+ res_hidden_states = res_hidden_states_tuple[-1]
537
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
538
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
539
+
540
+ if self.training and self.gradient_checkpointing:
541
+
542
+ def create_custom_forward(module, return_dict=None):
543
+ def custom_forward(*inputs):
544
+ if return_dict is not None:
545
+ return module(*inputs, return_dict=return_dict)
546
+ else:
547
+ return module(*inputs)
548
+
549
+ return custom_forward
550
+
551
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
552
+ def custom_forward(*inputs):
553
+ if return_dict is not None:
554
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
555
+ else:
556
+ return module(*inputs, use_image_num=use_image_num)
557
+
558
+ return custom_forward
559
+
560
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
561
+ hidden_states = torch.utils.checkpoint.checkpoint(
562
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
563
+ hidden_states,
564
+ encoder_hidden_states,
565
+ )[0]
566
+ else:
567
+ hidden_states = resnet(hidden_states, temb)
568
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
569
+
570
+ if self.upsamplers is not None:
571
+ for upsampler in self.upsamplers:
572
+ hidden_states = upsampler(hidden_states, upsample_size)
573
+
574
+ return hidden_states
575
+
576
+
577
+ class UpBlock3D(nn.Module):
578
+ def __init__(
579
+ self,
580
+ in_channels: int,
581
+ prev_output_channel: int,
582
+ out_channels: int,
583
+ temb_channels: int,
584
+ dropout: float = 0.0,
585
+ num_layers: int = 1,
586
+ resnet_eps: float = 1e-6,
587
+ resnet_time_scale_shift: str = "default",
588
+ resnet_act_fn: str = "swish",
589
+ resnet_groups: int = 32,
590
+ resnet_pre_norm: bool = True,
591
+ output_scale_factor=1.0,
592
+ add_upsample=True,
593
+ ):
594
+ super().__init__()
595
+ resnets = []
596
+
597
+ for i in range(num_layers):
598
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
599
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
600
+
601
+ resnets.append(
602
+ ResnetBlock3D(
603
+ in_channels=resnet_in_channels + res_skip_channels,
604
+ out_channels=out_channels,
605
+ temb_channels=temb_channels,
606
+ eps=resnet_eps,
607
+ groups=resnet_groups,
608
+ dropout=dropout,
609
+ time_embedding_norm=resnet_time_scale_shift,
610
+ non_linearity=resnet_act_fn,
611
+ output_scale_factor=output_scale_factor,
612
+ pre_norm=resnet_pre_norm,
613
+ )
614
+ )
615
+
616
+ self.resnets = nn.ModuleList(resnets)
617
+
618
+ if add_upsample:
619
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
620
+ else:
621
+ self.upsamplers = None
622
+
623
+ self.gradient_checkpointing = False
624
+
625
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
626
+ for resnet in self.resnets:
627
+ # pop res hidden states
628
+ res_hidden_states = res_hidden_states_tuple[-1]
629
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
630
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
631
+
632
+ if self.training and self.gradient_checkpointing:
633
+
634
+ def create_custom_forward(module):
635
+ def custom_forward(*inputs):
636
+ return module(*inputs)
637
+
638
+ return custom_forward
639
+
640
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
641
+ else:
642
+ hidden_states = resnet(hidden_states, temb)
643
+
644
+ if self.upsamplers is not None:
645
+ for upsampler in self.upsamplers:
646
+ hidden_states = upsampler(hidden_states, upsample_size)
647
+
648
+ return hidden_states
base/models/utils.py ADDED
@@ -0,0 +1,215 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+
15
+ import numpy as np
16
+ import torch.nn as nn
17
+
18
+ from einops import repeat
19
+
20
+
21
+ #################################################################################
22
+ # Unet Utils #
23
+ #################################################################################
24
+
25
+ def checkpoint(func, inputs, params, flag):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ """
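+ # Usage sketch (hypothetical module `block`):
+ #   out = checkpoint(block.forward, (x,), list(block.parameters()), True)
+ # recomputes block.forward(x) during the backward pass instead of caching its activations.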
35
+ if flag:
36
+ args = tuple(inputs) + tuple(params)
37
+ return CheckpointFunction.apply(func, len(inputs), *args)
38
+ else:
39
+ return func(*inputs)
40
+
41
+
42
+ class CheckpointFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, run_function, length, *args):
45
+ ctx.run_function = run_function
46
+ ctx.input_tensors = list(args[:length])
47
+ ctx.input_params = list(args[length:])
48
+
49
+ with torch.no_grad():
50
+ output_tensors = ctx.run_function(*ctx.input_tensors)
51
+ return output_tensors
52
+
53
+ @staticmethod
54
+ def backward(ctx, *output_grads):
55
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
56
+ with torch.enable_grad():
57
+ # Fixes a bug where the first op in run_function modifies the
58
+ # Tensor storage in place, which is not allowed for detach()'d
59
+ # Tensors.
60
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
61
+ output_tensors = ctx.run_function(*shallow_copies)
62
+ input_grads = torch.autograd.grad(
63
+ output_tensors,
64
+ ctx.input_tensors + ctx.input_params,
65
+ output_grads,
66
+ allow_unused=True,
67
+ )
68
+ del ctx.input_tensors
69
+ del ctx.input_params
70
+ del output_tensors
71
+ return (None, None) + input_grads
72
+
73
+
74
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
75
+ """
76
+ Create sinusoidal timestep embeddings.
77
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
78
+ These may be fractional.
79
+ :param dim: the dimension of the output.
80
+ :param max_period: controls the minimum frequency of the embeddings.
81
+ :return: an [N x dim] Tensor of positional embeddings.
82
+ """
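+ # For example, timestep_embedding(torch.tensor([0., 10.]), dim=320) returns a [2 x 320] tensor:
+ # the first 160 columns are cos(t * f_k) and the last 160 are sin(t * f_k), with
+ # f_k = exp(-ln(max_period) * k / 160) for k = 0..159.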
83
+ if not repeat_only:
84
+ half = dim // 2
85
+ freqs = torch.exp(
86
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
87
+ ).to(device=timesteps.device)
88
+ args = timesteps[:, None].float() * freqs[None]
89
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90
+ if dim % 2:
91
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
+ else:
93
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
94
+ return embedding
95
+
96
+
97
+ def zero_module(module):
98
+ """
99
+ Zero out the parameters of a module and return it.
100
+ """
101
+ for p in module.parameters():
102
+ p.detach().zero_()
103
+ return module
104
+
105
+
106
+ def scale_module(module, scale):
107
+ """
108
+ Scale the parameters of a module and return it.
109
+ """
110
+ for p in module.parameters():
111
+ p.detach().mul_(scale)
112
+ return module
113
+
114
+
115
+ def mean_flat(tensor):
116
+ """
117
+ Take the mean over all non-batch dimensions.
118
+ """
119
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
120
+
121
+
122
+ def normalization(channels):
123
+ """
124
+ Make a standard normalization layer.
125
+ :param channels: number of input channels.
126
+ :return: an nn.Module for normalization.
127
+ """
128
+ return GroupNorm32(32, channels)
129
+
130
+
131
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
132
+ class SiLU(nn.Module):
133
+ def forward(self, x):
134
+ return x * torch.sigmoid(x)
135
+
136
+
137
+ class GroupNorm32(nn.GroupNorm):
138
+ def forward(self, x):
139
+ return super().forward(x.float()).type(x.dtype)
140
+
141
+ def conv_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D convolution module.
144
+ """
145
+ if dims == 1:
146
+ return nn.Conv1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.Conv2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.Conv3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def linear(*args, **kwargs):
155
+ """
156
+ Create a linear module.
157
+ """
158
+ return nn.Linear(*args, **kwargs)
159
+
160
+
161
+ def avg_pool_nd(dims, *args, **kwargs):
162
+ """
163
+ Create a 1D, 2D, or 3D average pooling module.
164
+ """
165
+ if dims == 1:
166
+ return nn.AvgPool1d(*args, **kwargs)
167
+ elif dims == 2:
168
+ return nn.AvgPool2d(*args, **kwargs)
169
+ elif dims == 3:
170
+ return nn.AvgPool3d(*args, **kwargs)
171
+ raise ValueError(f"unsupported dimensions: {dims}")
172
+
173
+
174
+ # class HybridConditioner(nn.Module):
175
+
176
+ # def __init__(self, c_concat_config, c_crossattn_config):
177
+ # super().__init__()
178
+ # self.concat_conditioner = instantiate_from_config(c_concat_config)
179
+ # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
180
+
181
+ # def forward(self, c_concat, c_crossattn):
182
+ # c_concat = self.concat_conditioner(c_concat)
183
+ # c_crossattn = self.crossattn_conditioner(c_crossattn)
184
+ # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
185
+
186
+
187
+ def noise_like(shape, device, repeat=False):
188
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
189
+ noise = lambda: torch.randn(shape, device=device)
190
+ return repeat_noise() if repeat else noise()
191
+
192
+ def count_flops_attn(model, _x, y):
193
+ """
194
+ A counter for the `thop` package to count the operations in an
195
+ attention operation.
196
+ Meant to be used like:
197
+ macs, params = thop.profile(
198
+ model,
199
+ inputs=(inputs, timestamps),
200
+ custom_ops={QKVAttention: QKVAttention.count_flops},
201
+ )
202
+ """
203
+ b, c, *spatial = y[0].shape
204
+ num_spatial = int(np.prod(spatial))
205
+ # We perform two matmuls with the same number of ops.
206
+ # The first computes the weight matrix, the second computes
207
+ # the combination of the value vectors.
208
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
209
+ model.total_ops += torch.DoubleTensor([matmul_ops])
210
+
211
+ def count_params(model, verbose=False):
212
+ total_params = sum(p.numel() for p in model.parameters())
213
+ if verbose:
214
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
215
+ return total_params
base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc ADDED
Binary file (34.9 kB). View file
 
base/pipelines/pipeline_videogen.py ADDED
@@ -0,0 +1,677 @@
1
+
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+
14
+ import inspect
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+ import einops
17
+ import torch
18
+ from packaging import version
19
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
20
+
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.models import AutoencoderKL
23
+ from diffusers.schedulers import KarrasDiffusionSchedulers
24
+ from diffusers.utils import (
25
+ deprecate,
26
+ is_accelerate_available,
27
+ is_accelerate_version,
28
+ logging,
29
+ #randn_tensor,
30
+ replace_example_docstring,
31
+ BaseOutput,
32
+ )
33
+
34
+ try:
35
+ from diffusers.utils import randn_tensor
36
+ except:
37
+ from diffusers.utils.torch_utils import randn_tensor
38
+
39
+
40
+ from diffusers.pipeline_utils import DiffusionPipeline
41
+ from dataclasses import dataclass
42
+
43
+ import os, sys
44
+ sys.path.append(os.path.split(sys.path[0])[0])
45
+ from models.unet import UNet3DConditionModel
46
+
47
+ import numpy as np
48
+
49
+ @dataclass
50
+ class StableDiffusionPipelineOutput(BaseOutput):
51
+ video: torch.Tensor
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+ EXAMPLE_DOC_STRING = """
56
+ Examples:
57
+ ```py
58
+ >>> import torch
59
+ >>> from diffusers import StableDiffusionPipeline
60
+
61
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
62
+ >>> pipe = pipe.to("cuda")
63
+
64
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
65
+ >>> image = pipe(prompt).images[0]
66
+ ```
67
+ """
68
+
69
+
70
+ class VideoGenPipeline(DiffusionPipeline):
71
+ r"""
72
+ Pipeline for text-to-video generation using Stable Diffusion components.
73
+
74
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
75
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
76
+
77
+ Args:
78
+ vae ([`AutoencoderKL`]):
79
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
80
+ text_encoder ([`CLIPTextModel`]):
81
+ Frozen text-encoder. Stable Diffusion uses the text portion of
82
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
83
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
84
+ tokenizer (`CLIPTokenizer`):
85
+ Tokenizer of class
86
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
87
+ unet ([`UNet3DConditionModel`]): Conditional 3D U-Net architecture to denoise the encoded video latents.
88
+ scheduler ([`SchedulerMixin`]):
89
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
90
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
91
+ safety_checker ([`StableDiffusionSafetyChecker`]):
92
+ Classification module that estimates whether generated images could be considered offensive or harmful.
93
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
94
+ feature_extractor ([`CLIPFeatureExtractor`]):
95
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
96
+ """
97
+ _optional_components = ["safety_checker", "feature_extractor"]
98
+
99
+ def __init__(
100
+ self,
101
+ vae: AutoencoderKL,
102
+ text_encoder: CLIPTextModel,
103
+ tokenizer: CLIPTokenizer,
104
+ unet: UNet3DConditionModel,
105
+ scheduler: KarrasDiffusionSchedulers,
106
+ ):
107
+ super().__init__()
108
+
109
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
110
+ deprecation_message = (
111
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
112
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
113
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
114
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
115
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
116
+ " file"
117
+ )
118
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
119
+ new_config = dict(scheduler.config)
120
+ new_config["steps_offset"] = 1
121
+ scheduler._internal_dict = FrozenDict(new_config)
122
+
123
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
124
+ deprecation_message = (
125
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
126
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
127
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
128
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
129
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
130
+ )
131
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
132
+ new_config = dict(scheduler.config)
133
+ new_config["clip_sample"] = False
134
+ scheduler._internal_dict = FrozenDict(new_config)
135
+
136
+
137
+
138
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
139
+ version.parse(unet.config._diffusers_version).base_version
140
+ ) < version.parse("0.9.0.dev0")
141
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
142
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
143
+ deprecation_message = (
144
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
145
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
146
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
147
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
148
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
149
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
150
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
151
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
152
+ " the `unet/config.json` file"
153
+ )
154
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
155
+ new_config = dict(unet.config)
156
+ new_config["sample_size"] = 64
157
+ unet._internal_dict = FrozenDict(new_config)
158
+
159
+ self.register_modules(
160
+ vae=vae,
161
+ text_encoder=text_encoder,
162
+ tokenizer=tokenizer,
163
+ unet=unet,
164
+ scheduler=scheduler,
165
+ )
166
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
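+ # For the standard Stable Diffusion VAE (four entries in block_out_channels) this gives a
+ # scale factor of 8, i.e. 512x512-pixel frames decode from 64x64 latents.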
167
+ # self.register_to_config(requires_safety_checker=requires_safety_checker)
168
+
169
+ def enable_vae_slicing(self):
170
+ r"""
171
+ Enable sliced VAE decoding.
172
+
173
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
174
+ steps. This is useful to save some memory and allow larger batch sizes.
175
+ """
176
+ self.vae.enable_slicing()
177
+
178
+ def disable_vae_slicing(self):
179
+ r"""
180
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
181
+ computing decoding in one step.
182
+ """
183
+ self.vae.disable_slicing()
184
+
185
+ def enable_vae_tiling(self):
186
+ r"""
187
+ Enable tiled VAE decoding.
188
+
189
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
190
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
191
+ """
192
+ self.vae.enable_tiling()
193
+
194
+ def disable_vae_tiling(self):
195
+ r"""
196
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
197
+ computing decoding in one step.
198
+ """
199
+ self.vae.disable_tiling()
200
+
201
+ def enable_sequential_cpu_offload(self, gpu_id=0):
202
+ r"""
203
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
204
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
205
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
206
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
207
+ `enable_model_cpu_offload`, but performance is lower.
208
+ """
209
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
210
+ from accelerate import cpu_offload
211
+ else:
212
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
213
+
214
+ device = torch.device(f"cuda:{gpu_id}")
215
+
216
+ if self.device.type != "cpu":
217
+ self.to("cpu", silence_dtype_warnings=True)
218
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
219
+
220
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
221
+ cpu_offload(cpu_offloaded_model, device)
222
+
223
+ # if self.safety_checker is not None:
224
+ # cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
225
+
226
+ def enable_model_cpu_offload(self, gpu_id=0):
227
+ r"""
228
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
229
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
230
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
231
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
232
+ """
233
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
234
+ from accelerate import cpu_offload_with_hook
235
+ else:
236
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
237
+
238
+ device = torch.device(f"cuda:{gpu_id}")
239
+
240
+ if self.device.type != "cpu":
241
+ self.to("cpu", silence_dtype_warnings=True)
242
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
243
+
244
+ hook = None
245
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
246
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
247
+
248
+ self.final_offload_hook = hook
249
+
250
+ @property
251
+ def _execution_device(self):
252
+ r"""
253
+ Returns the device on which the pipeline's models will be executed. After calling
254
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
255
+ hooks.
256
+ """
257
+ if not hasattr(self.unet, "_hf_hook"):
258
+ return self.device
259
+ for module in self.unet.modules():
260
+ if (
261
+ hasattr(module, "_hf_hook")
262
+ and hasattr(module._hf_hook, "execution_device")
263
+ and module._hf_hook.execution_device is not None
264
+ ):
265
+ return torch.device(module._hf_hook.execution_device)
266
+ return self.device
267
+
268
+ def _encode_prompt(
269
+ self,
270
+ prompt,
271
+ device,
272
+ num_images_per_prompt,
273
+ do_classifier_free_guidance,
274
+ negative_prompt=None,
275
+ prompt_embeds: Optional[torch.FloatTensor] = None,
276
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
277
+ ):
278
+ r"""
279
+ Encodes the prompt into text encoder hidden states.
280
+
281
+ Args:
282
+ prompt (`str` or `List[str]`, *optional*):
283
+ prompt to be encoded
284
+ device: (`torch.device`):
285
+ torch device
286
+ num_images_per_prompt (`int`):
287
+ number of images that should be generated per prompt
288
+ do_classifier_free_guidance (`bool`):
289
+ whether to use classifier free guidance or not
290
+ negative_prompt (`str` or `List[str]`, *optional*):
291
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
292
+ `negative_prompt_embeds` instead.
293
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
294
+ prompt_embeds (`torch.FloatTensor`, *optional*):
295
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
296
+ provided, text embeddings will be generated from `prompt` input argument.
297
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
298
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
299
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
300
+ argument.
301
+ """
302
+ if prompt is not None and isinstance(prompt, str):
303
+ batch_size = 1
304
+ elif prompt is not None and isinstance(prompt, list):
305
+ batch_size = len(prompt)
306
+ else:
307
+ batch_size = prompt_embeds.shape[0]
308
+
309
+ if prompt_embeds is None:
310
+ text_inputs = self.tokenizer(
311
+ prompt,
312
+ padding="max_length",
313
+ max_length=self.tokenizer.model_max_length,
314
+ truncation=True,
315
+ return_tensors="pt",
316
+ )
317
+ text_input_ids = text_inputs.input_ids
318
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
319
+
320
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
321
+ text_input_ids, untruncated_ids
322
+ ):
323
+ removed_text = self.tokenizer.batch_decode(
324
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
325
+ )
326
+ logger.warning(
327
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
328
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
329
+ )
330
+
331
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
332
+ attention_mask = text_inputs.attention_mask.to(device)
333
+ else:
334
+ attention_mask = None
335
+
336
+ prompt_embeds = self.text_encoder(
337
+ text_input_ids.to(device),
338
+ attention_mask=attention_mask,
339
+ )
340
+ prompt_embeds = prompt_embeds[0]
341
+
342
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
343
+
344
+ bs_embed, seq_len, _ = prompt_embeds.shape
345
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
346
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
347
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
348
+
349
+ # get unconditional embeddings for classifier free guidance
350
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
351
+ uncond_tokens: List[str]
352
+ if negative_prompt is None:
353
+ uncond_tokens = [""] * batch_size
354
+ elif type(prompt) is not type(negative_prompt):
355
+ raise TypeError(
356
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
357
+ f" {type(prompt)}."
358
+ )
359
+ elif isinstance(negative_prompt, str):
360
+ uncond_tokens = [negative_prompt]
361
+ elif batch_size != len(negative_prompt):
362
+ raise ValueError(
363
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
364
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
365
+ " the batch size of `prompt`."
366
+ )
367
+ else:
368
+ uncond_tokens = negative_prompt
369
+
370
+ max_length = prompt_embeds.shape[1]
371
+ uncond_input = self.tokenizer(
372
+ uncond_tokens,
373
+ padding="max_length",
374
+ max_length=max_length,
375
+ truncation=True,
376
+ return_tensors="pt",
377
+ )
378
+
379
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
380
+ attention_mask = uncond_input.attention_mask.to(device)
381
+ else:
382
+ attention_mask = None
383
+
384
+ negative_prompt_embeds = self.text_encoder(
385
+ uncond_input.input_ids.to(device),
386
+ attention_mask=attention_mask,
387
+ )
388
+ negative_prompt_embeds = negative_prompt_embeds[0]
389
+
390
+ if do_classifier_free_guidance:
391
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
392
+ seq_len = negative_prompt_embeds.shape[1]
393
+
394
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
395
+
396
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
397
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
398
+
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
403
+
404
+ return prompt_embeds
405
+
406
+ def decode_latents(self, latents):
407
+ video_length = latents.shape[2]
408
+ latents = 1 / 0.18215 * latents
409
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
410
+ video = self.vae.decode(latents).sample
411
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
412
+ video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
413
+ return video
414
+
415
+ def prepare_extra_step_kwargs(self, generator, eta):
416
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
417
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
418
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
419
+ # and should be between [0, 1]
420
+
421
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
422
+ extra_step_kwargs = {}
423
+ if accepts_eta:
424
+ extra_step_kwargs["eta"] = eta
425
+
426
+ # check if the scheduler accepts generator
427
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
428
+ if accepts_generator:
429
+ extra_step_kwargs["generator"] = generator
430
+ return extra_step_kwargs
431
+
432
+ def check_inputs(
433
+ self,
434
+ prompt,
435
+ height,
436
+ width,
437
+ callback_steps,
438
+ negative_prompt=None,
439
+ prompt_embeds=None,
440
+ negative_prompt_embeds=None,
441
+ ):
442
+ if height % 8 != 0 or width % 8 != 0:
443
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
444
+
445
+ if (callback_steps is None) or (
446
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
447
+ ):
448
+ raise ValueError(
449
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
450
+ f" {type(callback_steps)}."
451
+ )
452
+
453
+ if prompt is not None and prompt_embeds is not None:
454
+ raise ValueError(
455
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
456
+ " only forward one of the two."
457
+ )
458
+ elif prompt is None and prompt_embeds is None:
459
+ raise ValueError(
460
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
461
+ )
462
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
463
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
464
+
465
+ if negative_prompt is not None and negative_prompt_embeds is not None:
466
+ raise ValueError(
467
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
468
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
469
+ )
470
+
471
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
472
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
473
+ raise ValueError(
474
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
475
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
476
+ f" {negative_prompt_embeds.shape}."
477
+ )
478
+
479
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
480
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
481
+ if isinstance(generator, list) and len(generator) != batch_size:
482
+ raise ValueError(
483
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
484
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
485
+ )
486
+
487
+ if latents is None:
488
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
489
+ else:
490
+ latents = latents.to(device)
491
+
492
+ # scale the initial noise by the standard deviation required by the scheduler
493
+ latents = latents * self.scheduler.init_noise_sigma
494
+ return latents
495
+
496
+ @torch.no_grad()
497
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
498
+ def __call__(
499
+ self,
500
+ prompt: Union[str, List[str]] = None,
501
+ height: Optional[int] = None,
502
+ width: Optional[int] = None,
503
+ video_length: int = 16,
504
+ num_inference_steps: int = 50,
505
+ guidance_scale: float = 7.5,
506
+ negative_prompt: Optional[Union[str, List[str]]] = None,
507
+ num_images_per_prompt: Optional[int] = 1,
508
+ eta: float = 0.0,
509
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
510
+ latents: Optional[torch.FloatTensor] = None,
511
+ prompt_embeds: Optional[torch.FloatTensor] = None,
512
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
513
+ output_type: Optional[str] = "pil",
514
+ return_dict: bool = True,
515
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
516
+ callback_steps: int = 1,
517
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
518
+ ):
519
+ r"""
520
+ Function invoked when calling the pipeline for generation.
521
+
522
+ Args:
523
+ prompt (`str` or `List[str]`, *optional*):
524
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
525
+ instead.
526
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
527
+ The height in pixels of the generated image.
528
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
529
+ The width in pixels of the generated image.
530
+ num_inference_steps (`int`, *optional*, defaults to 50):
531
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
532
+ expense of slower inference.
533
+ guidance_scale (`float`, *optional*, defaults to 7.5):
534
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
535
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
536
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
537
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
538
+ usually at the expense of lower image quality.
539
+ negative_prompt (`str` or `List[str]`, *optional*):
540
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
541
+ `negative_prompt_embeds` instead.
542
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
543
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
544
+ The number of images to generate per prompt.
545
+ eta (`float`, *optional*, defaults to 0.0):
546
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
547
+ [`schedulers.DDIMScheduler`], will be ignored for others.
548
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
549
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
550
+ to make generation deterministic.
551
+ latents (`torch.FloatTensor`, *optional*):
552
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
553
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
554
+ tensor will be generated by sampling using the supplied random `generator`.
555
+ prompt_embeds (`torch.FloatTensor`, *optional*):
556
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
557
+ provided, text embeddings will be generated from `prompt` input argument.
558
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
559
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
560
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
561
+ argument.
562
+ output_type (`str`, *optional*, defaults to `"pil"`):
563
+ The output format of the generated image. Choose between
564
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
565
+ return_dict (`bool`, *optional*, defaults to `True`):
566
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
567
+ plain tuple.
568
+ callback (`Callable`, *optional*):
569
+ A function that will be called every `callback_steps` steps during inference. The function will be
570
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
571
+ callback_steps (`int`, *optional*, defaults to 1):
572
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
573
+ called at every step.
574
+ cross_attention_kwargs (`dict`, *optional*):
575
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
576
+ `self.processor` in
577
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
578
+
579
+ Examples:
580
+
581
+ Returns:
582
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`]:
583
+ A [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] whose `video` attribute holds the generated
584
+ frames as a `uint8` tensor of shape `(batch, frames, height, width, channels)`. This pipeline does not run a
585
+ `safety_checker`, so no "not-safe-for-work" (nsfw) flags are returned, and the current implementation wraps
586
+ the result in the output class regardless of `return_dict`.
587
+ """
588
+ # 0. Default height and width to unet
589
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
590
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
591
+
592
+ # 1. Check inputs. Raise error if not correct
593
+ self.check_inputs(
594
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
595
+ )
596
+
597
+ # 2. Define call parameters
598
+ if prompt is not None and isinstance(prompt, str):
599
+ batch_size = 1
600
+ elif prompt is not None and isinstance(prompt, list):
601
+ batch_size = len(prompt)
602
+ else:
603
+ batch_size = prompt_embeds.shape[0]
604
+
605
+ device = self._execution_device
606
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
607
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
608
+ # corresponds to doing no classifier free guidance.
609
+ do_classifier_free_guidance = guidance_scale > 1.0
610
+
611
+ # 3. Encode input prompt
612
+ prompt_embeds = self._encode_prompt(
613
+ prompt,
614
+ device,
615
+ num_images_per_prompt,
616
+ do_classifier_free_guidance,
617
+ negative_prompt,
618
+ prompt_embeds=prompt_embeds,
619
+ negative_prompt_embeds=negative_prompt_embeds,
620
+ )
621
+
622
+ # 4. Prepare timesteps
623
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
624
+ timesteps = self.scheduler.timesteps
625
+
626
+ # 5. Prepare latent variables
627
+ num_channels_latents = self.unet.config.in_channels
628
+ latents = self.prepare_latents(
629
+ batch_size * num_images_per_prompt,
630
+ num_channels_latents,
631
+ video_length,
632
+ height,
633
+ width,
634
+ prompt_embeds.dtype,
635
+ device,
636
+ generator,
637
+ latents,
638
+ )
639
+
640
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
641
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
642
+
643
+ # 7. Denoising loop
644
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
645
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
646
+ for i, t in enumerate(timesteps):
647
+ # expand the latents if we are doing classifier free guidance
648
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
649
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
650
+
651
+ # predict the noise residual
652
+ noise_pred = self.unet(
653
+ latent_model_input,
654
+ t,
655
+ encoder_hidden_states=prompt_embeds,
656
+ # cross_attention_kwargs=cross_attention_kwargs,
657
+ ).sample
658
+
659
+ # perform guidance
660
+ if do_classifier_free_guidance:
661
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
662
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
663
+
664
+ # compute the previous noisy sample x_t -> x_t-1
665
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
666
+
667
+ # call the callback, if provided
668
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
669
+ progress_bar.update()
670
+ if callback is not None and i % callback_steps == 0:
671
+ callback(i, t, latents)
672
+
673
+
674
+ # 8. Post-processing
675
+ video = self.decode_latents(latents)
676
+
677
+ return StableDiffusionPipelineOutput(video=video)
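
The pipeline above exposes the usual diffusers memory-saving switches (`enable_vae_slicing`, `enable_vae_tiling`, `enable_sequential_cpu_offload`, `enable_model_cpu_offload`). A minimal, hedged usage sketch follows, assuming `pipe` is a VideoGenPipeline assembled exactly as in base/pipelines/sample.py below and that accelerate >= 0.17.0 is installed; the prompt and the 320x512 resolution are illustrative only.

import imageio

pipe.enable_vae_slicing()          # decode the (batch * frames) latents slice by slice
pipe.enable_model_cpu_offload()    # keep only the active sub-model (text encoder / unet / vae) on the GPU
videos = pipe(
    "a teddy bear walking on the street",
    video_length=16,
    height=320,
    width=512,
    num_inference_steps=50,
    guidance_scale=7.5,
).video                            # uint8 tensor of shape (batch, frames, height, width, channels)
imageio.mimwrite("sample.mp4", videos[0], fps=8)
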
base/pipelines/sample.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torchvision
5
+
6
+ from pipeline_videogen import VideoGenPipeline
7
+
8
+ from download import find_model
9
+ from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
10
+ from diffusers.models import AutoencoderKL
11
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
12
+ from omegaconf import OmegaConf
13
+
14
+ import os, sys
15
+ sys.path.append(os.path.split(sys.path[0])[0])
16
+ from models import get_models
17
+ import imageio
18
+
19
+ def main(args):
20
+ #torch.manual_seed(args.seed)
21
+ torch.set_grad_enabled(False)
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
25
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
26
+ state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
27
+ unet.load_state_dict(state_dict)
28
+
29
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
30
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
31
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
32
+
33
+ # set eval mode
34
+ unet.eval()
35
+ vae.eval()
36
+ text_encoder_one.eval()
37
+
38
+ if args.sample_method == 'ddim':
39
+ scheduler = DDIMScheduler.from_pretrained(sd_path,
40
+ subfolder="scheduler",
41
+ beta_start=args.beta_start,
42
+ beta_end=args.beta_end,
43
+ beta_schedule=args.beta_schedule)
44
+ elif args.sample_method == 'eulerdiscrete':
45
+ scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
46
+ subfolder="scheduler",
47
+ beta_start=args.beta_start,
48
+ beta_end=args.beta_end,
49
+ beta_schedule=args.beta_schedule)
50
+ elif args.sample_method == 'ddpm':
51
+ scheduler = DDPMScheduler.from_pretrained(sd_path,
52
+ subfolder="scheduler",
53
+ beta_start=args.beta_start,
54
+ beta_end=args.beta_end,
55
+ beta_schedule=args.beta_schedule)
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ videogen_pipeline = VideoGenPipeline(vae=vae,
60
+ text_encoder=text_encoder_one,
61
+ tokenizer=tokenizer_one,
62
+ scheduler=scheduler,
63
+ unet=unet).to(device)
64
+ videogen_pipeline.enable_xformers_memory_efficient_attention()
65
+
66
+ if not os.path.exists(args.output_folder):
67
+ os.makedirs(args.output_folder)
68
+
69
+ video_grids = []
70
+ for prompt in args.text_prompt:
71
+ print('Processing prompt: {}'.format(prompt))
72
+ videos = videogen_pipeline(prompt,
73
+ video_length=args.video_length,
74
+ height=args.image_size[0],
75
+ width=args.image_size[1],
76
+ num_inference_steps=args.num_sampling_steps,
77
+ guidance_scale=args.guidance_scale).video
78
+ imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
79
+
80
+ print('save path {}'.format(args.output_folder))
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("--config", type=str, default="")
85
+ args = parser.parse_args()
86
+
87
+ main(OmegaConf.load(args.config))
88
+
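
sample.py pulls everything from the YAML file passed via --config. The hedged sketch below lists only the fields the script reads directly, with illustrative values; the real base/configs/sample.yaml also carries the model-specific fields consumed by get_models.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "pretrained_path": "../pretrained_models",   # expects stable-diffusion-v1-4/ and lavie_base.pt inside
    "output_folder": "../res/base/",
    "sample_method": "ddim",                     # ddim | eulerdiscrete | ddpm
    "beta_start": 0.0001,                        # scheduler betas (illustrative values)
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "text_prompt": ["a teddy bear walking on the street"],
    "video_length": 16,
    "image_size": [320, 512],                    # [height, width]
    "num_sampling_steps": 50,
    "guidance_scale": 7.5,
})
OmegaConf.save(cfg, "my_sample.yaml")            # then: python pipelines/sample.py --config my_sample.yaml
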
base/pipelines/sample.sh ADDED
@@ -0,0 +1,2 @@
1
+ export CUDA_VISIBLE_DEVICES=6
2
+ python pipelines/sample.py --config configs/sample.yaml
base/text_to_video/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import torchvision
5
+
6
+ from pipelines.pipeline_videogen import VideoGenPipeline
7
+
8
+ from download import find_model
9
+ from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
10
+ from diffusers.models import AutoencoderKL
11
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
12
+ from omegaconf import OmegaConf
13
+
14
+ import os, sys
15
+ sys.path.append(os.path.split(sys.path[0])[0])
16
+ from models import get_models
17
+ import imageio
18
+
19
+ config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml"
20
+ args = OmegaConf.load("/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml")
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+
24
+ def model_t2v_fun(args):
25
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
26
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
27
+ # state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
28
+ state_dict = find_model("/mnt/petrelfs/share_data/wangyaohui/lavie/pretrained_models/lavie_base.pt")
29
+ unet.load_state_dict(state_dict)
30
+
31
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
32
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
33
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
34
+ unet.eval()
35
+ vae.eval()
36
+ text_encoder_one.eval()
37
+ scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule)
38
+ return VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
39
+
40
+ def setup_seed(seed):
41
+ torch.manual_seed(seed)
42
+ torch.cuda.manual_seed_all(seed)
43
+
44
+
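
The Gradio demo (base/app.py in the file list) is expected to build its pipeline through these helpers. Below is a hedged sketch of that wiring with an illustrative seed and prompt; the actual UI code lives in app.py.

import torch
from text_to_video import model_t2v_fun, setup_seed, args

setup_seed(400)                                   # illustrative seed for reproducible sampling
pipeline = model_t2v_fun(args)
pipeline.enable_xformers_memory_efficient_attention()

with torch.no_grad():
    video = pipeline(
        "a corgi running on the beach",
        video_length=16,
        height=320,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).video                                       # uint8 frames, shape (1, 16, 320, 512, 3)
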
base/text_to_video/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (3.38 kB). View file
 
base/try.py ADDED
@@ -0,0 +1,5 @@
1
+ import gradio as gr
2
+
3
+ with gr.Blocks() as demo:
4
+ prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in")
5
+ demo.launch(server_name="0.0.0.0")
environment.yml ADDED
@@ -0,0 +1,27 @@
1
+ name: lavie
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.11.3
7
+ - pytorch=2.0.1
8
+ - pytorch-cuda=11.7
9
+ - torchvision=0.15.2
10
+ - pip:
11
+ - accelerate==0.19.0
12
+ - av==10.0.0
13
+ - decord==0.6.0
14
+ - diffusers[torch]==0.16.0
15
+ - einops==0.6.1
16
+ - ffmpeg==1.4
17
+ - imageio==2.31.1
18
+ - imageio-ffmpeg==0.4.9
19
+ - pandas==2.0.1
20
+ - timm==0.6.13
21
+ - tqdm==4.65.0
22
+ - transformers==4.28.1
23
+ - xformers==0.0.20
24
+ - omegaconf==2.3.0
25
+ - natsort==8.4.0
26
+ - rotary_embedding_torch
27
+ - gradio==4.3.0
interpolation/configs/sample.yaml ADDED
@@ -0,0 +1,36 @@
1
+ args:
2
+ input_folder: "../res/base/"
3
+ pretrained_path: "../pretrained_models"
4
+ output_folder: "../res/interpolation/"
5
+ seed_list:
6
+ - 3418
7
+
8
+ fps_list:
9
+ - 24
10
+
11
+ # model config:
12
+ model: TSR
13
+ num_frames: 61
14
+ image_size: [320, 512]
15
+ num_sampling_steps: 50
16
+ vae: mse
17
+ use_timecross_transformer: False
18
+ frame_interval: 1
19
+
20
+ # sample config:
21
+ seed: 0
22
+ cfg_scale: 4.0
23
+ run_time: 12
24
+ use_compile: False
25
+ enable_xformers_memory_efficient_attention: True
26
+ num_sample: 1
27
+
28
+ additional_prompt: ", 4k."
29
+ negative_prompt: "None"
30
+ do_classifier_free_guidance: True
31
+ use_ddim_sample_loop: True
32
+
33
+ researve_frame: 3
34
+ mask_type: "tsr"
35
+ use_concat: True
36
+ copy_no_mask: True
interpolation/datasets/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from datasets import video_transforms
interpolation/datasets/video_transforms.py ADDED
@@ -0,0 +1,109 @@
1
+ import torch
2
+ import random
3
+ import numbers
4
+ from torchvision.transforms import RandomCrop, RandomResizedCrop
5
+
6
+ def _is_tensor_video_clip(clip):
7
+ if not torch.is_tensor(clip):
8
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
9
+
10
+ if not clip.ndimension() == 4:
11
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
12
+
13
+ return True
14
+
15
+
16
+ def to_tensor(clip):
17
+ """
18
+ Convert tensor data type from uint8 to float and divide values by 255.0
19
+ (the clip keeps its (T, C, H, W) layout; no permutation is applied)
20
+ Args:
21
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
22
+ Return:
23
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
24
+ """
25
+ _is_tensor_video_clip(clip)
26
+ if not clip.dtype == torch.uint8:
27
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
28
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
29
+ return clip.float() / 255.0
30
+
31
+
32
+ def resize(clip, target_size, interpolation_mode):
33
+ if len(target_size) != 2:
34
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
35
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
36
+
37
+
38
+ class ToTensorVideo:
39
+ """
40
+ Convert tensor data type from uint8 to float and divide values by 255.0
41
+ (the clip keeps its (T, C, H, W) layout; no permutation is applied)
42
+ """
43
+
44
+ def __init__(self):
45
+ pass
46
+
47
+ def __call__(self, clip):
48
+ """
49
+ Args:
50
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
51
+ Return:
52
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
53
+ """
54
+ return to_tensor(clip)
55
+
56
+ def __repr__(self) -> str:
57
+ return self.__class__.__name__
58
+
59
+
60
+ class ResizeVideo:
61
+ '''
62
+ Resize to the specified size
63
+ '''
64
+ def __init__(
65
+ self,
66
+ size,
67
+ interpolation_mode="bilinear",
68
+ ):
69
+ if isinstance(size, tuple):
70
+ if len(size) != 2:
71
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
72
+ self.size = size
73
+ else:
74
+ self.size = (size, size)
75
+
76
+ self.interpolation_mode = interpolation_mode
77
+
78
+
79
+ def __call__(self, clip):
80
+ """
81
+ Args:
82
+ clip (torch.tensor): Video clip to be resized. Size is (T, C, H, W)
83
+ Returns:
84
+ torch.tensor: scale resized video clip.
85
+ size is (T, C, h, w)
86
+ """
87
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
88
+ return clip_resize
89
+
90
+ def __repr__(self) -> str:
91
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
92
+
93
+
94
+ class TemporalRandomCrop(object):
95
+ """Temporally crop the given frame indices at a random location.
96
+
97
+ Args:
98
+ size (int): Desired number of frames to be seen by the model.
99
+ """
100
+
101
+ def __init__(self, size):
102
+ self.size = size
103
+
104
+ def __call__(self, total_frames):
105
+ rand_end = max(0, total_frames - self.size - 1)
106
+ begin_index = random.randint(0, rand_end)
107
+ end_index = min(begin_index + self.size, total_frames)
108
+ return begin_index, end_index
109
+
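
A hedged sketch of how the transforms above compose on a dummy uint8 clip (run from the interpolation/ directory so that `datasets.video_transforms` resolves; the shapes are illustrative).

import torch
from torchvision import transforms
from datasets.video_transforms import ToTensorVideo, ResizeVideo, TemporalRandomCrop

clip = torch.randint(0, 256, (61, 3, 240, 426), dtype=torch.uint8)   # fake decoded frames, (T, C, H, W)

begin, end = TemporalRandomCrop(16)(clip.shape[0])                   # pick a random 16-frame window
clip = clip[begin:end]

transform = transforms.Compose([
    ToTensorVideo(),                                                 # uint8 -> float in [0, 1]
    ResizeVideo((320, 512), interpolation_mode="bilinear"),
])
clip = transform(clip)
print(clip.shape, clip.dtype)                                        # torch.Size([16, 3, 320, 512]) torch.float32
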
interpolation/diffusion/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ # learn_sigma=True,
17
+ learn_sigma=False, # for unet
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
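
create_diffusion follows the OpenAI factory pattern: an empty `timestep_respacing` keeps the full 1000-step chain, while a string such as "ddim50" respaces it to 50 strided steps for sampling. A hedged sketch (run from the interpolation/ directory):

from diffusion import create_diffusion

train_diffusion = create_diffusion(timestep_respacing="")          # full 1000-step chain
print(train_diffusion.num_timesteps)                               # 1000

sample_diffusion = create_diffusion(timestep_respacing="ddim50")   # 50 evenly strided DDIM-style steps
print(sample_diffusion.num_timesteps)                              # 50
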
interpolation/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
interpolation/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1000 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "squaredcos_cap_v2":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
255
+ mask=None, x_start=None, use_concat=False,
256
+ copy_no_mask=False, ):
257
+ """
258
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
259
+ the initial x, x_0.
260
+ :param model: the model, which takes a signal and a batch of timesteps
261
+ as input.
262
+ :param x: the [N x C x ...] tensor at time t.
263
+ :param t: a 1-D Tensor of timesteps.
264
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
265
+ :param denoised_fn: if not None, a function which applies to the
266
+ x_start prediction before it is used to sample. Applies before
267
+ clip_denoised.
268
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
269
+ pass to the model. This can be used for conditioning.
270
+ :return: a dict with the following keys:
271
+ - 'mean': the model mean output.
272
+ - 'variance': the model variance output.
273
+ - 'log_variance': the log of 'variance'.
274
+ - 'pred_xstart': the prediction for x_0.
275
+ """
276
+ if model_kwargs is None:
277
+ model_kwargs = {}
278
+
279
+ B, F, C = x.shape[:3]
280
+ assert t.shape == (B,)
281
+ # model_output = model(x, t, **model_kwargs)
282
+ if copy_no_mask:
283
+ if use_concat:
284
+ try:
285
+ model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs).sample
286
+ except:
287
+ # print(f'x.shape = {x.shape}, x_start.shape = {x_start.shape}')
288
+ # )
289
+ # x.shape = torch.Size([2, 4, 61, 32, 32]), x_start.shape = torch.Size([2, 4, 61, 32, 32]
290
+ # print(f'x[0,0,:,0,0] = {x[0,0,:,0,0]}, \nx_start[0,0,:,0,0] = {x_start[0,0,:,0,0]}')
291
+ model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs)
292
+ else:
293
+ try:
294
+ model_output = model(x, t, **model_kwargs).sample # for tav unet
295
+ except:
296
+ model_output = model(x, t, **model_kwargs)
297
+ else:
298
+ if use_concat:
299
+ try:
300
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs).sample
301
+ except:
302
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
303
+ else:
304
+ try:
305
+ model_output = model(x, t, **model_kwargs).sample # for tav unet
306
+ except:
307
+ model_output = model(x, t, **model_kwargs)
308
+ if isinstance(model_output, tuple):
309
+ model_output, extra = model_output
310
+ else:
311
+ extra = None
312
+
313
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
314
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
315
+ model_output, model_var_values = th.split(model_output, C, dim=2)
316
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
317
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
318
+ # The model_var_values is [-1, 1] for [min_var, max_var].
319
+ frac = (model_var_values + 1) / 2
320
+ model_log_variance = frac * max_log + (1 - frac) * min_log
321
+ model_variance = th.exp(model_log_variance)
322
+ else:
323
+ model_variance, model_log_variance = {
324
+ # for fixedlarge, we set the initial (log-)variance like so
325
+ # to get a better decoder log likelihood.
326
+ ModelVarType.FIXED_LARGE: (
327
+ np.append(self.posterior_variance[1], self.betas[1:]),
328
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
329
+ ),
330
+ ModelVarType.FIXED_SMALL: (
331
+ self.posterior_variance,
332
+ self.posterior_log_variance_clipped,
333
+ ),
334
+ }[self.model_var_type]
335
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
336
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
337
+
338
+ def process_xstart(x):
339
+ if denoised_fn is not None:
340
+ x = denoised_fn(x)
341
+ if clip_denoised:
342
+ return x.clamp(-1, 1)
343
+ return x
344
+
345
+ if self.model_mean_type == ModelMeanType.START_X:
346
+ pred_xstart = process_xstart(model_output)
347
+ else:
348
+ pred_xstart = process_xstart(
349
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output, mask=mask, x_start=x_start, use_concat=use_concat)
350
+ )
351
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
352
+
353
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
354
+ return {
355
+ "mean": model_mean,
356
+ "variance": model_variance,
357
+ "log_variance": model_log_variance,
358
+ "pred_xstart": pred_xstart,
359
+ "extra": extra,
360
+ }
361
+
362
+ def _predict_xstart_from_eps(self, x_t, t, eps, mask=None, x_start=None, use_concat=False): # (x_t=x, t=t, eps=model_output)
363
+ assert x_t.shape == eps.shape
364
+ if not use_concat:
365
+ if mask is not None:
366
+ if x_start is None:
367
+ return (
368
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
369
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )* mask + x_t * (1-mask)
370
+ )
371
+ else:
372
+ # breakpoint()
373
+ if (t == 0).any():
374
+ print('t=0')
375
+ x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
376
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
377
+ return x_start * (1-mask) + x_unknown * mask
378
+ else:
379
+ x_known = self.q_sample(x_start, t-1)
380
+ x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
381
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
382
+ return (
383
+ x_known * (1-mask) + x_unknown * mask
384
+ )
385
+ else:
386
+ return (
387
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
388
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )
389
+ )
390
+ else:
391
+ return (
392
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
393
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )
394
+ )
395
+
396
+
397
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
398
+ return (
399
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
400
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
401
+
402
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
403
+ """
404
+ Compute the mean for the previous step, given a function cond_fn that
405
+ computes the gradient of a conditional log probability with respect to
406
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
407
+ condition on y.
408
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
409
+ """
410
+ gradient = cond_fn(x, t, **model_kwargs)
411
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
412
+ return new_mean
413
+
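A minimal standalone sketch of the mean shift above, assuming a toy gradient tensor in place of a real cond_fn output:
import torch as th
p_mean_var = {"mean": th.zeros(1, 4, 8, 8), "variance": th.full((1, 4, 8, 8), 0.1)}
gradient = th.ones(1, 4, 8, 8)                     # hypothetical grad of log p(y | x)
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
print(new_mean.mean().item())                      # ~0.1: the mean moves along the gradient, scaled by the variance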
414
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
415
+ """
416
+ Compute what the p_mean_variance output would have been, should the
417
+ model's score function be conditioned by cond_fn.
418
+ See condition_mean() for details on cond_fn.
419
+ Unlike condition_mean(), this instead uses the conditioning strategy
420
+ from Song et al (2020).
421
+ """
422
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
423
+
424
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
425
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
426
+
427
+ out = p_mean_var.copy()
428
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
429
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
430
+ return out
431
+
432
+ def p_sample(
433
+ self,
434
+ model,
435
+ x,
436
+ t,
437
+ clip_denoised=True,
438
+ denoised_fn=None,
439
+ cond_fn=None,
440
+ model_kwargs=None,
441
+ mask=None,
442
+ x_start=None,
443
+ use_concat=False
444
+ ):
445
+ """
446
+ Sample x_{t-1} from the model at the given timestep.
447
+ :param model: the model to sample from.
448
+ :param x: the current tensor at x_{t-1}.
449
+ :param t: the value of t, starting at 0 for the first diffusion step.
450
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
451
+ :param denoised_fn: if not None, a function which applies to the
452
+ x_start prediction before it is used to sample.
453
+ :param cond_fn: if not None, this is a gradient function that acts
454
+ similarly to the model.
455
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
456
+ pass to the model. This can be used for conditioning.
457
+ :return: a dict containing the following keys:
458
+ - 'sample': a random sample from the model.
459
+ - 'pred_xstart': a prediction of x_0.
460
+ """
461
+ out = self.p_mean_variance(
462
+ model,
463
+ x,
464
+ t,
465
+ clip_denoised=clip_denoised,
466
+ denoised_fn=denoised_fn,
467
+ model_kwargs=model_kwargs,
468
+ mask=mask,
469
+ x_start=x_start,
470
+ use_concat=use_concat,
471
+ )
472
+ noise = th.randn_like(x)
473
+ nonzero_mask = (
474
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
475
+ ) # no noise when t == 0
476
+ if cond_fn is not None:
477
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
478
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
479
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
480
+
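A minimal standalone sketch of the reparameterized step above, assuming synthetic mean/log-variance tensors in place of real p_mean_variance outputs:
import torch as th
x = th.randn(2, 4, 8, 8)                           # hypothetical batch of latents
t = th.tensor([5, 0])                              # the second element is at the final step
mean, log_variance = th.zeros_like(x), th.full_like(x, -2.0)
noise = th.randn_like(x)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
sample = mean + nonzero_mask * th.exp(0.5 * log_variance) * noise
print(sample[1].abs().max().item())                # 0.0: no noise is added when t == 0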
481
+ def p_sample_loop(
482
+ self,
483
+ model,
484
+ shape,
485
+ noise=None,
486
+ clip_denoised=True,
487
+ denoised_fn=None,
488
+ cond_fn=None,
489
+ model_kwargs=None,
490
+ device=None,
491
+ progress=False,
492
+ mask=None,
493
+ x_start=None,
494
+ use_concat=False,
495
+ ):
496
+ """
497
+ Generate samples from the model.
498
+ :param model: the model module.
499
+ :param shape: the shape of the samples, (N, C, H, W).
500
+ :param noise: if specified, the noise from the encoder to sample.
501
+ Should be of the same shape as `shape`.
502
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
503
+ :param denoised_fn: if not None, a function which applies to the
504
+ x_start prediction before it is used to sample.
505
+ :param cond_fn: if not None, this is a gradient function that acts
506
+ similarly to the model.
507
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
508
+ pass to the model. This can be used for conditioning.
509
+ :param device: if specified, the device to create the samples on.
510
+ If not specified, use a model parameter's device.
511
+ :param progress: if True, show a tqdm progress bar.
512
+ :return: a non-differentiable batch of samples.
513
+ """
514
+ final = None
515
+ for sample in self.p_sample_loop_progressive(
516
+ model,
517
+ shape,
518
+ noise=noise,
519
+ clip_denoised=clip_denoised,
520
+ denoised_fn=denoised_fn,
521
+ cond_fn=cond_fn,
522
+ model_kwargs=model_kwargs,
523
+ device=device,
524
+ progress=progress,
525
+ mask=mask,
526
+ x_start=x_start,
527
+ use_concat=use_concat
528
+ ):
529
+ final = sample
530
+ return final["sample"]
531
+
532
+ def p_sample_loop_progressive(
533
+ self,
534
+ model,
535
+ shape,
536
+ noise=None,
537
+ clip_denoised=True,
538
+ denoised_fn=None,
539
+ cond_fn=None,
540
+ model_kwargs=None,
541
+ device=None,
542
+ progress=False,
543
+ mask=None,
544
+ x_start=None,
545
+ use_concat=False
546
+ ):
547
+ """
548
+ Generate samples from the model and yield intermediate samples from
549
+ each timestep of diffusion.
550
+ Arguments are the same as p_sample_loop().
551
+ Returns a generator over dicts, where each dict is the return value of
552
+ p_sample().
553
+ """
554
+ if device is None:
555
+ device = next(model.parameters()).device
556
+ assert isinstance(shape, (tuple, list))
557
+ if noise is not None:
558
+ img = noise
559
+ else:
560
+ img = th.randn(*shape, device=device)
561
+ indices = list(range(self.num_timesteps))[::-1]
562
+
563
+ if progress:
564
+ # Lazy import so that we don't depend on tqdm.
565
+ from tqdm.auto import tqdm
566
+
567
+ indices = tqdm(indices)
568
+
569
+ for i in indices:
570
+ t = th.tensor([i] * shape[0], device=device)
571
+ with th.no_grad(): # loop
572
+ out = self.p_sample(
573
+ model,
574
+ img,
575
+ t,
576
+ clip_denoised=clip_denoised,
577
+ denoised_fn=denoised_fn,
578
+ cond_fn=cond_fn,
579
+ model_kwargs=model_kwargs,
580
+ mask=mask,
581
+ x_start=x_start,
582
+ use_concat=use_concat
583
+ )
584
+ yield out
585
+ img = out["sample"]
586
+
587
+ def ddim_sample(
588
+ self,
589
+ model,
590
+ x,
591
+ t,
592
+ clip_denoised=True,
593
+ denoised_fn=None,
594
+ cond_fn=None,
595
+ model_kwargs=None,
596
+ eta=0.0,
597
+ mask=None,
598
+ x_start=None,
599
+ use_concat=False,
600
+ copy_no_mask=False,
601
+ ):
602
+ """
603
+ Sample x_{t-1} from the model using DDIM.
604
+ Same usage as p_sample().
605
+ """
606
+ out = self.p_mean_variance(
607
+ model,
608
+ x,
609
+ t,
610
+ clip_denoised=clip_denoised,
611
+ denoised_fn=denoised_fn,
612
+ model_kwargs=model_kwargs,
613
+ mask=mask,
614
+ x_start=x_start,
615
+ use_concat=use_concat,
616
+ copy_no_mask=copy_no_mask,
617
+ )
618
+ if cond_fn is not None:
619
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
620
+
621
+ # Usually our model outputs epsilon, but we re-derive it
622
+ # in case we used x_start or x_prev prediction.
623
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
624
+
625
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
626
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
627
+ sigma = (
628
+ eta
629
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
630
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
631
+ )
632
+ # Equation 12.
633
+ noise = th.randn_like(x)
634
+ mean_pred = (
635
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
636
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
637
+ )
638
+ nonzero_mask = (
639
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
640
+ ) # no noise when t == 0
641
+ sample = mean_pred + nonzero_mask * sigma * noise
642
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
643
+
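A minimal standalone sketch of Equation 12 above with scalar alpha_bar values, assuming eta = 0 so the update is the deterministic DDIM step:
import torch as th
alpha_bar, alpha_bar_prev, eta = th.tensor(0.5), th.tensor(0.7), 0.0
pred_xstart = th.randn(1, 4, 8, 8)
eps = th.randn_like(pred_xstart)
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
mean_pred = pred_xstart * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
print(sigma.item())                                # 0.0: with eta = 0 the DDIM update adds no fresh noise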
644
+ def ddim_reverse_sample(
645
+ self,
646
+ model,
647
+ x,
648
+ t,
649
+ clip_denoised=True,
650
+ denoised_fn=None,
651
+ cond_fn=None,
652
+ model_kwargs=None,
653
+ eta=0.0,
654
+ ):
655
+ """
656
+ Sample x_{t+1} from the model using DDIM reverse ODE.
657
+ """
658
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
659
+ out = self.p_mean_variance(
660
+ model,
661
+ x,
662
+ t,
663
+ clip_denoised=clip_denoised,
664
+ denoised_fn=denoised_fn,
665
+ model_kwargs=model_kwargs,
666
+ )
667
+ if cond_fn is not None:
668
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
669
+ # Usually our model outputs epsilon, but we re-derive it
670
+ # in case we used x_start or x_prev prediction.
671
+ eps = (
672
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
673
+ - out["pred_xstart"]
674
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
675
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
676
+
677
+ # Equation 12. reversed
678
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
679
+
680
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
681
+
682
+ def ddim_sample_loop(
683
+ self,
684
+ model,
685
+ shape,
686
+ noise=None,
687
+ clip_denoised=True,
688
+ denoised_fn=None,
689
+ cond_fn=None,
690
+ model_kwargs=None,
691
+ device=None,
692
+ progress=False,
693
+ eta=0.0,
694
+ mask=None,
695
+ x_start=None,
696
+ use_concat=False,
697
+ copy_no_mask=False,
698
+ ):
699
+ """
700
+ Generate samples from the model using DDIM.
701
+ Same usage as p_sample_loop().
702
+ """
703
+ final = None
704
+ for sample in self.ddim_sample_loop_progressive(
705
+ model,
706
+ shape,
707
+ noise=noise,
708
+ clip_denoised=clip_denoised,
709
+ denoised_fn=denoised_fn,
710
+ cond_fn=cond_fn,
711
+ model_kwargs=model_kwargs,
712
+ device=device,
713
+ progress=progress,
714
+ eta=eta,
715
+ mask=mask,
716
+ x_start=x_start,
717
+ use_concat=use_concat,
718
+ copy_no_mask=copy_no_mask,
719
+ ):
720
+ final = sample
721
+ return final["sample"]
722
+
723
+ def ddim_sample_loop_progressive(
724
+ self,
725
+ model,
726
+ shape,
727
+ noise=None,
728
+ clip_denoised=True,
729
+ denoised_fn=None,
730
+ cond_fn=None,
731
+ model_kwargs=None,
732
+ device=None,
733
+ progress=False,
734
+ eta=0.0,
735
+ mask=None,
736
+ x_start=None,
737
+ use_concat=False,
738
+ copy_no_mask=False,
739
+ ):
740
+ """
741
+ Use DDIM to sample from the model and yield intermediate samples from
742
+ each timestep of DDIM.
743
+ Same usage as p_sample_loop_progressive().
744
+ """
745
+ if device is None:
746
+ device = next(model.parameters()).device
747
+ assert isinstance(shape, (tuple, list))
748
+ if noise is not None:
749
+ img = noise
750
+ else:
751
+ img = th.randn(*shape, device=device)
752
+ indices = list(range(self.num_timesteps))[::-1]
753
+
754
+ if progress:
755
+ # Lazy import so that we don't depend on tqdm.
756
+ from tqdm.auto import tqdm
757
+
758
+ indices = tqdm(indices)
759
+
760
+ for i in indices:
761
+ t = th.tensor([i] * shape[0], device=device)
762
+ with th.no_grad():
763
+ out = self.ddim_sample(
764
+ model,
765
+ img,
766
+ t,
767
+ clip_denoised=clip_denoised,
768
+ denoised_fn=denoised_fn,
769
+ cond_fn=cond_fn,
770
+ model_kwargs=model_kwargs,
771
+ eta=eta,
772
+ mask=mask,
773
+ x_start=x_start,
774
+ use_concat=use_concat,
775
+ copy_no_mask=copy_no_mask,
776
+ )
777
+ yield out
778
+ img = out["sample"]
779
+
780
+ def _vb_terms_bpd(
781
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
782
+ ):
783
+ """
784
+ Get a term for the variational lower-bound.
785
+ The resulting units are bits (rather than nats, as one might expect).
786
+ This allows for comparison to other papers.
787
+ :return: a dict with the following keys:
788
+ - 'output': a shape [N] tensor of NLLs or KLs.
789
+ - 'pred_xstart': the x_0 predictions.
790
+ """
791
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
792
+ x_start=x_start, x_t=x_t, t=t
793
+ )
794
+ out = self.p_mean_variance(
795
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
796
+ )
797
+ kl = normal_kl(
798
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
799
+ )
800
+ kl = mean_flat(kl) / np.log(2.0)
801
+
802
+ decoder_nll = -discretized_gaussian_log_likelihood(
803
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
804
+ )
805
+ assert decoder_nll.shape == x_start.shape
806
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
807
+
808
+ # At the first timestep return the decoder NLL,
809
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
810
+ output = th.where((t == 0), decoder_nll, kl)
811
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
812
+
813
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, t_head=None, copy_no_mask=False):
814
+ """
815
+ Compute training losses for a single timestep.
816
+ :param model: the model to evaluate loss on.
817
+ :param x_start: the [N x C x ...] tensor of inputs.
818
+ :param t: a batch of timestep indices.
819
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
820
+ pass to the model. This can be used for conditioning.
821
+ :param noise: if specified, the specific Gaussian noise to try to remove.
822
+ :return: a dict with the key "loss" containing a tensor of shape [N].
823
+ Some mean or variance settings may also have other keys.
824
+ """
825
+ # mask could be here
826
+ if model_kwargs is None:
827
+ model_kwargs = {}
828
+ if noise is None:
829
+ noise = th.randn_like(x_start)
830
+ x_t = self.q_sample(x_start, t, noise=noise)
831
+ x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
832
+ # mask is used for (0,0,0,1,1,1,...) which means the diffusion model can see the first 3 frames of the input video
833
+ # print(f'training_losses(): mask = {mask}') # None
834
+
835
+ if mask is not None:
836
+ x_t = x_t*mask + x_start*(1-mask)
837
+
838
+ # noise augmentation
839
+ if copy_no_mask:
840
+ if t_head is not None:
841
+ noise_aug = self.q_sample(x_start[:, 4:], t_head) # noise aug on copied_video
842
+ x_t = th.cat([x_t[:, :4], noise_aug], dim=1)
843
+ else:
844
+ if t_head is not None:
845
+ noise_aug = self.q_sample(x_start[:, 5:], t_head) # b, 4, f, h, w
846
+ noise_aug = noise_aug * (x_start[:, 4].unsqueeze(1).expand(-1, 4, -1, -1, -1) == 0) # use mask to zero out augmented noises
847
+ x_t = th.cat([x_t[:, :5], noise_aug], dim=1)
848
+ terms = {}
849
+ # for i in [0,1,2,3,4,5,6,7]:
850
+ # print(f'x_t[0,{i},:,0,0] = {x_t[0,i,:,0,0]}')
851
+
852
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
853
+ terms["loss"] = self._vb_terms_bpd(
854
+ model=model,
855
+ x_start=x_start,
856
+ x_t=x_t,
857
+ t=t,
858
+ clip_denoised=False,
859
+ model_kwargs=model_kwargs,
860
+ )["output"]
861
+ if self.loss_type == LossType.RESCALED_KL:
862
+ terms["loss"] *= self.num_timesteps
863
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
864
+ # print(f'self.loss_type = {self.loss_type}') # LossType.MSE
865
+ # model_output = model(x_t, t, **model_kwargs)
866
+ try:
867
+ model_output = model(x_t, t, **model_kwargs).sample # for tav unet
868
+ except AttributeError:  # a plain UNet returns a tensor rather than an output object with a .sample field
869
+ model_output = model(x_t, t, **model_kwargs)
870
+
871
+ if self.model_var_type in [
872
+ ModelVarType.LEARNED,
873
+ ModelVarType.LEARNED_RANGE,
874
+ ]:
875
+ B, F, C = x_t.shape[:3]
876
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
877
+ model_output, model_var_values = th.split(model_output, C, dim=2)
878
+ # Learn the variance using the variational bound, but don't let
879
+ # it affect our mean prediction.
880
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
881
+ terms["vb"] = self._vb_terms_bpd(
882
+ model=lambda *args, r=frozen_out: r,
883
+ x_start=x_start,
884
+ x_t=x_t,
885
+ t=t,
886
+ clip_denoised=False,
887
+ )["output"]
888
+ if self.loss_type == LossType.RESCALED_MSE:
889
+ # Divide by 1000 for equivalence with initial implementation.
890
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
891
+ terms["vb"] *= self.num_timesteps / 1000.0
892
+
893
+ # print(f'self.model_mean_type = {self.model_mean_type}') # ModelMeanType.EPSILON
894
+ target = {
895
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
896
+ x_start=x_start, x_t=x_t, t=t
897
+ )[0],
898
+ ModelMeanType.START_X: x_start,
899
+ ModelMeanType.EPSILON: noise,
900
+ }[self.model_mean_type]
901
+ # assert model_output.shape == target.shape == x_start.shape
902
+ # if mask is not None:
903
+ # nonzero_idx = th.nonzero(1-mask)
904
+ terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
905
+ # else:
906
+ # terms["mse"] = mean_flat((target - model_output) ** 2)
907
+ if "vb" in terms:
908
+ terms["loss"] = terms["mse"] + terms["vb"]
909
+ else:
910
+ terms["loss"] = terms["mse"]
911
+ else:
912
+ raise NotImplementedError(self.loss_type)
913
+
914
+ return terms
915
+
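A minimal standalone sketch of the mask conditioning above, assuming a (b, 1, f, 1, 1) mask that is 0 on the visible frames and 1 on the frames to be denoised:
import torch as th
x_start = th.randn(1, 4, 6, 8, 8)                  # b, c, f, h, w
x_noisy = th.randn_like(x_start)
mask = th.ones(1, 1, 6, 1, 1)
mask[:, :, :3] = 0                                 # the first 3 frames stay visible to the model
x_t = x_noisy * mask + x_start * (1 - mask)
print(th.equal(x_t[:, :, :3], x_start[:, :, :3]))  # True: conditioning frames are left untouched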
916
+ def _prior_bpd(self, x_start):
917
+ """
918
+ Get the prior KL term for the variational lower-bound, measured in
919
+ bits-per-dim.
920
+ This term can't be optimized, as it only depends on the encoder.
921
+ :param x_start: the [N x C x ...] tensor of inputs.
922
+ :return: a batch of [N] KL values (in bits), one per batch element.
923
+ """
924
+ batch_size = x_start.shape[0]
925
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
926
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
927
+ kl_prior = normal_kl(
928
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
929
+ )
930
+ return mean_flat(kl_prior) / np.log(2.0)
931
+
932
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
933
+ """
934
+ Compute the entire variational lower-bound, measured in bits-per-dim,
935
+ as well as other related quantities.
936
+ :param model: the model to evaluate loss on.
937
+ :param x_start: the [N x C x ...] tensor of inputs.
938
+ :param clip_denoised: if True, clip denoised samples.
939
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
940
+ pass to the model. This can be used for conditioning.
941
+ :return: a dict containing the following keys:
942
+ - total_bpd: the total variational lower-bound, per batch element.
943
+ - prior_bpd: the prior term in the lower-bound.
944
+ - vb: an [N x T] tensor of terms in the lower-bound.
945
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
946
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
947
+ """
948
+ device = x_start.device
949
+ batch_size = x_start.shape[0]
950
+
951
+ vb = []
952
+ xstart_mse = []
953
+ mse = []
954
+ for t in list(range(self.num_timesteps))[::-1]:
955
+ t_batch = th.tensor([t] * batch_size, device=device)
956
+ noise = th.randn_like(x_start)
957
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
958
+ # Calculate VLB term at the current timestep
959
+ with th.no_grad():
960
+ out = self._vb_terms_bpd(
961
+ model,
962
+ x_start=x_start,
963
+ x_t=x_t,
964
+ t=t_batch,
965
+ clip_denoised=clip_denoised,
966
+ model_kwargs=model_kwargs,
967
+ )
968
+ vb.append(out["output"])
969
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
970
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
971
+ mse.append(mean_flat((eps - noise) ** 2))
972
+
973
+ vb = th.stack(vb, dim=1)
974
+ xstart_mse = th.stack(xstart_mse, dim=1)
975
+ mse = th.stack(mse, dim=1)
976
+
977
+ prior_bpd = self._prior_bpd(x_start)
978
+ total_bpd = vb.sum(dim=1) + prior_bpd
979
+ return {
980
+ "total_bpd": total_bpd,
981
+ "prior_bpd": prior_bpd,
982
+ "vb": vb,
983
+ "xstart_mse": xstart_mse,
984
+ "mse": mse,
985
+ }
986
+
987
+
988
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
989
+ """
990
+ Extract values from a 1-D numpy array for a batch of indices.
991
+ :param arr: the 1-D numpy array.
992
+ :param timesteps: a tensor of indices into the array to extract.
993
+ :param broadcast_shape: a larger shape of K dimensions with the batch
994
+ dimension equal to the length of timesteps.
995
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
996
+ """
997
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
998
+ while len(res.shape) < len(broadcast_shape):
999
+ res = res[..., None]
1000
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
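A minimal standalone sketch of the broadcasting this helper performs, assuming a toy beta schedule and a batch of two timesteps:
import numpy as np
import torch as th
betas = np.linspace(1e-4, 2e-2, 1000)              # toy schedule
timesteps = th.tensor([0, 999])
broadcast_shape = (2, 4, 8, 8)
res = th.from_numpy(betas)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
    res = res[..., None]
out = res + th.zeros(broadcast_shape)
print(out.shape, out[0].unique(), out[1].unique()) # each sample gets its own scalar, broadcast to (2, 4, 8, 8)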
interpolation/diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+ import torch
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
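A minimal standalone sketch of the "ddimN" branch above, assuming a 1000-step process respaced to 25 steps:
num_timesteps, desired_count = 1000, 25
for stride in range(1, num_timesteps):
    if len(range(0, num_timesteps, stride)) == desired_count:
        kept = set(range(0, num_timesteps, stride))
        break
print(sorted(kept)[:4], len(kept))                 # [0, 40, 80, 120] 25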
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ # @torch.compile
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def condition_mean(self, cond_fn, *args, **kwargs):
101
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102
+
103
+ def condition_score(self, cond_fn, *args, **kwargs):
104
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105
+
106
+ def _wrap_model(self, model):
107
+ if isinstance(model, _WrappedModel):
108
+ return model
109
+ return _WrappedModel(
110
+ model, self.timestep_map, self.original_num_steps
111
+ )
112
+
113
+ def _scale_timesteps(self, t):
114
+ # Scaling is done by the wrapped model.
115
+ return t
116
+
117
+
118
+ class _WrappedModel:
119
+ def __init__(self, model, timestep_map, original_num_steps):
120
+ self.model = model
121
+ self.timestep_map = timestep_map
122
+ # self.rescale_timesteps = rescale_timesteps
123
+ self.original_num_steps = original_num_steps
124
+
125
+ def __call__(self, x, ts, **kwargs):
126
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127
+ new_ts = map_tensor[ts]
128
+ # if self.rescale_timesteps:
129
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130
+ return self.model(x, new_ts, **kwargs)
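A minimal standalone sketch of the timestep remapping _WrappedModel performs, assuming a hypothetical map of four retained original timesteps:
import torch as th
timestep_map = [0, 333, 666, 999]                  # hypothetical retained original timesteps
ts = th.tensor([3, 2, 1, 0])                       # timesteps as seen by SpacedDiffusion
map_tensor = th.tensor(timestep_map, device=ts.device, dtype=ts.dtype)
print(map_tensor[ts].tolist())                     # [999, 666, 333, 0]: what the wrapped model actually receives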
interpolation/diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy >= 1.24
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
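A minimal standalone sketch of the importance-sampling step in ScheduleSampler.sample, assuming uniform weights over a 1000-step schedule:
import numpy as np
import torch as th
w = np.ones(1000)                                  # UniformSampler weights
p = w / np.sum(w)
indices_np = np.random.choice(len(p), size=(8,), p=p)
weights_np = 1 / (len(p) * p[indices_np])
indices, weights = th.from_numpy(indices_np).long(), th.from_numpy(weights_np).float()
print(indices.shape, weights.unique())             # torch.Size([8]) tensor([1.]): uniform sampling leaves losses unscaled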
interpolation/download.py ADDED
@@ -0,0 +1,22 @@
1
+ # All rights reserved.
2
+
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import os
8
+
9
+
10
+ pretrained_models = {''}
11
+
12
+
13
+ def find_model(model_name):
14
+ """
15
+ Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path.
16
+ """
17
+ assert os.path.isfile(model_name), f'Could not find checkpoint at {model_name}'
18
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
19
+ if "ema" in checkpoint: # supports checkpoints from train.py
20
+ checkpoint = checkpoint["ema"]
21
+ return checkpoint
22
+
interpolation/models/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.split(sys.path[0])[0])
4
+
5
+ from .unet import UNet3DConditionModel
6
+ from torch.optim.lr_scheduler import LambdaLR
7
+
8
+ def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
9
+ from torch.optim.lr_scheduler import LambdaLR
10
+ def fn(step):
11
+ if warmup_steps > 0:
12
+ return min(step / warmup_steps, 1)
13
+ else:
14
+ return 1
15
+ return LambdaLR(optimizer, fn)
16
+
17
+
18
+ def get_lr_scheduler(optimizer, name, **kwargs):
19
+ if name == 'warmup':
20
+ return customized_lr_scheduler(optimizer, **kwargs)
21
+ elif name == 'cosine':
22
+ from torch.optim.lr_scheduler import CosineAnnealingLR
23
+ return CosineAnnealingLR(optimizer, **kwargs)
24
+ else:
25
+ raise NotImplementedError(name)
26
+
27
+ def get_models(args, ckpt_path):
28
+
29
+ if 'TSR' in args.model:
30
+ return UNet3DConditionModel.from_pretrained_2d(ckpt_path, subfolder="unet", use_concat=args.use_concat, copy_no_mask=args.copy_no_mask)
31
+ else:
32
+ raise NotImplementedError('{} Model Not Supported!'.format(args.model))  # raising a bare string is a TypeError in Python 3
33
+
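A minimal standalone sketch of the linear-warmup schedule customized_lr_scheduler builds, assuming a dummy parameter and a short 5-step warmup for illustration:
import torch
from torch.optim.lr_scheduler import LambdaLR
warmup_steps = 5                                   # the module above defaults to 5000
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = LambdaLR(optimizer, lambda step: min(step / warmup_steps, 1) if warmup_steps > 0 else 1)
lrs = []
for _ in range(7):
    optimizer.step()
    scheduler.step()
    lrs.append(round(optimizer.param_groups[0]["lr"], 6))
print(lrs)                                         # [2e-05, 4e-05, 6e-05, 8e-05, 0.0001, 0.0001, 0.0001]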
interpolation/models/attention.py ADDED
@@ -0,0 +1,665 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ from dataclasses import dataclass
7
+ from typing import Optional
8
+
9
+ import math
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch import nn
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.utils import BaseOutput
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
18
+
19
+ from einops import rearrange, repeat
20
+
21
+ try:
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ except ImportError:
24
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
25
+
26
+
27
+ @dataclass
28
+ class Transformer3DModelOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ if is_xformers_available():
33
+ import xformers
34
+ import xformers.ops
35
+ else:
36
+ xformers = None
37
+
38
+
39
+ class CrossAttention(nn.Module):
40
+ r"""
41
+ copied from diffusers 0.11.1
42
+ A cross attention layer.
43
+ Parameters:
44
+ query_dim (`int`): The number of channels in the query.
45
+ cross_attention_dim (`int`, *optional*):
46
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
47
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
48
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
49
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50
+ bias (`bool`, *optional*, defaults to False):
51
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ query_dim: int,
57
+ cross_attention_dim: Optional[int] = None,
58
+ heads: int = 8,
59
+ dim_head: int = 64,
60
+ dropout: float = 0.0,
61
+ bias=False,
62
+ upcast_attention: bool = False,
63
+ upcast_softmax: bool = False,
64
+ added_kv_proj_dim: Optional[int] = None,
65
+ norm_num_groups: Optional[int] = None,
66
+ use_relative_position: bool = False,
67
+ ):
68
+ super().__init__()
69
+ inner_dim = dim_head * heads
70
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
71
+ self.upcast_attention = upcast_attention
72
+ self.upcast_softmax = upcast_softmax
73
+
74
+ self.scale = dim_head**-0.5
75
+
76
+ self.heads = heads
77
+ self.dim_head = dim_head
78
+ # for slice_size > 0 the attention score computation
79
+ # is split across the batch axis to save memory
80
+ # You can set slice_size with `set_attention_slice`
81
+ self.sliceable_head_dim = heads
82
+ self._slice_size = None
83
+ self._use_memory_efficient_attention_xformers = False
84
+ self.added_kv_proj_dim = added_kv_proj_dim
85
+
86
+ if norm_num_groups is not None:
87
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
88
+ else:
89
+ self.group_norm = None
90
+
91
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
92
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
93
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
94
+
95
+ if self.added_kv_proj_dim is not None:
96
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
97
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
98
+
99
+ self.to_out = nn.ModuleList([])
100
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
101
+ self.to_out.append(nn.Dropout(dropout))
102
+
103
+ # print(use_relative_position)
104
+ self.use_relative_position = use_relative_position
105
+ if self.use_relative_position:
106
+ # print(dim_head)
107
+ # print(heads)
108
+ # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265
109
+ self.max_position_embeddings = 32
110
+ self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head)
111
+
112
+ self.dropout = nn.Dropout(dropout)
113
+
114
+
115
+ def reshape_heads_to_batch_dim(self, tensor):
116
+ batch_size, seq_len, dim = tensor.shape
117
+ head_size = self.heads
118
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
119
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
120
+ return tensor
121
+
122
+ def reshape_batch_dim_to_heads(self, tensor):
123
+ batch_size, seq_len, dim = tensor.shape
124
+ head_size = self.heads
125
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
126
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
127
+ return tensor
128
+
129
+ def reshape_for_scores(self, tensor):
130
+ # split heads and dims
131
+ # tensor should be [b (h w)] f (d nd)
132
+ batch_size, seq_len, dim = tensor.shape
133
+ head_size = self.heads
134
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
135
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
136
+ return tensor
137
+
138
+ def same_batch_dim_to_heads(self, tensor):
139
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
140
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
141
+ return tensor
142
+
143
+ def set_attention_slice(self, slice_size):
144
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
145
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
146
+
147
+ self._slice_size = slice_size
148
+
149
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
150
+ batch_size, sequence_length, _ = hidden_states.shape
151
+
152
+ encoder_hidden_states = encoder_hidden_states
153
+
154
+ if self.group_norm is not None:
155
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
156
+
157
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
158
+ # if self.use_relative_position:
159
+ # print('before attention query shape', query.shape)
160
+ dim = query.shape[-1]
161
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
162
+ # if self.use_relative_position:
163
+ # print('before attention query shape', query.shape)
164
+
165
+ if self.added_kv_proj_dim is not None:
166
+ key = self.to_k(hidden_states)
167
+ value = self.to_v(hidden_states)
168
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
169
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
170
+
171
+ key = self.reshape_heads_to_batch_dim(key)
172
+ value = self.reshape_heads_to_batch_dim(value)
173
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
174
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
175
+
176
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
177
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
178
+ else:
179
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
180
+ key = self.to_k(encoder_hidden_states)
181
+ value = self.to_v(encoder_hidden_states)
182
+
183
+ key = self.reshape_heads_to_batch_dim(key)
184
+ value = self.reshape_heads_to_batch_dim(value)
185
+
186
+ if attention_mask is not None:
187
+ if attention_mask.shape[-1] != query.shape[1]:
188
+ target_length = query.shape[1]
189
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
190
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
191
+
192
+ # attention, what we cannot get enough of
193
+ if self._use_memory_efficient_attention_xformers:
194
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
195
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
196
+ hidden_states = hidden_states.to(query.dtype)
197
+ else:
198
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
199
+ hidden_states = self._attention(query, key, value, attention_mask)
200
+ else:
201
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
202
+
203
+ # linear proj
204
+ hidden_states = self.to_out[0](hidden_states)
205
+
206
+ # dropout
207
+ hidden_states = self.to_out[1](hidden_states)
208
+ return hidden_states
209
+
210
+
211
+ def _attention(self, query, key, value, attention_mask=None):
212
+ if self.upcast_attention:
213
+ query = query.float()
214
+ key = key.float()
215
+
216
+ if self.use_relative_position:
217
+ query = self.reshape_for_scores(self.reshape_batch_dim_to_heads(query))
218
+ key = self.reshape_for_scores(self.reshape_batch_dim_to_heads(key))
219
+ value = self.reshape_for_scores(self.reshape_batch_dim_to_heads(value))
220
+
221
+ # torch.baddbmm only accepts 3-D tensors
222
+ # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
223
+ attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
224
+
225
+ # print('attention_scores shape', attention_scores.shape)
226
+
227
+ # print(query.shape) # [b (h w)] nd f d
228
+ query_length, key_length = query.shape[2], key.shape[2]
229
+ # print('query shape', query.shape)
230
+ # print('key shape', key.shape)
231
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) # hidden_states.device
232
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=key.device).view(1, -1) # hidden_states.device
233
+ distance = position_ids_l - position_ids_r
234
+ # print('distance shape', distance.shape)
235
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
236
+ positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
237
+ # print('positional_embedding shape', positional_embedding.shape)
238
+ relative_position_scores_query = torch.einsum("bhld, lrd -> bhlr", query, positional_embedding)
239
+ relative_position_scores_key = torch.einsum("bhrd, lrd -> bhlr", key, positional_embedding)
240
+ # print('relative_position_scores_key shape', relative_position_scores_key.shape)
241
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
242
+ # print(attention_scores.shape)
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.dim_head)
245
+
246
+ # Normalize the attention scores to probabilities.
247
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
248
+
249
+ # cast back to the original dtype
250
+ attention_probs = attention_probs.to(value.dtype)
251
+
252
+ # compute attention output
253
+ hidden_states = torch.matmul(attention_probs, value)
254
+ # print(hidden_states.shape)
255
+ hidden_states = self.same_batch_dim_to_heads(hidden_states)
256
+ # print(hidden_states.shape)
257
+ # exit()
258
+
259
+ else:
260
+ attention_scores = torch.baddbmm(
261
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
262
+ query,
263
+ key.transpose(-1, -2),
264
+ beta=0,
265
+ alpha=self.scale,
266
+ )
267
+
268
+ if attention_mask is not None:
269
+ attention_scores = attention_scores + attention_mask
270
+
271
+ if self.upcast_softmax:
272
+ attention_scores = attention_scores.float()
273
+
274
+ attention_probs = attention_scores.softmax(dim=-1)
275
+ # print(attention_probs.shape)
276
+
277
+ # cast back to the original dtype
278
+ attention_probs = attention_probs.to(value.dtype)
279
+ # print(attention_probs.shape)
280
+
281
+ # compute attention output
282
+ hidden_states = torch.bmm(attention_probs, value)
283
+ # print(hidden_states.shape)
284
+
285
+ # reshape hidden_states
286
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
287
+ # print(hidden_states.shape)
288
+ # exit()
289
+ return hidden_states
290
+
291
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
292
+ batch_size_attention = query.shape[0]
293
+ hidden_states = torch.zeros(
294
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
295
+ )
296
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
297
+ for i in range(hidden_states.shape[0] // slice_size):
298
+ start_idx = i * slice_size
299
+ end_idx = (i + 1) * slice_size
300
+
301
+ query_slice = query[start_idx:end_idx]
302
+ key_slice = key[start_idx:end_idx]
303
+
304
+ if self.upcast_attention:
305
+ query_slice = query_slice.float()
306
+ key_slice = key_slice.float()
307
+
308
+ attn_slice = torch.baddbmm(
309
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
310
+ query_slice,
311
+ key_slice.transpose(-1, -2),
312
+ beta=0,
313
+ alpha=self.scale,
314
+ )
315
+
316
+ if attention_mask is not None:
317
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
318
+
319
+ if self.upcast_softmax:
320
+ attn_slice = attn_slice.float()
321
+
322
+ attn_slice = attn_slice.softmax(dim=-1)
323
+
324
+ # cast back to the original dtype
325
+ attn_slice = attn_slice.to(value.dtype)
326
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
327
+
328
+ hidden_states[start_idx:end_idx] = attn_slice
329
+
330
+ # reshape hidden_states
331
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
332
+ return hidden_states
333
+
334
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
335
+ # TODO attention_mask
336
+ query = query.contiguous()
337
+ key = key.contiguous()
338
+ value = value.contiguous()
339
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
340
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
341
+ return hidden_states
342
+
343
+
344
+ class Transformer3DModel(ModelMixin, ConfigMixin):
345
+ @register_to_config
346
+ def __init__(
347
+ self,
348
+ num_attention_heads: int = 16,
349
+ attention_head_dim: int = 88,
350
+ in_channels: Optional[int] = None,
351
+ num_layers: int = 1,
352
+ dropout: float = 0.0,
353
+ norm_num_groups: int = 32,
354
+ cross_attention_dim: Optional[int] = None,
355
+ attention_bias: bool = False,
356
+ activation_fn: str = "geglu",
357
+ num_embeds_ada_norm: Optional[int] = None,
358
+ use_linear_projection: bool = False,
359
+ only_cross_attention: bool = False,
360
+ upcast_attention: bool = False,
361
+ use_first_frame: bool = False,
362
+ use_relative_position: bool = False,
363
+ ):
364
+ super().__init__()
365
+ self.use_linear_projection = use_linear_projection
366
+ self.num_attention_heads = num_attention_heads
367
+ self.attention_head_dim = attention_head_dim
368
+ inner_dim = num_attention_heads * attention_head_dim
369
+
370
+ # Define input layers
371
+ self.in_channels = in_channels
372
+
373
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
374
+ if use_linear_projection:
375
+ self.proj_in = nn.Linear(in_channels, inner_dim)
376
+ else:
377
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
378
+
379
+ # Define transformers blocks
380
+ self.transformer_blocks = nn.ModuleList(
381
+ [
382
+ BasicTransformerBlock(
383
+ inner_dim,
384
+ num_attention_heads,
385
+ attention_head_dim,
386
+ dropout=dropout,
387
+ cross_attention_dim=cross_attention_dim,
388
+ activation_fn=activation_fn,
389
+ num_embeds_ada_norm=num_embeds_ada_norm,
390
+ attention_bias=attention_bias,
391
+ only_cross_attention=only_cross_attention,
392
+ upcast_attention=upcast_attention,
393
+ use_first_frame=use_first_frame,
394
+ use_relative_position=use_relative_position,
395
+ )
396
+ for d in range(num_layers)
397
+ ]
398
+ )
399
+
400
+ # 4. Define output layers
401
+ if use_linear_projection:
402
+ self.proj_out = nn.Linear(in_channels, inner_dim)
403
+ else:
404
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
405
+
406
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
407
+ # Input
408
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
409
+ video_length = hidden_states.shape[2]
410
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
411
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
412
+
413
+ batch, channel, height, weight = hidden_states.shape
414
+ residual = hidden_states
415
+
416
+ hidden_states = self.norm(hidden_states)
417
+ if not self.use_linear_projection:
418
+ hidden_states = self.proj_in(hidden_states)
419
+ inner_dim = hidden_states.shape[1]
420
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
421
+ else:
422
+ inner_dim = hidden_states.shape[1]
423
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
424
+ hidden_states = self.proj_in(hidden_states)
425
+
426
+ # Blocks
427
+ for block in self.transformer_blocks:
428
+ hidden_states = block(
429
+ hidden_states,
430
+ encoder_hidden_states=encoder_hidden_states,
431
+ timestep=timestep,
432
+ video_length=video_length
433
+ )
434
+
435
+ # Output
436
+ if not self.use_linear_projection:
437
+ hidden_states = (
438
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
439
+ )
440
+ hidden_states = self.proj_out(hidden_states)
441
+ else:
442
+ hidden_states = self.proj_out(hidden_states)
443
+ hidden_states = (
444
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
445
+ )
446
+
447
+ output = hidden_states + residual
448
+
449
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
450
+ if not return_dict:
451
+ return (output,)
452
+
453
+ return Transformer3DModelOutput(sample=output)
454
+
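A minimal standalone sketch of the frame flattening Transformer3DModel.forward applies, assuming a toy (b, c, f, h, w) latent and einops installed:
import torch
from einops import rearrange, repeat
hidden_states = torch.randn(2, 320, 16, 8, 8)      # b c f h w
encoder_hidden_states = torch.randn(2, 77, 768)    # b n c (text embeddings)
video_length = hidden_states.shape[2]
flat = rearrange(hidden_states, "b c f h w -> (b f) c h w")
text = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
restored = rearrange(flat, "(b f) c h w -> b c f h w", f=video_length)
print(flat.shape, text.shape, torch.equal(restored, hidden_states))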
455
+
456
+ class BasicTransformerBlock(nn.Module):
457
+ def __init__(
458
+ self,
459
+ dim: int,
460
+ num_attention_heads: int,
461
+ attention_head_dim: int,
462
+ dropout=0.0,
463
+ cross_attention_dim: Optional[int] = None,
464
+ activation_fn: str = "geglu",
465
+ num_embeds_ada_norm: Optional[int] = None,
466
+ attention_bias: bool = False,
467
+ only_cross_attention: bool = False,
468
+ upcast_attention: bool = False,
469
+ use_first_frame: bool = False,
470
+ use_relative_position: bool = False,
471
+ ):
472
+ super().__init__()
473
+ self.only_cross_attention = only_cross_attention
474
+ # print(only_cross_attention)
475
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
476
+ self.use_first_frame = use_first_frame
477
+
478
+ # SC-Attn
479
+ if use_first_frame:
480
+ self.attn1 = SparseCausalAttention(
481
+ query_dim=dim,
482
+ heads=num_attention_heads,
483
+ dim_head=attention_head_dim,
484
+ dropout=dropout,
485
+ bias=attention_bias,
486
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
487
+ upcast_attention=upcast_attention,
488
+ )
489
+ # print(cross_attention_dim)
490
+ else:
491
+ self.attn1 = CrossAttention(
492
+ query_dim=dim,
493
+ heads=num_attention_heads,
494
+ dim_head=attention_head_dim,
495
+ dropout=dropout,
496
+ bias=attention_bias,
497
+ cross_attention_dim=None,
498
+ upcast_attention=upcast_attention,
499
+ )
500
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
501
+
502
+ # Cross-Attn
503
+ if cross_attention_dim is not None:
504
+ self.attn2 = CrossAttention(
505
+ query_dim=dim,
506
+ cross_attention_dim=cross_attention_dim,
507
+ heads=num_attention_heads,
508
+ dim_head=attention_head_dim,
509
+ dropout=dropout,
510
+ bias=attention_bias,
511
+ upcast_attention=upcast_attention,
512
+ )
513
+ else:
514
+ self.attn2 = None
515
+
516
+ if cross_attention_dim is not None:
517
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
518
+ else:
519
+ self.norm2 = None
520
+
521
+ # Feed-forward
522
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
523
+ self.norm3 = nn.LayerNorm(dim)
524
+
525
+ # Temp-Attn
526
+ self.attn_temp = CrossAttention(
527
+ query_dim=dim,
528
+ heads=num_attention_heads,
529
+ dim_head=attention_head_dim,
530
+ dropout=dropout,
531
+ bias=attention_bias,
532
+ upcast_attention=upcast_attention,
533
+ use_relative_position=use_relative_position
534
+ )
535
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
536
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
537
+
538
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op=None):
539
+ if not is_xformers_available():
540
+ print("Here is how to install it")
541
+ raise ModuleNotFoundError(
542
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
543
+ " xformers",
544
+ name="xformers",
545
+ )
546
+ elif not torch.cuda.is_available():
547
+ raise ValueError(
548
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
549
+ " available for GPU "
550
+ )
551
+ else:
552
+ try:
553
+ # Make sure we can run the memory efficient attention
554
+ _ = xformers.ops.memory_efficient_attention(
555
+ torch.randn((1, 2, 40), device="cuda"),
556
+ torch.randn((1, 2, 40), device="cuda"),
557
+ torch.randn((1, 2, 40), device="cuda"),
558
+ )
559
+ except Exception as e:
560
+ raise e
561
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
562
+ if self.attn2 is not None:
563
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
564
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
565
+
566
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
567
+ # SparseCausal-Attention
568
+ norm_hidden_states = (
569
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
570
+ )
571
+
572
+ if self.only_cross_attention:
573
+ hidden_states = (
574
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
575
+ )
576
+ else:
577
+ if self.use_first_frame:
578
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
579
+ else:
580
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
581
+
582
+ if self.attn2 is not None:
583
+ # Cross-Attention
584
+ norm_hidden_states = (
585
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
586
+ )
587
+ hidden_states = (
588
+ self.attn2(
589
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
590
+ )
591
+ + hidden_states
592
+ )
593
+
594
+ # Feed-forward
595
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
596
+
597
+ # Temporal-Attention
598
+ d = hidden_states.shape[1]
599
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
600
+ norm_hidden_states = (
601
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
602
+ )
603
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
604
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
605
+
606
+ return hidden_states
607
+
608
+
609
+ class SparseCausalAttention(CrossAttention):
610
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
611
+ batch_size, sequence_length, _ = hidden_states.shape
612
+
613
+ encoder_hidden_states = encoder_hidden_states
614
+
615
+ if self.group_norm is not None:
616
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
617
+
618
+ query = self.to_q(hidden_states)
619
+ dim = query.shape[-1]
620
+ query = self.reshape_heads_to_batch_dim(query)
621
+
622
+ if self.added_kv_proj_dim is not None:
623
+ raise NotImplementedError
624
+
625
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
626
+ key = self.to_k(encoder_hidden_states)
627
+ value = self.to_v(encoder_hidden_states)
628
+
629
+ former_frame_index = torch.arange(video_length) - 1
630
+ former_frame_index[0] = 0
631
+
632
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
633
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
634
+ key = rearrange(key, "b f d c -> (b f) d c")
635
+
636
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
637
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
638
+ value = rearrange(value, "b f d c -> (b f) d c")
639
+
640
+ key = self.reshape_heads_to_batch_dim(key)
641
+ value = self.reshape_heads_to_batch_dim(value)
642
+
643
+ if attention_mask is not None:
644
+ if attention_mask.shape[-1] != query.shape[1]:
645
+ target_length = query.shape[1]
646
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
647
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
648
+
649
+ # attention, what we cannot get enough of
650
+ if self._use_memory_efficient_attention_xformers:
651
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
652
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
653
+ hidden_states = hidden_states.to(query.dtype)
654
+ else:
655
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
656
+ hidden_states = self._attention(query, key, value, attention_mask)
657
+ else:
658
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
659
+
660
+ # linear proj
661
+ hidden_states = self.to_out[0](hidden_states)
662
+
663
+ # dropout
664
+ hidden_states = self.to_out[1](hidden_states)
665
+ return hidden_states
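The SparseCausalAttention above builds its keys and values so that every frame attends to the tokens of the first frame plus the tokens of its immediately preceding frame. A minimal sketch of just that frame-gathering step, with toy tensor sizes chosen for illustration (assumes torch and einops are installed):

    import torch
    from einops import rearrange

    # Hypothetical toy sizes: 2 videos, 4 frames, 3 tokens per frame, 8 channels.
    b, f, d, c = 2, 4, 3, 8
    key = torch.randn(b * f, d, c)  # the attention layer sees frames folded into the batch: (b*f, d, c)

    former_frame_index = torch.arange(f) - 1
    former_frame_index[0] = 0  # frame 0 has no predecessor, so it falls back to itself

    key = rearrange(key, "(b f) d c -> b f d c", f=f)
    # Each frame's keys become [first-frame tokens, previous-frame tokens], concatenated on the token axis.
    key = torch.cat([key[:, [0] * f], key[:, former_frame_index]], dim=2)
    key = rearrange(key, "b f d c -> (b f) d c")

    print(key.shape)  # torch.Size([8, 6, 8]): every frame now attends over 2*d key tokens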
interpolation/models/clip.py ADDED
@@ -0,0 +1,124 @@
1
+ import numpy
2
+ import torch.nn as nn
3
+ from transformers import CLIPTokenizer, CLIPTextModel
4
+
5
+ import transformers
6
+ transformers.logging.set_verbosity_error()
7
+
8
+ """
9
+ Loading the text encoder will emit the following warning:
10
+ - This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
11
+ or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
12
+ - This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
13
+ that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
14
+
15
+ https://github.com/CompVis/stable-diffusion/issues/97
16
+ According to this issue, the warning is safe to ignore.
17
+
18
+ This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
19
+ You can safely ignore the warning, it is not an error.
20
+
21
+ This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
22
+ """
23
+
24
+ class AbstractEncoder(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def encode(self, *args, **kwargs):
29
+ raise NotImplementedError
30
+
31
+
32
+ class FrozenCLIPEmbedder(AbstractEncoder):
33
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
34
+ def __init__(self, sd_path, device="cuda", max_length=77):
35
+ super().__init__()
36
+ self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", use_fast=False)
37
+ self.transformer = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
38
+ self.device = device
39
+ self.max_length = max_length
40
+ self.freeze()
41
+
42
+ def freeze(self):
43
+ self.transformer = self.transformer.eval()
44
+ for param in self.parameters():
45
+ param.requires_grad = False
46
+
47
+ def forward(self, text):
48
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
49
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
50
+ tokens = batch_encoding["input_ids"].to(self.device)
51
+ outputs = self.transformer(input_ids=tokens)
52
+
53
+ z = outputs.last_hidden_state
54
+ return z
55
+
56
+ def encode(self, text):
57
+ return self(text)
58
+
59
+
60
+ class TextEmbedder(nn.Module):
61
+ """
62
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
63
+ """
64
+ def __init__(self, args, dropout_prob=0.1):
65
+ super().__init__()
66
+ self.text_encodder = FrozenCLIPEmbedder(args)
67
+ self.dropout_prob = dropout_prob
68
+
69
+ def token_drop(self, text_prompts, force_drop_ids=None):
70
+ """
71
+ Drops text to enable classifier-free guidance.
72
+ """
73
+ if force_drop_ids is None:
74
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
75
+ else:
76
+ # TODO
77
+ drop_ids = force_drop_ids == 1
78
+ labels = list(numpy.where(drop_ids, "None", text_prompts))
79
+ # print(labels)
80
+ return labels
81
+
82
+ def forward(self, text_prompts, train, force_drop_ids=None):
83
+ use_dropout = self.dropout_prob > 0
84
+ if (train and use_dropout) or (force_drop_ids is not None):
85
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
86
+ embeddings = self.text_encodder(text_prompts)
87
+ return embeddings
88
+
89
+
90
+ if __name__ == '__main__':
91
+
92
+ r"""
93
+ Returns:
94
+
95
+ Examples from CLIPTextModel:
96
+
97
+ ```python
98
+ >>> from transformers import AutoTokenizer, CLIPTextModel
99
+
100
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
101
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
102
+
103
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
104
+
105
+ >>> outputs = model(**inputs)
106
+ >>> last_hidden_state = outputs.last_hidden_state
107
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
108
+ ```"""
109
+
110
+ import torch
111
+
112
+ device = "cuda" if torch.cuda.is_available() else "cpu"
113
+
114
+ sd_path = "/nvme/maxin/work/large-dit-video/pretrained/stable-diffusion-v1-4/"  # local Stable Diffusion checkpoint
+ text_encoder = TextEmbedder(sd_path, dropout_prob=0.00001).to(device)
116
+ text_encoder1 = FrozenCLIPEmbedder(sd_path).to(device)
116
+
117
+ text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human']
118
+ # text_prompt = ('None', 'None', 'None')
119
+ output = text_encoder(text_prompts=text_prompt, train=True)
120
+ output1 = text_encoder1(text_prompt)
121
+ # print(output)
122
+ print(output.shape)
123
+ print(output1.shape)
124
+ print((output==output1).all())
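FrozenCLIPEmbedder above is a frozen CLIP text encoder whose constructor expects the path of a Stable Diffusion checkpoint directory (it loads the "tokenizer" and "text_encoder" subfolders). A minimal usage sketch, assuming a local copy of stable-diffusion-v1-4; the path below is a placeholder, not a path shipped with this repo:

    import torch
    # FrozenCLIPEmbedder is defined in interpolation/models/clip.py above.

    sd_path = "./pretrained/stable-diffusion-v1-4"  # placeholder; point this at your local checkpoint
    device = "cuda" if torch.cuda.is_available() else "cpu"

    encoder = FrozenCLIPEmbedder(sd_path, device=device).to(device)
    with torch.no_grad():
        z = encoder.encode(["a corgi running on the beach"])
    print(z.shape)  # torch.Size([1, 77, 768]): one prompt, padded to 77 tokens, 768-dim hidden states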
interpolation/models/resnet.py ADDED
@@ -0,0 +1,212 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ class InflatedConv3d(nn.Conv2d):
14
+ def forward(self, x):
15
+ video_length = x.shape[2]
16
+
17
+ x = rearrange(x, "b c f h w -> (b f) c h w")
18
+ x = super().forward(x)
19
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
20
+
21
+ return x
22
+
23
+
24
+ class Upsample3D(nn.Module):
25
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
26
+ super().__init__()
27
+ self.channels = channels
28
+ self.out_channels = out_channels or channels
29
+ self.use_conv = use_conv
30
+ self.use_conv_transpose = use_conv_transpose
31
+ self.name = name
32
+
33
+ conv = None
34
+ if use_conv_transpose:
35
+ raise NotImplementedError
36
+ elif use_conv:
37
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
38
+
39
+ if name == "conv":
40
+ self.conv = conv
41
+ else:
42
+ self.Conv2d_0 = conv
43
+
44
+ def forward(self, hidden_states, output_size=None):
45
+ assert hidden_states.shape[1] == self.channels
46
+
47
+ if self.use_conv_transpose:
48
+ raise NotImplementedError
49
+
50
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
51
+ dtype = hidden_states.dtype
52
+ if dtype == torch.bfloat16:
53
+ hidden_states = hidden_states.to(torch.float32)
54
+
55
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
56
+ if hidden_states.shape[0] >= 64:
57
+ hidden_states = hidden_states.contiguous()
58
+
59
+ # if `output_size` is passed we force the interpolation output
60
+ # size and do not make use of `scale_factor=2`
61
+ if output_size is None:
62
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
63
+ else:
64
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
65
+
66
+ # If the input is bfloat16, we cast back to bfloat16
67
+ if dtype == torch.bfloat16:
68
+ hidden_states = hidden_states.to(dtype)
69
+
70
+ if self.use_conv:
71
+ if self.name == "conv":
72
+ hidden_states = self.conv(hidden_states)
73
+ else:
74
+ hidden_states = self.Conv2d_0(hidden_states)
75
+
76
+ return hidden_states
77
+
78
+
79
+ class Downsample3D(nn.Module):
80
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
81
+ super().__init__()
82
+ self.channels = channels
83
+ self.out_channels = out_channels or channels
84
+ self.use_conv = use_conv
85
+ self.padding = padding
86
+ stride = 2
87
+ self.name = name
88
+
89
+ if use_conv:
90
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
91
+ else:
92
+ raise NotImplementedError
93
+
94
+ if name == "conv":
95
+ self.Conv2d_0 = conv
96
+ self.conv = conv
97
+ elif name == "Conv2d_0":
98
+ self.conv = conv
99
+ else:
100
+ self.conv = conv
101
+
102
+ def forward(self, hidden_states):
103
+ assert hidden_states.shape[1] == self.channels
104
+ if self.use_conv and self.padding == 0:
105
+ raise NotImplementedError
106
+
107
+ assert hidden_states.shape[1] == self.channels
108
+ hidden_states = self.conv(hidden_states)
109
+
110
+ return hidden_states
111
+
112
+
113
+ class ResnetBlock3D(nn.Module):
114
+ def __init__(
115
+ self,
116
+ *,
117
+ in_channels,
118
+ out_channels=None,
119
+ conv_shortcut=False,
120
+ dropout=0.0,
121
+ temb_channels=512,
122
+ groups=32,
123
+ groups_out=None,
124
+ pre_norm=True,
125
+ eps=1e-6,
126
+ non_linearity="swish",
127
+ time_embedding_norm="default",
128
+ output_scale_factor=1.0,
129
+ use_in_shortcut=None,
130
+ ):
131
+ super().__init__()
132
+ self.pre_norm = pre_norm
133
+ self.pre_norm = True
134
+ self.in_channels = in_channels
135
+ out_channels = in_channels if out_channels is None else out_channels
136
+ self.out_channels = out_channels
137
+ self.use_conv_shortcut = conv_shortcut
138
+ self.time_embedding_norm = time_embedding_norm
139
+ self.output_scale_factor = output_scale_factor
140
+
141
+ if groups_out is None:
142
+ groups_out = groups
143
+
144
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
145
+
146
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
147
+
148
+ if temb_channels is not None:
149
+ if self.time_embedding_norm == "default":
150
+ time_emb_proj_out_channels = out_channels
151
+ elif self.time_embedding_norm == "scale_shift":
152
+ time_emb_proj_out_channels = out_channels * 2
153
+ else:
154
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
155
+
156
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
157
+ else:
158
+ self.time_emb_proj = None
159
+
160
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
161
+ self.dropout = torch.nn.Dropout(dropout)
162
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
163
+
164
+ if non_linearity == "swish":
165
+ self.nonlinearity = lambda x: F.silu(x)
166
+ elif non_linearity == "mish":
167
+ self.nonlinearity = Mish()
168
+ elif non_linearity == "silu":
169
+ self.nonlinearity = nn.SiLU()
170
+
171
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
172
+
173
+ self.conv_shortcut = None
174
+ if self.use_in_shortcut:
175
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
176
+
177
+ def forward(self, input_tensor, temb):
178
+ hidden_states = input_tensor
179
+
180
+ hidden_states = self.norm1(hidden_states)
181
+ hidden_states = self.nonlinearity(hidden_states)
182
+
183
+ hidden_states = self.conv1(hidden_states)
184
+
185
+ if temb is not None:
186
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
187
+
188
+ if temb is not None and self.time_embedding_norm == "default":
189
+ hidden_states = hidden_states + temb
190
+
191
+ hidden_states = self.norm2(hidden_states)
192
+
193
+ if temb is not None and self.time_embedding_norm == "scale_shift":
194
+ scale, shift = torch.chunk(temb, 2, dim=1)
195
+ hidden_states = hidden_states * (1 + scale) + shift
196
+
197
+ hidden_states = self.nonlinearity(hidden_states)
198
+
199
+ hidden_states = self.dropout(hidden_states)
200
+ hidden_states = self.conv2(hidden_states)
201
+
202
+ if self.conv_shortcut is not None:
203
+ input_tensor = self.conv_shortcut(input_tensor)
204
+
205
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
206
+
207
+ return output_tensor
208
+
209
+
210
+ class Mish(torch.nn.Module):
211
+ def forward(self, hidden_states):
212
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
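InflatedConv3d above reuses a plain 2D convolution for video by folding the frame axis into the batch axis, convolving each frame independently, and unfolding again, so there is no temporal mixing. A minimal sketch of the 5D in/out contract (toy shapes):

    import torch
    # InflatedConv3d is defined in interpolation/models/resnet.py above.

    conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)  # same constructor signature as nn.Conv2d
    x = torch.randn(2, 4, 16, 32, 32)                        # (batch, channels, frames, height, width)
    y = conv(x)
    print(y.shape)  # torch.Size([2, 320, 16, 32, 32]): spatial conv applied per frame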
interpolation/models/unet.py ADDED
@@ -0,0 +1,576 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import sys
8
+ sys.path.append(os.path.split(sys.path[0])[0])
9
+
10
+ import json
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils.checkpoint
15
+
16
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
17
+ from diffusers.utils import BaseOutput, logging
18
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
19
+
20
+ try:
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+ except:
23
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
24
+
25
+ try:
26
+ from .unet_blocks import (
27
+ CrossAttnDownBlock3D,
28
+ CrossAttnUpBlock3D,
29
+ DownBlock3D,
30
+ UNetMidBlock3DCrossAttn,
31
+ UpBlock3D,
32
+ get_down_block,
33
+ get_up_block,
34
+ )
35
+ from .resnet import InflatedConv3d
36
+ except:
37
+ from unet_blocks import (
38
+ CrossAttnDownBlock3D,
39
+ CrossAttnUpBlock3D,
40
+ DownBlock3D,
41
+ UNetMidBlock3DCrossAttn,
42
+ UpBlock3D,
43
+ get_down_block,
44
+ get_up_block,
45
+ )
46
+ from resnet import InflatedConv3d
47
+
48
+
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+
53
+ @dataclass
54
+ class UNet3DConditionOutput(BaseOutput):
55
+ sample: torch.FloatTensor
56
+
57
+
58
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
59
+ _supports_gradient_checkpointing = True
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ sample_size: Optional[int] = None, # 64
65
+ in_channels: int = 4,
66
+ out_channels: int = 4,
67
+ center_input_sample: bool = False,
68
+ flip_sin_to_cos: bool = True,
69
+ freq_shift: int = 0,
70
+ down_block_types: Tuple[str] = (
71
+ "CrossAttnDownBlock3D",
72
+ "CrossAttnDownBlock3D",
73
+ "CrossAttnDownBlock3D",
74
+ "DownBlock3D",
75
+ ),
76
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
77
+ up_block_types: Tuple[str] = (
78
+ "UpBlock3D",
79
+ "CrossAttnUpBlock3D",
80
+ "CrossAttnUpBlock3D",
81
+ "CrossAttnUpBlock3D"
82
+ ),
83
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
84
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
85
+ layers_per_block: int = 2,
86
+ downsample_padding: int = 1,
87
+ mid_block_scale_factor: float = 1,
88
+ act_fn: str = "silu",
89
+ norm_num_groups: int = 32,
90
+ norm_eps: float = 1e-5,
91
+ cross_attention_dim: int = 1280,
92
+ attention_head_dim: Union[int, Tuple[int]] = 8,
93
+ dual_cross_attention: bool = False,
94
+ use_linear_projection: bool = False,
95
+ class_embed_type: Optional[str] = None,
96
+ num_class_embeds: Optional[int] = None,
97
+ upcast_attention: bool = False,
98
+ resnet_time_scale_shift: str = "default",
99
+ use_first_frame: bool = False,
100
+ use_relative_position: bool = False,
101
+ ):
102
+ super().__init__()
103
+
104
+ # print(use_first_frame)
105
+
106
+ self.sample_size = sample_size
107
+ time_embed_dim = block_out_channels[0] * 4
108
+
109
+ # input
110
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
111
+
112
+ # time
113
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
114
+ timestep_input_dim = block_out_channels[0]
115
+
116
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
117
+
118
+ # class embedding
119
+ if class_embed_type is None and num_class_embeds is not None:
120
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
121
+ elif class_embed_type == "timestep":
122
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
123
+ elif class_embed_type == "identity":
124
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
125
+ else:
126
+ self.class_embedding = None
127
+
128
+ self.down_blocks = nn.ModuleList([])
129
+ self.mid_block = None
130
+ self.up_blocks = nn.ModuleList([])
131
+
132
+ if isinstance(only_cross_attention, bool):
133
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
134
+ # print(only_cross_attention)
135
+ # exit()
136
+
137
+ if isinstance(attention_head_dim, int):
138
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
139
+ # print(attention_head_dim)
140
+ # exit()
141
+
142
+ # down
143
+ output_channel = block_out_channels[0]
144
+ for i, down_block_type in enumerate(down_block_types):
145
+ input_channel = output_channel
146
+ output_channel = block_out_channels[i]
147
+ is_final_block = i == len(block_out_channels) - 1
148
+
149
+ down_block = get_down_block(
150
+ down_block_type,
151
+ num_layers=layers_per_block,
152
+ in_channels=input_channel,
153
+ out_channels=output_channel,
154
+ temb_channels=time_embed_dim,
155
+ add_downsample=not is_final_block,
156
+ resnet_eps=norm_eps,
157
+ resnet_act_fn=act_fn,
158
+ resnet_groups=norm_num_groups,
159
+ cross_attention_dim=cross_attention_dim,
160
+ attn_num_head_channels=attention_head_dim[i],
161
+ downsample_padding=downsample_padding,
162
+ dual_cross_attention=dual_cross_attention,
163
+ use_linear_projection=use_linear_projection,
164
+ only_cross_attention=only_cross_attention[i],
165
+ upcast_attention=upcast_attention,
166
+ resnet_time_scale_shift=resnet_time_scale_shift,
167
+ use_first_frame=use_first_frame,
168
+ use_relative_position=use_relative_position,
169
+ )
170
+ self.down_blocks.append(down_block)
171
+
172
+ # mid
173
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
174
+ self.mid_block = UNetMidBlock3DCrossAttn(
175
+ in_channels=block_out_channels[-1],
176
+ temb_channels=time_embed_dim,
177
+ resnet_eps=norm_eps,
178
+ resnet_act_fn=act_fn,
179
+ output_scale_factor=mid_block_scale_factor,
180
+ resnet_time_scale_shift=resnet_time_scale_shift,
181
+ cross_attention_dim=cross_attention_dim,
182
+ attn_num_head_channels=attention_head_dim[-1],
183
+ resnet_groups=norm_num_groups,
184
+ dual_cross_attention=dual_cross_attention,
185
+ use_linear_projection=use_linear_projection,
186
+ upcast_attention=upcast_attention,
187
+ use_first_frame=use_first_frame,
188
+ use_relative_position=use_relative_position,
189
+ )
190
+ else:
191
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
192
+
193
+ # count how many layers upsample the videos
194
+ self.num_upsamplers = 0
195
+
196
+ # up
197
+ reversed_block_out_channels = list(reversed(block_out_channels))
198
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
199
+ only_cross_attention = list(reversed(only_cross_attention))
200
+ output_channel = reversed_block_out_channels[0]
201
+ for i, up_block_type in enumerate(up_block_types):
202
+ is_final_block = i == len(block_out_channels) - 1
203
+
204
+ prev_output_channel = output_channel
205
+ output_channel = reversed_block_out_channels[i]
206
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
207
+
208
+ # add upsample block for all BUT final layer
209
+ if not is_final_block:
210
+ add_upsample = True
211
+ self.num_upsamplers += 1
212
+ else:
213
+ add_upsample = False
214
+
215
+ up_block = get_up_block(
216
+ up_block_type,
217
+ num_layers=layers_per_block + 1,
218
+ in_channels=input_channel,
219
+ out_channels=output_channel,
220
+ prev_output_channel=prev_output_channel,
221
+ temb_channels=time_embed_dim,
222
+ add_upsample=add_upsample,
223
+ resnet_eps=norm_eps,
224
+ resnet_act_fn=act_fn,
225
+ resnet_groups=norm_num_groups,
226
+ cross_attention_dim=cross_attention_dim,
227
+ attn_num_head_channels=reversed_attention_head_dim[i],
228
+ dual_cross_attention=dual_cross_attention,
229
+ use_linear_projection=use_linear_projection,
230
+ only_cross_attention=only_cross_attention[i],
231
+ upcast_attention=upcast_attention,
232
+ resnet_time_scale_shift=resnet_time_scale_shift,
233
+ use_first_frame=use_first_frame,
234
+ use_relative_position=use_relative_position,
235
+ )
236
+ self.up_blocks.append(up_block)
237
+ prev_output_channel = output_channel
238
+
239
+ # out
240
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
241
+ self.conv_act = nn.SiLU()
242
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
243
+
244
+ def set_attention_slice(self, slice_size):
245
+ r"""
246
+ Enable sliced attention computation.
247
+
248
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
249
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
250
+
251
+ Args:
252
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
253
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
254
+ `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
255
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
256
+ must be a multiple of `slice_size`.
257
+ """
258
+ sliceable_head_dims = []
259
+
260
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
261
+ if hasattr(module, "set_attention_slice"):
262
+ sliceable_head_dims.append(module.sliceable_head_dim)
263
+
264
+ for child in module.children():
265
+ fn_recursive_retrieve_slicable_dims(child)
266
+
267
+ # retrieve number of attention layers
268
+ for module in self.children():
269
+ fn_recursive_retrieve_slicable_dims(module)
270
+
271
+ num_slicable_layers = len(sliceable_head_dims)
272
+
273
+ if slice_size == "auto":
274
+ # half the attention head size is usually a good trade-off between
275
+ # speed and memory
276
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
277
+ elif slice_size == "max":
278
+ # make smallest slice possible
279
+ slice_size = num_slicable_layers * [1]
280
+
281
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
282
+
283
+ if len(slice_size) != len(sliceable_head_dims):
284
+ raise ValueError(
285
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
286
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
287
+ )
288
+
289
+ for i in range(len(slice_size)):
290
+ size = slice_size[i]
291
+ dim = sliceable_head_dims[i]
292
+ if size is not None and size > dim:
293
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
294
+
295
+ # Recursively walk through all the children.
296
+ # Any children which exposes the set_attention_slice method
297
+ # gets the message
298
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
299
+ if hasattr(module, "set_attention_slice"):
300
+ module.set_attention_slice(slice_size.pop())
301
+
302
+ for child in module.children():
303
+ fn_recursive_set_attention_slice(child, slice_size)
304
+
305
+ reversed_slice_size = list(reversed(slice_size))
306
+ for module in self.children():
307
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
308
+
309
+ def _set_gradient_checkpointing(self, module, value=False):
310
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
311
+ module.gradient_checkpointing = value
312
+
313
+ def forward(
314
+ self,
315
+ sample: torch.FloatTensor,
316
+ timestep: Union[torch.Tensor, float, int],
317
+ encoder_hidden_states: torch.Tensor = None,
318
+ class_labels: Optional[torch.Tensor] = None,
319
+ attention_mask: Optional[torch.Tensor] = None,
320
+ return_dict: bool = True,
321
+ ) -> Union[UNet3DConditionOutput, Tuple]:
322
+ r"""
323
+ Args:
324
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
325
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
326
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
327
+ return_dict (`bool`, *optional*, defaults to `True`):
328
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
329
+
330
+ Returns:
331
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
332
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
333
+ returning a tuple, the first element is the sample tensor.
334
+ """
335
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
336
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
337
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
338
+ # on the fly if necessary.
339
+ default_overall_up_factor = 2**self.num_upsamplers
340
+
341
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
342
+ forward_upsample_size = False
343
+ upsample_size = None
344
+
345
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
346
+ logger.info("Forward upsample size to force interpolation output size.")
347
+ forward_upsample_size = True
348
+
349
+ # prepare attention_mask
350
+ if attention_mask is not None:
351
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
352
+ attention_mask = attention_mask.unsqueeze(1)
353
+
354
+ # center input if necessary
355
+ if self.config.center_input_sample:
356
+ sample = 2 * sample - 1.0
357
+
358
+ # time
359
+ timesteps = timestep
360
+ if not torch.is_tensor(timesteps):
361
+ # This would be a good case for the `match` statement (Python 3.10+)
362
+ is_mps = sample.device.type == "mps"
363
+ if isinstance(timestep, float):
364
+ dtype = torch.float32 if is_mps else torch.float64
365
+ else:
366
+ dtype = torch.int32 if is_mps else torch.int64
367
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
368
+ elif len(timesteps.shape) == 0:
369
+ timesteps = timesteps[None].to(sample.device)
370
+
371
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
372
+ timesteps = timesteps.expand(sample.shape[0])
373
+
374
+ t_emb = self.time_proj(timesteps)
375
+
376
+ # timesteps does not contain any weights and will always return f32 tensors
377
+ # but time_embedding might actually be running in fp16. so we need to cast here.
378
+ # there might be better ways to encapsulate this.
379
+ t_emb = t_emb.to(dtype=self.dtype)
380
+ emb = self.time_embedding(t_emb)
381
+
382
+ if self.class_embedding is not None:
383
+ if class_labels is None:
384
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
385
+
386
+ if self.config.class_embed_type == "timestep":
387
+ class_labels = self.time_proj(class_labels)
388
+
389
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
390
+ # print(emb.shape) # torch.Size([3, 1280])
391
+ # print(class_emb.shape) # torch.Size([3, 1280])
392
+ emb = emb + class_emb
393
+
394
+ # pre-process
395
+ sample = self.conv_in(sample)
396
+
397
+ # down
398
+ down_block_res_samples = (sample,)
399
+ for downsample_block in self.down_blocks:
400
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
401
+ sample, res_samples = downsample_block(
402
+ hidden_states=sample,
403
+ temb=emb,
404
+ encoder_hidden_states=encoder_hidden_states,
405
+ attention_mask=attention_mask,
406
+ )
407
+ else:
408
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
409
+
410
+ down_block_res_samples += res_samples
411
+
412
+ # mid
413
+ sample = self.mid_block(
414
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
415
+ )
416
+
417
+ # up
418
+ for i, upsample_block in enumerate(self.up_blocks):
419
+ is_final_block = i == len(self.up_blocks) - 1
420
+
421
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
422
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
423
+
424
+ # if we have not reached the final block and need to forward the
425
+ # upsample size, we do it here
426
+ if not is_final_block and forward_upsample_size:
427
+ upsample_size = down_block_res_samples[-1].shape[2:]
428
+
429
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
430
+ sample = upsample_block(
431
+ hidden_states=sample,
432
+ temb=emb,
433
+ res_hidden_states_tuple=res_samples,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ upsample_size=upsample_size,
436
+ attention_mask=attention_mask,
437
+ )
438
+ else:
439
+ sample = upsample_block(
440
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
441
+ )
442
+ # post-process
443
+ sample = self.conv_norm_out(sample)
444
+ sample = self.conv_act(sample)
445
+ sample = self.conv_out(sample)
446
+ # print(sample.shape)
447
+
448
+ if not return_dict:
449
+ return (sample,)
450
+ sample = UNet3DConditionOutput(sample=sample)
451
+ return sample
452
+
453
+ def forward_with_cfg(self,
454
+ x,
455
+ t,
456
+ encoder_hidden_states = None,
457
+ class_labels: Optional[torch.Tensor] = None,
458
+ cfg_scale=4.0):
459
+ """
460
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
461
+ """
462
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
463
+ half = x[: len(x) // 2]
464
+ combined = torch.cat([half, half], dim=0)
465
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
466
+ # Classifier-free guidance is applied to the noise-prediction channels: the first four
467
+ # output channels are the predicted noise for the latent video, while any remaining
468
+ # channels in `rest` are passed through unguided.
469
+ eps, rest = model_out[:, :4], model_out[:, 4:] # b c f h w
471
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
472
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
473
+ eps = torch.cat([half_eps, half_eps], dim=0)
474
+ return torch.cat([eps, rest], dim=1)
475
+
476
+ @classmethod
477
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False, copy_no_mask=False):
478
+ if subfolder is not None:
479
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
480
+
481
+
482
+ config_file = os.path.join(pretrained_model_path, 'config.json')
483
+ if not os.path.isfile(config_file):
484
+ raise RuntimeError(f"{config_file} does not exist")
485
+ with open(config_file, "r") as f:
486
+ config = json.load(f)
487
+ config["_class_name"] = cls.__name__
488
+ config["down_block_types"] = [
489
+ "CrossAttnDownBlock3D",
490
+ "CrossAttnDownBlock3D",
491
+ "CrossAttnDownBlock3D",
492
+ "DownBlock3D"
493
+ ]
494
+ config["up_block_types"] = [
495
+ "UpBlock3D",
496
+ "CrossAttnUpBlock3D",
497
+ "CrossAttnUpBlock3D",
498
+ "CrossAttnUpBlock3D"
499
+ ]
500
+
501
+ config["use_first_frame"] = True
502
+
503
+ if copy_no_mask:
504
+ config["in_channels"] = 8
505
+ else:
506
+ if use_concat:
507
+ config["in_channels"] = 9
508
+
509
+
510
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
511
+
512
+
513
+ model = cls.from_config(config)
514
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
515
+ if not os.path.isfile(model_file):
516
+ raise RuntimeError(f"{model_file} does not exist")
517
+ state_dict = torch.load(model_file, map_location="cpu")
518
+
519
+
520
+ if use_concat:
521
+ new_state_dict = {}
522
+ conv_in_weight = state_dict["conv_in.weight"]
523
+
524
+ print(f'from_pretrained_2d copy_no_mask = {copy_no_mask}')
525
+ if copy_no_mask:
526
+ new_conv_in_channel = 8
527
+ new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7]
528
+ else:
529
+ new_conv_in_channel = 9
530
+ new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7, 8]
531
+ new_conv_weight = torch.zeros((conv_in_weight.shape[0], new_conv_in_channel, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
532
+
533
+ for i, j in zip([0, 1, 2, 3], new_conv_in_list):
534
+ new_conv_weight[:, j] = conv_in_weight[:, i]
535
+ new_state_dict["conv_in.weight"] = new_conv_weight
536
+ new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
537
+ for k, v in model.state_dict().items():
538
+ # print(k)
539
+ if '_temp.' in k:
540
+ new_state_dict.update({k: v})
541
+ elif 'conv_in' in k:
542
+ continue
543
+ else:
544
+ new_state_dict[k] = v
545
+ # # tmp
546
+ # if 'class_embedding' in k:
547
+ # state_dict.update({k: v})
548
+ # breakpoint()
549
+ model.load_state_dict(new_state_dict)
550
+ else:
551
+ for k, v in model.state_dict().items():
552
+ # print(k)
553
+ if '_temp.' in k:
554
+ state_dict.update({k: v})
555
+ model.load_state_dict(state_dict)
556
+ return model
557
+
558
+ if __name__ == '__main__':
559
+ import torch
560
+
561
+ device = "cuda" if torch.cuda.is_available() else "cpu"
562
+
563
+ pretrained_model_path = "/nvme/maxin/work/large-dit-video/pretrained/stable-diffusion-v1-4/" # 43
564
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
565
+
566
+ noisy_latents = torch.randn((3, 4, 16, 32, 32)).to(device)
567
+ bsz = noisy_latents.shape[0]
568
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
569
+ timesteps = timesteps.long()
570
+ encoder_hidden_states = torch.randn((bsz, 77, 768)).to(device)
571
+ class_labels = torch.randn((bsz, )).to(device)
572
+
573
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
574
+ encoder_hidden_states=encoder_hidden_states,
575
+ class_labels=class_labels).sample
576
+ print(model_pred.shape)
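forward_with_cfg above uses the usual classifier-free guidance batching trick: the conditional and unconditional inputs share one forward pass, and the noise prediction is recombined as uncond + scale * (cond - uncond). A standalone sketch of just that recombination, mirroring the tensor bookkeeping in the method (no UNet required; shapes are illustrative):

    import torch

    def apply_cfg(model_out: torch.Tensor, cfg_scale: float = 4.0) -> torch.Tensor:
        """model_out stacks [conditional, unconditional] predictions along the batch axis."""
        eps, rest = model_out[:, :4], model_out[:, 4:]       # first four channels are the predicted noise
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)         # both halves now carry the guided prediction
        return torch.cat([eps, rest], dim=1)

    out = apply_cfg(torch.randn(6, 4, 16, 32, 32))           # 3 conditional + 3 unconditional latents
    print(out.shape)  # torch.Size([6, 4, 16, 32, 32])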
interpolation/models/unet_blocks.py ADDED
@@ -0,0 +1,619 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ import os
3
+ import sys
4
+ sys.path.append(os.path.split(sys.path[0])[0])
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ try:
10
+ from .attention import Transformer3DModel
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ except:
13
+ from attention import Transformer3DModel
14
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
15
+
16
+
17
+ def get_down_block(
18
+ down_block_type,
19
+ num_layers,
20
+ in_channels,
21
+ out_channels,
22
+ temb_channels,
23
+ add_downsample,
24
+ resnet_eps,
25
+ resnet_act_fn,
26
+ attn_num_head_channels,
27
+ resnet_groups=None,
28
+ cross_attention_dim=None,
29
+ downsample_padding=None,
30
+ dual_cross_attention=False,
31
+ use_linear_projection=False,
32
+ only_cross_attention=False,
33
+ upcast_attention=False,
34
+ resnet_time_scale_shift="default",
35
+ use_first_frame=False,
36
+ use_relative_position=False,
37
+ ):
38
+ # print(down_block_type)
39
+ # print(use_first_frame)
40
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
41
+ if down_block_type == "DownBlock3D":
42
+ return DownBlock3D(
43
+ num_layers=num_layers,
44
+ in_channels=in_channels,
45
+ out_channels=out_channels,
46
+ temb_channels=temb_channels,
47
+ add_downsample=add_downsample,
48
+ resnet_eps=resnet_eps,
49
+ resnet_act_fn=resnet_act_fn,
50
+ resnet_groups=resnet_groups,
51
+ downsample_padding=downsample_padding,
52
+ resnet_time_scale_shift=resnet_time_scale_shift,
53
+ )
54
+ elif down_block_type == "CrossAttnDownBlock3D":
55
+ if cross_attention_dim is None:
56
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
57
+ return CrossAttnDownBlock3D(
58
+ num_layers=num_layers,
59
+ in_channels=in_channels,
60
+ out_channels=out_channels,
61
+ temb_channels=temb_channels,
62
+ add_downsample=add_downsample,
63
+ resnet_eps=resnet_eps,
64
+ resnet_act_fn=resnet_act_fn,
65
+ resnet_groups=resnet_groups,
66
+ downsample_padding=downsample_padding,
67
+ cross_attention_dim=cross_attention_dim,
68
+ attn_num_head_channels=attn_num_head_channels,
69
+ dual_cross_attention=dual_cross_attention,
70
+ use_linear_projection=use_linear_projection,
71
+ only_cross_attention=only_cross_attention,
72
+ upcast_attention=upcast_attention,
73
+ resnet_time_scale_shift=resnet_time_scale_shift,
74
+ use_first_frame=use_first_frame,
75
+ use_relative_position=use_relative_position,
76
+ )
77
+ raise ValueError(f"{down_block_type} does not exist.")
78
+
79
+
80
+ def get_up_block(
81
+ up_block_type,
82
+ num_layers,
83
+ in_channels,
84
+ out_channels,
85
+ prev_output_channel,
86
+ temb_channels,
87
+ add_upsample,
88
+ resnet_eps,
89
+ resnet_act_fn,
90
+ attn_num_head_channels,
91
+ resnet_groups=None,
92
+ cross_attention_dim=None,
93
+ dual_cross_attention=False,
94
+ use_linear_projection=False,
95
+ only_cross_attention=False,
96
+ upcast_attention=False,
97
+ resnet_time_scale_shift="default",
98
+ use_first_frame=False,
99
+ use_relative_position=False,
100
+ ):
101
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
102
+ if up_block_type == "UpBlock3D":
103
+ return UpBlock3D(
104
+ num_layers=num_layers,
105
+ in_channels=in_channels,
106
+ out_channels=out_channels,
107
+ prev_output_channel=prev_output_channel,
108
+ temb_channels=temb_channels,
109
+ add_upsample=add_upsample,
110
+ resnet_eps=resnet_eps,
111
+ resnet_act_fn=resnet_act_fn,
112
+ resnet_groups=resnet_groups,
113
+ resnet_time_scale_shift=resnet_time_scale_shift,
114
+ )
115
+ elif up_block_type == "CrossAttnUpBlock3D":
116
+ if cross_attention_dim is None:
117
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
118
+ return CrossAttnUpBlock3D(
119
+ num_layers=num_layers,
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ prev_output_channel=prev_output_channel,
123
+ temb_channels=temb_channels,
124
+ add_upsample=add_upsample,
125
+ resnet_eps=resnet_eps,
126
+ resnet_act_fn=resnet_act_fn,
127
+ resnet_groups=resnet_groups,
128
+ cross_attention_dim=cross_attention_dim,
129
+ attn_num_head_channels=attn_num_head_channels,
130
+ dual_cross_attention=dual_cross_attention,
131
+ use_linear_projection=use_linear_projection,
132
+ only_cross_attention=only_cross_attention,
133
+ upcast_attention=upcast_attention,
134
+ resnet_time_scale_shift=resnet_time_scale_shift,
135
+ use_first_frame=use_first_frame,
136
+ use_relative_position=use_relative_position,
137
+ )
138
+ raise ValueError(f"{up_block_type} does not exist.")
139
+
140
+
141
+ class UNetMidBlock3DCrossAttn(nn.Module):
142
+ def __init__(
143
+ self,
144
+ in_channels: int,
145
+ temb_channels: int,
146
+ dropout: float = 0.0,
147
+ num_layers: int = 1,
148
+ resnet_eps: float = 1e-6,
149
+ resnet_time_scale_shift: str = "default",
150
+ resnet_act_fn: str = "swish",
151
+ resnet_groups: int = 32,
152
+ resnet_pre_norm: bool = True,
153
+ attn_num_head_channels=1,
154
+ output_scale_factor=1.0,
155
+ cross_attention_dim=1280,
156
+ dual_cross_attention=False,
157
+ use_linear_projection=False,
158
+ upcast_attention=False,
159
+ use_first_frame=False,
160
+ use_relative_position=False,
161
+ ):
162
+ super().__init__()
163
+
164
+ self.has_cross_attention = True
165
+ self.attn_num_head_channels = attn_num_head_channels
166
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
167
+
168
+ # there is always at least one resnet
169
+ resnets = [
170
+ ResnetBlock3D(
171
+ in_channels=in_channels,
172
+ out_channels=in_channels,
173
+ temb_channels=temb_channels,
174
+ eps=resnet_eps,
175
+ groups=resnet_groups,
176
+ dropout=dropout,
177
+ time_embedding_norm=resnet_time_scale_shift,
178
+ non_linearity=resnet_act_fn,
179
+ output_scale_factor=output_scale_factor,
180
+ pre_norm=resnet_pre_norm,
181
+ )
182
+ ]
183
+ attentions = []
184
+
185
+ for _ in range(num_layers):
186
+ if dual_cross_attention:
187
+ raise NotImplementedError
188
+ attentions.append(
189
+ Transformer3DModel(
190
+ attn_num_head_channels,
191
+ in_channels // attn_num_head_channels,
192
+ in_channels=in_channels,
193
+ num_layers=1,
194
+ cross_attention_dim=cross_attention_dim,
195
+ norm_num_groups=resnet_groups,
196
+ use_linear_projection=use_linear_projection,
197
+ upcast_attention=upcast_attention,
198
+ use_first_frame=use_first_frame,
199
+ use_relative_position=use_relative_position,
200
+ )
201
+ )
202
+ resnets.append(
203
+ ResnetBlock3D(
204
+ in_channels=in_channels,
205
+ out_channels=in_channels,
206
+ temb_channels=temb_channels,
207
+ eps=resnet_eps,
208
+ groups=resnet_groups,
209
+ dropout=dropout,
210
+ time_embedding_norm=resnet_time_scale_shift,
211
+ non_linearity=resnet_act_fn,
212
+ output_scale_factor=output_scale_factor,
213
+ pre_norm=resnet_pre_norm,
214
+ )
215
+ )
216
+
217
+ self.attentions = nn.ModuleList(attentions)
218
+ self.resnets = nn.ModuleList(resnets)
219
+
220
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
221
+ hidden_states = self.resnets[0](hidden_states, temb)
222
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
223
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
224
+ hidden_states = resnet(hidden_states, temb)
225
+
226
+ return hidden_states
227
+
228
+
229
+ class CrossAttnDownBlock3D(nn.Module):
230
+ def __init__(
231
+ self,
232
+ in_channels: int,
233
+ out_channels: int,
234
+ temb_channels: int,
235
+ dropout: float = 0.0,
236
+ num_layers: int = 1,
237
+ resnet_eps: float = 1e-6,
238
+ resnet_time_scale_shift: str = "default",
239
+ resnet_act_fn: str = "swish",
240
+ resnet_groups: int = 32,
241
+ resnet_pre_norm: bool = True,
242
+ attn_num_head_channels=1,
243
+ cross_attention_dim=1280,
244
+ output_scale_factor=1.0,
245
+ downsample_padding=1,
246
+ add_downsample=True,
247
+ dual_cross_attention=False,
248
+ use_linear_projection=False,
249
+ only_cross_attention=False,
250
+ upcast_attention=False,
251
+ use_first_frame=False,
252
+ use_relative_position=False,
253
+ ):
254
+ super().__init__()
255
+ resnets = []
256
+ attentions = []
257
+
258
+ # print(use_first_frame)
259
+
260
+ self.has_cross_attention = True
261
+ self.attn_num_head_channels = attn_num_head_channels
262
+
263
+ for i in range(num_layers):
264
+ in_channels = in_channels if i == 0 else out_channels
265
+ resnets.append(
266
+ ResnetBlock3D(
267
+ in_channels=in_channels,
268
+ out_channels=out_channels,
269
+ temb_channels=temb_channels,
270
+ eps=resnet_eps,
271
+ groups=resnet_groups,
272
+ dropout=dropout,
273
+ time_embedding_norm=resnet_time_scale_shift,
274
+ non_linearity=resnet_act_fn,
275
+ output_scale_factor=output_scale_factor,
276
+ pre_norm=resnet_pre_norm,
277
+ )
278
+ )
279
+ if dual_cross_attention:
280
+ raise NotImplementedError
281
+ attentions.append(
282
+ Transformer3DModel(
283
+ attn_num_head_channels,
284
+ out_channels // attn_num_head_channels,
285
+ in_channels=out_channels,
286
+ num_layers=1,
287
+ cross_attention_dim=cross_attention_dim,
288
+ norm_num_groups=resnet_groups,
289
+ use_linear_projection=use_linear_projection,
290
+ only_cross_attention=only_cross_attention,
291
+ upcast_attention=upcast_attention,
292
+ use_first_frame=use_first_frame,
293
+ use_relative_position=use_relative_position,
294
+ )
295
+ )
296
+ self.attentions = nn.ModuleList(attentions)
297
+ self.resnets = nn.ModuleList(resnets)
298
+
299
+ if add_downsample:
300
+ self.downsamplers = nn.ModuleList(
301
+ [
302
+ Downsample3D(
303
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
304
+ )
305
+ ]
306
+ )
307
+ else:
308
+ self.downsamplers = None
309
+
310
+ self.gradient_checkpointing = False
311
+
312
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
313
+ output_states = ()
314
+
315
+ for resnet, attn in zip(self.resnets, self.attentions):
316
+ if self.training and self.gradient_checkpointing:
317
+
318
+ def create_custom_forward(module, return_dict=None):
319
+ def custom_forward(*inputs):
320
+ if return_dict is not None:
321
+ return module(*inputs, return_dict=return_dict)
322
+ else:
323
+ return module(*inputs)
324
+
325
+ return custom_forward
326
+
327
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
328
+ hidden_states = torch.utils.checkpoint.checkpoint(
329
+ create_custom_forward(attn, return_dict=False),
330
+ hidden_states,
331
+ encoder_hidden_states,
332
+ )[0]
333
+ else:
334
+ hidden_states = resnet(hidden_states, temb)
335
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
336
+
337
+ output_states += (hidden_states,)
338
+
339
+ if self.downsamplers is not None:
340
+ for downsampler in self.downsamplers:
341
+ hidden_states = downsampler(hidden_states)
342
+
343
+ output_states += (hidden_states,)
344
+
345
+ return hidden_states, output_states
346
+
347
+
348
+ class DownBlock3D(nn.Module):
349
+ def __init__(
350
+ self,
351
+ in_channels: int,
352
+ out_channels: int,
353
+ temb_channels: int,
354
+ dropout: float = 0.0,
355
+ num_layers: int = 1,
356
+ resnet_eps: float = 1e-6,
357
+ resnet_time_scale_shift: str = "default",
358
+ resnet_act_fn: str = "swish",
359
+ resnet_groups: int = 32,
360
+ resnet_pre_norm: bool = True,
361
+ output_scale_factor=1.0,
362
+ add_downsample=True,
363
+ downsample_padding=1,
364
+ ):
365
+ super().__init__()
366
+ resnets = []
367
+
368
+ for i in range(num_layers):
369
+ in_channels = in_channels if i == 0 else out_channels
370
+ resnets.append(
371
+ ResnetBlock3D(
372
+ in_channels=in_channels,
373
+ out_channels=out_channels,
374
+ temb_channels=temb_channels,
375
+ eps=resnet_eps,
376
+ groups=resnet_groups,
377
+ dropout=dropout,
378
+ time_embedding_norm=resnet_time_scale_shift,
379
+ non_linearity=resnet_act_fn,
380
+ output_scale_factor=output_scale_factor,
381
+ pre_norm=resnet_pre_norm,
382
+ )
383
+ )
384
+
385
+ self.resnets = nn.ModuleList(resnets)
386
+
387
+ if add_downsample:
388
+ self.downsamplers = nn.ModuleList(
389
+ [
390
+ Downsample3D(
391
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
392
+ )
393
+ ]
394
+ )
395
+ else:
396
+ self.downsamplers = None
397
+
398
+ self.gradient_checkpointing = False
399
+
400
+ def forward(self, hidden_states, temb=None):
401
+ output_states = ()
402
+
403
+ for resnet in self.resnets:
404
+ if self.training and self.gradient_checkpointing:
405
+
406
+ def create_custom_forward(module):
407
+ def custom_forward(*inputs):
408
+ return module(*inputs)
409
+
410
+ return custom_forward
411
+
412
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
413
+ else:
414
+ hidden_states = resnet(hidden_states, temb)
415
+
416
+ output_states += (hidden_states,)
417
+
418
+ if self.downsamplers is not None:
419
+ for downsampler in self.downsamplers:
420
+ hidden_states = downsampler(hidden_states)
421
+
422
+ output_states += (hidden_states,)
423
+
424
+ return hidden_states, output_states
425
+
426
+
427
+ class CrossAttnUpBlock3D(nn.Module):
428
+ def __init__(
429
+ self,
430
+ in_channels: int,
431
+ out_channels: int,
432
+ prev_output_channel: int,
433
+ temb_channels: int,
434
+ dropout: float = 0.0,
435
+ num_layers: int = 1,
436
+ resnet_eps: float = 1e-6,
437
+ resnet_time_scale_shift: str = "default",
438
+ resnet_act_fn: str = "swish",
439
+ resnet_groups: int = 32,
440
+ resnet_pre_norm: bool = True,
441
+ attn_num_head_channels=1,
442
+ cross_attention_dim=1280,
443
+ output_scale_factor=1.0,
444
+ add_upsample=True,
445
+ dual_cross_attention=False,
446
+ use_linear_projection=False,
447
+ only_cross_attention=False,
448
+ upcast_attention=False,
449
+ use_first_frame=False,
450
+ use_relative_position=False,
451
+ ):
452
+ super().__init__()
453
+ resnets = []
454
+ attentions = []
455
+
456
+ self.has_cross_attention = True
457
+ self.attn_num_head_channels = attn_num_head_channels
458
+
459
+ for i in range(num_layers):
460
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
461
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
462
+
463
+ resnets.append(
464
+ ResnetBlock3D(
465
+ in_channels=resnet_in_channels + res_skip_channels,
466
+ out_channels=out_channels,
467
+ temb_channels=temb_channels,
468
+ eps=resnet_eps,
469
+ groups=resnet_groups,
470
+ dropout=dropout,
471
+ time_embedding_norm=resnet_time_scale_shift,
472
+ non_linearity=resnet_act_fn,
473
+ output_scale_factor=output_scale_factor,
474
+ pre_norm=resnet_pre_norm,
475
+ )
476
+ )
477
+ if dual_cross_attention:
478
+ raise NotImplementedError
479
+ attentions.append(
480
+ Transformer3DModel(
481
+ attn_num_head_channels,
482
+ out_channels // attn_num_head_channels,
483
+ in_channels=out_channels,
484
+ num_layers=1,
485
+ cross_attention_dim=cross_attention_dim,
486
+ norm_num_groups=resnet_groups,
487
+ use_linear_projection=use_linear_projection,
488
+ only_cross_attention=only_cross_attention,
489
+ upcast_attention=upcast_attention,
490
+ use_first_frame=use_first_frame,
491
+ use_relative_position=use_relative_position,
492
+ )
493
+ )
494
+
495
+ self.attentions = nn.ModuleList(attentions)
496
+ self.resnets = nn.ModuleList(resnets)
497
+
498
+ if add_upsample:
499
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
500
+ else:
501
+ self.upsamplers = None
502
+
503
+ self.gradient_checkpointing = False
504
+
505
+ def forward(
506
+ self,
507
+ hidden_states,
508
+ res_hidden_states_tuple,
509
+ temb=None,
510
+ encoder_hidden_states=None,
511
+ upsample_size=None,
512
+ attention_mask=None,
513
+ ):
514
+ for resnet, attn in zip(self.resnets, self.attentions):
515
+ # pop res hidden states
516
+ res_hidden_states = res_hidden_states_tuple[-1]
517
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
518
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
519
+
520
+ if self.training and self.gradient_checkpointing:
521
+
522
+ def create_custom_forward(module, return_dict=None):
523
+ def custom_forward(*inputs):
524
+ if return_dict is not None:
525
+ return module(*inputs, return_dict=return_dict)
526
+ else:
527
+ return module(*inputs)
528
+
529
+ return custom_forward
530
+
531
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
532
+ hidden_states = torch.utils.checkpoint.checkpoint(
533
+ create_custom_forward(attn, return_dict=False),
534
+ hidden_states,
535
+ encoder_hidden_states,
536
+ )[0]
537
+ else:
538
+ hidden_states = resnet(hidden_states, temb)
539
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
540
+
541
+ if self.upsamplers is not None:
542
+ for upsampler in self.upsamplers:
543
+ hidden_states = upsampler(hidden_states, upsample_size)
544
+
545
+ return hidden_states
546
+
547
+
548
+ class UpBlock3D(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int,
552
+ prev_output_channel: int,
553
+ out_channels: int,
554
+ temb_channels: int,
555
+ dropout: float = 0.0,
556
+ num_layers: int = 1,
557
+ resnet_eps: float = 1e-6,
558
+ resnet_time_scale_shift: str = "default",
559
+ resnet_act_fn: str = "swish",
560
+ resnet_groups: int = 32,
561
+ resnet_pre_norm: bool = True,
562
+ output_scale_factor=1.0,
563
+ add_upsample=True,
564
+ ):
565
+ super().__init__()
566
+ resnets = []
567
+
568
+ for i in range(num_layers):
569
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
570
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
571
+
572
+ resnets.append(
573
+ ResnetBlock3D(
574
+ in_channels=resnet_in_channels + res_skip_channels,
575
+ out_channels=out_channels,
576
+ temb_channels=temb_channels,
577
+ eps=resnet_eps,
578
+ groups=resnet_groups,
579
+ dropout=dropout,
580
+ time_embedding_norm=resnet_time_scale_shift,
581
+ non_linearity=resnet_act_fn,
582
+ output_scale_factor=output_scale_factor,
583
+ pre_norm=resnet_pre_norm,
584
+ )
585
+ )
586
+
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_upsample:
590
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
591
+ else:
592
+ self.upsamplers = None
593
+
594
+ self.gradient_checkpointing = False
595
+
596
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
597
+ for resnet in self.resnets:
598
+ # pop res hidden states
599
+ res_hidden_states = res_hidden_states_tuple[-1]
600
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
601
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
602
+
603
+ if self.training and self.gradient_checkpointing:
604
+
605
+ def create_custom_forward(module):
606
+ def custom_forward(*inputs):
607
+ return module(*inputs)
608
+
609
+ return custom_forward
610
+
611
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
612
+ else:
613
+ hidden_states = resnet(hidden_states, temb)
614
+
615
+ if self.upsamplers is not None:
616
+ for upsampler in self.upsamplers:
617
+ hidden_states = upsampler(hidden_states, upsample_size)
618
+
619
+ return hidden_states
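For orientation, here is a minimal, self-contained sketch of the skip-connection bookkeeping shared by the blocks above: each down block appends its per-layer output to a tuple, and the matching up block pops those entries newest-first and concatenates them along the channel dimension. The identity "blocks", layer count, and tensor shapes below are illustrative stand-ins for ResnetBlock3D/Transformer3DModel, not code from this repository.

import torch

def toy_down_block(hidden_states):
    # Stand-in for DownBlock3D.forward: cache one skip tensor per layer.
    output_states = ()
    for _ in range(2):                       # num_layers = 2 (assumed)
        hidden_states = hidden_states        # real code: resnet(hidden_states, temb)
        output_states += (hidden_states,)
    return hidden_states, output_states

def toy_up_block(hidden_states, res_hidden_states_tuple):
    # Stand-in for UpBlock3D.forward: pop skips newest-first, concat on channels.
    for _ in range(2):
        res = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res], dim=1)
    return hidden_states

x = torch.randn(1, 4, 16, 32, 32)            # (b, c, f, h, w), shapes assumed
x, skips = toy_down_block(x)
print(toy_up_block(x, skips).shape)          # torch.Size([1, 12, 16, 32, 32])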
interpolation/models/utils.py ADDED
@@ -0,0 +1,215 @@
1
+ # adopted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import os
12
+ import math
13
+ import torch
14
+
15
+ import numpy as np
16
+ import torch.nn as nn
17
+
18
+ from einops import repeat
19
+
20
+
21
+ #################################################################################
22
+ # Unet Utils #
23
+ #################################################################################
24
+
25
+ def checkpoint(func, inputs, params, flag):
26
+ """
27
+ Evaluate a function without caching intermediate activations, allowing for
28
+ reduced memory at the expense of extra compute in the backward pass.
29
+ :param func: the function to evaluate.
30
+ :param inputs: the argument sequence to pass to `func`.
31
+ :param params: a sequence of parameters `func` depends on but does not
32
+ explicitly take as arguments.
33
+ :param flag: if False, disable gradient checkpointing.
34
+ """
35
+ if flag:
36
+ args = tuple(inputs) + tuple(params)
37
+ return CheckpointFunction.apply(func, len(inputs), *args)
38
+ else:
39
+ return func(*inputs)
40
+
41
+
42
+ class CheckpointFunction(torch.autograd.Function):
43
+ @staticmethod
44
+ def forward(ctx, run_function, length, *args):
45
+ ctx.run_function = run_function
46
+ ctx.input_tensors = list(args[:length])
47
+ ctx.input_params = list(args[length:])
48
+
49
+ with torch.no_grad():
50
+ output_tensors = ctx.run_function(*ctx.input_tensors)
51
+ return output_tensors
52
+
53
+ @staticmethod
54
+ def backward(ctx, *output_grads):
55
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
56
+ with torch.enable_grad():
57
+ # Fixes a bug where the first op in run_function modifies the
58
+ # Tensor storage in place, which is not allowed for detach()'d
59
+ # Tensors.
60
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
61
+ output_tensors = ctx.run_function(*shallow_copies)
62
+ input_grads = torch.autograd.grad(
63
+ output_tensors,
64
+ ctx.input_tensors + ctx.input_params,
65
+ output_grads,
66
+ allow_unused=True,
67
+ )
68
+ del ctx.input_tensors
69
+ del ctx.input_params
70
+ del output_tensors
71
+ return (None, None) + input_grads
72
+
73
+
74
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
75
+ """
76
+ Create sinusoidal timestep embeddings.
77
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
78
+ These may be fractional.
79
+ :param dim: the dimension of the output.
80
+ :param max_period: controls the minimum frequency of the embeddings.
81
+ :return: an [N x dim] Tensor of positional embeddings.
82
+ """
83
+ if not repeat_only:
84
+ half = dim // 2
85
+ freqs = torch.exp(
86
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
87
+ ).to(device=timesteps.device)
88
+ args = timesteps[:, None].float() * freqs[None]
89
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
90
+ if dim % 2:
91
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
+ else:
93
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
94
+ return embedding
95
+
96
+
97
+ def zero_module(module):
98
+ """
99
+ Zero out the parameters of a module and return it.
100
+ """
101
+ for p in module.parameters():
102
+ p.detach().zero_()
103
+ return module
104
+
105
+
106
+ def scale_module(module, scale):
107
+ """
108
+ Scale the parameters of a module and return it.
109
+ """
110
+ for p in module.parameters():
111
+ p.detach().mul_(scale)
112
+ return module
113
+
114
+
115
+ def mean_flat(tensor):
116
+ """
117
+ Take the mean over all non-batch dimensions.
118
+ """
119
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
120
+
121
+
122
+ def normalization(channels):
123
+ """
124
+ Make a standard normalization layer.
125
+ :param channels: number of input channels.
126
+ :return: an nn.Module for normalization.
127
+ """
128
+ return GroupNorm32(32, channels)
129
+
130
+
131
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
132
+ class SiLU(nn.Module):
133
+ def forward(self, x):
134
+ return x * torch.sigmoid(x)
135
+
136
+
137
+ class GroupNorm32(nn.GroupNorm):
138
+ def forward(self, x):
139
+ return super().forward(x.float()).type(x.dtype)
140
+
141
+ def conv_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D convolution module.
144
+ """
145
+ if dims == 1:
146
+ return nn.Conv1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.Conv2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.Conv3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def linear(*args, **kwargs):
155
+ """
156
+ Create a linear module.
157
+ """
158
+ return nn.Linear(*args, **kwargs)
159
+
160
+
161
+ def avg_pool_nd(dims, *args, **kwargs):
162
+ """
163
+ Create a 1D, 2D, or 3D average pooling module.
164
+ """
165
+ if dims == 1:
166
+ return nn.AvgPool1d(*args, **kwargs)
167
+ elif dims == 2:
168
+ return nn.AvgPool2d(*args, **kwargs)
169
+ elif dims == 3:
170
+ return nn.AvgPool3d(*args, **kwargs)
171
+ raise ValueError(f"unsupported dimensions: {dims}")
172
+
173
+
174
+ # class HybridConditioner(nn.Module):
175
+
176
+ # def __init__(self, c_concat_config, c_crossattn_config):
177
+ # super().__init__()
178
+ # self.concat_conditioner = instantiate_from_config(c_concat_config)
179
+ # self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
180
+
181
+ # def forward(self, c_concat, c_crossattn):
182
+ # c_concat = self.concat_conditioner(c_concat)
183
+ # c_crossattn = self.crossattn_conditioner(c_crossattn)
184
+ # return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
185
+
186
+
187
+ def noise_like(shape, device, repeat=False):
188
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
189
+ noise = lambda: torch.randn(shape, device=device)
190
+ return repeat_noise() if repeat else noise()
191
+
192
+ def count_flops_attn(model, _x, y):
193
+ """
194
+ A counter for the `thop` package to count the operations in an
195
+ attention operation.
196
+ Meant to be used like:
197
+ macs, params = thop.profile(
198
+ model,
199
+ inputs=(inputs, timestamps),
200
+ custom_ops={QKVAttention: QKVAttention.count_flops},
201
+ )
202
+ """
203
+ b, c, *spatial = y[0].shape
204
+ num_spatial = int(np.prod(spatial))
205
+ # We perform two matmuls with the same number of ops.
206
+ # The first computes the weight matrix, the second computes
207
+ # the combination of the value vectors.
208
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
209
+ model.total_ops += torch.DoubleTensor([matmul_ops])
210
+
211
+ def count_params(model, verbose=False):
212
+ total_params = sum(p.numel() for p in model.parameters())
213
+ if verbose:
214
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
215
+ return total_params
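As a quick reference for the sinusoidal formula in timestep_embedding above, here is a minimal self-contained sketch reproducing the repeat_only=False branch; the timestep values and embedding width are arbitrary choices for illustration.

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # Frequencies decay geometrically from 1 down to ~1/max_period,
    # mirroring timestep_embedding above (repeat_only=False branch).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:                               # pad odd widths with a zero column
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

t = torch.tensor([0, 10, 999])                # example diffusion timesteps (arbitrary)
e = sinusoidal_embedding(t, dim=8)
print(e.shape)                                # torch.Size([3, 8])
print(e[0])                                   # t=0 -> cosines are 1, sines are 0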
interpolation/sample.py ADDED
@@ -0,0 +1,312 @@
1
+ """
2
+ We introduce a temporal interpolation network to enhance the smoothness of generated videos and synthesize richer temporal details.
3
+ This network takes a 16-frame base video as input and produces an upsampled output consisting of 61 frames.
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import math
9
+ try:
10
+ import utils
11
+
12
+ from diffusion import create_diffusion
13
+ from download import find_model
14
+ except:
15
+ sys.path.append(os.path.split(sys.path[0])[0])
16
+
17
+ import utils
18
+
19
+ from diffusion import create_diffusion
20
+ from download import find_model
21
+
22
+ import torch
23
+ import argparse
24
+ import torchvision
25
+
26
+ from einops import rearrange
27
+ from models import get_models
28
+ from torchvision.utils import save_image
29
+ from diffusers.models import AutoencoderKL
30
+ from models.clip import TextEmbedder
31
+ from omegaconf import OmegaConf
32
+ from PIL import Image
33
+ import numpy as np
34
+ from torchvision import transforms
35
+ sys.path.append("..")
36
+ from datasets import video_transforms
37
+ from decord import VideoReader
38
+ from utils import mask_generation, mask_generation_before
39
+ from natsort import natsorted
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+
42
+ torch.backends.cuda.matmul.allow_tf32 = True
43
+ torch.backends.cudnn.allow_tf32 = True
44
+
45
+
46
+ def get_input(args):
47
+ input_path = args.input_path
48
+ transform_video = transforms.Compose([
49
+ video_transforms.ToTensorVideo(), # TCHW
50
+ # video_transforms.CenterCropResizeVideo((args.image_h, args.image_w)),
51
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
52
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
53
+ ])
54
+ temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval)
55
+ if input_path is not None:
56
+ print(f'loading video from {input_path}')
57
+ if os.path.isdir(input_path):
58
+ file_list = os.listdir(input_path)
59
+ video_frames = []
60
+ for file in file_list:
61
+ if file.endswith('jpg') or file.endswith('png'):
62
+ image = torch.as_tensor(np.array(Image.open(os.path.join(input_path, file)), dtype=np.uint8, copy=True)).unsqueeze(0)
63
+ video_frames.append(image)
64
+ else:
65
+ continue
66
+ n = 0
67
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
68
+ video_frames = transform_video(video_frames)
69
+ return video_frames, n
70
+ elif os.path.isfile(input_path):
71
+ _, full_file_name = os.path.split(input_path)
72
+ file_name, extention = os.path.splitext(full_file_name)
73
+ if extention == '.mp4':
74
+ video_reader = VideoReader(input_path)
75
+ total_frames = len(video_reader)
76
+ start_frame_ind, end_frame_ind = temporal_sample_func(total_frames)
77
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, args.num_frames, dtype=int)
78
+ video_frames = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
79
+ video_frames = transform_video(video_frames)
80
+ n = 3
81
+ del video_reader
82
+ return video_frames, n
83
+ else:
84
+ raise TypeError(f'{extention} is not supported !!')
85
+ else:
86
+ raise ValueError('Please check your path input!!')
87
+ else:
88
+ print('given video is None, using text to video')
89
+ video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8)
90
+ args.mask_type = 'all'
91
+ video_frames = transform_video(video_frames)
92
+ n = 0
93
+ return video_frames, n
94
+
95
+
96
+ def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
97
+
98
+
99
+ b,f,c,h,w=video_input.shape
100
+ latent_h = args.image_size[0] // 8
101
+ latent_w = args.image_size[1] // 8
102
+
103
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
104
+
105
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
106
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
107
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
108
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
109
+
110
+
111
+ masked_video = torch.cat([masked_video] * 2) if args.do_classifier_free_guidance else masked_video
112
+ mask = torch.cat([mask] * 2) if args.do_classifier_free_guidance else mask
113
+ z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z
114
+
115
+ prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt]
116
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
117
+ model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None)
118
+
119
+ if args.use_ddim_sample_loop:
120
+ samples = diffusion.ddim_sample_loop(
121
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
122
+ progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat
123
+ )
124
+ else:
125
+ samples = diffusion.p_sample_loop(
126
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
127
+ progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat
128
+ ) # torch.Size([2, 4, 16, 32, 32])
129
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
130
+
131
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
132
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
133
+ return video_clip
134
+
135
+
136
+ def auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,):
137
+
138
+ b,f,c,h,w=video_input.shape
139
+ latent_h = args.image_size[0] // 8
140
+ latent_w = args.image_size[1] // 8
141
+
142
+ video_input = rearrange(video_input, 'b f c h w -> (b f) c h w').contiguous()
143
+ video_input = vae.encode(video_input).latent_dist.sample().mul_(0.18215)
144
+ video_input = rearrange(video_input, '(b f) c h w -> b c f h w', b=b).contiguous()
145
+
146
+ lr_indice = torch.IntTensor([i for i in range(0,62,4)]).to(device)
147
+ copied_video = torch.index_select(video_input, 2, lr_indice)
148
+ copied_video = torch.repeat_interleave(copied_video, 4, dim=2)
149
+ copied_video = copied_video[:,:,1:-2,:,:]
150
+ copied_video = torch.cat([copied_video] * 2) if args.do_classifier_free_guidance else copied_video
151
+
152
+ torch.manual_seed(args.seed)
153
+ torch.cuda.manual_seed(args.seed)
154
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
155
+ z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z
156
+
157
+ prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt]
158
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
159
+ model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None)
160
+
161
+ torch.manual_seed(args.seed)
162
+ torch.cuda.manual_seed(args.seed)
163
+ if args.use_ddim_sample_loop:
164
+ samples = diffusion.ddim_sample_loop(
165
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
166
+ progress=True, device=device, mask=None, x_start=copied_video, use_concat=args.use_concat, copy_no_mask=args.copy_no_mask,
167
+ )
168
+ else:
169
+ raise ValueError('Only DDIM sampling is implemented for now')
170
+
171
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
172
+
173
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
174
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
175
+ return video_clip
176
+
177
+
178
+
179
+ def main(args):
180
+
181
+ for seed in args.seed_list:
182
+
183
+ args.seed = seed
184
+ torch.manual_seed(args.seed)
185
+ torch.cuda.manual_seed(args.seed)
186
+ # print(f'torch.seed() = {torch.seed()}')
187
+
188
+ print('sampling begins')
189
+ torch.set_grad_enabled(False)
190
+ device = "cuda" if torch.cuda.is_available() else "cpu"
191
+ # device = "cpu"
192
+
193
+ ckpt_path = args.pretrained_path + "/lavie_interpolation.pt"
194
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
195
+ for ckpt in [ckpt_path]:
196
+
197
+ ckpt_num = str(ckpt_path).zfill(7)
198
+
199
+ # Load model:
200
+ latent_h = args.image_size[0] // 8
201
+ latent_w = args.image_size[1] // 8
202
+ args.image_h = args.image_size[0]
203
+ args.image_w = args.image_size[1]
204
+ args.latent_h = latent_h
205
+ args.latent_w = latent_w
206
+ print(f'args.copy_no_mask = {args.copy_no_mask}')
207
+ model = get_models(args, sd_path).to(device)
208
+
209
+ if args.use_compile:
210
+ model = torch.compile(model)
211
+ if args.enable_xformers_memory_efficient_attention:
212
+ if is_xformers_available():
213
+ model.enable_xformers_memory_efficient_attention()
214
+ # model.enable_vae_slicing() # ziqi added
215
+ else:
216
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
217
+
218
+ # Auto-download a pre-trained model or load a custom checkpoint from train.py:
219
+ print(f'loading model from {ckpt_path}')
220
+
221
+ # load ckpt
222
+ state_dict = find_model(ckpt_path)
223
+
224
+ print(f'state_dict["conv_in.weight"].shape = {state_dict["conv_in.weight"].shape}') # [320, 8, 3, 3]
225
+ print('loading succeed')
226
+ # model.load_state_dict(state_dict)
227
+
228
+ torch.manual_seed(args.seed)
229
+ torch.cuda.manual_seed(args.seed)
230
+
231
+ model.eval() # important!
232
+ diffusion = create_diffusion(str(args.num_sampling_steps))
233
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(device)
234
+ text_encoder = TextEmbedder(sd_path).to(device)
235
+
236
+ video_list = os.listdir(args.input_folder)
237
+ args.input_path_list = [os.path.join(args.input_folder, video) for video in video_list]
238
+ for input_path in args.input_path_list:
239
+
240
+ args.input_path = input_path
241
+
242
+ print(f'=======================================')
243
+ if not args.input_path.endswith('.mp4'):
244
+ print(f'Skipping {args.input_path}')
245
+ continue
246
+
247
+ print(f'args.input_path = {args.input_path}')
248
+
249
+ torch.manual_seed(args.seed)
250
+ torch.cuda.manual_seed(args.seed)
251
+
252
+ # Labels to condition the model with (feel free to change):
253
+ video_name = args.input_path.split('/')[-1].split('.mp4')[0]
254
+ args.prompt = [video_name]
255
+ print(f'args.prompt = {args.prompt}')
256
+ prompts = args.prompt
257
+ class_name = [p + args.additional_prompt for p in prompts]
258
+
259
+ if not os.path.exists(os.path.join(args.output_folder)):
260
+ os.makedirs(os.path.join(args.output_folder))
261
+ video_input, researve_frames = get_input(args) # f,c,h,w
262
+ video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
263
+ if args.copy_no_mask:
264
+ pass
265
+ else:
266
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
267
+
268
+ if args.copy_no_mask:
269
+ pass
270
+ else:
271
+ if args.mask_type == 'tsr':
272
+ masked_video = video_input * (mask == 0)
273
+ else:
274
+ masked_video = video_input * (mask == 0)
275
+
276
+ all_video = []
277
+ if researve_frames != 0:
278
+ all_video.append(video_input)
279
+ for idx, prompt in enumerate(class_name):
280
+ if idx == 0:
281
+ if args.copy_no_mask:
282
+ video_clip = auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,)
283
+ video_clip_ = video_clip.unsqueeze(0)
284
+ all_video.append(video_clip_)
285
+ else:
286
+ video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
287
+ video_clip_ = video_clip.unsqueeze(0)
288
+ all_video.append(video_clip_)
289
+ else:
290
+ raise NotImplementedError
291
+ masked_video = video_input * (mask == 0)
292
+ video_clip = auto_inpainting_copy_no_mask(args, video_clip.unsqueeze(0), masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
293
+ video_clip_ = video_clip.unsqueeze(0)
294
+ all_video.append(video_clip_[:, 3:])
295
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
296
+ for fps in args.fps_list:
297
+ save_path = args.output_folder
298
+ if not os.path.exists(os.path.join(save_path)):
299
+ os.makedirs(os.path.join(save_path))
300
+ local_save_path = os.path.join(save_path, f'{video_name}.mp4')
301
+ print(f'save in {local_save_path}')
302
+ torchvision.io.write_video(local_save_path, video_, fps=fps)
303
+
304
+
305
+
306
+ if __name__ == "__main__":
307
+ parser = argparse.ArgumentParser()
308
+ parser.add_argument("--config", type=str, required=True)
309
+ args = parser.parse_args()
310
+ main(OmegaConf.load(args.config))
311
+
312
+
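To make the frame layout concrete, here is a small sketch of the index arithmetic implied by lr_indice = range(0, 62, 4) and the 61-frame output described in the docstring: the 16 base frames occupy every 4th slot and the 45 frames in between are synthesized. The variable names are illustrative only.

base_frames = 16
stride = 4
total = stride * (base_frames - 1) + 1        # 61 output frames

kept = list(range(0, total, stride))          # slots filled by the base video
generated = [i for i in range(total) if i % stride != 0]

print(total)                                  # 61
print(kept[:5])                               # [0, 4, 8, 12, 16]
print(len(kept), len(generated))              # 16 45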
interpolation/utils.py ADDED
@@ -0,0 +1,371 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import logging
5
+ import subprocess
6
+ import numpy as np
7
+ import torch.distributed as dist
8
+
9
+ # from torch._six import inf
10
+ from torch import inf
11
+ from PIL import Image
12
+ from typing import Union, Iterable
13
+ from collections import OrderedDict
+ from torch.utils.tensorboard import SummaryWriter  # used by create_tensorboard below
14
+
15
+
16
+ _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
17
+
18
+ #################################################################################
19
+ # Training Helper Functions #
20
+ #################################################################################
21
+
22
+ #################################################################################
23
+ # Training Clip Gradients #
24
+ #################################################################################
25
+
26
+ def get_grad_norm(
27
+ parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
28
+ r"""
29
+ Copy from torch.nn.utils.clip_grad_norm_
30
+
31
+ Clips gradient norm of an iterable of parameters.
32
+
33
+ The norm is computed over all gradients together, as if they were
34
+ concatenated into a single vector. Gradients are modified in-place.
35
+
36
+ Args:
37
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
38
+ single Tensor that will have gradients normalized
39
+ max_norm (float or int): max norm of the gradients
40
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
41
+ infinity norm.
42
+ error_if_nonfinite (bool): if True, an error is thrown if the total
43
+ norm of the gradients from :attr:`parameters` is ``nan``,
44
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
45
+
46
+ Returns:
47
+ Total norm of the parameter gradients (viewed as a single vector).
48
+ """
49
+ if isinstance(parameters, torch.Tensor):
50
+ parameters = [parameters]
51
+ grads = [p.grad for p in parameters if p.grad is not None]
52
+ norm_type = float(norm_type)
53
+ if len(grads) == 0:
54
+ return torch.tensor(0.)
55
+ device = grads[0].device
56
+ if norm_type == inf:
57
+ norms = [g.detach().abs().max().to(device) for g in grads]
58
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
59
+ else:
60
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
61
+ return total_norm
62
+
63
+ def clip_grad_norm_(
64
+ parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
65
+ error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
66
+ r"""
67
+ Copy from torch.nn.utils.clip_grad_norm_
68
+
69
+ Clips gradient norm of an iterable of parameters.
70
+
71
+ The norm is computed over all gradients together, as if they were
72
+ concatenated into a single vector. Gradients are modified in-place.
73
+
74
+ Args:
75
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
76
+ single Tensor that will have gradients normalized
77
+ max_norm (float or int): max norm of the gradients
78
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
79
+ infinity norm.
80
+ error_if_nonfinite (bool): if True, an error is thrown if the total
81
+ norm of the gradients from :attr:`parameters` is ``nan``,
82
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
83
+
84
+ Returns:
85
+ Total norm of the parameter gradients (viewed as a single vector).
86
+ """
87
+ if isinstance(parameters, torch.Tensor):
88
+ parameters = [parameters]
89
+ grads = [p.grad for p in parameters if p.grad is not None]
90
+ max_norm = float(max_norm)
91
+ norm_type = float(norm_type)
92
+ if len(grads) == 0:
93
+ return torch.tensor(0.)
94
+ device = grads[0].device
95
+ if norm_type == inf:
96
+ norms = [g.detach().abs().max().to(device) for g in grads]
97
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
98
+ else:
99
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
100
+ # print(total_norm)
101
+
102
+ if clip_grad:
103
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
104
+ raise RuntimeError(
105
+ f'The total norm of order {norm_type} for gradients from '
106
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
107
+ 'this error and scale the gradients by the non-finite norm anyway, '
108
+ 'set `error_if_nonfinite=False`')
109
+ clip_coef = max_norm / (total_norm + 1e-6)
110
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
111
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
112
+ # when the gradients do not reside in CPU memory.
113
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
114
+ for g in grads:
115
+ g.detach().mul_(clip_coef_clamped.to(g.device))
116
+ # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
117
+ # print(gradient_cliped)
118
+ return total_norm
119
+
120
+ #################################################################################
121
+ # Training Logger #
122
+ #################################################################################
123
+
124
+ def create_logger(logging_dir):
125
+ """
126
+ Create a logger that writes to a log file and stdout.
127
+ """
128
+ if dist.get_rank() == 0: # real logger
129
+ logging.basicConfig(
130
+ level=logging.INFO,
131
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
132
+ format='[%(asctime)s] %(message)s',
133
+ datefmt='%Y-%m-%d %H:%M:%S',
134
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
135
+ )
136
+ logger = logging.getLogger(__name__)
137
+
138
+ else: # dummy logger (does nothing)
139
+ logger = logging.getLogger(__name__)
140
+ logger.addHandler(logging.NullHandler())
141
+ return logger
142
+
143
+ def create_accelerate_logger(logging_dir, is_main_process=False):
144
+ """
145
+ Create a logger that writes to a log file and stdout.
146
+ """
147
+ if is_main_process: # real logger
148
+ logging.basicConfig(
149
+ level=logging.INFO,
150
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
151
+ format='[%(asctime)s] %(message)s',
152
+ datefmt='%Y-%m-%d %H:%M:%S',
153
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
154
+ )
155
+ logger = logging.getLogger(__name__)
156
+ else: # dummy logger (does nothing)
157
+ logger = logging.getLogger(__name__)
158
+ logger.addHandler(logging.NullHandler())
159
+ return logger
160
+
161
+
162
+ def create_tensorboard(tensorboard_dir):
163
+ """
164
+ Create a tensorboard that saves losses.
165
+ """
166
+ if dist.get_rank() == 0: # real tensorboard
167
+ # tensorboard
168
+ writer = SummaryWriter(tensorboard_dir)
169
+
170
+ return writer
171
+
172
+ def write_tensorboard(writer, *args):
173
+ '''
174
+ write the loss information to a tensorboard file.
175
+ Only for pytorch DDP mode.
176
+ '''
177
+ if dist.get_rank() == 0: # real tensorboard
178
+ writer.add_scalar(args[0], args[1], args[2])
179
+
180
+ #################################################################################
181
+ # EMA Update/ DDP Training Utils #
182
+ #################################################################################
183
+
184
+ @torch.no_grad()
185
+ def update_ema(ema_model, model, decay=0.9999):
186
+ """
187
+ Step the EMA model towards the current model.
188
+ """
189
+ ema_params = OrderedDict(ema_model.named_parameters())
190
+ model_params = OrderedDict(model.named_parameters())
191
+
192
+ for name, param in model_params.items():
193
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
194
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
195
+
196
+ def requires_grad(model, flag=True):
197
+ """
198
+ Set requires_grad flag for all parameters in a model.
199
+ """
200
+ for p in model.parameters():
201
+ p.requires_grad = flag
202
+
203
+ def cleanup():
204
+ """
205
+ End DDP training.
206
+ """
207
+ dist.destroy_process_group()
208
+
209
+
210
+ def setup_distributed(backend="nccl", port=None):
211
+ """Initialize distributed training environment.
212
+ support both slurm and torch.distributed.launch
213
+ see torch.distributed.init_process_group() for more details
214
+ """
215
+ num_gpus = torch.cuda.device_count()
216
+
217
+ print(f'Hahahahahaha')
218
+ if "SLURM_JOB_ID" in os.environ:
219
+ rank = int(os.environ["SLURM_PROCID"])
220
+ world_size = int(os.environ["SLURM_NTASKS"])
221
+ node_list = os.environ["SLURM_NODELIST"]
222
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
223
+ # specify master port
224
+ if port is not None:
225
+ os.environ["MASTER_PORT"] = str(port)
226
+ elif "MASTER_PORT" not in os.environ:
227
+ # os.environ["MASTER_PORT"] = "29566"
228
+ os.environ["MASTER_PORT"] = str(29566 + num_gpus)
229
+ if "MASTER_ADDR" not in os.environ:
230
+ os.environ["MASTER_ADDR"] = addr
231
+ os.environ["WORLD_SIZE"] = str(world_size)
232
+ os.environ["LOCAL_RANK"] = str(rank % num_gpus)
233
+ os.environ["RANK"] = str(rank)
234
+ else:
235
+ rank = int(os.environ["RANK"])
236
+ world_size = int(os.environ["WORLD_SIZE"])
237
+
238
+ # torch.cuda.set_device(rank % num_gpus)
239
+
240
+ print(f'before dist.init_process_group')
241
+
242
+ dist.init_process_group(
243
+ backend=backend,
244
+ world_size=world_size,
245
+ rank=rank,
246
+ )
247
+ print(f'after dist.init_process_group')
248
+
249
+ #################################################################################
250
+ # Testing Utils #
251
+ #################################################################################
252
+
253
+ def save_video_grid(video, nrow=None):
254
+ b, t, h, w, c = video.shape
255
+
256
+ if nrow is None:
257
+ nrow = math.ceil(math.sqrt(b))
258
+ ncol = math.ceil(b / nrow)
259
+ padding = 1
260
+ video_grid = torch.zeros((t, (padding + h) * nrow + padding,
261
+ (padding + w) * ncol + padding, c), dtype=torch.uint8)
262
+
263
+ print(video_grid.shape)
264
+ for i in range(b):
265
+ r = i // ncol
266
+ c = i % ncol
267
+ start_r = (padding + h) * r
268
+ start_c = (padding + w) * c
269
+ video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
270
+
271
+ return video_grid
272
+
273
+
274
+ #################################################################################
275
+ # MMCV Utils #
276
+ #################################################################################
277
+
278
+
279
+ def collect_env():
280
+ # Copyright (c) OpenMMLab. All rights reserved.
281
+ from mmcv.utils import collect_env as collect_base_env
282
+ from mmcv.utils import get_git_hash
283
+ """Collect the information of the running environments."""
284
+
285
+ env_info = collect_base_env()
286
+ env_info['MMClassification'] = get_git_hash()[:7]
287
+
288
+ for name, val in env_info.items():
289
+ print(f'{name}: {val}')
290
+
291
+ print(torch.cuda.get_arch_list())
292
+ print(torch.version.cuda)
293
+
294
+ #################################################################################
295
+ # Long video generation Utils #
296
+ #################################################################################
297
+
298
+ def mask_generation(mask_type, shape, dtype, device):
299
+ b, c, f, h, w = shape
300
+ if mask_type.startswith('random'):
301
+ num = float(mask_type.split('random')[-1])
302
+ mask_f = torch.ones(1, 1, f, 1, 1, dtype=dtype, device=device)
303
+ indices = torch.randperm(f, device=device)[:int(f*num)]
304
+ mask_f[0, 0, indices, :, :] = 0
305
+ mask = mask_f.expand(b, c, -1, h, w)
306
+ elif mask_type.startswith('first'):
307
+ num = int(mask_type.split('first')[-1])
308
+ mask_f = torch.cat([torch.zeros(1, 1, num, 1, 1, dtype=dtype, device=device),
309
+ torch.ones(1, 1, f-num, 1, 1, dtype=dtype, device=device)], dim=2)
310
+ mask = mask_f.expand(b, c, -1, h, w)
311
+ else:
312
+ raise ValueError(f"Invalid mask type: {mask_type}")
313
+ return mask
314
+
315
+
316
+
317
+ def mask_generation_before(mask_type, shape, dtype, device):
318
+ b, f, c, h, w = shape
319
+ if mask_type.startswith('random'):
320
+ num = float(mask_type.split('random')[-1])
321
+ mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device)
322
+ indices = torch.randperm(f, device=device)[:int(f*num)]
323
+ mask_f[0, indices, :, :, :] = 0
324
+ mask = mask_f.expand(b, -1, c, h, w)
325
+ elif mask_type.startswith('first'):
326
+ num = int(mask_type.split('first')[-1])
327
+ mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device),
328
+ torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1)
329
+ mask = mask_f.expand(b, -1, c, h, w)
330
+ elif mask_type.startswith('uniform'):
331
+ p = float(mask_type.split('uniform')[-1])
332
+ mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device)
333
+ mask_f[0, torch.rand(f, device=device) < p, :, :, :] = 0
334
+ print(f'mask_f: = {mask_f}')
335
+ mask = mask_f.expand(b, -1, c, h, w)
336
+ print(f'mask.shape: = {mask.shape}, mask: = {mask}')
337
+ elif mask_type.startswith('all'):
338
+ mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device)
339
+ elif mask_type.startswith('onelast'):
340
+ num = int(mask_type.split('onelast')[-1])
341
+ mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
342
+ mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device)
343
+ mask_last = torch.zeros_like(mask_one)
344
+ mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1)
345
+ # breakpoint()
346
+ mask = mask.expand(b, -1, c, h, w)
347
+ elif mask_type.startswith('interpolate'):
348
+ mask_f = []
349
+ for i in range(4):
350
+ mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
351
+ mask_f.append(mask_zero)
352
+ mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device)
353
+ mask_f.append(mask_one)
354
+ mask = torch.cat(mask_f, dim=1)
355
+ print(f'mask={mask}')
356
+ elif mask_type.startswith('tsr'):
357
+ mask_f = []
358
+ mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
359
+ mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device)
360
+ for i in range(15):
361
+ mask_f.append(mask_zero) # not masked
362
+ mask_f.append(mask_one) # masked
363
+ mask_f.append(mask_zero) # not masked
364
+ mask = torch.cat(mask_f, dim=1)
365
+ # print(f'before mask.shape = {mask.shape}, mask = {mask}') # [1, 61, 1, 1, 1]
366
+ mask = mask.expand(b, -1, c, h, w)
367
+ # print(f'after mask.shape = {mask.shape}, mask = {mask}') # [4, 61, 3, 256, 256]
368
+ else:
369
+ raise ValueError(f"Invalid mask type: {mask_type}")
370
+
371
+ return mask
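For illustration, a minimal sketch that rebuilds the mask produced by the 'tsr' branch of mask_generation_before, where 0 marks a conditioning frame that is kept and 1 marks a frame the model must synthesize; the batch and spatial sizes are assumed for the example only.

import torch

zero = torch.zeros(1, 1, 1, 1, 1)             # one kept (conditioning) frame
one = torch.ones(1, 3, 1, 1, 1)               # three frames to synthesize
parts = []
for _ in range(15):
    parts += [zero, one]
parts.append(zero)                            # trailing kept frame -> 61 total
mask_f = torch.cat(parts, dim=1)              # shape (1, 61, 1, 1, 1)

b, c, h, w = 1, 3, 256, 256                   # assumed sizes for the example
mask = mask_f.expand(b, -1, c, h, w)          # broadcast to (b, f, c, h, w)
print(mask_f.flatten()[:9].tolist())          # [0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0]
print(mask.shape)                             # torch.Size([1, 61, 3, 256, 256])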