wren93 commited on
Commit
ef16dc7
·
1 Parent(s): 09d869f
.gitignore ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ samples/
2
+ wandb/
3
+ outputs/
4
+ __pycache__/
5
+ scripts/animate_inter.py
6
+ scripts/gradio_app.py
7
+ *.ipynb
8
+ *.safetensors
9
+ *.ckpt
10
+ .ossutil_checkpoint/
11
+ ossutil_output/
12
+ debugs/
13
+ .vscode
14
+ .env
15
+ models
16
+ !*/models
17
+ .ipynb_checkpoints
18
+ checkpoints
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 TIGER Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: ConsistI2V
3
- emoji: 💻
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
  title: ConsistI2V
3
+ emoji: 🎥
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import requests
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ import gradio as gr
11
+ from datetime import datetime
12
+
13
+ import torchvision.transforms as T
14
+
15
+ from diffusers import DDIMScheduler
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
18
+ from consisti2v.utils.util import save_videos_grid
19
+ from omegaconf import OmegaConf
20
+
21
+
22
+ sample_idx = 0
23
+ scheduler_dict = {
24
+ "DDIM": DDIMScheduler,
25
+ }
26
+
27
+ css = """
28
+ .toolbutton {
29
+ margin-buttom: 0em 0em 0em 0em;
30
+ max-width: 2.5em;
31
+ min-width: 2.5em !important;
32
+ height: 2.5em;
33
+ }
34
+ """
35
+
36
+ class AnimateController:
37
+ def __init__(self):
38
+
39
+ # config dirs
40
+ self.basedir = os.getcwd()
41
+ self.savedir = os.path.join(self.basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
42
+ self.savedir_sample = os.path.join(self.savedir, "sample")
43
+ os.makedirs(self.savedir, exist_ok=True)
44
+
45
+ self.image_resolution = (256, 256)
46
+ # config models
47
+ self.pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16,)
48
+ self.pipeline.to("cuda")
49
+
50
+ def update_textbox_and_save_image(self, input_image, height_slider, width_slider, center_crop):
51
+ pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
52
+ img_path = os.path.join(self.savedir, "input_image.png")
53
+ pil_image.save(img_path)
54
+ self.image_resolution = pil_image.size
55
+ pil_image = pil_image.resize((width_slider, height_slider))
56
+ if center_crop:
57
+ width, height = width_slider, height_slider
58
+ aspect_ratio = width / height
59
+ if aspect_ratio > 16 / 10:
60
+ pil_image = pil_image.crop((int((width - height * 16 / 10) / 2), 0, int((width + height * 16 / 10) / 2), height))
61
+ elif aspect_ratio < 16 / 10:
62
+ pil_image = pil_image.crop((0, int((height - width * 10 / 16) / 2), width, int((height + width * 10 / 16) / 2)))
63
+ return gr.Textbox.update(value=img_path), gr.Image.update(value=np.array(pil_image))
64
+
65
+ @spaces.GPU
66
+ def animate(
67
+ self,
68
+ prompt_textbox,
69
+ negative_prompt_textbox,
70
+ input_image_path,
71
+ sampler_dropdown,
72
+ sample_step_slider,
73
+ width_slider,
74
+ height_slider,
75
+ txt_cfg_scale_slider,
76
+ img_cfg_scale_slider,
77
+ center_crop,
78
+ frame_stride,
79
+ use_frameinit,
80
+ frame_init_noise_level,
81
+ seed_textbox
82
+ ):
83
+ if self.pipeline is None:
84
+ raise gr.Error(f"Please select a pretrained pipeline path.")
85
+ if input_image_path == "":
86
+ raise gr.Error(f"Please upload an input image.")
87
+ if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0):
88
+ raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")
89
+ if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0):
90
+ raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.")
91
+
92
+ if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: self.pipeline.unet.enable_xformers_memory_efficient_attention()
93
+
94
+ if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
95
+ else: torch.seed()
96
+ seed = torch.initial_seed()
97
+
98
+ if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
99
+ first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
100
+ else:
101
+ first_frame = Image.open(input_image_path).convert('RGB')
102
+
103
+ original_width, original_height = first_frame.size
104
+
105
+ if not center_crop:
106
+ img_transform = T.Compose([
107
+ T.ToTensor(),
108
+ T.Resize((height_slider, width_slider), antialias=None),
109
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
110
+ ])
111
+ else:
112
+ aspect_ratio = original_width / original_height
113
+ crop_aspect_ratio = width_slider / height_slider
114
+ if aspect_ratio > crop_aspect_ratio:
115
+ center_crop_width = int(crop_aspect_ratio * original_height)
116
+ center_crop_height = original_height
117
+ elif aspect_ratio < crop_aspect_ratio:
118
+ center_crop_width = original_width
119
+ center_crop_height = int(original_width / crop_aspect_ratio)
120
+ else:
121
+ center_crop_width = original_width
122
+ center_crop_height = original_height
123
+ img_transform = T.Compose([
124
+ T.ToTensor(),
125
+ T.CenterCrop((center_crop_height, center_crop_width)),
126
+ T.Resize((height_slider, width_slider), antialias=None),
127
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
128
+ ])
129
+
130
+ first_frame = img_transform(first_frame).unsqueeze(0)
131
+ first_frame = first_frame.to("cuda")
132
+
133
+ if use_frameinit:
134
+ self.pipeline.init_filter(
135
+ width = width_slider,
136
+ height = height_slider,
137
+ video_length = 16,
138
+ filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,})
139
+ )
140
+
141
+
142
+ sample = self.pipeline(
143
+ prompt_textbox,
144
+ negative_prompt = negative_prompt_textbox,
145
+ first_frames = first_frame,
146
+ num_inference_steps = sample_step_slider,
147
+ guidance_scale_txt = txt_cfg_scale_slider,
148
+ guidance_scale_img = img_cfg_scale_slider,
149
+ width = width_slider,
150
+ height = height_slider,
151
+ video_length = 16,
152
+ noise_sampling_method = "pyoco_mixed",
153
+ noise_alpha = 1.0,
154
+ frame_stride = frame_stride,
155
+ use_frameinit = use_frameinit,
156
+ frameinit_noise_level = frame_init_noise_level,
157
+ camera_motion = None,
158
+ ).videos
159
+
160
+ global sample_idx
161
+ sample_idx += 1
162
+ save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
163
+ save_videos_grid(sample, save_sample_path, format="mp4")
164
+
165
+ sample_config = {
166
+ "prompt": prompt_textbox,
167
+ "n_prompt": negative_prompt_textbox,
168
+ "first_frame_path": input_image_path,
169
+ "sampler": sampler_dropdown,
170
+ "num_inference_steps": sample_step_slider,
171
+ "guidance_scale_text": txt_cfg_scale_slider,
172
+ "guidance_scale_image": img_cfg_scale_slider,
173
+ "width": width_slider,
174
+ "height": height_slider,
175
+ "video_length": 8,
176
+ "seed": seed
177
+ }
178
+ json_str = json.dumps(sample_config, indent=4)
179
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
180
+ f.write(json_str)
181
+ f.write("\n\n")
182
+
183
+ return gr.Video.update(value=save_sample_path)
184
+
185
+
186
+ controller = AnimateController()
187
+
188
+
189
+ def ui():
190
+ with gr.Blocks(css=css) as demo:
191
+ gr.Markdown(
192
+ """
193
+ # ConsistI2V Text+Image to Video Generation
194
+ Input image will be used as the first frame of the video. Text prompts will be used to control the output video content.
195
+ """
196
+ )
197
+
198
+ with gr.Column(variant="panel"):
199
+ gr.Markdown(
200
+ """
201
+ - Input image can be specified using the "Input Image Path/URL" text box (this can be either a local image path or an image URL) or uploaded by clicking or dragging the image to the "Input Image" box. The uploaded image will be temporarily stored in the "samples/Gradio" folder under the project root folder.
202
+ - Input image can be resized and/or center cropped to a given resolution by adjusting the "Width" and "Height" sliders. It is recommended to use the same resolution as the training resolution (256x256).
203
+ - After setting the input image path or changed the width/height of the input image, press the "Preview" button to visualize the resized input image.
204
+ """
205
+ )
206
+
207
+ with gr.Row():
208
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2)
209
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
210
+
211
+ with gr.Row().style(equal_height=False):
212
+ with gr.Column():
213
+ with gr.Row():
214
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
215
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
216
+
217
+ with gr.Row():
218
+ center_crop = gr.Checkbox(label="Center Crop the Image", value=True)
219
+ width_slider = gr.Slider(label="Width", value=256, minimum=0, maximum=512, step=64)
220
+ height_slider = gr.Slider(label="Height", value=256, minimum=0, maximum=512, step=64)
221
+ with gr.Row():
222
+ txt_cfg_scale_slider = gr.Slider(label="Text CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.5)
223
+ img_cfg_scale_slider = gr.Slider(label="Image CFG Scale", value=1.0, minimum=1.0, maximum=20.0, step=0.5)
224
+ frame_stride = gr.Slider(label="Frame Stride", value=3, minimum=1, maximum=5, step=1)
225
+
226
+ with gr.Row():
227
+ use_frameinit = gr.Checkbox(label="Enable FrameInit", value=True)
228
+ frameinit_noise_level = gr.Slider(label="FrameInit Noise Level", value=850, minimum=1, maximum=999, step=1)
229
+
230
+
231
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
232
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
233
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
234
+
235
+
236
+
237
+ generate_button = gr.Button(value="Generate", variant='primary')
238
+
239
+ with gr.Column():
240
+ with gr.Row():
241
+ input_image_path = gr.Textbox(label="Input Image Path/URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
242
+ preview_button = gr.Button(value="Preview")
243
+
244
+ with gr.Row():
245
+ input_image = gr.Image(label="Input Image", interactive=True)
246
+ input_image.upload(fn=controller.update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
247
+ result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
248
+
249
+ def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop):
250
+ if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
251
+ pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
252
+ else:
253
+ pil_image = Image.open(input_image_path).convert('RGB')
254
+ controller.image_resolution = pil_image.size
255
+ original_width, original_height = pil_image.size
256
+
257
+ if center_crop:
258
+ crop_aspect_ratio = width_slider / height_slider
259
+ aspect_ratio = original_width / original_height
260
+ if aspect_ratio > crop_aspect_ratio:
261
+ new_width = int(crop_aspect_ratio * original_height)
262
+ left = (original_width - new_width) / 2
263
+ top = 0
264
+ right = left + new_width
265
+ bottom = original_height
266
+ pil_image = pil_image.crop((left, top, right, bottom))
267
+ elif aspect_ratio < crop_aspect_ratio:
268
+ new_height = int(original_width / crop_aspect_ratio)
269
+ top = (original_height - new_height) / 2
270
+ left = 0
271
+ right = original_width
272
+ bottom = top + new_height
273
+ pil_image = pil_image.crop((left, top, right, bottom))
274
+
275
+ pil_image = pil_image.resize((width_slider, height_slider))
276
+ return gr.Image.update(value=np.array(pil_image))
277
+
278
+ preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
279
+ input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
280
+
281
+ generate_button.click(
282
+ fn=controller.animate,
283
+ inputs=[
284
+ prompt_textbox,
285
+ negative_prompt_textbox,
286
+ input_image_path,
287
+ sampler_dropdown,
288
+ sample_step_slider,
289
+ width_slider,
290
+ height_slider,
291
+ txt_cfg_scale_slider,
292
+ img_cfg_scale_slider,
293
+ center_crop,
294
+ frame_stride,
295
+ use_frameinit,
296
+ frameinit_noise_level,
297
+ seed_textbox,
298
+ ],
299
+ outputs=[result_video]
300
+ )
301
+
302
+ return demo
303
+
304
+
305
+ if __name__ == "__main__":
306
+ demo = ui()
307
+ demo.launch(share=True)
configs/inference/inference.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: "samples/inference"
2
+ output_name: "i2v"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+
21
+ unet_additional_kwargs:
22
+ variant: null
23
+ n_temp_heads: 8
24
+ augment_temporal_attention: true
25
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
26
+ first_frame_condition_mode: "concat"
27
+ use_frame_stride_condition: true
28
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
29
+ noise_alpha: 1.0
30
+
31
+ noise_scheduler_kwargs:
32
+ beta_start: 0.00085
33
+ beta_end: 0.012
34
+ beta_schedule: "linear"
35
+ steps_offset: 1
36
+ clip_sample: false
37
+ rescale_betas_zero_snr: false # true if using zero terminal snr
38
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
39
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
40
+
41
+ frameinit_kwargs:
42
+ enable: true
43
+ camera_motion: null
44
+ noise_level: 850
45
+ filter_params:
46
+ method: 'gaussian'
47
+ d_s: 0.25
48
+ d_t: 0.25
configs/inference/inference_autoregress.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: "samples/inference"
2
+ output_name: "long_video"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+ autoregress_steps: 3
21
+
22
+ unet_additional_kwargs:
23
+ variant: null
24
+ n_temp_heads: 8
25
+ augment_temporal_attention: true
26
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
27
+ first_frame_condition_mode: "concat"
28
+ use_frame_stride_condition: true
29
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
30
+ noise_alpha: 1.0
31
+
32
+ noise_scheduler_kwargs:
33
+ beta_start: 0.00085
34
+ beta_end: 0.012
35
+ beta_schedule: "linear"
36
+ steps_offset: 1
37
+ clip_sample: false
38
+ rescale_betas_zero_snr: false # true if using zero terminal snr
39
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
40
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
41
+
42
+
43
+ frameinit_kwargs:
44
+ enable: true
45
+ noise_level: 850
46
+ filter_params:
47
+ method: 'gaussian'
48
+ d_s: 0.25
49
+ d_t: 0.25
configs/prompts/default.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seeds: random
2
+
3
+ prompts:
4
+ - "timelapse at the snow land with aurora in the sky."
5
+ - "fireworks."
6
+ - "clown fish swimming through the coral reef."
7
+ - "melting ice cream dripping down the cone."
8
+
9
+ n_prompts:
10
+ - ""
11
+
12
+ path_to_first_frames:
13
+ - "assets/example/example_01.png"
14
+ - "assets/example/example_02.png"
15
+ - "assets/example/example_03.png"
16
+ - "assets/example/example_04.png"
configs/training/training.yaml ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_dir: "checkpoints"
2
+ pretrained_model_path: "stabilityai/stable-diffusion-2-1-base"
3
+
4
+ noise_scheduler_kwargs:
5
+ num_train_timesteps: 1000
6
+ beta_start: 0.00085
7
+ beta_end: 0.012
8
+ beta_schedule: "linear"
9
+ steps_offset: 1
10
+ clip_sample: false
11
+ rescale_betas_zero_snr: false # true if using zero terminal snr
12
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
13
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
14
+
15
+ train_data:
16
+ dataset: "joint"
17
+ pexels_config:
18
+ enable: false
19
+ json_path: null
20
+ caption_json_path: null
21
+ video_folder: null
22
+ webvid_config:
23
+ enable: true
24
+ json_path: "/path/to/webvid/annotation"
25
+ video_folder: "/path/to/webvid/data"
26
+ sample_size: 256
27
+ sample_duration: null
28
+ sample_fps: null
29
+ sample_stride: [1, 5]
30
+ sample_n_frames: 16
31
+
32
+ validation_data:
33
+ prompts:
34
+ - "timelapse at the snow land with aurora in the sky."
35
+ - "fireworks."
36
+ - "clown fish swimming through the coral reef."
37
+ - "melting ice cream dripping down the cone."
38
+
39
+ path_to_first_frames:
40
+ - "assets/example/example_01.jpg"
41
+ - "assets/example/example_02.jpg"
42
+ - "assets/example/example_03.jpg"
43
+ - "assets/example/example_04.jpg"
44
+
45
+ num_inference_steps: 50
46
+ ddim_eta: 0.0
47
+ guidance_scale_txt: 7.5
48
+ guidance_scale_img: 1.0
49
+ guidance_rescale: 0.0
50
+ frame_stride: 3
51
+
52
+ trainable_modules:
53
+ - "all"
54
+ # - "conv3ds."
55
+ # - "tempo_attns."
56
+
57
+ resume_from_checkpoint: null
58
+
59
+ unet_additional_kwargs:
60
+ variant: null
61
+ n_temp_heads: 8
62
+ augment_temporal_attention: true
63
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
64
+ first_frame_condition_mode: "concat"
65
+ use_frame_stride_condition: true
66
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
67
+ noise_alpha: 1.0
68
+
69
+ cfg_random_null_text_ratio: 0.1
70
+ cfg_random_null_img_ratio: 0.1
71
+
72
+ use_ema: false
73
+ ema_decay: 0.9999
74
+
75
+ learning_rate: 5.e-5
76
+ train_batch_size: 3
77
+ gradient_accumulation_steps: 1
78
+ max_grad_norm: 0.5
79
+
80
+ max_train_epoch: -1
81
+ max_train_steps: 200000
82
+ checkpointing_epochs: -1
83
+ checkpointing_steps: 2000
84
+ validation_steps: 1000
85
+
86
+ seed: 42
87
+ mixed_precision: "bf16"
88
+ num_workers: 32
89
+ enable_xformers_memory_efficient_attention: true
90
+
91
+ is_image: false
92
+ is_debug: false
consisti2v/data/dataset.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, io, csv, math, random
2
+ import json
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from decord import VideoReader
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ from diffusers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ class WebVid10M(Dataset):
16
+ def __init__(
17
+ self,
18
+ json_path, video_folder=None,
19
+ sample_size=256, sample_stride=4, sample_n_frames=16,
20
+ is_image=False,
21
+ **kwargs,
22
+ ):
23
+ logger.info(f"loading annotations from {json_path} ...")
24
+ with open(json_path, 'rb') as json_file:
25
+ json_list = list(json_file)
26
+ self.dataset = [json.loads(json_str) for json_str in json_list]
27
+ self.length = len(self.dataset)
28
+ logger.info(f"data scale: {self.length}")
29
+
30
+ self.video_folder = video_folder
31
+ self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
32
+ self.sample_n_frames = sample_n_frames
33
+ self.is_image = is_image
34
+
35
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
36
+ self.pixel_transforms = transforms.Compose([
37
+ transforms.RandomHorizontalFlip(),
38
+ transforms.Resize(sample_size[0], antialias=None),
39
+ transforms.CenterCrop(sample_size),
40
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41
+ ])
42
+
43
+ def get_batch(self, idx):
44
+ video_dict = self.dataset[idx]
45
+ video_relative_path, name = video_dict['file'], video_dict['text']
46
+
47
+ if self.video_folder is not None:
48
+ if video_relative_path[0] == '/':
49
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
50
+ else:
51
+ video_dir = os.path.join(self.video_folder, video_relative_path)
52
+ else:
53
+ video_dir = video_relative_path
54
+ video_reader = VideoReader(video_dir)
55
+ video_length = len(video_reader)
56
+
57
+ if not self.is_image:
58
+ if isinstance(self.sample_stride, int):
59
+ stride = self.sample_stride
60
+ elif isinstance(self.sample_stride, tuple):
61
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
62
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
63
+ start_idx = random.randint(0, video_length - clip_length)
64
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
65
+ else:
66
+ frame_difference = random.randint(2, self.sample_n_frames)
67
+ clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1)
68
+ start_idx = random.randint(0, video_length - clip_length)
69
+ batch_index = [start_idx, start_idx + clip_length - 1]
70
+
71
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
72
+ pixel_values = pixel_values / 255.
73
+ del video_reader
74
+
75
+ return pixel_values, name
76
+
77
+ def __len__(self):
78
+ return self.length
79
+
80
+ def __getitem__(self, idx):
81
+ while True:
82
+ try:
83
+ pixel_values, name = self.get_batch(idx)
84
+ break
85
+
86
+ except Exception as e:
87
+ idx = random.randint(0, self.length-1)
88
+
89
+ pixel_values = self.pixel_transforms(pixel_values)
90
+ sample = dict(pixel_values=pixel_values, text=name)
91
+ return sample
92
+
93
+
94
+ class Pexels(Dataset):
95
+ def __init__(
96
+ self,
97
+ json_path, caption_json_path, video_folder=None,
98
+ sample_size=256, sample_duration=1, sample_fps=8,
99
+ is_image=False,
100
+ **kwargs,
101
+ ):
102
+ logger.info(f"loading captions from {caption_json_path} ...")
103
+ with open(caption_json_path, 'rb') as caption_json_file:
104
+ caption_json_list = list(caption_json_file)
105
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
106
+
107
+ logger.info(f"loading annotations from {json_path} ...")
108
+ with open(json_path, 'rb') as json_file:
109
+ json_list = list(json_file)
110
+ dataset = [json.loads(json_str) for json_str in json_list]
111
+
112
+ self.dataset = []
113
+ for data in dataset:
114
+ data['text'] = self.caption_dict[data['id']]
115
+ if data['height'] / data['width'] < 0.625:
116
+ self.dataset.append(data)
117
+ self.length = len(self.dataset)
118
+ logger.info(f"data scale: {self.length}")
119
+
120
+ self.video_folder = video_folder
121
+ self.sample_duration = sample_duration
122
+ self.sample_fps = sample_fps
123
+ self.sample_n_frames = sample_duration * sample_fps
124
+ self.is_image = is_image
125
+
126
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
+ self.pixel_transforms = transforms.Compose([
128
+ transforms.RandomHorizontalFlip(),
129
+ transforms.Resize(sample_size[0], antialias=None),
130
+ transforms.CenterCrop(sample_size),
131
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
+ ])
133
+
134
+ def get_batch(self, idx):
135
+ video_dict = self.dataset[idx]
136
+ video_relative_path, name = video_dict['file'], video_dict['text']
137
+ fps = video_dict['fps']
138
+
139
+ if self.video_folder is not None:
140
+ if video_relative_path[0] == '/':
141
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
142
+ else:
143
+ video_dir = os.path.join(self.video_folder, video_relative_path)
144
+ else:
145
+ video_dir = video_relative_path
146
+ video_reader = VideoReader(video_dir)
147
+ video_length = len(video_reader)
148
+
149
+ if not self.is_image:
150
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
151
+ start_idx = random.randint(0, video_length - clip_length)
152
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
153
+ else:
154
+ frame_difference = random.randint(2, self.sample_n_frames)
155
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
156
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
157
+ start_idx = random.randint(0, video_length - clip_length)
158
+ batch_index = [start_idx, start_idx + clip_length - 1]
159
+
160
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
161
+ pixel_values = pixel_values / 255.
162
+ del video_reader
163
+
164
+ return pixel_values, name
165
+
166
+ def __len__(self):
167
+ return self.length
168
+
169
+ def __getitem__(self, idx):
170
+ while True:
171
+ try:
172
+ pixel_values, name = self.get_batch(idx)
173
+ break
174
+
175
+ except Exception as e:
176
+ idx = random.randint(0, self.length-1)
177
+
178
+ pixel_values = self.pixel_transforms(pixel_values)
179
+ sample = dict(pixel_values=pixel_values, text=name)
180
+ return sample
181
+
182
+
183
+ class JointDataset(Dataset):
184
+ def __init__(
185
+ self,
186
+ webvid_config, pexels_config,
187
+ sample_size=256,
188
+ sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
189
+ is_image=False,
190
+ **kwargs,
191
+ ):
192
+ assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
193
+ if sample_duration is not None and sample_fps is not None:
194
+ assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
195
+ if sample_stride is not None:
196
+ assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"
197
+
198
+ self.dataset = []
199
+
200
+ if pexels_config.enable:
201
+ logger.info(f"loading pexels dataset")
202
+ logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
203
+ with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
204
+ caption_json_list = list(caption_json_file)
205
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
206
+
207
+ logger.info(f"loading annotations from {pexels_config.json_path} ...")
208
+ with open(pexels_config.json_path, 'rb') as json_file:
209
+ json_list = list(json_file)
210
+ dataset = [json.loads(json_str) for json_str in json_list]
211
+
212
+ for data in dataset:
213
+ data['text'] = self.caption_dict[data['id']]
214
+ data['dataset'] = 'pexels'
215
+ if data['height'] / data['width'] < 0.625:
216
+ self.dataset.append(data)
217
+
218
+ if webvid_config.enable:
219
+ logger.info(f"loading webvid dataset")
220
+ logger.info(f"loading annotations from {webvid_config.json_path} ...")
221
+ with open(webvid_config.json_path, 'rb') as json_file:
222
+ json_list = list(json_file)
223
+ dataset = [json.loads(json_str) for json_str in json_list]
224
+ for data in dataset:
225
+ data['dataset'] = 'webvid'
226
+ self.dataset.extend(dataset)
227
+
228
+ self.length = len(self.dataset)
229
+ logger.info(f"data scale: {self.length}")
230
+
231
+ self.pexels_folder = pexels_config.video_folder
232
+ self.webvid_folder = webvid_config.video_folder
233
+ self.sample_duration = sample_duration
234
+ self.sample_fps = sample_fps
235
+ self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames
236
+ self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride)
237
+ self.is_image = is_image
238
+
239
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
240
+ self.pixel_transforms = transforms.Compose([
241
+ transforms.RandomHorizontalFlip(),
242
+ transforms.Resize(sample_size[0], antialias=None),
243
+ transforms.CenterCrop(sample_size),
244
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
245
+ ])
246
+
247
+ def get_batch(self, idx):
248
+ video_dict = self.dataset[idx]
249
+ video_relative_path, name = video_dict['file'], video_dict['text']
250
+
251
+ if video_dict['dataset'] == 'pexels':
252
+ video_folder = self.pexels_folder
253
+ elif video_dict['dataset'] == 'webvid':
254
+ video_folder = self.webvid_folder
255
+ else:
256
+ raise NotImplementedError
257
+
258
+ if video_folder is not None:
259
+ if video_relative_path[0] == '/':
260
+ video_dir = os.path.join(video_folder, os.path.basename(video_relative_path))
261
+ else:
262
+ video_dir = os.path.join(video_folder, video_relative_path)
263
+ else:
264
+ video_dir = video_relative_path
265
+ video_reader = VideoReader(video_dir)
266
+ video_length = len(video_reader)
267
+
268
+ stride = None
269
+ if not self.is_image:
270
+ if self.sample_duration is not None:
271
+ fps = video_dict['fps']
272
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
273
+ elif self.sample_stride is not None:
274
+ if isinstance(self.sample_stride, int):
275
+ stride = self.sample_stride
276
+ elif isinstance(self.sample_stride, tuple):
277
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
278
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
279
+
280
+ start_idx = random.randint(0, video_length - clip_length)
281
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
282
+
283
+ else:
284
+ frame_difference = random.randint(2, self.sample_n_frames)
285
+ if self.sample_duration is not None:
286
+ fps = video_dict['fps']
287
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
288
+ elif self.sample_stride is not None:
289
+ sample_stride = self.sample_stride
290
+
291
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
292
+ start_idx = random.randint(0, video_length - clip_length)
293
+ batch_index = [start_idx, start_idx + clip_length - 1]
294
+
295
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
296
+ pixel_values = pixel_values / 255.
297
+ del video_reader
298
+
299
+ return pixel_values, name, stride
300
+
301
+ def __len__(self):
302
+ return self.length
303
+
304
+ def __getitem__(self, idx):
305
+ while True:
306
+ try:
307
+ pixel_values, name, stride = self.get_batch(idx)
308
+ break
309
+
310
+ except Exception as e:
311
+ idx = random.randint(0, self.length-1)
312
+
313
+ pixel_values = self.pixel_transforms(pixel_values)
314
+ sample = dict(pixel_values=pixel_values, text=name, stride=stride)
315
+ return sample
consisti2v/models/rotary_embedding.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import pi, log
2
+
3
+ import torch
4
+ from torch.nn import Module, ModuleList
5
+ from torch.cuda.amp import autocast
6
+ from torch import nn, einsum, broadcast_tensors, Tensor
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ from beartype import beartype
11
+ from beartype.typing import Literal, Union, Optional
12
+
13
+ # helper functions
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+ def default(val, d):
19
+ return val if exists(val) else d
20
+
21
+ # broadcat, as tortoise-tts was using it
22
+
23
+ def broadcat(tensors, dim = -1):
24
+ broadcasted_tensors = broadcast_tensors(*tensors)
25
+ return torch.cat(broadcasted_tensors, dim = dim)
26
+
27
+ # rotary embedding helper functions
28
+
29
+ def rotate_half(x):
30
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
31
+ x1, x2 = x.unbind(dim = -1)
32
+ x = torch.stack((-x2, x1), dim = -1)
33
+ return rearrange(x, '... d r -> ... (d r)')
34
+
35
+ @autocast(enabled = False)
36
+ def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
37
+ if t.ndim == 3:
38
+ seq_len = t.shape[seq_dim]
39
+ freqs = freqs[-seq_len:].to(t)
40
+
41
+ rot_dim = freqs.shape[-1]
42
+ end_index = start_index + rot_dim
43
+
44
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
45
+
46
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
47
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
48
+ return torch.cat((t_left, t, t_right), dim = -1)
49
+
50
+ # learned rotation helpers
51
+
52
+ def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
53
+ if exists(freq_ranges):
54
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
55
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
56
+
57
+ rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
58
+ return apply_rotary_emb(rotations, t, start_index = start_index)
59
+
60
+ # classes
61
+
62
+ class RotaryEmbedding(Module):
63
+ @beartype
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ custom_freqs: Optional[Tensor] = None,
68
+ freqs_for: Union[
69
+ Literal['lang'],
70
+ Literal['pixel'],
71
+ Literal['constant']
72
+ ] = 'lang',
73
+ theta = 10000,
74
+ max_freq = 10,
75
+ num_freqs = 1,
76
+ learned_freq = False,
77
+ use_xpos = False,
78
+ xpos_scale_base = 512,
79
+ interpolate_factor = 1.,
80
+ theta_rescale_factor = 1.,
81
+ seq_before_head_dim = False
82
+ ):
83
+ super().__init__()
84
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
85
+ # has some connection to NTK literature
86
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
87
+
88
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
89
+
90
+ self.freqs_for = freqs_for
91
+
92
+ if exists(custom_freqs):
93
+ freqs = custom_freqs
94
+ elif freqs_for == 'lang':
95
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
96
+ elif freqs_for == 'pixel':
97
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
98
+ elif freqs_for == 'constant':
99
+ freqs = torch.ones(num_freqs).float()
100
+
101
+ self.tmp_store('cached_freqs', None)
102
+ self.tmp_store('cached_scales', None)
103
+
104
+ self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
105
+
106
+ self.learned_freq = learned_freq
107
+
108
+ # dummy for device
109
+
110
+ self.tmp_store('dummy', torch.tensor(0))
111
+
112
+ # default sequence dimension
113
+
114
+ self.seq_before_head_dim = seq_before_head_dim
115
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
116
+
117
+ # interpolation factors
118
+
119
+ assert interpolate_factor >= 1.
120
+ self.interpolate_factor = interpolate_factor
121
+
122
+ # xpos
123
+
124
+ self.use_xpos = use_xpos
125
+ if not use_xpos:
126
+ self.tmp_store('scale', None)
127
+ return
128
+
129
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
130
+ self.scale_base = xpos_scale_base
131
+ self.tmp_store('scale', scale)
132
+
133
+ @property
134
+ def device(self):
135
+ return self.dummy.device
136
+
137
+ def tmp_store(self, key, value):
138
+ self.register_buffer(key, value, persistent = False)
139
+
140
+ def get_seq_pos(self, seq_len, device, dtype, offset = 0):
141
+ return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
142
+
143
+ def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None):
144
+ seq_dim = default(seq_dim, self.default_seq_dim)
145
+
146
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
147
+
148
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
149
+
150
+ if exists(freq_seq_len):
151
+ assert freq_seq_len >= seq_len
152
+ seq_len = freq_seq_len
153
+
154
+ if seq_pos is None:
155
+ seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
156
+ else:
157
+ assert seq_pos.shape[0] == seq_len
158
+
159
+ freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset)
160
+
161
+ if seq_dim == -3:
162
+ freqs = rearrange(freqs, 'n d -> n 1 d')
163
+
164
+ return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
165
+
166
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
167
+ seq_dim = default(seq_dim, self.default_seq_dim)
168
+
169
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
170
+ assert q_len <= k_len
171
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
172
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)
173
+
174
+ rotated_q = rotated_q.type(q.dtype)
175
+ rotated_k = rotated_k.type(k.dtype)
176
+
177
+ return rotated_q, rotated_k
178
+
179
+ def rotate_queries_and_keys(self, q, k, seq_dim = None):
180
+ seq_dim = default(seq_dim, self.default_seq_dim)
181
+
182
+ assert self.use_xpos
183
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
184
+
185
+ seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
186
+
187
+ freqs = self.forward(seq, seq_len = seq_len)
188
+ scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
189
+
190
+ if seq_dim == -3:
191
+ freqs = rearrange(freqs, 'n d -> n 1 d')
192
+ scale = rearrange(scale, 'n d -> n 1 d')
193
+
194
+ rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
195
+ rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
196
+
197
+ rotated_q = rotated_q.type(q.dtype)
198
+ rotated_k = rotated_k.type(k.dtype)
199
+
200
+ return rotated_q, rotated_k
201
+
202
+ @beartype
203
+ def get_scale(
204
+ self,
205
+ t: Tensor,
206
+ seq_len: Optional[int] = None,
207
+ offset = 0
208
+ ):
209
+ assert self.use_xpos
210
+
211
+ should_cache = exists(seq_len)
212
+
213
+ if (
214
+ should_cache and \
215
+ exists(self.cached_scales) and \
216
+ (seq_len + offset) <= self.cached_scales.shape[0]
217
+ ):
218
+ return self.cached_scales[offset:(offset + seq_len)]
219
+
220
+ scale = 1.
221
+ if self.use_xpos:
222
+ power = (t - len(t) // 2) / self.scale_base
223
+ scale = self.scale ** rearrange(power, 'n -> n 1')
224
+ scale = torch.cat((scale, scale), dim = -1)
225
+
226
+ if should_cache:
227
+ self.tmp_store('cached_scales', scale)
228
+
229
+ return scale
230
+
231
+ def get_axial_freqs(self, *dims):
232
+ Colon = slice(None)
233
+ all_freqs = []
234
+
235
+ for ind, dim in enumerate(dims):
236
+ if self.freqs_for == 'pixel':
237
+ pos = torch.linspace(-1, 1, steps = dim, device = self.device)
238
+ else:
239
+ pos = torch.arange(dim, device = self.device)
240
+
241
+ freqs = self.forward(pos, seq_len = dim)
242
+
243
+ all_axis = [None] * len(dims)
244
+ all_axis[ind] = Colon
245
+
246
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
247
+ all_freqs.append(freqs[new_axis_slice])
248
+
249
+ all_freqs = broadcast_tensors(*all_freqs)
250
+ return torch.cat(all_freqs, dim = -1)
251
+
252
+ @autocast(enabled = False)
253
+ def forward(
254
+ self,
255
+ t: Tensor,
256
+ seq_len = None,
257
+ offset = 0
258
+ ):
259
+ # should_cache = (
260
+ # not self.learned_freq and \
261
+ # exists(seq_len) and \
262
+ # self.freqs_for != 'pixel'
263
+ # )
264
+
265
+ # if (
266
+ # should_cache and \
267
+ # exists(self.cached_freqs) and \
268
+ # (offset + seq_len) <= self.cached_freqs.shape[0]
269
+ # ):
270
+ # return self.cached_freqs[offset:(offset + seq_len)].detach()
271
+
272
+ freqs = self.freqs
273
+
274
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
275
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
276
+
277
+ # if should_cache:
278
+ # self.tmp_store('cached_freqs', freqs.detach())
279
+
280
+ return freqs
consisti2v/models/videoldm_attention.py ADDED
@@ -0,0 +1,809 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from typing import Callable, Optional, Union
3
+ import math
4
+
5
+ from einops import rearrange, repeat
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from diffusers.utils import deprecate, logging
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
14
+ from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer
15
+ from diffusers.models.attention_processor import (
16
+ Attention,
17
+ AttnAddedKVProcessor,
18
+ AttnAddedKVProcessor2_0,
19
+ AttnProcessor,
20
+ AttnProcessor2_0,
21
+ SpatialNorm,
22
+ LORA_ATTENTION_PROCESSORS,
23
+ CustomDiffusionAttnProcessor,
24
+ CustomDiffusionXFormersAttnProcessor,
25
+ SlicedAttnAddedKVProcessor,
26
+ XFormersAttnAddedKVProcessor,
27
+ LoRAAttnAddedKVProcessor,
28
+ XFormersAttnProcessor,
29
+ LoRAXFormersAttnProcessor,
30
+ LoRAAttnProcessor,
31
+ LoRAAttnProcessor2_0,
32
+ SlicedAttnProcessor,
33
+ AttentionProcessor
34
+ )
35
+
36
+ from .rotary_embedding import RotaryEmbedding
37
+
38
+
39
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
40
+
41
+
42
+ if is_xformers_available():
43
+ import xformers
44
+ import xformers.ops
45
+ else:
46
+ xformers = None
47
+
48
+ @maybe_allow_in_graph
49
+ class ConditionalAttention(nn.Module):
50
+ r"""
51
+ A cross attention layer.
52
+
53
+ Parameters:
54
+ query_dim (`int`): The number of channels in the query.
55
+ cross_attention_dim (`int`, *optional*):
56
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
57
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
58
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
59
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
60
+ bias (`bool`, *optional*, defaults to False):
61
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
62
+ """
63
+
64
+ def __init__(
65
+ self,
66
+ query_dim: int,
67
+ cross_attention_dim: Optional[int] = None,
68
+ heads: int = 8,
69
+ dim_head: int = 64,
70
+ dropout: float = 0.0,
71
+ bias=False,
72
+ upcast_attention: bool = False,
73
+ upcast_softmax: bool = False,
74
+ cross_attention_norm: Optional[str] = None,
75
+ cross_attention_norm_num_groups: int = 32,
76
+ added_kv_proj_dim: Optional[int] = None,
77
+ norm_num_groups: Optional[int] = None,
78
+ spatial_norm_dim: Optional[int] = None,
79
+ out_bias: bool = True,
80
+ scale_qk: bool = True,
81
+ only_cross_attention: bool = False,
82
+ eps: float = 1e-5,
83
+ rescale_output_factor: float = 1.0,
84
+ residual_connection: bool = False,
85
+ _from_deprecated_attn_block=False,
86
+ processor: Optional["AttnProcessor"] = None,
87
+ ):
88
+ super().__init__()
89
+ self.inner_dim = dim_head * heads
90
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
91
+ self.upcast_attention = upcast_attention
92
+ self.upcast_softmax = upcast_softmax
93
+ self.rescale_output_factor = rescale_output_factor
94
+ self.residual_connection = residual_connection
95
+ self.dropout = dropout
96
+
97
+ # we make use of this private variable to know whether this class is loaded
98
+ # with an deprecated state dict so that we can convert it on the fly
99
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
100
+
101
+ self.scale_qk = scale_qk
102
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
103
+
104
+ self.heads = heads
105
+ # for slice_size > 0 the attention score computation
106
+ # is split across the batch axis to save memory
107
+ # You can set slice_size with `set_attention_slice`
108
+ self.sliceable_head_dim = heads
109
+
110
+ self.added_kv_proj_dim = added_kv_proj_dim
111
+ self.only_cross_attention = only_cross_attention
112
+
113
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
114
+ raise ValueError(
115
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
116
+ )
117
+
118
+ if norm_num_groups is not None:
119
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
120
+ else:
121
+ self.group_norm = None
122
+
123
+ if spatial_norm_dim is not None:
124
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
125
+ else:
126
+ self.spatial_norm = None
127
+
128
+ if cross_attention_norm is None:
129
+ self.norm_cross = None
130
+ elif cross_attention_norm == "layer_norm":
131
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
132
+ elif cross_attention_norm == "group_norm":
133
+ if self.added_kv_proj_dim is not None:
134
+ # The given `encoder_hidden_states` are initially of shape
135
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
136
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
137
+ # before the projection, so we need to use `added_kv_proj_dim` as
138
+ # the number of channels for the group norm.
139
+ norm_cross_num_channels = added_kv_proj_dim
140
+ else:
141
+ norm_cross_num_channels = self.cross_attention_dim
142
+
143
+ self.norm_cross = nn.GroupNorm(
144
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
145
+ )
146
+ else:
147
+ raise ValueError(
148
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
149
+ )
150
+
151
+ self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
152
+
153
+ if not self.only_cross_attention:
154
+ # only relevant for the `AddedKVProcessor` classes
155
+ self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
156
+ self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
157
+ else:
158
+ self.to_k = None
159
+ self.to_v = None
160
+
161
+ if self.added_kv_proj_dim is not None:
162
+ self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
163
+ self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
164
+
165
+ self.to_out = nn.ModuleList([])
166
+ self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
167
+ self.to_out.append(nn.Dropout(dropout))
168
+
169
+ # set attention processor
170
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
171
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
172
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
173
+ if processor is None:
174
+ processor = (
175
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
176
+ )
177
+ self.set_processor(processor)
178
+
179
+ def set_use_memory_efficient_attention_xformers(
180
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
181
+ ):
182
+ is_lora = hasattr(self, "processor") and isinstance(
183
+ self.processor,
184
+ LORA_ATTENTION_PROCESSORS,
185
+ )
186
+ is_custom_diffusion = hasattr(self, "processor") and isinstance(
187
+ self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
188
+ )
189
+ is_added_kv_processor = hasattr(self, "processor") and isinstance(
190
+ self.processor,
191
+ (
192
+ AttnAddedKVProcessor,
193
+ AttnAddedKVProcessor2_0,
194
+ SlicedAttnAddedKVProcessor,
195
+ XFormersAttnAddedKVProcessor,
196
+ LoRAAttnAddedKVProcessor,
197
+ ),
198
+ )
199
+
200
+ if use_memory_efficient_attention_xformers:
201
+ if is_added_kv_processor and (is_lora or is_custom_diffusion):
202
+ raise NotImplementedError(
203
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
204
+ )
205
+ if not is_xformers_available():
206
+ raise ModuleNotFoundError(
207
+ (
208
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
209
+ " xformers"
210
+ ),
211
+ name="xformers",
212
+ )
213
+ elif not torch.cuda.is_available():
214
+ raise ValueError(
215
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
216
+ " only available for GPU "
217
+ )
218
+ else:
219
+ try:
220
+ # Make sure we can run the memory efficient attention
221
+ _ = xformers.ops.memory_efficient_attention(
222
+ torch.randn((1, 2, 40), device="cuda"),
223
+ torch.randn((1, 2, 40), device="cuda"),
224
+ torch.randn((1, 2, 40), device="cuda"),
225
+ )
226
+ except Exception as e:
227
+ raise e
228
+
229
+ if is_lora:
230
+ # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
231
+ # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
232
+ processor = LoRAXFormersAttnProcessor(
233
+ hidden_size=self.processor.hidden_size,
234
+ cross_attention_dim=self.processor.cross_attention_dim,
235
+ rank=self.processor.rank,
236
+ attention_op=attention_op,
237
+ )
238
+ processor.load_state_dict(self.processor.state_dict())
239
+ processor.to(self.processor.to_q_lora.up.weight.device)
240
+ elif is_custom_diffusion:
241
+ processor = CustomDiffusionXFormersAttnProcessor(
242
+ train_kv=self.processor.train_kv,
243
+ train_q_out=self.processor.train_q_out,
244
+ hidden_size=self.processor.hidden_size,
245
+ cross_attention_dim=self.processor.cross_attention_dim,
246
+ attention_op=attention_op,
247
+ )
248
+ processor.load_state_dict(self.processor.state_dict())
249
+ if hasattr(self.processor, "to_k_custom_diffusion"):
250
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
251
+ elif is_added_kv_processor:
252
+ # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
253
+ # which uses this type of cross attention ONLY because the attention mask of format
254
+ # [0, ..., -10.000, ..., 0, ...,] is not supported
255
+ # throw warning
256
+ logger.info(
257
+ "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
258
+ )
259
+ processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
260
+ else:
261
+ processor = XFormersAttnProcessor(attention_op=attention_op)
262
+ else:
263
+ if is_lora:
264
+ attn_processor_class = (
265
+ LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
266
+ )
267
+ processor = attn_processor_class(
268
+ hidden_size=self.processor.hidden_size,
269
+ cross_attention_dim=self.processor.cross_attention_dim,
270
+ rank=self.processor.rank,
271
+ )
272
+ processor.load_state_dict(self.processor.state_dict())
273
+ processor.to(self.processor.to_q_lora.up.weight.device)
274
+ elif is_custom_diffusion:
275
+ processor = CustomDiffusionAttnProcessor(
276
+ train_kv=self.processor.train_kv,
277
+ train_q_out=self.processor.train_q_out,
278
+ hidden_size=self.processor.hidden_size,
279
+ cross_attention_dim=self.processor.cross_attention_dim,
280
+ )
281
+ processor.load_state_dict(self.processor.state_dict())
282
+ if hasattr(self.processor, "to_k_custom_diffusion"):
283
+ processor.to(self.processor.to_k_custom_diffusion.weight.device)
284
+ else:
285
+ # set attention processor
286
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
287
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
288
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
289
+ processor = (
290
+ AttnProcessor2_0()
291
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
292
+ else AttnProcessor()
293
+ )
294
+
295
+ self.set_processor(processor)
296
+
297
+ def set_attention_slice(self, slice_size):
298
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
299
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
300
+
301
+ if slice_size is not None and self.added_kv_proj_dim is not None:
302
+ processor = SlicedAttnAddedKVProcessor(slice_size)
303
+ elif slice_size is not None:
304
+ processor = SlicedAttnProcessor(slice_size)
305
+ elif self.added_kv_proj_dim is not None:
306
+ processor = AttnAddedKVProcessor()
307
+ else:
308
+ # set attention processor
309
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
310
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
311
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
312
+ processor = (
313
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
314
+ )
315
+
316
+ self.set_processor(processor)
317
+
318
+ def set_processor(self, processor: "AttnProcessor"):
319
+ if (
320
+ hasattr(self, "processor")
321
+ and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
322
+ and self.to_q.lora_layer is not None
323
+ ):
324
+ deprecate(
325
+ "set_processor to offload LoRA",
326
+ "0.26.0",
327
+ "In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
328
+ )
329
+ # (Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
330
+ # We need to remove all LoRA layers
331
+ for module in self.modules():
332
+ if hasattr(module, "set_lora_layer"):
333
+ module.set_lora_layer(None)
334
+
335
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
336
+ # pop `processor` from `self._modules`
337
+ if (
338
+ hasattr(self, "processor")
339
+ and isinstance(self.processor, torch.nn.Module)
340
+ and not isinstance(processor, torch.nn.Module)
341
+ ):
342
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
343
+ self._modules.pop("processor")
344
+
345
+ self.processor = processor
346
+
347
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
348
+ if not return_deprecated_lora:
349
+ return self.processor
350
+
351
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
352
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
353
+ # with PEFT is completed.
354
+ is_lora_activated = {
355
+ name: module.lora_layer is not None
356
+ for name, module in self.named_modules()
357
+ if hasattr(module, "lora_layer")
358
+ }
359
+
360
+ # 1. if no layer has a LoRA activated we can return the processor as usual
361
+ if not any(is_lora_activated.values()):
362
+ return self.processor
363
+
364
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
365
+ is_lora_activated.pop("add_k_proj", None)
366
+ is_lora_activated.pop("add_v_proj", None)
367
+ # 2. else it is not posssible that only some layers have LoRA activated
368
+ if not all(is_lora_activated.values()):
369
+ raise ValueError(
370
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
371
+ )
372
+
373
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
374
+ non_lora_processor_cls_name = self.processor.__class__.__name__
375
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
376
+
377
+ hidden_size = self.inner_dim
378
+
379
+ # now create a LoRA attention processor from the LoRA layers
380
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
381
+ kwargs = {
382
+ "cross_attention_dim": self.cross_attention_dim,
383
+ "rank": self.to_q.lora_layer.rank,
384
+ "network_alpha": self.to_q.lora_layer.network_alpha,
385
+ "q_rank": self.to_q.lora_layer.rank,
386
+ "q_hidden_size": self.to_q.lora_layer.out_features,
387
+ "k_rank": self.to_k.lora_layer.rank,
388
+ "k_hidden_size": self.to_k.lora_layer.out_features,
389
+ "v_rank": self.to_v.lora_layer.rank,
390
+ "v_hidden_size": self.to_v.lora_layer.out_features,
391
+ "out_rank": self.to_out[0].lora_layer.rank,
392
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
393
+ }
394
+
395
+ if hasattr(self.processor, "attention_op"):
396
+ kwargs["attention_op"] = self.prcoessor.attention_op
397
+
398
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
399
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
400
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
401
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
402
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
403
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
404
+ lora_processor = lora_processor_cls(
405
+ hidden_size,
406
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
407
+ rank=self.to_q.lora_layer.rank,
408
+ network_alpha=self.to_q.lora_layer.network_alpha,
409
+ )
410
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
411
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
412
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
413
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
414
+
415
+ # only save if used
416
+ if self.add_k_proj.lora_layer is not None:
417
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
418
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
419
+ else:
420
+ lora_processor.add_k_proj_lora = None
421
+ lora_processor.add_v_proj_lora = None
422
+ else:
423
+ raise ValueError(f"{lora_processor_cls} does not exist.")
424
+
425
+ return lora_processor
426
+
427
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
428
+ # The `Attention` class can call different attention processors / attention functions
429
+ # here we simply pass along all tensors to the selected processor class
430
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
431
+ return self.processor(
432
+ self,
433
+ hidden_states,
434
+ encoder_hidden_states=encoder_hidden_states,
435
+ attention_mask=attention_mask,
436
+ **cross_attention_kwargs,
437
+ )
438
+
439
+ def batch_to_head_dim(self, tensor):
440
+ head_size = self.heads
441
+ batch_size, seq_len, dim = tensor.shape
442
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
443
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
444
+ return tensor
445
+
446
+ def head_to_batch_dim(self, tensor, out_dim=3):
447
+ head_size = self.heads
448
+ batch_size, seq_len, dim = tensor.shape
449
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
450
+ tensor = tensor.permute(0, 2, 1, 3)
451
+
452
+ if out_dim == 3:
453
+ tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
454
+
455
+ return tensor
456
+
457
+ def get_attention_scores(self, query, key, attention_mask=None):
458
+ dtype = query.dtype
459
+ if self.upcast_attention:
460
+ query = query.float()
461
+ key = key.float()
462
+
463
+ if attention_mask is None:
464
+ baddbmm_input = torch.empty(
465
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
466
+ )
467
+ beta = 0
468
+ else:
469
+ baddbmm_input = attention_mask
470
+ beta = 1
471
+
472
+ attention_scores = torch.baddbmm(
473
+ baddbmm_input,
474
+ query,
475
+ key.transpose(-1, -2),
476
+ beta=beta,
477
+ alpha=self.scale,
478
+ )
479
+ del baddbmm_input
480
+
481
+ if self.upcast_softmax:
482
+ attention_scores = attention_scores.float()
483
+
484
+ attention_probs = attention_scores.softmax(dim=-1)
485
+ del attention_scores
486
+
487
+ attention_probs = attention_probs.to(dtype)
488
+
489
+ return attention_probs
490
+
491
+ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
492
+ if batch_size is None:
493
+ deprecate(
494
+ "batch_size=None",
495
+ "0.22.0",
496
+ (
497
+ "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
498
+ " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
499
+ " `prepare_attention_mask` when preparing the attention_mask."
500
+ ),
501
+ )
502
+ batch_size = 1
503
+
504
+ head_size = self.heads
505
+ if attention_mask is None:
506
+ return attention_mask
507
+
508
+ current_length: int = attention_mask.shape[-1]
509
+ if current_length != target_length:
510
+ if attention_mask.device.type == "mps":
511
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
512
+ # Instead, we can manually construct the padding tensor.
513
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
514
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
515
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
516
+ else:
517
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
518
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
519
+ # remaining_length: int = target_length - current_length
520
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
521
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
522
+
523
+ if out_dim == 3:
524
+ if attention_mask.shape[0] < batch_size * head_size:
525
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
526
+ elif out_dim == 4:
527
+ attention_mask = attention_mask.unsqueeze(1)
528
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
529
+
530
+ return attention_mask
531
+
532
+ def norm_encoder_hidden_states(self, encoder_hidden_states):
533
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
534
+
535
+ if isinstance(self.norm_cross, nn.LayerNorm):
536
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
537
+ elif isinstance(self.norm_cross, nn.GroupNorm):
538
+ # Group norm norms along the channels dimension and expects
539
+ # input to be in the shape of (N, C, *). In this case, we want
540
+ # to norm along the hidden dimension, so we need to move
541
+ # (batch_size, sequence_length, hidden_size) ->
542
+ # (batch_size, hidden_size, sequence_length)
543
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
544
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
545
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
546
+ else:
547
+ assert False
548
+
549
+ return encoder_hidden_states
550
+
551
+
552
+ class TemporalConditionalAttention(Attention):
553
+ def __init__(self, n_frames=8, rotary_emb=False, *args, **kwargs):
554
+ super().__init__(processor=RotaryEmbAttnProcessor2_0() if rotary_emb else None, *args, **kwargs)
555
+
556
+ if not rotary_emb:
557
+ self.pos_enc = PositionalEncoding(self.inner_dim)
558
+ else:
559
+ rotary_bias = RelativePositionBias(heads=kwargs['heads'], max_distance=32)
560
+ self.rotary_bias = rotary_bias
561
+ self.rotary_emb = RotaryEmbedding(self.inner_dim // 2)
562
+
563
+ self.use_rotary_emb = rotary_emb
564
+ self.n_frames = n_frames
565
+
566
+ def forward(
567
+ self,
568
+ hidden_states,
569
+ encoder_hidden_states=None,
570
+ attention_mask=None,
571
+ adjacent_slices=None,
572
+ **cross_attention_kwargs):
573
+
574
+ key_pos_idx = None
575
+
576
+ bt, hw, c = hidden_states.shape
577
+ hidden_states = rearrange(hidden_states, '(b t) hw c -> b hw t c', t=self.n_frames)
578
+ if not self.use_rotary_emb:
579
+ pos_embed = self.pos_enc(self.n_frames)
580
+ hidden_states = hidden_states + pos_embed
581
+ hidden_states = rearrange(hidden_states, 'b hw t c -> (b hw) t c')
582
+
583
+ if encoder_hidden_states is not None:
584
+ assert adjacent_slices is None
585
+ encoder_hidden_states = encoder_hidden_states[::self.n_frames]
586
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b hw) n c', hw=hw)
587
+
588
+ if adjacent_slices is not None:
589
+ assert encoder_hidden_states is None
590
+ adjacent_slices = rearrange(adjacent_slices, 'b c h w n -> b (h w) n c')
591
+ if not self.use_rotary_emb:
592
+ first_frame_pos_embed = pos_embed[0:1, :]
593
+ adjacent_slices = adjacent_slices + first_frame_pos_embed
594
+ else:
595
+ pos_idx = torch.arange(self.n_frames, device=hidden_states.device, dtype=hidden_states.dtype)
596
+ first_frame_pos_pad = torch.zeros(adjacent_slices.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
597
+ key_pos_idx = torch.cat([pos_idx, first_frame_pos_pad], dim=0)
598
+ adjacent_slices = rearrange(adjacent_slices, 'b hw n c -> (b hw) n c')
599
+ encoder_hidden_states = torch.cat([hidden_states, adjacent_slices], dim=1)
600
+
601
+ if not self.use_rotary_emb:
602
+ out = self.processor(
603
+ self,
604
+ hidden_states,
605
+ encoder_hidden_states=encoder_hidden_states,
606
+ attention_mask=attention_mask,
607
+ **cross_attention_kwargs,
608
+ )
609
+ else:
610
+ out = self.processor(
611
+ self,
612
+ hidden_states,
613
+ encoder_hidden_states=encoder_hidden_states,
614
+ attention_mask=attention_mask,
615
+ key_pos_idx=key_pos_idx,
616
+ **cross_attention_kwargs,
617
+ )
618
+
619
+ out = rearrange(out, '(b hw) t c -> (b t) hw c', hw=hw)
620
+
621
+ return out
622
+
623
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers, attention_op=None):
624
+ if use_memory_efficient_attention_xformers:
625
+ try:
626
+ # Make sure we can run the memory efficient attention
627
+ _ = xformers.ops.memory_efficient_attention(
628
+ torch.randn((1, 2, 40), device="cuda"),
629
+ torch.randn((1, 2, 40), device="cuda"),
630
+ torch.randn((1, 2, 40), device="cuda"),
631
+ )
632
+ except Exception as e:
633
+ raise e
634
+ processor = XFormersAttnProcessor(attention_op=attention_op)
635
+ else:
636
+ processor = (
637
+ AttnProcessor2_0()
638
+ if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
639
+ else AttnProcessor()
640
+ )
641
+ self.set_processor(processor)
642
+
643
+
644
+ class PositionalEncoding(nn.Module):
645
+ def __init__(self, dim, max_pos=512):
646
+ super().__init__()
647
+
648
+ pos = torch.arange(max_pos)
649
+
650
+ freq = torch.arange(dim//2) / dim
651
+ freq = (freq * torch.tensor(10000).log()).exp()
652
+
653
+ x = rearrange(pos, 'L -> L 1') / freq
654
+ x = rearrange(x, 'L d -> L d 1')
655
+
656
+ pe = torch.cat((x.sin(), x.cos()), dim=-1)
657
+ self.pe = rearrange(pe, 'L d sc -> L (d sc)')
658
+
659
+ self.dummy = nn.Parameter(torch.rand(1))
660
+
661
+ def forward(self, length):
662
+ enc = self.pe[:length]
663
+ enc = enc.to(self.dummy.device, self.dummy.dtype)
664
+ return enc
665
+
666
+
667
+ # code taken from https://github.com/Vchitect/LaVie/blob/main/base/models/temporal_attention.py
668
+ class RelativePositionBias(nn.Module):
669
+ def __init__(
670
+ self,
671
+ heads=8,
672
+ num_buckets=32,
673
+ max_distance=128,
674
+ ):
675
+ super().__init__()
676
+ self.num_buckets = num_buckets
677
+ self.max_distance = max_distance
678
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
679
+
680
+ @staticmethod
681
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
682
+ ret = 0
683
+ n = -relative_position
684
+
685
+ num_buckets //= 2
686
+ ret += (n < 0).long() * num_buckets
687
+ n = torch.abs(n)
688
+
689
+ max_exact = num_buckets // 2
690
+ is_small = n < max_exact
691
+
692
+ val_if_large = max_exact + (
693
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
694
+ ).long()
695
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
696
+
697
+ ret += torch.where(is_small, n, val_if_large)
698
+ return ret
699
+
700
+ def forward(self, qlen, klen, device, dtype):
701
+ q_pos = torch.arange(qlen, dtype = torch.long, device = device)
702
+ k_pos = torch.arange(klen, dtype = torch.long, device = device)
703
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
704
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
705
+ values = self.relative_attention_bias(rp_bucket)
706
+ values = values.to(device, dtype)
707
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
708
+
709
+
710
+ class RotaryEmbAttnProcessor2_0:
711
+ r"""
712
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
713
+ Add rotary embedding support
714
+ """
715
+
716
+ def __init__(self):
717
+
718
+ if not hasattr(F, "scaled_dot_product_attention"):
719
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
720
+
721
+ def __call__(
722
+ self,
723
+ attn: Attention,
724
+ hidden_states,
725
+ encoder_hidden_states=None,
726
+ attention_mask=None,
727
+ temb=None,
728
+ scale: float = 1.0,
729
+ key_pos_idx: Optional[torch.Tensor] = None,
730
+ ):
731
+ assert attention_mask is None
732
+ residual = hidden_states
733
+
734
+ if attn.spatial_norm is not None:
735
+ hidden_states = attn.spatial_norm(hidden_states, temb)
736
+
737
+ input_ndim = hidden_states.ndim
738
+
739
+ if input_ndim == 4:
740
+ batch_size, channel, height, width = hidden_states.shape
741
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
742
+
743
+ batch_size, sequence_length, _ = (
744
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
745
+ )
746
+
747
+ # if attention_mask is not None:
748
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
749
+ # # scaled_dot_product_attention expects attention_mask shape to be
750
+ # # (batch, heads, source_length, target_length)
751
+ # attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
752
+
753
+ if attn.group_norm is not None:
754
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
755
+
756
+ query = attn.to_q(hidden_states, scale=scale)
757
+
758
+ if encoder_hidden_states is None:
759
+ encoder_hidden_states = hidden_states
760
+ elif attn.norm_cross:
761
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
762
+
763
+ qlen = hidden_states.shape[1]
764
+ klen = encoder_hidden_states.shape[1]
765
+ # currently only add bias for self attention. Relative distance doesn't make sense for cross attention.
766
+ # if qlen == klen:
767
+ # time_rel_pos_bias = attn.rotary_bias(qlen, klen, device=hidden_states.device, dtype=hidden_states.dtype)
768
+ # attention_mask = repeat(time_rel_pos_bias, "h d1 d2 -> b h d1 d2", b=batch_size)
769
+
770
+ key = attn.to_k(encoder_hidden_states, scale=scale)
771
+ value = attn.to_v(encoder_hidden_states, scale=scale)
772
+
773
+ query = attn.rotary_emb.rotate_queries_or_keys(query)
774
+ if qlen == klen:
775
+ key = attn.rotary_emb.rotate_queries_or_keys(key)
776
+ elif key_pos_idx is not None:
777
+ key = attn.rotary_emb.rotate_queries_or_keys(key, seq_pos=key_pos_idx)
778
+
779
+ inner_dim = key.shape[-1]
780
+ head_dim = inner_dim // attn.heads
781
+
782
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
783
+
784
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
785
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
786
+
787
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
788
+ # TODO: add support for attn.scale when we move to Torch 2.1
789
+ hidden_states = F.scaled_dot_product_attention(
790
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
791
+ )
792
+
793
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
794
+ hidden_states = hidden_states.to(query.dtype)
795
+
796
+ # linear proj
797
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
798
+ # dropout
799
+ hidden_states = attn.to_out[1](hidden_states)
800
+
801
+ if input_ndim == 4:
802
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
803
+
804
+ if attn.residual_connection:
805
+ hidden_states = hidden_states + residual
806
+
807
+ hidden_states = hidden_states / attn.rescale_output_factor
808
+
809
+ return hidden_states
consisti2v/models/videoldm_transformer_blocks.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/huggingface/diffusers/blob/v0.21.0/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
13
+ from diffusers.utils import BaseOutput, deprecate
14
+ from diffusers.models.attention import AdaLayerNorm, AdaLayerNormZero, FeedForward, GatedSelfAttentionDense
15
+ from diffusers.models.embeddings import PatchEmbed
16
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.models.transformer_2d import Transformer2DModelOutput
19
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
+ from diffusers.models.attention_processor import Attention
21
+ from diffusers.models.lora import LoRACompatibleLinear
22
+
23
+ from .videoldm_attention import ConditionalAttention, TemporalConditionalAttention
24
+
25
+
26
+ class Transformer2DConditionModel(ModelMixin, ConfigMixin):
27
+ @register_to_config
28
+ def __init__(
29
+ self,
30
+ num_attention_heads: int = 16,
31
+ attention_head_dim: int = 88,
32
+ in_channels: Optional[int] = None,
33
+ out_channels: Optional[int] = None,
34
+ num_layers: int = 1,
35
+ dropout: float = 0.0,
36
+ norm_num_groups: int = 32,
37
+ cross_attention_dim: Optional[int] = None,
38
+ attention_bias: bool = False,
39
+ sample_size: Optional[int] = None,
40
+ num_vector_embeds: Optional[int] = None,
41
+ patch_size: Optional[int] = None,
42
+ activation_fn: str = "geglu",
43
+ num_embeds_ada_norm: Optional[int] = None,
44
+ use_linear_projection: bool = False,
45
+ only_cross_attention: bool = False,
46
+ double_self_attention: bool = False,
47
+ upcast_attention: bool = False,
48
+ norm_type: str = "layer_norm",
49
+ norm_elementwise_affine: bool = True,
50
+ attention_type: str = "default",
51
+ # additional
52
+ n_frames: int = 8,
53
+ is_temporal: bool = False,
54
+ augment_temporal_attention: bool = False,
55
+ rotary_emb=False,
56
+ ):
57
+ super().__init__()
58
+ self.use_linear_projection = use_linear_projection
59
+ self.num_attention_heads = num_attention_heads
60
+ self.attention_head_dim = attention_head_dim
61
+ inner_dim = num_attention_heads * attention_head_dim
62
+
63
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
64
+ # Define whether input is continuous or discrete depending on configuration
65
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
66
+ self.is_input_vectorized = num_vector_embeds is not None
67
+ self.is_input_patches = in_channels is not None and patch_size is not None
68
+
69
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
70
+ deprecation_message = (
71
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
72
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
73
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
74
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
75
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
76
+ )
77
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
78
+ norm_type = "ada_norm"
79
+
80
+ if self.is_input_continuous and self.is_input_vectorized:
81
+ raise ValueError(
82
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
83
+ " sure that either `in_channels` or `num_vector_embeds` is None."
84
+ )
85
+ elif self.is_input_vectorized and self.is_input_patches:
86
+ raise ValueError(
87
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
88
+ " sure that either `num_vector_embeds` or `num_patches` is None."
89
+ )
90
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
91
+ raise ValueError(
92
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
93
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
94
+ )
95
+
96
+ # 2. Define input layers
97
+ if self.is_input_continuous:
98
+ self.in_channels = in_channels
99
+
100
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
101
+ if use_linear_projection:
102
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
103
+ else:
104
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
105
+ elif self.is_input_vectorized:
106
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
107
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
108
+
109
+ self.height = sample_size
110
+ self.width = sample_size
111
+ self.num_vector_embeds = num_vector_embeds
112
+ self.num_latent_pixels = self.height * self.width
113
+
114
+ self.latent_image_embedding = ImagePositionalEmbeddings(
115
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
116
+ )
117
+ elif self.is_input_patches:
118
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
119
+
120
+ self.height = sample_size
121
+ self.width = sample_size
122
+
123
+ self.patch_size = patch_size
124
+ self.pos_embed = PatchEmbed(
125
+ height=sample_size,
126
+ width=sample_size,
127
+ patch_size=patch_size,
128
+ in_channels=in_channels,
129
+ embed_dim=inner_dim,
130
+ )
131
+
132
+ # 3. Define transformers blocks
133
+ self.transformer_blocks = nn.ModuleList(
134
+ [
135
+ BasicConditionalTransformerBlock(
136
+ inner_dim,
137
+ num_attention_heads,
138
+ attention_head_dim,
139
+ dropout=dropout,
140
+ cross_attention_dim=cross_attention_dim,
141
+ activation_fn=activation_fn,
142
+ num_embeds_ada_norm=num_embeds_ada_norm,
143
+ attention_bias=attention_bias,
144
+ only_cross_attention=only_cross_attention,
145
+ double_self_attention=double_self_attention,
146
+ upcast_attention=upcast_attention,
147
+ norm_type=norm_type,
148
+ norm_elementwise_affine=norm_elementwise_affine,
149
+ attention_type=attention_type,
150
+ # additional
151
+ n_frames=n_frames,
152
+ is_temporal=is_temporal,
153
+ augment_temporal_attention=augment_temporal_attention,
154
+ rotary_emb=rotary_emb,
155
+ )
156
+ for d in range(num_layers)
157
+ ]
158
+ )
159
+
160
+ # 4. Define output layers
161
+ self.out_channels = in_channels if out_channels is None else out_channels
162
+ if self.is_input_continuous:
163
+ # TODO: should use out_channels for continuous projections
164
+ if use_linear_projection:
165
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
166
+ else:
167
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
168
+ elif self.is_input_vectorized:
169
+ self.norm_out = nn.LayerNorm(inner_dim)
170
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
171
+ elif self.is_input_patches:
172
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
173
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
174
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
175
+
176
+ self.alpha = None
177
+ if is_temporal:
178
+ self.alpha = nn.Parameter(torch.ones(1))
179
+
180
+ self.gradient_checkpointing = False
181
+
182
+ def forward(
183
+ self,
184
+ hidden_states: torch.Tensor,
185
+ encoder_hidden_states: Optional[torch.Tensor] = None,
186
+ timestep: Optional[torch.LongTensor] = None,
187
+ class_labels: Optional[torch.LongTensor] = None,
188
+ cross_attention_kwargs: Dict[str, Any] = None,
189
+ attention_mask: Optional[torch.Tensor] = None,
190
+ encoder_attention_mask: Optional[torch.Tensor] = None,
191
+ return_dict: bool = True,
192
+ condition_on_first_frame: bool = False,
193
+ ):
194
+ input_states = hidden_states
195
+ input_height, input_width = hidden_states.shape[-2:]
196
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
197
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
198
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
199
+ # expects mask of shape:
200
+ # [batch, key_tokens]
201
+ # adds singleton query_tokens dimension:
202
+ # [batch, 1, key_tokens]
203
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
204
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
205
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
206
+ if attention_mask is not None and attention_mask.ndim == 2:
207
+ # assume that mask is expressed as:
208
+ # (1 = keep, 0 = discard)
209
+ # convert mask into a bias that can be added to attention scores:
210
+ # (keep = +0, discard = -10000.0)
211
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
212
+ attention_mask = attention_mask.unsqueeze(1)
213
+
214
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
215
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
216
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
217
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
218
+
219
+ # Retrieve lora scale.
220
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
221
+
222
+ # 1. Input
223
+ if self.is_input_continuous:
224
+ batch, _, height, width = hidden_states.shape
225
+ residual = hidden_states
226
+
227
+ hidden_states = self.norm(hidden_states)
228
+ if not self.use_linear_projection:
229
+ hidden_states = self.proj_in(hidden_states, lora_scale)
230
+ inner_dim = hidden_states.shape[1]
231
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
232
+ else:
233
+ inner_dim = hidden_states.shape[1]
234
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
235
+ hidden_states = self.proj_in(hidden_states, scale=lora_scale)
236
+
237
+ elif self.is_input_vectorized:
238
+ hidden_states = self.latent_image_embedding(hidden_states)
239
+ elif self.is_input_patches:
240
+ hidden_states = self.pos_embed(hidden_states)
241
+
242
+ # 2. Blocks
243
+ for block in self.transformer_blocks:
244
+ if self.training and self.gradient_checkpointing:
245
+ hidden_states = torch.utils.checkpoint.checkpoint(
246
+ block,
247
+ hidden_states,
248
+ attention_mask,
249
+ encoder_hidden_states,
250
+ encoder_attention_mask,
251
+ timestep,
252
+ cross_attention_kwargs,
253
+ class_labels,
254
+ use_reentrant=False,
255
+ )
256
+ else:
257
+ hidden_states = block(
258
+ hidden_states,
259
+ attention_mask=attention_mask,
260
+ encoder_hidden_states=encoder_hidden_states,
261
+ encoder_attention_mask=encoder_attention_mask,
262
+ timestep=timestep,
263
+ cross_attention_kwargs=cross_attention_kwargs,
264
+ class_labels=class_labels,
265
+ # additional
266
+ condition_on_first_frame=condition_on_first_frame,
267
+ input_height=input_height,
268
+ input_width=input_width,
269
+ )
270
+
271
+ # 3. Output
272
+ if self.is_input_continuous:
273
+ if not self.use_linear_projection:
274
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
275
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
276
+ else:
277
+ hidden_states = self.proj_out(hidden_states, scale=lora_scale)
278
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
279
+
280
+ output = hidden_states + residual
281
+ elif self.is_input_vectorized:
282
+ hidden_states = self.norm_out(hidden_states)
283
+ logits = self.out(hidden_states)
284
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
285
+ logits = logits.permute(0, 2, 1)
286
+
287
+ # log(p(x_0))
288
+ output = F.log_softmax(logits.double(), dim=1).float()
289
+ elif self.is_input_patches:
290
+ # TODO: cleanup!
291
+ conditioning = self.transformer_blocks[0].norm1.emb(
292
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
293
+ )
294
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
295
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
296
+ hidden_states = self.proj_out_2(hidden_states)
297
+
298
+ # unpatchify
299
+ height = width = int(hidden_states.shape[1] ** 0.5)
300
+ hidden_states = hidden_states.reshape(
301
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
302
+ )
303
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
304
+ output = hidden_states.reshape(
305
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
306
+ )
307
+
308
+ if self.alpha is not None:
309
+ with torch.no_grad():
310
+ self.alpha.clamp_(0, 1)
311
+
312
+ output = self.alpha * input_states + (1 - self.alpha) * output
313
+
314
+ if not return_dict:
315
+ return (output,)
316
+
317
+ return Transformer2DModelOutput(sample=output)
318
+
319
+
320
+ @maybe_allow_in_graph
321
+ class BasicConditionalTransformerBlock(nn.Module):
322
+ """ transformer block with first frame conditioning """
323
+ def __init__(
324
+ self,
325
+ dim: int,
326
+ num_attention_heads: int,
327
+ attention_head_dim: int,
328
+ dropout=0.0,
329
+ cross_attention_dim: Optional[int] = None,
330
+ activation_fn: str = "geglu",
331
+ num_embeds_ada_norm: Optional[int] = None,
332
+ attention_bias: bool = False,
333
+ only_cross_attention: bool = False,
334
+ double_self_attention: bool = False,
335
+ upcast_attention: bool = False,
336
+ norm_elementwise_affine: bool = True,
337
+ norm_type: str = "layer_norm",
338
+ final_dropout: bool = False,
339
+ attention_type: str = "default",
340
+ # additional
341
+ n_frames: int = 8,
342
+ is_temporal: bool = False,
343
+ augment_temporal_attention: bool = False,
344
+ rotary_emb=False,
345
+ ):
346
+ super().__init__()
347
+ self.n_frames = n_frames
348
+ self.only_cross_attention = only_cross_attention
349
+ self.augment_temporal_attention = augment_temporal_attention
350
+ self.is_temporal = is_temporal
351
+
352
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
353
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
354
+
355
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
356
+ raise ValueError(
357
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
358
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
359
+ )
360
+
361
+ # Define 3 blocks. Each block has its own normalization layer.
362
+ # 1. Self-Attn
363
+ if self.use_ada_layer_norm:
364
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
365
+ elif self.use_ada_layer_norm_zero:
366
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
367
+ else:
368
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
369
+
370
+ if not is_temporal:
371
+ self.attn1 = ConditionalAttention(
372
+ query_dim=dim,
373
+ heads=num_attention_heads,
374
+ dim_head=attention_head_dim,
375
+ dropout=dropout,
376
+ bias=attention_bias,
377
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
378
+ upcast_attention=upcast_attention,
379
+ )
380
+ else:
381
+ self.attn1 = TemporalConditionalAttention(
382
+ query_dim=dim,
383
+ heads=num_attention_heads,
384
+ dim_head=attention_head_dim,
385
+ dropout=dropout,
386
+ bias=attention_bias,
387
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
388
+ upcast_attention=upcast_attention,
389
+ # additional
390
+ n_frames=n_frames,
391
+ rotary_emb=rotary_emb,
392
+ )
393
+
394
+ # 2. Cross-Attn
395
+ if cross_attention_dim is not None or double_self_attention:
396
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
397
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
398
+ # the second cross attention block.
399
+ self.norm2 = (
400
+ AdaLayerNorm(dim, num_embeds_ada_norm)
401
+ if self.use_ada_layer_norm
402
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
403
+ )
404
+ if not is_temporal:
405
+ self.attn2 = ConditionalAttention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ ) # is self-attn if encoder_hidden_states is none
414
+ else:
415
+ self.attn2 = TemporalConditionalAttention(
416
+ query_dim=dim,
417
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
418
+ heads=num_attention_heads,
419
+ dim_head=attention_head_dim,
420
+ dropout=dropout,
421
+ bias=attention_bias,
422
+ upcast_attention=upcast_attention,
423
+ # additional
424
+ n_frames=n_frames,
425
+ rotary_emb=rotary_emb,
426
+ )
427
+ else:
428
+ self.norm2 = None
429
+ self.attn2 = None
430
+
431
+ # 3. Feed-forward
432
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
433
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
434
+
435
+ # 4. Fuser
436
+ if attention_type == "gated" or attention_type == "gated-text-image":
437
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
438
+
439
+ # let chunk size default to None
440
+ self._chunk_size = None
441
+ self._chunk_dim = 0
442
+
443
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
444
+ # Sets chunk feed-forward
445
+ self._chunk_size = chunk_size
446
+ self._chunk_dim = dim
447
+
448
+ def forward(
449
+ self,
450
+ hidden_states: torch.FloatTensor,
451
+ attention_mask: Optional[torch.FloatTensor] = None,
452
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
453
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
454
+ timestep: Optional[torch.LongTensor] = None,
455
+ cross_attention_kwargs: Dict[str, Any] = None,
456
+ class_labels: Optional[torch.LongTensor] = None,
457
+ condition_on_first_frame: bool = False,
458
+ input_height: Optional[int] = None,
459
+ input_width: Optional[int] = None,
460
+ ):
461
+ # Notice that normalization is always applied before the real computation in the following blocks.
462
+ # 0. Self-Attention
463
+ if self.use_ada_layer_norm:
464
+ norm_hidden_states = self.norm1(hidden_states, timestep)
465
+ elif self.use_ada_layer_norm_zero:
466
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
467
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
468
+ )
469
+ else:
470
+ norm_hidden_states = self.norm1(hidden_states)
471
+
472
+ # 1. Retrieve lora scale.
473
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
474
+
475
+ # 2. Prepare GLIGEN inputs
476
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
477
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
478
+
479
+ if condition_on_first_frame:
480
+ first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
481
+ first_frame_hidden_states = repeat(first_frame_hidden_states, 'b d h -> b f d h', f=self.n_frames)
482
+ first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b f d h -> (b f) d h')
483
+ first_frame_concat_hidden_states = torch.cat((norm_hidden_states, first_frame_hidden_states), dim=1)
484
+ attn_output = self.attn1(
485
+ norm_hidden_states,
486
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else first_frame_concat_hidden_states,
487
+ attention_mask=attention_mask,
488
+ **cross_attention_kwargs,
489
+ )
490
+ elif self.is_temporal and self.augment_temporal_attention:
491
+ first_frame_hidden_states = rearrange(norm_hidden_states, '(b f) d h -> b f d h', f=self.n_frames)[:, 0, :, :]
492
+ first_frame_hidden_states = rearrange(first_frame_hidden_states, 'b (h w) c -> b h w c', h=input_height, w=input_width)
493
+ first_frame_hidden_states = first_frame_hidden_states.permute(0, 3, 1, 2)
494
+ padded_first_frame = torch.nn.functional.pad(first_frame_hidden_states, (1, 1, 1, 1), "replicate")
495
+ first_frame_windows = padded_first_frame.unfold(2, 3, 1).unfold(3, 3, 1)
496
+ mask = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]], dtype=torch.bool)
497
+ adjacent_slices = first_frame_windows[:, :, :, :, mask]
498
+ attn_output = self.attn1(
499
+ norm_hidden_states,
500
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
501
+ attention_mask=attention_mask,
502
+ adjacent_slices=adjacent_slices,
503
+ **cross_attention_kwargs,
504
+ )
505
+ else:
506
+ attn_output = self.attn1(
507
+ norm_hidden_states,
508
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
509
+ attention_mask=attention_mask,
510
+ **cross_attention_kwargs,
511
+ )
512
+ if self.use_ada_layer_norm_zero:
513
+ attn_output = gate_msa.unsqueeze(1) * attn_output
514
+ hidden_states = attn_output + hidden_states
515
+
516
+ # 2.5 GLIGEN Control
517
+ if gligen_kwargs is not None:
518
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
519
+ # 2.5 ends
520
+
521
+ # 3. Cross-Attention
522
+ if self.attn2 is not None:
523
+ norm_hidden_states = (
524
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
525
+ )
526
+
527
+ attn_output = self.attn2(
528
+ norm_hidden_states,
529
+ encoder_hidden_states=encoder_hidden_states,
530
+ attention_mask=encoder_attention_mask,
531
+ **cross_attention_kwargs,
532
+ )
533
+ hidden_states = attn_output + hidden_states
534
+
535
+ # 4. Feed-forward
536
+ norm_hidden_states = self.norm3(hidden_states)
537
+
538
+ if self.use_ada_layer_norm_zero:
539
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
540
+
541
+ if self._chunk_size is not None:
542
+ # "feed_forward_chunk_size" can be used to save memory
543
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
544
+ raise ValueError(
545
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
546
+ )
547
+
548
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
549
+ ff_output = torch.cat(
550
+ [
551
+ self.ff(hid_slice, scale=lora_scale)
552
+ for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
553
+ ],
554
+ dim=self._chunk_dim,
555
+ )
556
+ else:
557
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
558
+
559
+ if self.use_ada_layer_norm_zero:
560
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
561
+
562
+ hidden_states = ff_output + hidden_states
563
+
564
+ return hidden_states
consisti2v/models/videoldm_unet.py ADDED
@@ -0,0 +1,1371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import Optional, Tuple, Union, Dict, List, Any
4
+ from einops import rearrange, repeat
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from diffusers.loaders import UNet2DConditionLoadersMixin
9
+ from diffusers.models import ModelMixin
10
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
11
+ from diffusers.models.unet_2d_blocks import UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn
12
+ from diffusers.models.embeddings import (
13
+ GaussianFourierProjection,
14
+ ImageHintTimeEmbedding,
15
+ ImageProjection,
16
+ ImageTimeEmbedding,
17
+ PositionNet,
18
+ TextImageProjection,
19
+ TextImageTimeEmbedding,
20
+ TextTimeEmbedding,
21
+ TimestepEmbedding,
22
+ Timesteps,
23
+ )
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.activations import get_activation
32
+ from diffusers.configuration_utils import register_to_config, ConfigMixin
33
+ from diffusers.models.modeling_utils import load_state_dict, load_model_dict_into_meta
34
+ from diffusers.utils import (
35
+ CONFIG_NAME,
36
+ DIFFUSERS_CACHE,
37
+ FLAX_WEIGHTS_NAME,
38
+ HF_HUB_OFFLINE,
39
+ SAFETENSORS_WEIGHTS_NAME,
40
+ WEIGHTS_NAME,
41
+ _add_variant,
42
+ _get_model_file,
43
+ deprecate,
44
+ is_accelerate_available,
45
+ is_torch_version,
46
+ logging,
47
+ )
48
+ from diffusers import __version__
49
+
50
+ if is_torch_version(">=", "1.9.0"):
51
+ _LOW_CPU_MEM_USAGE_DEFAULT = True
52
+ else:
53
+ _LOW_CPU_MEM_USAGE_DEFAULT = False
54
+
55
+
56
+ if is_accelerate_available():
57
+ import accelerate
58
+ from accelerate.utils import set_module_tensor_to_device
59
+ from accelerate.utils.versions import is_torch_version
60
+
61
+
62
+
63
+ from .videoldm_unet_blocks import get_down_block, get_up_block, VideoLDMUNetMidBlock2DCrossAttn
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+
68
+ class VideoLDMUNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
69
+ _supports_gradient_checkpointing = True
70
+ @register_to_config
71
+ def __init__(
72
+ self,
73
+ sample_size: Optional[int] = None,
74
+ in_channels: int = 4,
75
+ out_channels: int = 4,
76
+ center_input_sample: bool = False,
77
+ flip_sin_to_cos: bool = True,
78
+ freq_shift: int = 0,
79
+ down_block_types: Tuple[str] = (
80
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
81
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
82
+ "CrossAttnDownBlock2D", # -> VideoLDMDownBlock
83
+ "DownBlock2D",
84
+ ),
85
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
86
+ up_block_types: Tuple[str] = (
87
+ "UpBlock2D",
88
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
89
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
90
+ "CrossAttnUpBlock2D", # -> VideoLDMUpBlock
91
+ ),
92
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
93
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
94
+ layers_per_block: Union[int, Tuple[int]] = 2,
95
+ downsample_padding: int = 1,
96
+ mid_block_scale_factor: float = 1,
97
+ dropout: float = 0.0,
98
+ act_fn: str = "silu",
99
+ norm_num_groups: Optional[int] = 32,
100
+ norm_eps: float = 1e-5,
101
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
102
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
103
+ encoder_hid_dim: Optional[int] = None,
104
+ encoder_hid_dim_type: Optional[str] = None,
105
+ attention_head_dim: Union[int, Tuple[int]] = 8,
106
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
107
+ dual_cross_attention: bool = False,
108
+ use_linear_projection: bool = False,
109
+ class_embed_type: Optional[str] = None,
110
+ addition_embed_type: Optional[str] = None,
111
+ addition_time_embed_dim: Optional[int] = None,
112
+ num_class_embeds: Optional[int] = None,
113
+ upcast_attention: bool = False,
114
+ resnet_time_scale_shift: str = "default",
115
+ resnet_skip_time_act: bool = False,
116
+ resnet_out_scale_factor: int = 1.0,
117
+ time_embedding_type: str = "positional",
118
+ time_embedding_dim: Optional[int] = None,
119
+ time_embedding_act_fn: Optional[str] = None,
120
+ timestep_post_act: Optional[str] = None,
121
+ time_cond_proj_dim: Optional[int] = None,
122
+ conv_in_kernel: int = 3,
123
+ conv_out_kernel: int = 3,
124
+ projection_class_embeddings_input_dim: Optional[int] = None,
125
+ attention_type: str = "default",
126
+ class_embeddings_concat: bool = False,
127
+ mid_block_only_cross_attention: Optional[bool] = None,
128
+ cross_attention_norm: Optional[str] = None,
129
+ addition_embed_type_num_heads=64,
130
+ # additional
131
+ use_temporal: bool = True,
132
+ n_frames: int = 8,
133
+ n_temp_heads: int = 8,
134
+ first_frame_condition_mode: str = "none",
135
+ augment_temporal_attention: bool = False,
136
+ temp_pos_embedding: str = "sinusoidal",
137
+ use_frame_stride_condition: bool = False,
138
+ ):
139
+ super().__init__()
140
+
141
+ rotary_emb = False
142
+ if temp_pos_embedding == "rotary":
143
+ # from rotary_embedding_torch import RotaryEmbedding
144
+ # rotary_emb = RotaryEmbedding(32)
145
+ # self.rotary_emb = rotary_emb
146
+ rotary_emb = True
147
+ self.rotary_emb = rotary_emb
148
+
149
+ self.use_temporal = use_temporal
150
+ self.augment_temporal_attention = augment_temporal_attention
151
+
152
+ assert first_frame_condition_mode in ["none", "concat", "conv2d", "input_only"], f"first_frame_condition_mode: {first_frame_condition_mode} must be one of ['none', 'concat', 'conv2d', 'input_only']"
153
+ self.first_frame_condition_mode = first_frame_condition_mode
154
+ latent_channels = in_channels
155
+
156
+ self.sample_size = sample_size
157
+
158
+ if num_attention_heads is not None:
159
+ raise ValueError(
160
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
161
+ )
162
+
163
+ num_attention_heads = num_attention_heads or attention_head_dim
164
+
165
+ # Check inputs
166
+ if len(down_block_types) != len(up_block_types):
167
+ raise ValueError(
168
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
169
+ )
170
+
171
+ if len(block_out_channels) != len(down_block_types):
172
+ raise ValueError(
173
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
174
+ )
175
+
176
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
177
+ raise ValueError(
178
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
179
+ )
180
+
181
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
182
+ raise ValueError(
183
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
184
+ )
185
+
186
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
187
+ raise ValueError(
188
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
189
+ )
190
+
191
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
192
+ raise ValueError(
193
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
194
+ )
195
+
196
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
197
+ raise ValueError(
198
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
199
+ )
200
+
201
+ # input
202
+ conv_in_padding = (conv_in_kernel - 1) // 2
203
+ self.conv_in = nn.Conv2d(
204
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
205
+ )
206
+
207
+ # time
208
+ if time_embedding_type == "fourier":
209
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
210
+ if time_embed_dim % 2 != 0:
211
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
212
+ self.time_proj = GaussianFourierProjection(
213
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
214
+ )
215
+ timestep_input_dim = time_embed_dim
216
+ elif time_embedding_type == "positional":
217
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
218
+
219
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
220
+ timestep_input_dim = block_out_channels[0]
221
+ else:
222
+ raise ValueError(
223
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
224
+ )
225
+
226
+ self.time_embedding = TimestepEmbedding(
227
+ timestep_input_dim,
228
+ time_embed_dim,
229
+ act_fn=act_fn,
230
+ post_act_fn=timestep_post_act,
231
+ cond_proj_dim=time_cond_proj_dim,
232
+ )
233
+
234
+ self.use_frame_stride_condition = use_frame_stride_condition
235
+ if self.use_frame_stride_condition:
236
+ self.frame_stride_embedding = TimestepEmbedding(
237
+ timestep_input_dim,
238
+ time_embed_dim,
239
+ act_fn=act_fn,
240
+ post_act_fn=timestep_post_act,
241
+ cond_proj_dim=time_cond_proj_dim,
242
+ )
243
+ # zero init
244
+ nn.init.zeros_(self.frame_stride_embedding.linear_2.weight)
245
+ nn.init.zeros_(self.frame_stride_embedding.linear_2.bias)
246
+
247
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
248
+ encoder_hid_dim_type = "text_proj"
249
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
250
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
251
+
252
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
253
+ raise ValueError(
254
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
255
+ )
256
+
257
+ if encoder_hid_dim_type == "text_proj":
258
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
259
+ elif encoder_hid_dim_type == "text_image_proj":
260
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
261
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
262
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
263
+ self.encoder_hid_proj = TextImageProjection(
264
+ text_embed_dim=encoder_hid_dim,
265
+ image_embed_dim=cross_attention_dim,
266
+ cross_attention_dim=cross_attention_dim,
267
+ )
268
+ elif encoder_hid_dim_type == "image_proj":
269
+ # Kandinsky 2.2
270
+ self.encoder_hid_proj = ImageProjection(
271
+ image_embed_dim=encoder_hid_dim,
272
+ cross_attention_dim=cross_attention_dim,
273
+ )
274
+ elif encoder_hid_dim_type is not None:
275
+ raise ValueError(
276
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
277
+ )
278
+ else:
279
+ self.encoder_hid_proj = None
280
+
281
+ # class embedding
282
+ if class_embed_type is None and num_class_embeds is not None:
283
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
284
+ elif class_embed_type == "timestep":
285
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
286
+ elif class_embed_type == "identity":
287
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
288
+ elif class_embed_type == "projection":
289
+ if projection_class_embeddings_input_dim is None:
290
+ raise ValueError(
291
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
292
+ )
293
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
294
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
295
+ # 2. it projects from an arbitrary input dimension.
296
+ #
297
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
298
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
299
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
300
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
301
+ elif class_embed_type == "simple_projection":
302
+ if projection_class_embeddings_input_dim is None:
303
+ raise ValueError(
304
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
305
+ )
306
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
307
+ else:
308
+ self.class_embedding = None
309
+
310
+ if addition_embed_type == "text":
311
+ if encoder_hid_dim is not None:
312
+ text_time_embedding_from_dim = encoder_hid_dim
313
+ else:
314
+ text_time_embedding_from_dim = cross_attention_dim
315
+
316
+ self.add_embedding = TextTimeEmbedding(
317
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
318
+ )
319
+ elif addition_embed_type == "text_image":
320
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
321
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
322
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
323
+ self.add_embedding = TextImageTimeEmbedding(
324
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
325
+ )
326
+ elif addition_embed_type == "text_time":
327
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
328
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
329
+ elif addition_embed_type == "image":
330
+ # Kandinsky 2.2
331
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
332
+ elif addition_embed_type == "image_hint":
333
+ # Kandinsky 2.2 ControlNet
334
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
335
+ elif addition_embed_type is not None:
336
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
337
+
338
+ if time_embedding_act_fn is None:
339
+ self.time_embed_act = None
340
+ else:
341
+ self.time_embed_act = get_activation(time_embedding_act_fn)
342
+
343
+ self.down_blocks = nn.ModuleList([])
344
+ self.up_blocks = nn.ModuleList([])
345
+
346
+ if isinstance(only_cross_attention, bool):
347
+ if mid_block_only_cross_attention is None:
348
+ mid_block_only_cross_attention = only_cross_attention
349
+
350
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
351
+
352
+ if mid_block_only_cross_attention is None:
353
+ mid_block_only_cross_attention = False
354
+
355
+ if isinstance(num_attention_heads, int):
356
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
357
+
358
+ if isinstance(attention_head_dim, int):
359
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
360
+
361
+ if isinstance(cross_attention_dim, int):
362
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
363
+
364
+ if isinstance(layers_per_block, int):
365
+ layers_per_block = [layers_per_block] * len(down_block_types)
366
+
367
+ if isinstance(transformer_layers_per_block, int):
368
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
369
+
370
+ if class_embeddings_concat:
371
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
372
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
373
+ # regular time embeddings
374
+ blocks_time_embed_dim = time_embed_dim * 2
375
+ else:
376
+ blocks_time_embed_dim = time_embed_dim
377
+ # down
378
+ output_channel = block_out_channels[0]
379
+ for i, down_block_type in enumerate(down_block_types):
380
+ input_channel = output_channel
381
+ output_channel = block_out_channels[i]
382
+ is_final_block = i == len(block_out_channels) - 1
383
+
384
+ down_block = get_down_block(
385
+ down_block_type,
386
+ num_layers=layers_per_block[i],
387
+ transformer_layers_per_block=transformer_layers_per_block[i],
388
+ in_channels=input_channel,
389
+ out_channels=output_channel,
390
+ temb_channels=blocks_time_embed_dim,
391
+ add_downsample=not is_final_block,
392
+ resnet_eps=norm_eps,
393
+ resnet_act_fn=act_fn,
394
+ resnet_groups=norm_num_groups,
395
+ cross_attention_dim=cross_attention_dim[i],
396
+ num_attention_heads=num_attention_heads[i],
397
+ downsample_padding=downsample_padding,
398
+ dual_cross_attention=dual_cross_attention,
399
+ use_linear_projection=use_linear_projection,
400
+ only_cross_attention=only_cross_attention[i],
401
+ upcast_attention=upcast_attention,
402
+ resnet_time_scale_shift=resnet_time_scale_shift,
403
+ attention_type=attention_type,
404
+ resnet_skip_time_act=resnet_skip_time_act,
405
+ resnet_out_scale_factor=resnet_out_scale_factor,
406
+ cross_attention_norm=cross_attention_norm,
407
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
408
+ dropout=dropout,
409
+ # additional
410
+ use_temporal=use_temporal,
411
+ augment_temporal_attention=augment_temporal_attention,
412
+ n_frames=n_frames,
413
+ n_temp_heads=n_temp_heads,
414
+ first_frame_condition_mode=first_frame_condition_mode,
415
+ latent_channels=latent_channels,
416
+ rotary_emb=rotary_emb,
417
+ )
418
+ self.down_blocks.append(down_block)
419
+
420
+ # mid
421
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
422
+ self.mid_block = VideoLDMUNetMidBlock2DCrossAttn(
423
+ transformer_layers_per_block=transformer_layers_per_block[-1],
424
+ in_channels=block_out_channels[-1],
425
+ temb_channels=blocks_time_embed_dim,
426
+ dropout=dropout,
427
+ resnet_eps=norm_eps,
428
+ resnet_act_fn=act_fn,
429
+ output_scale_factor=mid_block_scale_factor,
430
+ resnet_time_scale_shift=resnet_time_scale_shift,
431
+ cross_attention_dim=cross_attention_dim[-1],
432
+ num_attention_heads=num_attention_heads[-1],
433
+ resnet_groups=norm_num_groups,
434
+ dual_cross_attention=dual_cross_attention,
435
+ use_linear_projection=use_linear_projection,
436
+ upcast_attention=upcast_attention,
437
+ attention_type=attention_type,
438
+ # additional
439
+ use_temporal=use_temporal,
440
+ n_frames=n_frames,
441
+ first_frame_condition_mode=first_frame_condition_mode,
442
+ latent_channels=latent_channels,
443
+ )
444
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
445
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
446
+ in_channels=block_out_channels[-1],
447
+ temb_channels=blocks_time_embed_dim,
448
+ dropout=dropout,
449
+ resnet_eps=norm_eps,
450
+ resnet_act_fn=act_fn,
451
+ output_scale_factor=mid_block_scale_factor,
452
+ cross_attention_dim=cross_attention_dim[-1],
453
+ attention_head_dim=attention_head_dim[-1],
454
+ resnet_groups=norm_num_groups,
455
+ resnet_time_scale_shift=resnet_time_scale_shift,
456
+ skip_time_act=resnet_skip_time_act,
457
+ only_cross_attention=mid_block_only_cross_attention,
458
+ cross_attention_norm=cross_attention_norm,
459
+ )
460
+ elif mid_block_type is None:
461
+ self.mid_block = None
462
+ else:
463
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
464
+
465
+ # count how many layers upsample the images
466
+ self.num_upsamplers = 0
467
+
468
+ # up
469
+ reversed_block_out_channels = list(reversed(block_out_channels))
470
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
471
+ reversed_layers_per_block = list(reversed(layers_per_block))
472
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
473
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
474
+ only_cross_attention = list(reversed(only_cross_attention))
475
+
476
+ output_channel = reversed_block_out_channels[0]
477
+ for i, up_block_type in enumerate(up_block_types):
478
+ is_final_block = i == len(block_out_channels) - 1
479
+
480
+ prev_output_channel = output_channel
481
+ output_channel = reversed_block_out_channels[i]
482
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
483
+
484
+ # add upsample block for all BUT final layer
485
+ if not is_final_block:
486
+ add_upsample = True
487
+ self.num_upsamplers += 1
488
+ else:
489
+ add_upsample = False
490
+
491
+ up_block = get_up_block(
492
+ up_block_type,
493
+ num_layers=reversed_layers_per_block[i] + 1,
494
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
495
+ in_channels=input_channel,
496
+ out_channels=output_channel,
497
+ prev_output_channel=prev_output_channel,
498
+ temb_channels=blocks_time_embed_dim,
499
+ add_upsample=add_upsample,
500
+ resnet_eps=norm_eps,
501
+ resnet_act_fn=act_fn,
502
+ resnet_groups=norm_num_groups,
503
+ cross_attention_dim=reversed_cross_attention_dim[i],
504
+ num_attention_heads=reversed_num_attention_heads[i],
505
+ dual_cross_attention=dual_cross_attention,
506
+ use_linear_projection=use_linear_projection,
507
+ only_cross_attention=only_cross_attention[i],
508
+ upcast_attention=upcast_attention,
509
+ resnet_time_scale_shift=resnet_time_scale_shift,
510
+ attention_type=attention_type,
511
+ resnet_skip_time_act=resnet_skip_time_act,
512
+ resnet_out_scale_factor=resnet_out_scale_factor,
513
+ cross_attention_norm=cross_attention_norm,
514
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
515
+ dropout=dropout,
516
+ # additional
517
+ use_temporal=use_temporal,
518
+ augment_temporal_attention=augment_temporal_attention,
519
+ n_frames=n_frames,
520
+ n_temp_heads=n_temp_heads,
521
+ first_frame_condition_mode=first_frame_condition_mode,
522
+ latent_channels=latent_channels,
523
+ rotary_emb=rotary_emb,
524
+ )
525
+ self.up_blocks.append(up_block)
526
+ prev_output_channel = output_channel
527
+
528
+ # out
529
+ if norm_num_groups is not None:
530
+ self.conv_norm_out = nn.GroupNorm(
531
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
532
+ )
533
+
534
+ self.conv_act = get_activation(act_fn)
535
+
536
+ else:
537
+ self.conv_norm_out = None
538
+ self.conv_act = None
539
+
540
+ conv_out_padding = (conv_out_kernel - 1) // 2
541
+ self.conv_out = nn.Conv2d(
542
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
543
+ )
544
+
545
+ @property
546
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
547
+ r"""
548
+ Returns:
549
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
550
+ indexed by its weight name.
551
+ """
552
+ # set recursively
553
+ processors = {}
554
+
555
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
556
+ if hasattr(module, "get_processor"):
557
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
558
+
559
+ for sub_name, child in module.named_children():
560
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
561
+
562
+ return processors
563
+
564
+ for name, module in self.named_children():
565
+ fn_recursive_add_processors(name, module, processors)
566
+
567
+ return processors
568
+
569
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
570
+ r"""
571
+ Sets the attention processor to use to compute attention.
572
+
573
+ Parameters:
574
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
575
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
576
+ for **all** `Attention` layers.
577
+
578
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
579
+ processor. This is strongly recommended when setting trainable attention processors.
580
+
581
+ """
582
+ count = len(self.attn_processors.keys())
583
+
584
+ if isinstance(processor, dict) and len(processor) != count:
585
+ raise ValueError(
586
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
587
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
588
+ )
589
+
590
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
591
+ if hasattr(module, "set_processor"):
592
+ if not isinstance(processor, dict):
593
+ module.set_processor(processor)
594
+ else:
595
+ module.set_processor(processor.pop(f"{name}.processor"))
596
+
597
+ for sub_name, child in module.named_children():
598
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
599
+
600
+ for name, module in self.named_children():
601
+ fn_recursive_attn_processor(name, module, processor)
602
+
603
+ def set_default_attn_processor(self):
604
+ """
605
+ Disables custom attention processors and sets the default attention implementation.
606
+ """
607
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
608
+ processor = AttnAddedKVProcessor()
609
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
+ processor = AttnProcessor()
611
+ else:
612
+ raise ValueError(
613
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
614
+ )
615
+
616
+ self.set_attn_processor(processor)
617
+
618
+ def set_attention_slice(self, slice_size):
619
+ r"""
620
+ Enable sliced attention computation.
621
+
622
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
623
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
624
+
625
+ Args:
626
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
627
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
628
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
629
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
630
+ must be a multiple of `slice_size`.
631
+ """
632
+ sliceable_head_dims = []
633
+
634
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
635
+ if hasattr(module, "set_attention_slice"):
636
+ sliceable_head_dims.append(module.sliceable_head_dim)
637
+
638
+ for child in module.children():
639
+ fn_recursive_retrieve_sliceable_dims(child)
640
+
641
+ # retrieve number of attention layers
642
+ for module in self.children():
643
+ fn_recursive_retrieve_sliceable_dims(module)
644
+
645
+ num_sliceable_layers = len(sliceable_head_dims)
646
+
647
+ if slice_size == "auto":
648
+ # half the attention head size is usually a good trade-off between
649
+ # speed and memory
650
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
651
+ elif slice_size == "max":
652
+ # make smallest slice possible
653
+ slice_size = num_sliceable_layers * [1]
654
+
655
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
656
+
657
+ if len(slice_size) != len(sliceable_head_dims):
658
+ raise ValueError(
659
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
660
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
661
+ )
662
+
663
+ for i in range(len(slice_size)):
664
+ size = slice_size[i]
665
+ dim = sliceable_head_dims[i]
666
+ if size is not None and size > dim:
667
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
668
+
669
+ # Recursively walk through all the children.
670
+ # Any children which exposes the set_attention_slice method
671
+ # gets the message
672
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
673
+ if hasattr(module, "set_attention_slice"):
674
+ module.set_attention_slice(slice_size.pop())
675
+
676
+ for child in module.children():
677
+ fn_recursive_set_attention_slice(child, slice_size)
678
+
679
+ reversed_slice_size = list(reversed(slice_size))
680
+ for module in self.children():
681
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
682
+
683
+ def _set_gradient_checkpointing(self, module, value=False):
684
+ if hasattr(module, "gradient_checkpointing"):
685
+ module.gradient_checkpointing = value
686
+
687
+ def forward(
688
+ self,
689
+ sample: torch.FloatTensor,
690
+ timestep: Union[torch.Tensor, float, int],
691
+ encoder_hidden_states: torch.Tensor,
692
+ class_labels: Optional[torch.Tensor] = None,
693
+ timestep_cond: Optional[torch.Tensor] = None,
694
+ attention_mask: Optional[torch.Tensor] = None,
695
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
696
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
697
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
698
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
699
+ encoder_attention_mask: Optional[torch.Tensor] = None,
700
+ return_dict: bool = True,
701
+ # additional
702
+ first_frame_latents: Optional[torch.Tensor] = None,
703
+ frame_stride: Optional[Union[torch.Tensor, float, int]] = None,
704
+ ) -> Union[UNet2DConditionOutput, Tuple]:
705
+ # reshape video data
706
+ assert sample.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={sample.dim()}."
707
+ video_length = sample.shape[2]
708
+
709
+ if first_frame_latents is not None:
710
+ assert self.config.first_frame_condition_mode != "none", "first_frame_latents is not None, but first_frame_condition_mode is 'none'."
711
+
712
+ if self.config.first_frame_condition_mode != "none":
713
+ sample = torch.cat([first_frame_latents, sample], dim=2)
714
+ video_length += 1
715
+
716
+ # copy conditioning embeddings for cross attention
717
+ if encoder_hidden_states is not None:
718
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
719
+
720
+ sample = rearrange(sample, "b c f h w -> (b f) c h w")
721
+
722
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
723
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
724
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
725
+ # on the fly if necessary.
726
+ default_overall_up_factor = 2**self.num_upsamplers
727
+
728
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
729
+ forward_upsample_size = False
730
+ upsample_size = None
731
+
732
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
733
+ logger.info("Forward upsample size to force interpolation output size.")
734
+ forward_upsample_size = True
735
+
736
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
737
+ # expects mask of shape:
738
+ # [batch, key_tokens]
739
+ # adds singleton query_tokens dimension:
740
+ # [batch, 1, key_tokens]
741
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
742
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
743
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
744
+ if attention_mask is not None:
745
+ # assume that mask is expressed as:
746
+ # (1 = keep, 0 = discard)
747
+ # convert mask into a bias that can be added to attention scores:
748
+ # (keep = +0, discard = -10000.0)
749
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
750
+ attention_mask = attention_mask.unsqueeze(1)
751
+
752
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
753
+ if encoder_attention_mask is not None:
754
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
755
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
756
+
757
+ # 0. center input if necessary
758
+ if self.config.center_input_sample:
759
+ sample = 2 * sample - 1.0
760
+
761
+ # 1. time
762
+ timesteps = timestep
763
+ if not torch.is_tensor(timesteps):
764
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
765
+ # This would be a good case for the `match` statement (Python 3.10+)
766
+ is_mps = sample.device.type == "mps"
767
+ if isinstance(timestep, float):
768
+ dtype = torch.float32 if is_mps else torch.float64
769
+ else:
770
+ dtype = torch.int32 if is_mps else torch.int64
771
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
772
+ elif len(timesteps.shape) == 0:
773
+ timesteps = timesteps[None].to(sample.device)
774
+
775
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
776
+ timesteps = timesteps.expand(sample.shape[0])
777
+
778
+ t_emb = self.time_proj(timesteps)
779
+
780
+ # `Timesteps` does not contain any weights and will always return f32 tensors
781
+ # but time_embedding might actually be running in fp16. so we need to cast here.
782
+ # there might be better ways to encapsulate this.
783
+ t_emb = t_emb.to(dtype=sample.dtype)
784
+
785
+ emb = self.time_embedding(t_emb, timestep_cond)
786
+
787
+ if self.use_frame_stride_condition:
788
+ if not torch.is_tensor(frame_stride):
789
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
790
+ # This would be a good case for the `match` statement (Python 3.10+)
791
+ is_mps = sample.device.type == "mps"
792
+ if isinstance(timestep, float):
793
+ dtype = torch.float32 if is_mps else torch.float64
794
+ else:
795
+ dtype = torch.int32 if is_mps else torch.int64
796
+ frame_stride = torch.tensor([frame_stride], dtype=dtype, device=sample.device)
797
+ elif len(frame_stride.shape) == 0:
798
+ frame_stride = frame_stride[None].to(sample.device)
799
+
800
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
801
+ frame_stride = frame_stride.expand(sample.shape[0])
802
+
803
+ fs_emb = self.time_proj(frame_stride)
804
+
805
+ # `Timesteps` does not contain any weights and will always return f32 tensors
806
+ # but time_embedding might actually be running in fp16. so we need to cast here.
807
+ # there might be better ways to encapsulate this.
808
+ fs_emb = fs_emb.to(dtype=sample.dtype)
809
+
810
+ fs_emb = self.frame_stride_embedding(fs_emb, timestep_cond)
811
+ emb = emb + fs_emb
812
+
813
+ aug_emb = None
814
+
815
+ if self.class_embedding is not None:
816
+ if class_labels is None:
817
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
818
+
819
+ if self.config.class_embed_type == "timestep":
820
+ class_labels = self.time_proj(class_labels)
821
+
822
+ # `Timesteps` does not contain any weights and will always return f32 tensors
823
+ # there might be better ways to encapsulate this.
824
+ class_labels = class_labels.to(dtype=sample.dtype)
825
+
826
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
827
+
828
+ if self.config.class_embeddings_concat:
829
+ emb = torch.cat([emb, class_emb], dim=-1)
830
+ else:
831
+ emb = emb + class_emb
832
+
833
+ if self.config.addition_embed_type == "text":
834
+ aug_emb = self.add_embedding(encoder_hidden_states)
835
+ elif self.config.addition_embed_type == "text_image":
836
+ # Kandinsky 2.1 - style
837
+ if "image_embeds" not in added_cond_kwargs:
838
+ raise ValueError(
839
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
840
+ )
841
+
842
+ image_embs = added_cond_kwargs.get("image_embeds")
843
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
844
+ aug_emb = self.add_embedding(text_embs, image_embs)
845
+ elif self.config.addition_embed_type == "text_time":
846
+ # SDXL - style
847
+ if "text_embeds" not in added_cond_kwargs:
848
+ raise ValueError(
849
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
850
+ )
851
+ text_embeds = added_cond_kwargs.get("text_embeds")
852
+ if "time_ids" not in added_cond_kwargs:
853
+ raise ValueError(
854
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
855
+ )
856
+ time_ids = added_cond_kwargs.get("time_ids")
857
+ time_embeds = self.add_time_proj(time_ids.flatten())
858
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
859
+
860
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
861
+ add_embeds = add_embeds.to(emb.dtype)
862
+ aug_emb = self.add_embedding(add_embeds)
863
+ elif self.config.addition_embed_type == "image":
864
+ # Kandinsky 2.2 - style
865
+ if "image_embeds" not in added_cond_kwargs:
866
+ raise ValueError(
867
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
868
+ )
869
+ image_embs = added_cond_kwargs.get("image_embeds")
870
+ aug_emb = self.add_embedding(image_embs)
871
+ elif self.config.addition_embed_type == "image_hint":
872
+ # Kandinsky 2.2 - style
873
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
874
+ raise ValueError(
875
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
876
+ )
877
+ image_embs = added_cond_kwargs.get("image_embeds")
878
+ hint = added_cond_kwargs.get("hint")
879
+ aug_emb, hint = self.add_embedding(image_embs, hint)
880
+ sample = torch.cat([sample, hint], dim=1)
881
+
882
+ emb = emb + aug_emb if aug_emb is not None else emb
883
+
884
+ if self.time_embed_act is not None:
885
+ emb = self.time_embed_act(emb)
886
+
887
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
888
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
889
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
890
+ # Kadinsky 2.1 - style
891
+ if "image_embeds" not in added_cond_kwargs:
892
+ raise ValueError(
893
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
894
+ )
895
+
896
+ image_embeds = added_cond_kwargs.get("image_embeds")
897
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
898
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
899
+ # Kandinsky 2.2 - style
900
+ if "image_embeds" not in added_cond_kwargs:
901
+ raise ValueError(
902
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
903
+ )
904
+ image_embeds = added_cond_kwargs.get("image_embeds")
905
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
906
+ # 2. pre-process
907
+ sample = self.conv_in(sample)
908
+
909
+ # 2.5 GLIGEN position net
910
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
911
+ cross_attention_kwargs = cross_attention_kwargs.copy()
912
+ gligen_args = cross_attention_kwargs.pop("gligen")
913
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
914
+
915
+ # 3. down
916
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
917
+
918
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
919
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
920
+
921
+ down_block_res_samples = (sample,)
922
+ for downsample_block in self.down_blocks:
923
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
924
+ # For t2i-adapter CrossAttnDownBlock2D
925
+ additional_residuals = {}
926
+ if is_adapter and len(down_block_additional_residuals) > 0:
927
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
928
+
929
+ sample, res_samples = downsample_block(
930
+ hidden_states=sample,
931
+ temb=emb,
932
+ encoder_hidden_states=encoder_hidden_states,
933
+ attention_mask=attention_mask,
934
+ cross_attention_kwargs=cross_attention_kwargs,
935
+ encoder_attention_mask=encoder_attention_mask,
936
+ first_frame_latents=first_frame_latents,
937
+ **additional_residuals,
938
+ )
939
+ else:
940
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, first_frame_latents=first_frame_latents,)
941
+
942
+ if is_adapter and len(down_block_additional_residuals) > 0:
943
+ sample += down_block_additional_residuals.pop(0)
944
+
945
+ down_block_res_samples += res_samples
946
+
947
+ if is_controlnet:
948
+ new_down_block_res_samples = ()
949
+
950
+ for down_block_res_sample, down_block_additional_residual in zip(
951
+ down_block_res_samples, down_block_additional_residuals
952
+ ):
953
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
954
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
955
+
956
+ down_block_res_samples = new_down_block_res_samples
957
+
958
+ # 4. mid
959
+ if self.mid_block is not None:
960
+ sample = self.mid_block(
961
+ sample,
962
+ emb,
963
+ encoder_hidden_states=encoder_hidden_states,
964
+ attention_mask=attention_mask,
965
+ cross_attention_kwargs=cross_attention_kwargs,
966
+ encoder_attention_mask=encoder_attention_mask,
967
+ # additional
968
+ first_frame_latents=first_frame_latents,
969
+ )
970
+ # To support T2I-Adapter-XL
971
+ if (
972
+ is_adapter
973
+ and len(down_block_additional_residuals) > 0
974
+ and sample.shape == down_block_additional_residuals[0].shape
975
+ ):
976
+ sample += down_block_additional_residuals.pop(0)
977
+
978
+ if is_controlnet:
979
+ sample = sample + mid_block_additional_residual
980
+
981
+ # 5. up
982
+ for i, upsample_block in enumerate(self.up_blocks):
983
+ is_final_block = i == len(self.up_blocks) - 1
984
+
985
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
986
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
987
+
988
+ # if we have not reached the final block and need to forward the
989
+ # upsample size, we do it here
990
+ if not is_final_block and forward_upsample_size:
991
+ upsample_size = down_block_res_samples[-1].shape[2:]
992
+
993
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
994
+ sample = upsample_block(
995
+ hidden_states=sample,
996
+ temb=emb,
997
+ res_hidden_states_tuple=res_samples,
998
+ encoder_hidden_states=encoder_hidden_states,
999
+ cross_attention_kwargs=cross_attention_kwargs,
1000
+ upsample_size=upsample_size,
1001
+ attention_mask=attention_mask,
1002
+ encoder_attention_mask=encoder_attention_mask,
1003
+ first_frame_latents=first_frame_latents,
1004
+ )
1005
+ else:
1006
+ sample = upsample_block(
1007
+ hidden_states=sample,
1008
+ temb=emb,
1009
+ res_hidden_states_tuple=res_samples,
1010
+ upsample_size=upsample_size,
1011
+ scale=lora_scale,
1012
+ first_frame_latents=first_frame_latents,
1013
+ )
1014
+
1015
+ # 6. post-process
1016
+ if self.conv_norm_out:
1017
+ sample = self.conv_norm_out(sample)
1018
+ sample = self.conv_act(sample)
1019
+ sample = self.conv_out(sample)
1020
+
1021
+ sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
1022
+ if self.config.first_frame_condition_mode != "none":
1023
+ sample = sample[:, :, 1:, :, :]
1024
+
1025
+ if not return_dict:
1026
+ return (sample,)
1027
+
1028
+ return UNet2DConditionOutput(sample=sample)
1029
+
1030
+ @classmethod
1031
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
1032
+
1033
+ kwargs.pop("low_cpu_mem_usage", False)
1034
+ kwargs.pop("device_map", None)
1035
+
1036
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1037
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1038
+ force_download = kwargs.pop("force_download", False)
1039
+ from_flax = kwargs.pop("from_flax", False)
1040
+ resume_download = kwargs.pop("resume_download", False)
1041
+ proxies = kwargs.pop("proxies", None)
1042
+ output_loading_info = kwargs.pop("output_loading_info", False)
1043
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1044
+ use_auth_token = kwargs.pop("use_auth_token", None)
1045
+ revision = kwargs.pop("revision", None)
1046
+ torch_dtype = kwargs.pop("torch_dtype", None)
1047
+ subfolder = kwargs.pop("subfolder", None)
1048
+ device_map = None
1049
+ max_memory = kwargs.pop("max_memory", None)
1050
+ offload_folder = kwargs.pop("offload_folder", None)
1051
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1052
+ low_cpu_mem_usage = False
1053
+ variant = kwargs.pop("variant", None)
1054
+ use_safetensors = kwargs.pop("use_safetensors", None)
1055
+
1056
+ allow_pickle = False
1057
+ if use_safetensors is None:
1058
+ use_safetensors = True
1059
+ allow_pickle = True
1060
+
1061
+ if low_cpu_mem_usage and not is_accelerate_available():
1062
+ low_cpu_mem_usage = False
1063
+ logger.warning(
1064
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
1065
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
1066
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
1067
+ " install accelerate\n```\n."
1068
+ )
1069
+
1070
+ if device_map is not None and not is_accelerate_available():
1071
+ raise NotImplementedError(
1072
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1073
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1074
+ )
1075
+
1076
+ # Check if we can handle device_map and dispatching the weights
1077
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1078
+ raise NotImplementedError(
1079
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1080
+ " `device_map=None`."
1081
+ )
1082
+
1083
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
1084
+ raise NotImplementedError(
1085
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
1086
+ " `low_cpu_mem_usage=False`."
1087
+ )
1088
+
1089
+ if low_cpu_mem_usage is False and device_map is not None:
1090
+ raise ValueError(
1091
+ f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
1092
+ " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
1093
+ )
1094
+
1095
+ # Load config if we don't provide a configuration
1096
+ config_path = pretrained_model_name_or_path
1097
+
1098
+ user_agent = {
1099
+ "diffusers": __version__,
1100
+ "file_type": "model",
1101
+ "framework": "pytorch",
1102
+ }
1103
+
1104
+ # load config
1105
+ config, unused_kwargs, commit_hash = cls.load_config(
1106
+ config_path,
1107
+ cache_dir=cache_dir,
1108
+ return_unused_kwargs=True,
1109
+ return_commit_hash=True,
1110
+ force_download=force_download,
1111
+ resume_download=resume_download,
1112
+ proxies=proxies,
1113
+ local_files_only=local_files_only,
1114
+ use_auth_token=use_auth_token,
1115
+ revision=revision,
1116
+ subfolder=subfolder,
1117
+ device_map=device_map,
1118
+ max_memory=max_memory,
1119
+ offload_folder=offload_folder,
1120
+ offload_state_dict=offload_state_dict,
1121
+ user_agent=user_agent,
1122
+ **kwargs,
1123
+ )
1124
+
1125
+ # load model
1126
+ model_file = None
1127
+ if from_flax:
1128
+ model_file = _get_model_file(
1129
+ pretrained_model_name_or_path,
1130
+ weights_name=FLAX_WEIGHTS_NAME,
1131
+ cache_dir=cache_dir,
1132
+ force_download=force_download,
1133
+ resume_download=resume_download,
1134
+ proxies=proxies,
1135
+ local_files_only=local_files_only,
1136
+ use_auth_token=use_auth_token,
1137
+ revision=revision,
1138
+ subfolder=subfolder,
1139
+ user_agent=user_agent,
1140
+ commit_hash=commit_hash,
1141
+ )
1142
+ model = cls.from_config(config, **unused_kwargs)
1143
+
1144
+ # Convert the weights
1145
+ from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
1146
+
1147
+ model = load_flax_checkpoint_in_pytorch_model(model, model_file)
1148
+ else:
1149
+ if use_safetensors:
1150
+ try:
1151
+ model_file = _get_model_file(
1152
+ pretrained_model_name_or_path,
1153
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1154
+ cache_dir=cache_dir,
1155
+ force_download=force_download,
1156
+ resume_download=resume_download,
1157
+ proxies=proxies,
1158
+ local_files_only=local_files_only,
1159
+ use_auth_token=use_auth_token,
1160
+ revision=revision,
1161
+ subfolder=subfolder,
1162
+ user_agent=user_agent,
1163
+ commit_hash=commit_hash,
1164
+ )
1165
+ except IOError as e:
1166
+ if not allow_pickle:
1167
+ raise e
1168
+ pass
1169
+ if model_file is None:
1170
+ model_file = _get_model_file(
1171
+ pretrained_model_name_or_path,
1172
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1173
+ cache_dir=cache_dir,
1174
+ force_download=force_download,
1175
+ resume_download=resume_download,
1176
+ proxies=proxies,
1177
+ local_files_only=local_files_only,
1178
+ use_auth_token=use_auth_token,
1179
+ revision=revision,
1180
+ subfolder=subfolder,
1181
+ user_agent=user_agent,
1182
+ commit_hash=commit_hash,
1183
+ )
1184
+
1185
+ if low_cpu_mem_usage:
1186
+ # Instantiate model with empty weights
1187
+ with accelerate.init_empty_weights():
1188
+ model = cls.from_config(config, **unused_kwargs)
1189
+
1190
+ # if device_map is None, load the state dict and move the params from meta device to the cpu
1191
+ if device_map is None:
1192
+ param_device = "cpu"
1193
+ state_dict = load_state_dict(model_file, variant=variant)
1194
+ model._convert_deprecated_attention_blocks(state_dict)
1195
+ # move the params from meta device to cpu
1196
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
1197
+ if len(missing_keys) > 0:
1198
+ raise ValueError(
1199
+ f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
1200
+ f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
1201
+ " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
1202
+ " those weights or else make sure your checkpoint file is correct."
1203
+ )
1204
+
1205
+ unexpected_keys = load_model_dict_into_meta(
1206
+ model,
1207
+ state_dict,
1208
+ device=param_device,
1209
+ dtype=torch_dtype,
1210
+ model_name_or_path=pretrained_model_name_or_path,
1211
+ )
1212
+
1213
+ if cls._keys_to_ignore_on_load_unexpected is not None:
1214
+ for pat in cls._keys_to_ignore_on_load_unexpected:
1215
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1216
+
1217
+ if len(unexpected_keys) > 0:
1218
+ logger.warn(
1219
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1220
+ )
1221
+
1222
+ else: # else let accelerate handle loading and dispatching.
1223
+ # Load weights and dispatch according to the device_map
1224
+ # by default the device_map is None and the weights are loaded on the CPU
1225
+ try:
1226
+ accelerate.load_checkpoint_and_dispatch(
1227
+ model,
1228
+ model_file,
1229
+ device_map,
1230
+ max_memory=max_memory,
1231
+ offload_folder=offload_folder,
1232
+ offload_state_dict=offload_state_dict,
1233
+ dtype=torch_dtype,
1234
+ )
1235
+ except AttributeError as e:
1236
+ # When using accelerate loading, we do not have the ability to load the state
1237
+ # dict and rename the weight names manually. Additionally, accelerate skips
1238
+ # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
1239
+ # (which look like they should be private variables?), so we can't use the standard hooks
1240
+ # to rename parameters on load. We need to mimic the original weight names so the correct
1241
+ # attributes are available. After we have loaded the weights, we convert the deprecated
1242
+ # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
1243
+ # the weights so we don't have to do this again.
1244
+
1245
+ if "'Attention' object has no attribute" in str(e):
1246
+ logger.warn(
1247
+ f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
1248
+ " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
1249
+ " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
1250
+ " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
1251
+ " please also re-upload it or open a PR on the original repository."
1252
+ )
1253
+ model._temp_convert_self_to_deprecated_attention_blocks()
1254
+ accelerate.load_checkpoint_and_dispatch(
1255
+ model,
1256
+ model_file,
1257
+ device_map,
1258
+ max_memory=max_memory,
1259
+ offload_folder=offload_folder,
1260
+ offload_state_dict=offload_state_dict,
1261
+ dtype=torch_dtype,
1262
+ )
1263
+ model._undo_temp_convert_self_to_deprecated_attention_blocks()
1264
+ else:
1265
+ raise e
1266
+
1267
+ loading_info = {
1268
+ "missing_keys": [],
1269
+ "unexpected_keys": [],
1270
+ "mismatched_keys": [],
1271
+ "error_msgs": [],
1272
+ }
1273
+ else:
1274
+ model = cls.from_config(config, **unused_kwargs)
1275
+
1276
+ state_dict = load_state_dict(model_file, variant=variant)
1277
+ model._convert_deprecated_attention_blocks(state_dict)
1278
+
1279
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
1280
+ model,
1281
+ state_dict,
1282
+ model_file,
1283
+ pretrained_model_name_or_path,
1284
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
1285
+ )
1286
+
1287
+ loading_info = {
1288
+ "missing_keys": missing_keys,
1289
+ "unexpected_keys": unexpected_keys,
1290
+ "mismatched_keys": mismatched_keys,
1291
+ "error_msgs": error_msgs,
1292
+ }
1293
+
1294
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1295
+ raise ValueError(
1296
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1297
+ )
1298
+ elif torch_dtype is not None:
1299
+ model = model.to(torch_dtype)
1300
+
1301
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1302
+
1303
+ m, u = loading_info["missing_keys"], loading_info["unexpected_keys"]
1304
+ logger.info(f"### missing keys: {len(m)}; unexpected keys: {len(u)};")
1305
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
1306
+
1307
+ spatial_params = [p.numel() if "conv3ds" not in n and "tempo_attns" not in n else 0 for n, p in model.named_parameters()]
1308
+ tconv_params = [p.numel() if "conv3ds." in n else 0 for n, p in model.named_parameters()]
1309
+ tattn_params = [p.numel() if "tempo_attns." in n else 0 for n, p in model.named_parameters()]
1310
+ tffconv_params = [p.numel() if "first_frame_conv." in n else 0 for n, p in model.named_parameters()]
1311
+ logger.info(f"### First Frame Convolution Layer Parameters: {sum(tffconv_params) / 1e6} M")
1312
+ logger.info(f"### Spatial UNet Parameters: {sum(spatial_params) / 1e6} M")
1313
+ logger.info(f"### Temporal Convolution Module Parameters: {sum(tconv_params) / 1e6} M")
1314
+ logger.info(f"### Temporal Attention Module Parameters: {sum(tattn_params) / 1e6} M")
1315
+
1316
+ # Set model in evaluation mode to deactivate DropOut modules by default
1317
+ model.eval()
1318
+ if output_loading_info:
1319
+ return model, loading_info
1320
+
1321
+ return model
1322
+
1323
+ if __name__ == "__main__":
1324
+ # test
1325
+ from diffusers import AutoencoderKL, DDIMScheduler
1326
+ from transformers import CLIPTextModel, CLIPTokenizer
1327
+ from consisti2v.pipelines.pipeline_animation import AnimationPipeline
1328
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
1329
+ from consisti2v.utils.util import save_videos_grid
1330
+
1331
+ pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
1332
+ prompt = "apply eye makeup"
1333
+ first_frame_path = "/ML-A100/home/weiming/datasets/UCF/frames/v_ApplyEyeMakeup_g01_c01_frame_90.jpg"
1334
+
1335
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
1336
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
1337
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", use_safetensors=True)
1338
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
1339
+ pretrained_model_path,
1340
+ subfolder="unet",
1341
+ use_safetensors=True
1342
+ )
1343
+
1344
+ noise_scheduler_kwargs = {
1345
+ "num_train_timesteps": 1000,
1346
+ "beta_start": 0.00085,
1347
+ "beta_end": 0.012,
1348
+ "beta_schedule": "linear",
1349
+ "steps_offset": 1,
1350
+ "clip_sample": False,
1351
+ }
1352
+ noise_scheduler = DDIMScheduler(**noise_scheduler_kwargs)
1353
+ # latent = torch.randn(1, 4, 8, 64, 64).to("cuda")
1354
+ # text_embedding = torch.randn(1, 77, 768).to("cuda")
1355
+ # timestep = torch.randint(0, 1000, (1,)).to("cuda").squeeze(0)
1356
+ # output = unet(latent, timestep, text_embedding)
1357
+
1358
+ pipeline = ConditionalAnimationPipeline(
1359
+ unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
1360
+ ).to("cuda")
1361
+ sample = pipeline(
1362
+ prompt,
1363
+ num_inference_steps = 25,
1364
+ guidance_scale = 8.,
1365
+ video_length = 8,
1366
+ height = 256,
1367
+ width = 256,
1368
+ first_frame_paths = first_frame_path,
1369
+ ).videos
1370
+ print(sample.shape)
1371
+ save_videos_grid(sample, f"samples/videoldm.gif")
consisti2v/models/videoldm_unet_blocks.py ADDED
@@ -0,0 +1,1159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Dict, Tuple, Any
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from einops.layers.torch import Rearrange
8
+ from diffusers.utils import logging
9
+ from diffusers.models.unet_2d_blocks import (
10
+ DownBlock2D,
11
+ UpBlock2D
12
+ )
13
+ from diffusers.models.resnet import (
14
+ ResnetBlock2D,
15
+ Downsample2D,
16
+ Upsample2D,
17
+ )
18
+ from diffusers.models.transformer_2d import Transformer2DModelOutput
19
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
20
+ from diffusers.models.activations import get_activation
21
+ from diffusers.utils import logging, is_torch_version
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+ from .videoldm_transformer_blocks import Transformer2DConditionModel
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_down_block(
35
+ down_block_type,
36
+ num_layers,
37
+ in_channels,
38
+ out_channels,
39
+ temb_channels,
40
+ add_downsample,
41
+ resnet_eps,
42
+ resnet_act_fn,
43
+ transformer_layers_per_block=1,
44
+ num_attention_heads=None,
45
+ resnet_groups=None,
46
+ cross_attention_dim=None,
47
+ downsample_padding=None,
48
+ dual_cross_attention=False,
49
+ use_linear_projection=False,
50
+ only_cross_attention=False,
51
+ upcast_attention=False,
52
+ resnet_time_scale_shift="default",
53
+ attention_type="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ dropout=0.0,
60
+ # additional
61
+ use_temporal=True,
62
+ augment_temporal_attention=False,
63
+ n_frames=8,
64
+ n_temp_heads=8,
65
+ first_frame_condition_mode="none",
66
+ latent_channels=4,
67
+ rotary_emb=False,
68
+ ):
69
+ # If attn head dim is not defined, we default it to the number of heads
70
+ if attention_head_dim is None:
71
+ logger.warn(
72
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
73
+ )
74
+ attention_head_dim = num_attention_heads
75
+
76
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
77
+ if down_block_type == "DownBlock2D":
78
+ return VideoLDMDownBlock(
79
+ num_layers=num_layers,
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ temb_channels=temb_channels,
83
+ dropout=dropout,
84
+ add_downsample=add_downsample,
85
+ resnet_eps=resnet_eps,
86
+ resnet_act_fn=resnet_act_fn,
87
+ resnet_groups=resnet_groups,
88
+ downsample_padding=downsample_padding,
89
+ resnet_time_scale_shift=resnet_time_scale_shift,
90
+ # additional
91
+ use_temporal=use_temporal,
92
+ n_frames=n_frames,
93
+ first_frame_condition_mode=first_frame_condition_mode,
94
+ latent_channels=latent_channels
95
+ )
96
+ elif down_block_type == "CrossAttnDownBlock2D":
97
+ return VideoLDMCrossAttnDownBlock(
98
+ num_layers=num_layers,
99
+ transformer_layers_per_block=transformer_layers_per_block,
100
+ in_channels=in_channels,
101
+ out_channels=out_channels,
102
+ temb_channels=temb_channels,
103
+ dropout=dropout,
104
+ add_downsample=add_downsample,
105
+ resnet_eps=resnet_eps,
106
+ resnet_act_fn=resnet_act_fn,
107
+ resnet_groups=resnet_groups,
108
+ downsample_padding=downsample_padding,
109
+ cross_attention_dim=cross_attention_dim,
110
+ num_attention_heads=num_attention_heads,
111
+ dual_cross_attention=dual_cross_attention,
112
+ use_linear_projection=use_linear_projection,
113
+ only_cross_attention=only_cross_attention,
114
+ upcast_attention=upcast_attention,
115
+ resnet_time_scale_shift=resnet_time_scale_shift,
116
+ attention_type=attention_type,
117
+ # additional
118
+ use_temporal=use_temporal,
119
+ augment_temporal_attention=augment_temporal_attention,
120
+ n_frames=n_frames,
121
+ n_temp_heads=n_temp_heads,
122
+ first_frame_condition_mode=first_frame_condition_mode,
123
+ latent_channels=latent_channels,
124
+ rotary_emb=rotary_emb,
125
+ )
126
+
127
+ raise ValueError(f'{down_block_type} does not exist.')
128
+
129
+
130
+ def get_up_block(
131
+ up_block_type,
132
+ num_layers,
133
+ in_channels,
134
+ out_channels,
135
+ prev_output_channel,
136
+ temb_channels,
137
+ add_upsample,
138
+ resnet_eps,
139
+ resnet_act_fn,
140
+ transformer_layers_per_block=1,
141
+ num_attention_heads=None,
142
+ resnet_groups=None,
143
+ cross_attention_dim=None,
144
+ dual_cross_attention=False,
145
+ use_linear_projection=False,
146
+ only_cross_attention=False,
147
+ upcast_attention=False,
148
+ resnet_time_scale_shift="default",
149
+ attention_type="default",
150
+ resnet_skip_time_act=False,
151
+ resnet_out_scale_factor=1.0,
152
+ cross_attention_norm=None,
153
+ attention_head_dim=None,
154
+ upsample_type=None,
155
+ dropout=0.0,
156
+ # additional
157
+ use_temporal=True,
158
+ augment_temporal_attention=False,
159
+ n_frames=8,
160
+ n_temp_heads=8,
161
+ first_frame_condition_mode="none",
162
+ latent_channels=4,
163
+ rotary_emb=None,
164
+ ):
165
+ if attention_head_dim is None:
166
+ logger.warn(
167
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
168
+ )
169
+ attention_head_dim = num_attention_heads
170
+
171
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
172
+ if up_block_type == "UpBlock2D":
173
+ return VideoLDMUpBlock(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ prev_output_channel=prev_output_channel,
178
+ temb_channels=temb_channels,
179
+ dropout=dropout,
180
+ add_upsample=add_upsample,
181
+ resnet_eps=resnet_eps,
182
+ resnet_act_fn=resnet_act_fn,
183
+ resnet_groups=resnet_groups,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ # additional
186
+ use_temporal=use_temporal,
187
+ n_frames=n_frames,
188
+ first_frame_condition_mode=first_frame_condition_mode,
189
+ latent_channels=latent_channels
190
+ )
191
+ elif up_block_type == 'CrossAttnUpBlock2D':
192
+ return VideoLDMCrossAttnUpBlock(
193
+ num_layers=num_layers,
194
+ transformer_layers_per_block=transformer_layers_per_block,
195
+ in_channels=in_channels,
196
+ out_channels=out_channels,
197
+ prev_output_channel=prev_output_channel,
198
+ temb_channels=temb_channels,
199
+ dropout=dropout,
200
+ add_upsample=add_upsample,
201
+ resnet_eps=resnet_eps,
202
+ resnet_act_fn=resnet_act_fn,
203
+ resnet_groups=resnet_groups,
204
+ cross_attention_dim=cross_attention_dim,
205
+ num_attention_heads=num_attention_heads,
206
+ dual_cross_attention=dual_cross_attention,
207
+ use_linear_projection=use_linear_projection,
208
+ only_cross_attention=only_cross_attention,
209
+ upcast_attention=upcast_attention,
210
+ resnet_time_scale_shift=resnet_time_scale_shift,
211
+ attention_type=attention_type,
212
+ # additional
213
+ use_temporal=use_temporal,
214
+ augment_temporal_attention=augment_temporal_attention,
215
+ n_frames=n_frames,
216
+ n_temp_heads=n_temp_heads,
217
+ first_frame_condition_mode=first_frame_condition_mode,
218
+ latent_channels=latent_channels,
219
+ rotary_emb=rotary_emb,
220
+ )
221
+
222
+ raise ValueError(f'{up_block_type} does not exist.')
223
+
224
+
225
+ class TemporalResnetBlock(nn.Module):
226
+ def __init__(
227
+ self,
228
+ *,
229
+ in_channels,
230
+ out_channels=None,
231
+ dropout=0.0,
232
+ temb_channels=512,
233
+ groups=32,
234
+ groups_out=None,
235
+ pre_norm=True,
236
+ eps=1e-6,
237
+ non_linearity="swish",
238
+ time_embedding_norm="default",
239
+ output_scale_factor=1.0,
240
+ # additional
241
+ n_frames=8,
242
+ ):
243
+ super().__init__()
244
+ self.pre_norm = pre_norm
245
+ self.pre_norm = True
246
+ self.in_channels = in_channels
247
+ out_channels = in_channels if out_channels is None else out_channels
248
+ self.out_channels = out_channels
249
+ self.time_embedding_norm = time_embedding_norm
250
+ self.output_scale_factor = output_scale_factor
251
+
252
+ if groups_out is None:
253
+ groups_out = groups
254
+
255
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
256
+
257
+ self.conv1 = Conv3DLayer(in_channels, out_channels, n_frames=n_frames)
258
+
259
+ if temb_channels is not None:
260
+ if self.time_embedding_norm == "default":
261
+ time_emb_proj_out_channels = out_channels
262
+ elif self.time_embedding_norm == "scale_shift":
263
+ time_emb_proj_out_channels = out_channels * 2
264
+ else:
265
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
266
+
267
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
268
+ else:
269
+ self.time_emb_proj = None
270
+
271
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
272
+
273
+ self.dropout = torch.nn.Dropout(dropout)
274
+ self.conv2 = Conv3DLayer(out_channels, out_channels, n_frames=n_frames)
275
+
276
+ self.nonlinearity = get_activation(non_linearity)
277
+
278
+ self.alpha = nn.Parameter(torch.ones(1))
279
+
280
+ def forward(self, input_tensor, temb=None):
281
+ hidden_states = input_tensor
282
+
283
+ hidden_states = self.norm1(hidden_states)
284
+ hidden_states = self.nonlinearity(hidden_states)
285
+
286
+ hidden_states = self.conv1(hidden_states)
287
+
288
+ if temb is not None:
289
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
290
+
291
+ if temb is not None and self.time_embedding_norm == "default":
292
+ hidden_states = hidden_states + temb
293
+
294
+ hidden_states = self.norm2(hidden_states)
295
+
296
+ if temb is not None and self.time_embedding_norm == "scale_shift":
297
+ scale, shift = torch.chunk(temb, 2, dim=1)
298
+ hidden_states = hidden_states * (1 + scale) + shift
299
+
300
+ hidden_states = self.nonlinearity(hidden_states)
301
+
302
+ hidden_states = self.dropout(hidden_states)
303
+ hidden_states = self.conv2(hidden_states)
304
+
305
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
306
+
307
+ # weighted sum between spatial and temporal features
308
+ with torch.no_grad():
309
+ self.alpha.clamp_(0, 1)
310
+
311
+ output_tensor = self.alpha * input_tensor + (1 - self.alpha) * output_tensor
312
+
313
+ return output_tensor
314
+
315
+
316
+ class Conv3DLayer(nn.Conv3d):
317
+ def __init__(self, in_dim, out_dim, n_frames):
318
+ k, p = (3, 1, 1), (1, 0, 0)
319
+ super().__init__(in_channels=in_dim, out_channels=out_dim, kernel_size=k, stride=1, padding=p)
320
+
321
+ self.to_3d = Rearrange('(b t) c h w -> b c t h w', t=n_frames)
322
+ self.to_2d = Rearrange('b c t h w -> (b t) c h w')
323
+
324
+ def forward(self, x):
325
+ h = self.to_3d(x)
326
+ h = super().forward(h)
327
+ out = self.to_2d(h)
328
+ return out
329
+
330
+
331
+ class IdentityLayer(nn.Identity):
332
+ def __init__(self, return_trans2d_output, *args, **kwargs):
333
+ super().__init__()
334
+ self.return_trans2d_output = return_trans2d_output
335
+
336
+ def forward(self, x, *args, **kwargs):
337
+ if self.return_trans2d_output:
338
+ return Transformer2DModelOutput(sample=x)
339
+ else:
340
+ return x
341
+
342
+
343
+ class VideoLDMCrossAttnDownBlock(nn.Module):
344
+ def __init__(
345
+ self,
346
+ in_channels: int,
347
+ out_channels: int,
348
+ temb_channels: int,
349
+ dropout: float = 0.0,
350
+ num_layers: int = 1,
351
+ transformer_layers_per_block: int = 1,
352
+ resnet_eps: float = 1e-6,
353
+ resnet_time_scale_shift: str = "default",
354
+ resnet_act_fn: str = "swish",
355
+ resnet_groups: int = 32,
356
+ resnet_pre_norm: bool = True,
357
+ num_attention_heads=1,
358
+ cross_attention_dim=1280,
359
+ output_scale_factor=1.0,
360
+ downsample_padding=1,
361
+ add_downsample=True,
362
+ dual_cross_attention=False,
363
+ use_linear_projection=False,
364
+ only_cross_attention=False,
365
+ upcast_attention=False,
366
+ attention_type="default",
367
+ # additional
368
+ use_temporal=True,
369
+ augment_temporal_attention=False,
370
+ n_frames=8,
371
+ n_temp_heads=8,
372
+ first_frame_condition_mode="none",
373
+ latent_channels=4,
374
+ rotary_emb=False,
375
+ ):
376
+ super().__init__()
377
+
378
+ self.use_temporal = use_temporal
379
+
380
+ self.n_frames = n_frames
381
+ self.first_frame_condition_mode = first_frame_condition_mode
382
+ if self.first_frame_condition_mode == "conv2d":
383
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
384
+
385
+ resnets = []
386
+ attentions = []
387
+
388
+ self.n_frames = n_frames
389
+ self.n_temp_heads = n_temp_heads
390
+
391
+ self.has_cross_attention = True
392
+ self.num_attention_heads = num_attention_heads
393
+
394
+ for i in range(num_layers):
395
+ in_channels = in_channels if i == 0 else out_channels
396
+ resnets.append(
397
+ ResnetBlock2D(
398
+ in_channels=in_channels,
399
+ out_channels=out_channels,
400
+ temb_channels=temb_channels,
401
+ eps=resnet_eps,
402
+ groups=resnet_groups,
403
+ dropout=dropout,
404
+ time_embedding_norm=resnet_time_scale_shift,
405
+ non_linearity=resnet_act_fn,
406
+ output_scale_factor=output_scale_factor,
407
+ pre_norm=resnet_pre_norm,
408
+ )
409
+ )
410
+ if not dual_cross_attention:
411
+ attentions.append(
412
+ Transformer2DConditionModel(
413
+ num_attention_heads,
414
+ out_channels // num_attention_heads,
415
+ in_channels=out_channels,
416
+ num_layers=transformer_layers_per_block,
417
+ cross_attention_dim=cross_attention_dim,
418
+ norm_num_groups=resnet_groups,
419
+ use_linear_projection=use_linear_projection,
420
+ only_cross_attention=only_cross_attention,
421
+ upcast_attention=upcast_attention,
422
+ attention_type=attention_type,
423
+ # additional
424
+ n_frames=n_frames,
425
+ )
426
+ )
427
+ else:
428
+ attentions.append(
429
+ DualTransformer2DModel(
430
+ num_attention_heads,
431
+ out_channels // num_attention_heads,
432
+ in_channels=out_channels,
433
+ num_layers=1,
434
+ cross_attention_dim=cross_attention_dim,
435
+ norm_num_groups=resnet_groups,
436
+ )
437
+ )
438
+ self.attentions = nn.ModuleList(attentions)
439
+ self.resnets = nn.ModuleList(resnets)
440
+
441
+ if add_downsample:
442
+ self.downsamplers = nn.ModuleList(
443
+ [
444
+ Downsample2D(
445
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
446
+ )
447
+ ]
448
+ )
449
+ else:
450
+ self.downsamplers = None
451
+
452
+ self.gradient_checkpointing = False
453
+
454
+ # >>> Temporal Layers >>>
455
+ conv3ds = []
456
+ tempo_attns = []
457
+
458
+ for i in range(num_layers):
459
+ if self.use_temporal:
460
+ conv3ds.append(
461
+ TemporalResnetBlock(
462
+ in_channels=out_channels,
463
+ out_channels=out_channels,
464
+ n_frames=n_frames,
465
+ )
466
+ )
467
+
468
+ tempo_attns.append(
469
+ Transformer2DConditionModel(
470
+ n_temp_heads,
471
+ out_channels // n_temp_heads,
472
+ in_channels=out_channels,
473
+ num_layers=transformer_layers_per_block,
474
+ cross_attention_dim=cross_attention_dim,
475
+ norm_num_groups=resnet_groups,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention,
478
+ upcast_attention=upcast_attention,
479
+ attention_type=attention_type,
480
+ # additional
481
+ n_frames=n_frames,
482
+ is_temporal=True,
483
+ augment_temporal_attention=augment_temporal_attention,
484
+ rotary_emb=rotary_emb
485
+ )
486
+ )
487
+ else:
488
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
489
+ tempo_attns.append(IdentityLayer(return_trans2d_output=True))
490
+
491
+ self.conv3ds = nn.ModuleList(conv3ds)
492
+ self.tempo_attns = nn.ModuleList(tempo_attns)
493
+ # <<< Temporal Layers <<<
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states: torch.FloatTensor,
498
+ temb: Optional[torch.FloatTensor] = None,
499
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
500
+ attention_mask: Optional[torch.FloatTensor] = None,
501
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
502
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
503
+ # additional
504
+ first_frame_latents=None,
505
+ ):
506
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
507
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
508
+ if self.first_frame_condition_mode == "conv2d":
509
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
510
+ hidden_height = hidden_states.shape[3]
511
+ first_frame_height = first_frame_latents.shape[3]
512
+ downsample_ratio = hidden_height / first_frame_height
513
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
514
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
515
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
516
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
517
+
518
+ output_states = ()
519
+
520
+ for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
521
+
522
+ hidden_states = resnet(hidden_states, temb)
523
+ hidden_states = conv3d(hidden_states)
524
+ hidden_states = attn(
525
+ hidden_states,
526
+ encoder_hidden_states=encoder_hidden_states,
527
+ cross_attention_kwargs=cross_attention_kwargs,
528
+ condition_on_first_frame=condition_on_first_frame,
529
+ ).sample
530
+ hidden_states = tempo_attn(
531
+ hidden_states,
532
+ encoder_hidden_states=encoder_hidden_states,
533
+ cross_attention_kwargs=cross_attention_kwargs,
534
+ condition_on_first_frame=False,
535
+ ).sample
536
+
537
+ output_states += (hidden_states,)
538
+
539
+ if self.downsamplers is not None:
540
+ for downsampler in self.downsamplers:
541
+ hidden_states = downsampler(hidden_states)
542
+
543
+ output_states += (hidden_states,)
544
+
545
+ return hidden_states, output_states
546
+
547
+
548
+ class VideoLDMCrossAttnUpBlock(nn.Module):
549
+ def __init__(
550
+ self,
551
+ in_channels: int,
552
+ out_channels: int,
553
+ prev_output_channel: int,
554
+ temb_channels: int,
555
+ dropout: float = 0.0,
556
+ num_layers: int = 1,
557
+ transformer_layers_per_block: int = 1,
558
+ resnet_eps: float = 1e-6,
559
+ resnet_time_scale_shift: str = "default",
560
+ resnet_act_fn: str = "swish",
561
+ resnet_groups: int = 32,
562
+ resnet_pre_norm: bool = True,
563
+ num_attention_heads=1,
564
+ cross_attention_dim=1280,
565
+ output_scale_factor=1.0,
566
+ add_upsample=True,
567
+ dual_cross_attention=False,
568
+ use_linear_projection=False,
569
+ only_cross_attention=False,
570
+ upcast_attention=False,
571
+ attention_type="default",
572
+ # additional
573
+ use_temporal=True,
574
+ augment_temporal_attention=False,
575
+ n_frames=8,
576
+ n_temp_heads=8,
577
+ first_frame_condition_mode="none",
578
+ latent_channels=4,
579
+ rotary_emb=False,
580
+ ):
581
+ super().__init__()
582
+
583
+ self.use_temporal = use_temporal
584
+
585
+ self.n_frames = n_frames
586
+ self.first_frame_condition_mode = first_frame_condition_mode
587
+ if self.first_frame_condition_mode == "conv2d":
588
+ self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
589
+
590
+ resnets = []
591
+ attentions = []
592
+
593
+ self.n_frames = n_frames
594
+ self.n_temp_heads = n_temp_heads
595
+
596
+ self.has_cross_attention = True
597
+ self.num_attention_heads = num_attention_heads
598
+
599
+ for i in range(num_layers):
600
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
601
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
602
+
603
+ resnets.append(
604
+ ResnetBlock2D(
605
+ in_channels=resnet_in_channels + res_skip_channels,
606
+ out_channels=out_channels,
607
+ temb_channels=temb_channels,
608
+ eps=resnet_eps,
609
+ groups=resnet_groups,
610
+ dropout=dropout,
611
+ time_embedding_norm=resnet_time_scale_shift,
612
+ non_linearity=resnet_act_fn,
613
+ output_scale_factor=output_scale_factor,
614
+ pre_norm=resnet_pre_norm,
615
+ )
616
+ )
617
+ if not dual_cross_attention:
618
+ attentions.append(
619
+ Transformer2DConditionModel(
620
+ num_attention_heads,
621
+ out_channels // num_attention_heads,
622
+ in_channels=out_channels,
623
+ num_layers=transformer_layers_per_block,
624
+ cross_attention_dim=cross_attention_dim,
625
+ norm_num_groups=resnet_groups,
626
+ use_linear_projection=use_linear_projection,
627
+ only_cross_attention=only_cross_attention,
628
+ upcast_attention=upcast_attention,
629
+ attention_type=attention_type,
630
+ # additional
631
+ n_frames=n_frames,
632
+ )
633
+ )
634
+ else:
635
+ attentions.append(
636
+ DualTransformer2DModel(
637
+ num_attention_heads,
638
+ out_channels // num_attention_heads,
639
+ in_channels=out_channels,
640
+ num_layers=1,
641
+ cross_attention_dim=cross_attention_dim,
642
+ norm_num_groups=resnet_groups,
643
+ )
644
+ )
645
+ self.attentions = nn.ModuleList(attentions)
646
+ self.resnets = nn.ModuleList(resnets)
647
+
648
+ if add_upsample:
649
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
650
+ else:
651
+ self.upsamplers = None
652
+
653
+ self.gradient_checkpointing = False
654
+
655
+ # >>> Temporal Layers >>>
656
+ conv3ds = []
657
+ tempo_attns = []
658
+
659
+ for i in range(num_layers):
660
+ if self.use_temporal:
661
+ conv3ds.append(
662
+ TemporalResnetBlock(
663
+ in_channels=out_channels,
664
+ out_channels=out_channels,
665
+ n_frames=n_frames,
666
+ )
667
+ )
668
+
669
+ tempo_attns.append(
670
+ Transformer2DConditionModel(
671
+ n_temp_heads,
672
+ out_channels // n_temp_heads,
673
+ in_channels=out_channels,
674
+ num_layers=transformer_layers_per_block,
675
+ cross_attention_dim=cross_attention_dim,
676
+ norm_num_groups=resnet_groups,
677
+ use_linear_projection=use_linear_projection,
678
+ only_cross_attention=only_cross_attention,
679
+ upcast_attention=upcast_attention,
680
+ attention_type=attention_type,
681
+ # additional
682
+ n_frames=n_frames,
683
+ augment_temporal_attention=augment_temporal_attention,
684
+ is_temporal=True,
685
+ rotary_emb=rotary_emb,
686
+ )
687
+ )
688
+ else:
689
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
690
+ tempo_attns.append(IdentityLayer(return_trans2d_output=True))
691
+
692
+ self.conv3ds = nn.ModuleList(conv3ds)
693
+ self.tempo_attns = nn.ModuleList(tempo_attns)
694
+ # <<< Temporal Layers <<<
695
+
696
+ def forward(
697
+ self,
698
+ hidden_states: torch.FloatTensor,
699
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
700
+ temb: Optional[torch.FloatTensor] = None,
701
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
702
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
703
+ upsample_size: Optional[int] = None,
704
+ attention_mask: Optional[torch.FloatTensor] = None,
705
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
706
+ # additional
707
+ first_frame_latents=None,
708
+ ):
709
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
710
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
711
+ if self.first_frame_condition_mode == "conv2d":
712
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
713
+ hidden_height = hidden_states.shape[3]
714
+ first_frame_height = first_frame_latents.shape[3]
715
+ downsample_ratio = hidden_height / first_frame_height
716
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
717
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
718
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
719
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
720
+
721
+ for resnet, conv3d, attn, tempo_attn in zip(self.resnets, self.conv3ds, self.attentions, self.tempo_attns):
722
+
723
+ res_hidden_states = res_hidden_states_tuple[-1]
724
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
725
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
726
+
727
+ hidden_states = resnet(hidden_states, temb)
728
+ hidden_states = conv3d(hidden_states)
729
+ hidden_states = attn(
730
+ hidden_states,
731
+ encoder_hidden_states=encoder_hidden_states,
732
+ cross_attention_kwargs=cross_attention_kwargs,
733
+ condition_on_first_frame=condition_on_first_frame,
734
+ ).sample
735
+ hidden_states = tempo_attn(
736
+ hidden_states,
737
+ encoder_hidden_states=encoder_hidden_states,
738
+ cross_attention_kwargs=cross_attention_kwargs,
739
+ condition_on_first_frame=False,
740
+ ).sample
741
+
742
+ if self.upsamplers is not None:
743
+ for upsampler in self.upsamplers:
744
+ hidden_states = upsampler(hidden_states, upsample_size)
745
+ return hidden_states
746
+
747
+
748
+ class VideoLDMUNetMidBlock2DCrossAttn(nn.Module):
749
+ def __init__(
750
+ self,
751
+ in_channels: int,
752
+ temb_channels: int,
753
+ dropout: float = 0.0,
754
+ num_layers: int = 1,
755
+ transformer_layers_per_block: int = 1,
756
+ resnet_eps: float = 1e-6,
757
+ resnet_time_scale_shift: str = "default",
758
+ resnet_act_fn: str = "swish",
759
+ resnet_groups: int = 32,
760
+ resnet_pre_norm: bool = True,
761
+ num_attention_heads=1,
762
+ output_scale_factor=1.0,
763
+ cross_attention_dim=1280,
764
+ dual_cross_attention=False,
765
+ use_linear_projection=False,
766
+ upcast_attention=False,
767
+ attention_type="default",
768
+ # additional
769
+ use_temporal=True,
770
+ n_frames: int = 8,
771
+ first_frame_condition_mode="none",
772
+ latent_channels=4,
773
+ ):
774
+ super().__init__()
775
+
776
+ self.use_temporal = use_temporal
777
+
778
+ self.n_frames = n_frames
779
+ self.first_frame_condition_mode = first_frame_condition_mode
780
+ if self.first_frame_condition_mode == "conv2d":
781
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
782
+
783
+ self.has_cross_attention = True
784
+ self.num_attention_heads = num_attention_heads
785
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
786
+
787
+ # there is always at least one resnet
788
+ resnets = [
789
+ ResnetBlock2D(
790
+ in_channels=in_channels,
791
+ out_channels=in_channels,
792
+ temb_channels=temb_channels,
793
+ eps=resnet_eps,
794
+ groups=resnet_groups,
795
+ dropout=dropout,
796
+ time_embedding_norm=resnet_time_scale_shift,
797
+ non_linearity=resnet_act_fn,
798
+ output_scale_factor=output_scale_factor,
799
+ pre_norm=resnet_pre_norm,
800
+ )
801
+ ]
802
+ if self.use_temporal:
803
+ conv3ds = [
804
+ TemporalResnetBlock(
805
+ in_channels=in_channels,
806
+ out_channels=in_channels,
807
+ n_frames=n_frames,
808
+ )
809
+ ]
810
+ else:
811
+ conv3ds = [IdentityLayer(return_trans2d_output=False)]
812
+
813
+ attentions = []
814
+
815
+ for _ in range(num_layers):
816
+ if not dual_cross_attention:
817
+ attentions.append(
818
+ Transformer2DConditionModel(
819
+ num_attention_heads,
820
+ in_channels // num_attention_heads,
821
+ in_channels=in_channels,
822
+ num_layers=transformer_layers_per_block,
823
+ cross_attention_dim=cross_attention_dim,
824
+ norm_num_groups=resnet_groups,
825
+ use_linear_projection=use_linear_projection,
826
+ upcast_attention=upcast_attention,
827
+ attention_type=attention_type,
828
+ # additional
829
+ n_frames=n_frames,
830
+ )
831
+ )
832
+ else:
833
+ attentions.append(
834
+ DualTransformer2DModel(
835
+ num_attention_heads,
836
+ in_channels // num_attention_heads,
837
+ in_channels=in_channels,
838
+ num_layers=1,
839
+ cross_attention_dim=cross_attention_dim,
840
+ norm_num_groups=resnet_groups,
841
+ )
842
+ )
843
+ resnets.append(
844
+ ResnetBlock2D(
845
+ in_channels=in_channels,
846
+ out_channels=in_channels,
847
+ temb_channels=temb_channels,
848
+ eps=resnet_eps,
849
+ groups=resnet_groups,
850
+ dropout=dropout,
851
+ time_embedding_norm=resnet_time_scale_shift,
852
+ non_linearity=resnet_act_fn,
853
+ output_scale_factor=output_scale_factor,
854
+ pre_norm=resnet_pre_norm,
855
+ )
856
+ )
857
+ if self.use_temporal:
858
+ conv3ds.append(
859
+ TemporalResnetBlock(
860
+ in_channels=in_channels,
861
+ out_channels=in_channels,
862
+ n_frames=n_frames,
863
+ )
864
+ )
865
+ else:
866
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
867
+
868
+ self.attentions = nn.ModuleList(attentions)
869
+ self.resnets = nn.ModuleList(resnets)
870
+ self.conv3ds = nn.ModuleList(conv3ds)
871
+
872
+ self.gradient_checkpointing = False
873
+
874
+ def forward(
875
+ self,
876
+ hidden_states: torch.FloatTensor,
877
+ temb: Optional[torch.FloatTensor] = None,
878
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
879
+ attention_mask: Optional[torch.FloatTensor] = None,
880
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ # additional
883
+ first_frame_latents=None,
884
+ ) -> torch.FloatTensor:
885
+ condition_on_first_frame = (self.first_frame_condition_mode != "none" and self.first_frame_condition_mode != "input_only")
886
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
887
+ if self.first_frame_condition_mode == "conv2d":
888
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
889
+ hidden_height = hidden_states.shape[3]
890
+ first_frame_height = first_frame_latents.shape[3]
891
+ downsample_ratio = hidden_height / first_frame_height
892
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
893
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
894
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
895
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
896
+
897
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
898
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
899
+ hidden_states = self.conv3ds[0](hidden_states)
900
+ for attn, resnet, conv3d in zip(self.attentions, self.resnets[1:], self.conv3ds[1:]):
901
+ if self.training and self.gradient_checkpointing:
902
+
903
+ def create_custom_forward(module, return_dict=None):
904
+ def custom_forward(*inputs):
905
+ if return_dict is not None:
906
+ return module(*inputs, return_dict=return_dict)
907
+ else:
908
+ return module(*inputs)
909
+
910
+ return custom_forward
911
+
912
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
913
+ hidden_states = attn(
914
+ hidden_states,
915
+ encoder_hidden_states=encoder_hidden_states,
916
+ cross_attention_kwargs=cross_attention_kwargs,
917
+ attention_mask=attention_mask,
918
+ encoder_attention_mask=encoder_attention_mask,
919
+ return_dict=False,
920
+ # additional
921
+ condition_on_first_frame=condition_on_first_frame,
922
+ )[0]
923
+ hidden_states = torch.utils.checkpoint.checkpoint(
924
+ create_custom_forward(resnet),
925
+ hidden_states,
926
+ temb,
927
+ **ckpt_kwargs,
928
+ )
929
+ hidden_states = conv3d(hidden_states)
930
+ else:
931
+ hidden_states = attn(
932
+ hidden_states,
933
+ encoder_hidden_states=encoder_hidden_states,
934
+ cross_attention_kwargs=cross_attention_kwargs,
935
+ attention_mask=attention_mask,
936
+ encoder_attention_mask=encoder_attention_mask,
937
+ return_dict=False,
938
+ # additional
939
+ condition_on_first_frame=condition_on_first_frame,
940
+ )[0]
941
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
942
+ hidden_states = conv3d(hidden_states)
943
+
944
+ return hidden_states
945
+
946
+
947
+ class VideoLDMDownBlock(DownBlock2D):
948
+ def __init__(
949
+ self,
950
+ in_channels: int,
951
+ out_channels: int,
952
+ temb_channels: int,
953
+ dropout: float = 0.0,
954
+ num_layers: int = 1,
955
+ resnet_eps: float = 1e-6,
956
+ resnet_time_scale_shift: str = "default",
957
+ resnet_act_fn: str = "swish",
958
+ resnet_groups: int = 32,
959
+ resnet_pre_norm: bool = True,
960
+ output_scale_factor=1.0,
961
+ add_downsample=True,
962
+ downsample_padding=1,
963
+ # additional
964
+ use_temporal=True,
965
+ n_frames: int = 8,
966
+ first_frame_condition_mode="none",
967
+ latent_channels=4,
968
+ ):
969
+ super().__init__(
970
+ in_channels,
971
+ out_channels,
972
+ temb_channels,
973
+ dropout,
974
+ num_layers,
975
+ resnet_eps,
976
+ resnet_time_scale_shift,
977
+ resnet_act_fn,
978
+ resnet_groups,
979
+ resnet_pre_norm,
980
+ output_scale_factor,
981
+ add_downsample,
982
+ downsample_padding,)
983
+
984
+ self.use_temporal = use_temporal
985
+
986
+ self.n_frames = n_frames
987
+ self.first_frame_condition_mode = first_frame_condition_mode
988
+ if self.first_frame_condition_mode == "conv2d":
989
+ self.first_frame_conv = nn.Conv2d(latent_channels, in_channels, kernel_size=1)
990
+
991
+ # >>> Temporal Layers >>>
992
+ conv3ds = []
993
+ for i in range(num_layers):
994
+ if self.use_temporal:
995
+ conv3ds.append(
996
+ TemporalResnetBlock(
997
+ in_channels=out_channels,
998
+ out_channels=out_channels,
999
+ n_frames=n_frames,
1000
+ )
1001
+ )
1002
+ else:
1003
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
1004
+ self.conv3ds = nn.ModuleList(conv3ds)
1005
+ # <<< Temporal Layers <<<
1006
+
1007
+ def forward(self, hidden_states, temb=None, scale: float = 1, first_frame_latents=None):
1008
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
1009
+ if self.first_frame_condition_mode == "conv2d":
1010
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
1011
+ hidden_height = hidden_states.shape[3]
1012
+ first_frame_height = first_frame_latents.shape[3]
1013
+ downsample_ratio = hidden_height / first_frame_height
1014
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
1015
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
1016
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
1017
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
1018
+
1019
+ output_states = ()
1020
+
1021
+ for resnet, conv3d in zip(self.resnets, self.conv3ds):
1022
+ if self.training and self.gradient_checkpointing:
1023
+
1024
+ def create_custom_forward(module):
1025
+ def custom_forward(*inputs):
1026
+ return module(*inputs)
1027
+
1028
+ return custom_forward
1029
+
1030
+ if is_torch_version(">=", "1.11.0"):
1031
+ hidden_states = torch.utils.checkpoint.checkpoint(
1032
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1033
+ )
1034
+ else:
1035
+ hidden_states = torch.utils.checkpoint.checkpoint(
1036
+ create_custom_forward(resnet), hidden_states, temb
1037
+ )
1038
+ else:
1039
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1040
+
1041
+ hidden_states = conv3d(hidden_states)
1042
+
1043
+ output_states = output_states + (hidden_states,)
1044
+
1045
+ if self.downsamplers is not None:
1046
+ for downsampler in self.downsamplers:
1047
+ hidden_states = downsampler(hidden_states, scale=scale)
1048
+
1049
+ output_states = output_states + (hidden_states,)
1050
+
1051
+ return hidden_states, output_states
1052
+
1053
+
1054
+ class VideoLDMUpBlock(UpBlock2D):
1055
+ def __init__(
1056
+ self,
1057
+ in_channels: int,
1058
+ prev_output_channel: int,
1059
+ out_channels: int,
1060
+ temb_channels: int,
1061
+ dropout: float = 0.0,
1062
+ num_layers: int = 1,
1063
+ resnet_eps: float = 1e-6,
1064
+ resnet_time_scale_shift: str = "default",
1065
+ resnet_act_fn: str = "swish",
1066
+ resnet_groups: int = 32,
1067
+ resnet_pre_norm: bool = True,
1068
+ output_scale_factor=1.0,
1069
+ add_upsample=True,
1070
+ # additional
1071
+ use_temporal=True,
1072
+ n_frames: int = 8,
1073
+ first_frame_condition_mode="none",
1074
+ latent_channels=4,
1075
+ ):
1076
+ super().__init__(
1077
+ in_channels,
1078
+ prev_output_channel,
1079
+ out_channels,
1080
+ temb_channels,
1081
+ dropout,
1082
+ num_layers,
1083
+ resnet_eps,
1084
+ resnet_time_scale_shift,
1085
+ resnet_act_fn,
1086
+ resnet_groups,
1087
+ resnet_pre_norm,
1088
+ output_scale_factor,
1089
+ add_upsample,
1090
+ )
1091
+
1092
+ self.use_temporal = use_temporal
1093
+
1094
+ self.n_frames = n_frames
1095
+ self.first_frame_condition_mode = first_frame_condition_mode
1096
+ if self.first_frame_condition_mode == "conv2d":
1097
+ self.first_frame_conv = nn.Conv2d(latent_channels, prev_output_channel, kernel_size=1)
1098
+
1099
+ # >>> Temporal Layers >>>
1100
+ conv3ds = []
1101
+ for i in range(num_layers):
1102
+ if self.use_temporal:
1103
+ conv3ds.append(
1104
+ TemporalResnetBlock(
1105
+ in_channels=out_channels,
1106
+ out_channels=out_channels,
1107
+ n_frames=n_frames,
1108
+ )
1109
+ )
1110
+ else:
1111
+ conv3ds.append(IdentityLayer(return_trans2d_output=False))
1112
+
1113
+ self.conv3ds = nn.ModuleList(conv3ds)
1114
+ # <<< Temporal Layers <<<
1115
+
1116
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1, first_frame_latents=None):
1117
+ # input shape: hidden_states = (b f) c h w, first_frame_latents = b c 1 h w
1118
+ if self.first_frame_condition_mode == "conv2d":
1119
+ hidden_states = rearrange(hidden_states, '(b t) c h w -> b c t h w', t=self.n_frames)
1120
+ hidden_height = hidden_states.shape[3]
1121
+ first_frame_height = first_frame_latents.shape[3]
1122
+ downsample_ratio = hidden_height / first_frame_height
1123
+ first_frame_latents = F.interpolate(first_frame_latents.squeeze(2), scale_factor=downsample_ratio, mode="nearest")
1124
+ first_frame_latents = self.first_frame_conv(first_frame_latents).unsqueeze(2)
1125
+ hidden_states[:, :, 0:1, :, :] = first_frame_latents
1126
+ hidden_states = rearrange(hidden_states, 'b c t h w -> (b t) c h w', t=self.n_frames)
1127
+
1128
+ for resnet, conv3d in zip(self.resnets, self.conv3ds):
1129
+ # pop res hidden states
1130
+ res_hidden_states = res_hidden_states_tuple[-1]
1131
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1132
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1133
+
1134
+ if self.training and self.gradient_checkpointing:
1135
+
1136
+ def create_custom_forward(module):
1137
+ def custom_forward(*inputs):
1138
+ return module(*inputs)
1139
+
1140
+ return custom_forward
1141
+
1142
+ if is_torch_version(">=", "1.11.0"):
1143
+ hidden_states = torch.utils.checkpoint.checkpoint(
1144
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
1145
+ )
1146
+ else:
1147
+ hidden_states = torch.utils.checkpoint.checkpoint(
1148
+ create_custom_forward(resnet), hidden_states, temb
1149
+ )
1150
+ else:
1151
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1152
+
1153
+ hidden_states = conv3d(hidden_states)
1154
+
1155
+ if self.upsamplers is not None:
1156
+ for upsampler in self.upsamplers:
1157
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1158
+
1159
+ return hidden_states
consisti2v/pipelines/pipeline_autoregress_animation.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from torchvision import transforms as T
13
+ from PIL import Image
14
+
15
+ from diffusers.utils import is_accelerate_available
16
+ from packaging import version
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+
19
+ from diffusers.configuration_utils import FrozenDict
20
+ from diffusers.models import AutoencoderKL
21
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
22
+ from diffusers.schedulers import (
23
+ DDIMScheduler,
24
+ DPMSolverMultistepScheduler,
25
+ EulerAncestralDiscreteScheduler,
26
+ EulerDiscreteScheduler,
27
+ LMSDiscreteScheduler,
28
+ PNDMScheduler,
29
+ )
30
+ from diffusers.utils import deprecate, logging, BaseOutput
31
+
32
+ from einops import rearrange, repeat
33
+
34
+ from ..models.unet import UNet3DConditionModel
35
+ from ..utils.frameinit_utils import freq_mix_3d, get_freq_filter
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+ # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
41
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
42
+ """
43
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
44
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
45
+ """
46
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
47
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
48
+ # rescale the results from guidance (fixes overexposure)
49
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
50
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
51
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
52
+ return noise_cfg
53
+
54
+
55
+ @dataclass
56
+ class AnimationPipelineOutput(BaseOutput):
57
+ videos: Union[torch.Tensor, np.ndarray]
58
+
59
+
60
+ class AutoregressiveAnimationPipeline(DiffusionPipeline):
61
+ _optional_components = []
62
+
63
+ def __init__(
64
+ self,
65
+ vae: AutoencoderKL,
66
+ text_encoder: CLIPTextModel,
67
+ tokenizer: CLIPTokenizer,
68
+ unet: UNet3DConditionModel,
69
+ scheduler: Union[
70
+ DDIMScheduler,
71
+ PNDMScheduler,
72
+ LMSDiscreteScheduler,
73
+ EulerDiscreteScheduler,
74
+ EulerAncestralDiscreteScheduler,
75
+ DPMSolverMultistepScheduler,
76
+ ],
77
+ ):
78
+ super().__init__()
79
+
80
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
81
+ deprecation_message = (
82
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
83
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
84
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
85
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
86
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
87
+ " file"
88
+ )
89
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
90
+ new_config = dict(scheduler.config)
91
+ new_config["steps_offset"] = 1
92
+ scheduler._internal_dict = FrozenDict(new_config)
93
+
94
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
95
+ deprecation_message = (
96
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
97
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
98
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
99
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
100
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
101
+ )
102
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
103
+ new_config = dict(scheduler.config)
104
+ new_config["clip_sample"] = False
105
+ scheduler._internal_dict = FrozenDict(new_config)
106
+
107
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
108
+ version.parse(unet.config._diffusers_version).base_version
109
+ ) < version.parse("0.9.0.dev0")
110
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
111
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
112
+ deprecation_message = (
113
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
114
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
115
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
116
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
117
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
118
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
119
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
120
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
121
+ " the `unet/config.json` file"
122
+ )
123
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
124
+ new_config = dict(unet.config)
125
+ new_config["sample_size"] = 64
126
+ unet._internal_dict = FrozenDict(new_config)
127
+
128
+ self.register_modules(
129
+ vae=vae,
130
+ text_encoder=text_encoder,
131
+ tokenizer=tokenizer,
132
+ unet=unet,
133
+ scheduler=scheduler,
134
+ )
135
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
136
+
137
+ self.freq_filter = None
138
+
139
+ @torch.no_grad()
140
+ def init_filter(self, video_length, height, width, filter_params):
141
+ # initialize frequency filter for noise reinitialization
142
+ batch_size = 1
143
+ num_channels_latents = self.unet.config.in_channels
144
+ filter_shape = [
145
+ batch_size,
146
+ num_channels_latents,
147
+ video_length,
148
+ height // self.vae_scale_factor,
149
+ width // self.vae_scale_factor
150
+ ]
151
+ # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
152
+ self.freq_filter = get_freq_filter(
153
+ filter_shape,
154
+ device=self._execution_device,
155
+ filter_type=filter_params.method,
156
+ n=filter_params.n if filter_params.method=="butterworth" else None,
157
+ d_s=filter_params.d_s,
158
+ d_t=filter_params.d_t
159
+ )
160
+
161
+ def enable_vae_slicing(self):
162
+ self.vae.enable_slicing()
163
+
164
+ def disable_vae_slicing(self):
165
+ self.vae.disable_slicing()
166
+
167
+ def enable_sequential_cpu_offload(self, gpu_id=0):
168
+ if is_accelerate_available():
169
+ from accelerate import cpu_offload
170
+ else:
171
+ raise ImportError("Please install accelerate via `pip install accelerate`")
172
+
173
+ device = torch.device(f"cuda:{gpu_id}")
174
+
175
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
176
+ if cpu_offloaded_model is not None:
177
+ cpu_offload(cpu_offloaded_model, device)
178
+
179
+
180
+ @property
181
+ def _execution_device(self):
182
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
183
+ return self.device
184
+ for module in self.unet.modules():
185
+ if (
186
+ hasattr(module, "_hf_hook")
187
+ and hasattr(module._hf_hook, "execution_device")
188
+ and module._hf_hook.execution_device is not None
189
+ ):
190
+ return torch.device(module._hf_hook.execution_device)
191
+ return self.device
192
+
193
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
194
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
195
+
196
+ text_inputs = self.tokenizer(
197
+ prompt,
198
+ padding="max_length",
199
+ max_length=self.tokenizer.model_max_length,
200
+ truncation=True,
201
+ return_tensors="pt",
202
+ )
203
+ text_input_ids = text_inputs.input_ids
204
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
207
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
208
+ logger.warning(
209
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
210
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
211
+ )
212
+
213
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
214
+ attention_mask = text_inputs.attention_mask.to(device)
215
+ else:
216
+ attention_mask = None
217
+
218
+ text_embeddings = self.text_encoder(
219
+ text_input_ids.to(device),
220
+ attention_mask=attention_mask,
221
+ )
222
+ text_embeddings = text_embeddings[0]
223
+
224
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
225
+ bs_embed, seq_len, _ = text_embeddings.shape
226
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
227
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
228
+
229
+ # get unconditional embeddings for classifier free guidance
230
+ if do_classifier_free_guidance is not None:
231
+ uncond_tokens: List[str]
232
+ if negative_prompt is None:
233
+ uncond_tokens = [""] * batch_size
234
+ elif type(prompt) is not type(negative_prompt):
235
+ raise TypeError(
236
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
237
+ f" {type(prompt)}."
238
+ )
239
+ elif isinstance(negative_prompt, str):
240
+ uncond_tokens = [negative_prompt]
241
+ elif batch_size != len(negative_prompt):
242
+ raise ValueError(
243
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
244
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
245
+ " the batch size of `prompt`."
246
+ )
247
+ else:
248
+ uncond_tokens = negative_prompt
249
+
250
+ max_length = text_input_ids.shape[-1]
251
+ uncond_input = self.tokenizer(
252
+ uncond_tokens,
253
+ padding="max_length",
254
+ max_length=max_length,
255
+ truncation=True,
256
+ return_tensors="pt",
257
+ )
258
+
259
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
274
+
275
+ # For classifier free guidance, we need to do two forward passes.
276
+ # Here we concatenate the unconditional and text embeddings into a single batch
277
+ # to avoid doing two forward passes
278
+ if do_classifier_free_guidance == "text":
279
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
280
+ elif do_classifier_free_guidance == "both":
281
+ text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
282
+
283
+ return text_embeddings
284
+
285
+ def decode_latents(self, latents, first_frames=None):
286
+ video_length = latents.shape[2]
287
+ latents = 1 / self.vae.config.scaling_factor * latents
288
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
289
+ # video = self.vae.decode(latents).sample
290
+ video = []
291
+ for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
292
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
293
+ video = torch.cat(video)
294
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
295
+
296
+ if first_frames is not None:
297
+ first_frames = first_frames.unsqueeze(2)
298
+ video = torch.cat([first_frames, video], dim=2)
299
+
300
+ video = (video / 2 + 0.5).clamp(0, 1)
301
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
302
+ video = video.cpu().float().numpy()
303
+ return video
304
+
305
+ def prepare_extra_step_kwargs(self, generator, eta):
306
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
307
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
308
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
309
+ # and should be between [0, 1]
310
+
311
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
312
+ extra_step_kwargs = {}
313
+ if accepts_eta:
314
+ extra_step_kwargs["eta"] = eta
315
+
316
+ # check if the scheduler accepts generator
317
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
318
+ if accepts_generator:
319
+ extra_step_kwargs["generator"] = generator
320
+ return extra_step_kwargs
321
+
322
+ def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
323
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
324
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
325
+
326
+ if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)):
327
+ raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
328
+
329
+ if height % 8 != 0 or width % 8 != 0:
330
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
331
+
332
+ if (callback_steps is None) or (
333
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
334
+ ):
335
+ raise ValueError(
336
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
337
+ f" {type(callback_steps)}."
338
+ )
339
+
340
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
341
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
342
+ if isinstance(generator, list) and len(generator) != batch_size:
343
+ raise ValueError(
344
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
345
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
346
+ )
347
+ if latents is None:
348
+ rand_device = "cpu" if device.type == "mps" else device
349
+
350
+ if isinstance(generator, list):
351
+ # shape = shape
352
+ shape = (1,) + shape[1:]
353
+ if noise_sampling_method == "vanilla":
354
+ latents = [
355
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
356
+ for i in range(batch_size)
357
+ ]
358
+ elif noise_sampling_method == "pyoco_mixed":
359
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
360
+ latents = []
361
+ noise_alpha_squared = noise_alpha ** 2
362
+ for i in range(batch_size):
363
+ base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
364
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
365
+ latents.append(base_latent + ind_latent)
366
+ elif noise_sampling_method == "pyoco_progressive":
367
+ latents = []
368
+ noise_alpha_squared = noise_alpha ** 2
369
+ for i in range(batch_size):
370
+ latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
371
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
372
+ for j in range(1, video_length):
373
+ latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
374
+ latents.append(latent)
375
+ latents = torch.cat(latents, dim=0).to(device)
376
+ else:
377
+ if noise_sampling_method == "vanilla":
378
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
379
+ elif noise_sampling_method == "pyoco_mixed":
380
+ noise_alpha_squared = noise_alpha ** 2
381
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
382
+ base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
383
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
384
+ latents = base_latents + ind_latents
385
+ elif noise_sampling_method == "pyoco_progressive":
386
+ noise_alpha_squared = noise_alpha ** 2
387
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
388
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
389
+ for j in range(1, video_length):
390
+ latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
391
+ else:
392
+ if latents.shape != shape:
393
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
394
+ latents = latents.to(device)
395
+
396
+ # scale the initial noise by the standard deviation required by the scheduler
397
+ latents = latents * self.scheduler.init_noise_sigma
398
+ return latents
399
+
400
+ @torch.no_grad()
401
+ def __call__(
402
+ self,
403
+ prompt: Union[str, List[str]],
404
+ video_length: Optional[int],
405
+ height: Optional[int] = None,
406
+ width: Optional[int] = None,
407
+ num_inference_steps: int = 50,
408
+ guidance_scale_txt: float = 7.5,
409
+ guidance_scale_img: float = 2.0,
410
+ negative_prompt: Optional[Union[str, List[str]]] = None,
411
+ num_videos_per_prompt: Optional[int] = 1,
412
+ eta: float = 0.0,
413
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
414
+ latents: Optional[torch.FloatTensor] = None,
415
+ output_type: Optional[str] = "tensor",
416
+ return_dict: bool = True,
417
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
418
+ callback_steps: Optional[int] = 1,
419
+ # additional
420
+ first_frame_paths: Optional[Union[str, List[str]]] = None,
421
+ first_frames: Optional[torch.FloatTensor] = None,
422
+ noise_sampling_method: str = "vanilla",
423
+ noise_alpha: float = 1.0,
424
+ guidance_rescale: float = 0.0,
425
+ frame_stride: Optional[int] = None,
426
+ autoregress_steps: int = 3,
427
+ use_frameinit: bool = False,
428
+ frameinit_noise_level: int = 999,
429
+ **kwargs,
430
+ ):
431
+ if first_frame_paths is not None and first_frames is not None:
432
+ raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
433
+ # Default height and width to unet
434
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
435
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
436
+
437
+ # Check inputs. Raise error if not correct
438
+ self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
439
+
440
+ # Define call parameters
441
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
442
+ batch_size = 1
443
+ if latents is not None:
444
+ batch_size = latents.shape[0]
445
+ if isinstance(prompt, list):
446
+ batch_size = len(prompt)
447
+ first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
448
+ if first_frame_input is not None:
449
+ assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
450
+
451
+ device = self._execution_device
452
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
453
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
454
+ # corresponds to doing no classifier free guidance.
455
+ do_classifier_free_guidance = None
456
+ # two guidance mode: text and text+image
457
+ if guidance_scale_txt > 1.0:
458
+ do_classifier_free_guidance = "text"
459
+ if guidance_scale_img > 1.0:
460
+ do_classifier_free_guidance = "both"
461
+
462
+ # Encode input prompt
463
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
464
+ if negative_prompt is not None:
465
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
466
+ text_embeddings = self._encode_prompt(
467
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
468
+ )
469
+
470
+ # Encode input first frame
471
+ first_frame_latents = None
472
+ if first_frame_paths is not None:
473
+ first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
474
+ img_transform = T.Compose([
475
+ T.ToTensor(),
476
+ T.Resize(height, antialias=None),
477
+ T.CenterCrop((height, width)),
478
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
479
+ ])
480
+ first_frames = []
481
+ for first_frame_path in first_frame_paths:
482
+ first_frame = Image.open(first_frame_path).convert('RGB')
483
+ first_frame = img_transform(first_frame).unsqueeze(0)
484
+ first_frames.append(first_frame)
485
+ first_frames = torch.cat(first_frames, dim=0)
486
+ if first_frames is not None:
487
+ first_frames = first_frames.to(device, dtype=self.vae.dtype)
488
+ first_frame_latents = self.vae.encode(first_frames).latent_dist
489
+ first_frame_latents = first_frame_latents.sample()
490
+ first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
491
+ first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
492
+ first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
493
+
494
+ full_video_latent = torch.zeros(batch_size * num_videos_per_prompt, self.unet.config.in_channels, video_length * autoregress_steps - autoregress_steps + 1, height // self.vae_scale_factor, width // self.vae_scale_factor, device=device, dtype=self.vae.dtype)
495
+
496
+ start_idx = 0
497
+ for ar_step in range(autoregress_steps):
498
+ # Prepare timesteps
499
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
500
+ timesteps = self.scheduler.timesteps
501
+
502
+ # Prepare latent variables
503
+ num_channels_latents = self.unet.config.in_channels
504
+ latents = self.prepare_latents(
505
+ batch_size * num_videos_per_prompt,
506
+ num_channels_latents,
507
+ video_length,
508
+ height,
509
+ width,
510
+ text_embeddings.dtype,
511
+ device,
512
+ generator,
513
+ latents,
514
+ noise_sampling_method,
515
+ noise_alpha,
516
+ )
517
+ latents_dtype = latents.dtype
518
+
519
+ if use_frameinit:
520
+ current_diffuse_timestep = frameinit_noise_level # diffuse to noise level
521
+ diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
522
+ diffuse_timesteps = diffuse_timesteps.long()
523
+ first_frames_static_vid = repeat(first_frame_latents, "b c h w -> b c t h w", t=video_length)
524
+ z_T = self.scheduler.add_noise(
525
+ original_samples=first_frames_static_vid.to(device),
526
+ noise=latents.to(device),
527
+ timesteps=diffuse_timesteps.to(device)
528
+ )
529
+ latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents, LPF=self.freq_filter)
530
+ latents = latents.to(dtype=latents_dtype)
531
+
532
+ if first_frame_latents is not None:
533
+ first_frame_noisy_latent = latents[:, :, 0, :, :]
534
+ latents = latents[:, :, 1:, :, :]
535
+
536
+ # Prepare extra step kwargs.
537
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
538
+
539
+ # Denoising loop
540
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
541
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
542
+ for i, t in enumerate(timesteps):
543
+ # expand the latents if we are doing classifier free guidance
544
+ if do_classifier_free_guidance is None:
545
+ latent_model_input = latents
546
+ elif do_classifier_free_guidance == "text":
547
+ latent_model_input = torch.cat([latents] * 2)
548
+ elif do_classifier_free_guidance == "both":
549
+ latent_model_input = torch.cat([latents] * 3)
550
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
551
+ if first_frame_latents is not None:
552
+ if do_classifier_free_guidance is None:
553
+ first_frame_latents_input = first_frame_latents
554
+ elif do_classifier_free_guidance == "text":
555
+ first_frame_latents_input = torch.cat([first_frame_latents] * 2)
556
+ elif do_classifier_free_guidance == "both":
557
+ first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
558
+
559
+ first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
560
+
561
+ # predict the noise residual
562
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
563
+ else:
564
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
565
+ # noise_pred = []
566
+ # import pdb
567
+ # pdb.set_trace()
568
+ # for batch_idx in range(latent_model_input.shape[0]):
569
+ # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype)
570
+ # noise_pred.append(noise_pred_single)
571
+ # noise_pred = torch.cat(noise_pred)
572
+
573
+ # perform guidance
574
+ if do_classifier_free_guidance:
575
+ if do_classifier_free_guidance == "text":
576
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
577
+ noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
578
+ elif do_classifier_free_guidance == "both":
579
+ noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
580
+ noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
581
+
582
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
583
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
584
+ # currently only support text guidance
585
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
586
+
587
+ # compute the previous noisy sample x_t -> x_t-1
588
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
589
+
590
+ # call the callback, if provided
591
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
592
+ progress_bar.update()
593
+ if callback is not None and i % callback_steps == 0:
594
+ callback(i, t, latents)
595
+
596
+ # Post-processing
597
+
598
+ latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
599
+ first_frame_latents = latents[:, :, -1, :, :]
600
+ full_video_latent[:, :, start_idx:start_idx + video_length, :, :] = latents
601
+
602
+ latents = None
603
+ start_idx += (video_length - 1)
604
+
605
+ # video = self.decode_latents(latents, first_frames)
606
+ video = self.decode_latents(full_video_latent)
607
+
608
+ # Convert to tensor
609
+ if output_type == "tensor":
610
+ video = torch.from_numpy(video)
611
+
612
+ if not return_dict:
613
+ return video
614
+
615
+ return AnimationPipelineOutput(videos=video)
consisti2v/pipelines/pipeline_conditional_animation.py ADDED
@@ -0,0 +1,695 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ from typing import Callable, List, Optional, Union
5
+ from dataclasses import dataclass
6
+
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from torchvision import transforms as T
13
+ from torchvision.transforms import functional as F
14
+ from PIL import Image
15
+
16
+ from diffusers.utils import is_accelerate_available
17
+ from packaging import version
18
+ from transformers import CLIPTextModel, CLIPTokenizer
19
+
20
+ from diffusers.configuration_utils import FrozenDict
21
+ from diffusers.models import AutoencoderKL
22
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23
+ from diffusers.schedulers import (
24
+ DDIMScheduler,
25
+ DPMSolverMultistepScheduler,
26
+ EulerAncestralDiscreteScheduler,
27
+ EulerDiscreteScheduler,
28
+ LMSDiscreteScheduler,
29
+ PNDMScheduler,
30
+ )
31
+ from diffusers.utils import deprecate, logging, BaseOutput
32
+
33
+ from einops import rearrange, repeat
34
+
35
+ from ..models.videoldm_unet import VideoLDMUNet3DConditionModel
36
+
37
+ from ..utils.frameinit_utils import get_freq_filter, freq_mix_3d
38
+
39
+
40
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
+
42
+ # copied from https://github.com/huggingface/diffusers/blob/v0.23.0/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L59C1-L70C21
43
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
44
+ """
45
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
46
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
47
+ """
48
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
49
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
50
+ # rescale the results from guidance (fixes overexposure)
51
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
52
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
53
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
54
+ return noise_cfg
55
+
56
+ def pan_right(image, num_frames=16, crop_width=256):
57
+ frames = []
58
+ height, width = image.shape[-2:]
59
+
60
+ for i in range(num_frames):
61
+ # Calculate the start position of the crop
62
+ start_x = int((width - crop_width) * (i / num_frames))
63
+ crop = F.crop(image, 0, start_x, height, crop_width)
64
+ frames.append(crop.unsqueeze(0))
65
+
66
+ return torch.cat(frames, dim=0)
67
+
68
+
69
+ def pan_left(image, num_frames=16, crop_width=256):
70
+ frames = []
71
+ height, width = image.shape[-2:]
72
+
73
+ for i in range(num_frames):
74
+ # Start position moves from right to left
75
+ start_x = int((width - crop_width) * (1 - (i / num_frames)))
76
+ crop = F.crop(image, 0, start_x, height, crop_width)
77
+ frames.append(crop.unsqueeze(0))
78
+
79
+ return torch.cat(frames, dim=0)
80
+
81
+
82
+ def zoom_in(image, num_frames=16, crop_width=256, ratio=1.5):
83
+ frames = []
84
+ height, width = image.shape[-2:]
85
+ max_crop_size = min(width, height)
86
+
87
+ for i in range(num_frames):
88
+ # Calculate the size of the crop
89
+ crop_size = max_crop_size - int((max_crop_size - max_crop_size // ratio) * (i / num_frames))
90
+ start_x = (width - crop_size) // 2
91
+ start_y = (height - crop_size) // 2
92
+ crop = F.crop(image, start_y, start_x, crop_size, crop_size)
93
+ resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
94
+ frames.append(resized_crop.unsqueeze(0))
95
+
96
+ return torch.cat(frames, dim=0)
97
+
98
+
99
+ def zoom_out(image, num_frames=16, crop_width=256, ratio=1.5):
100
+ frames = []
101
+ height, width = image.shape[-2:]
102
+ min_crop_size = min(width, height) // ratio # Starting from a quarter of the size
103
+
104
+ for i in range(num_frames):
105
+ # Calculate the size of the crop
106
+ crop_size = min_crop_size + int((min(width, height) - min_crop_size) * (i / num_frames))
107
+ start_x = (width - crop_size) // 2
108
+ start_y = (height - crop_size) // 2
109
+ crop = F.crop(image, start_y, start_x, crop_size, crop_size)
110
+ resized_crop = F.resize(crop, (crop_width, crop_width), antialias=None) # Resize back to original size
111
+ frames.append(resized_crop.unsqueeze(0))
112
+
113
+ return torch.cat(frames, dim=0)
114
+
115
+
116
+ @dataclass
117
+ class AnimationPipelineOutput(BaseOutput):
118
+ videos: Union[torch.Tensor, np.ndarray]
119
+
120
+
121
+ class ConditionalAnimationPipeline(DiffusionPipeline):
122
+ _optional_components = []
123
+
124
+ def __init__(
125
+ self,
126
+ vae: AutoencoderKL,
127
+ text_encoder: CLIPTextModel,
128
+ tokenizer: CLIPTokenizer,
129
+ unet: VideoLDMUNet3DConditionModel,
130
+ scheduler: Union[
131
+ DDIMScheduler,
132
+ PNDMScheduler,
133
+ LMSDiscreteScheduler,
134
+ EulerDiscreteScheduler,
135
+ EulerAncestralDiscreteScheduler,
136
+ DPMSolverMultistepScheduler,
137
+ ],
138
+ ):
139
+ super().__init__()
140
+
141
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
142
+ deprecation_message = (
143
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
144
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
145
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
146
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
147
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
148
+ " file"
149
+ )
150
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
151
+ new_config = dict(scheduler.config)
152
+ new_config["steps_offset"] = 1
153
+ scheduler._internal_dict = FrozenDict(new_config)
154
+
155
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
156
+ deprecation_message = (
157
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
158
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
159
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
160
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
161
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
162
+ )
163
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
164
+ new_config = dict(scheduler.config)
165
+ new_config["clip_sample"] = False
166
+ scheduler._internal_dict = FrozenDict(new_config)
167
+
168
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
169
+ version.parse(unet.config._diffusers_version).base_version
170
+ ) < version.parse("0.9.0.dev0")
171
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
172
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
173
+ deprecation_message = (
174
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
175
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
176
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
177
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
178
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
179
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
180
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
181
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
182
+ " the `unet/config.json` file"
183
+ )
184
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
185
+ new_config = dict(unet.config)
186
+ new_config["sample_size"] = 64
187
+ unet._internal_dict = FrozenDict(new_config)
188
+
189
+ self.register_modules(
190
+ vae=vae,
191
+ text_encoder=text_encoder,
192
+ tokenizer=tokenizer,
193
+ unet=unet,
194
+ scheduler=scheduler,
195
+ )
196
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
197
+
198
+ self.freq_filter = None
199
+
200
+ @torch.no_grad()
201
+ def init_filter(self, video_length, height, width, filter_params):
202
+ # initialize frequency filter for noise reinitialization
203
+ batch_size = 1
204
+ num_channels_latents = self.unet.config.in_channels
205
+ filter_shape = [
206
+ batch_size,
207
+ num_channels_latents,
208
+ video_length,
209
+ height // self.vae_scale_factor,
210
+ width // self.vae_scale_factor
211
+ ]
212
+ # self.freq_filter = get_freq_filter(filter_shape, device=self._execution_device, params=filter_params)
213
+ self.freq_filter = get_freq_filter(
214
+ filter_shape,
215
+ device=self._execution_device,
216
+ filter_type=filter_params.method,
217
+ n=filter_params.n if filter_params.method=="butterworth" else None,
218
+ d_s=filter_params.d_s,
219
+ d_t=filter_params.d_t
220
+ )
221
+
222
+ def enable_vae_slicing(self):
223
+ self.vae.enable_slicing()
224
+
225
+ def disable_vae_slicing(self):
226
+ self.vae.disable_slicing()
227
+
228
+ def enable_sequential_cpu_offload(self, gpu_id=0):
229
+ if is_accelerate_available():
230
+ from accelerate import cpu_offload
231
+ else:
232
+ raise ImportError("Please install accelerate via `pip install accelerate`")
233
+
234
+ device = torch.device(f"cuda:{gpu_id}")
235
+
236
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
237
+ if cpu_offloaded_model is not None:
238
+ cpu_offload(cpu_offloaded_model, device)
239
+
240
+
241
+ @property
242
+ def _execution_device(self):
243
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
244
+ return self.device
245
+ for module in self.unet.modules():
246
+ if (
247
+ hasattr(module, "_hf_hook")
248
+ and hasattr(module._hf_hook, "execution_device")
249
+ and module._hf_hook.execution_device is not None
250
+ ):
251
+ return torch.device(module._hf_hook.execution_device)
252
+ return self.device
253
+
254
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
255
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
256
+
257
+ text_inputs = self.tokenizer(
258
+ prompt,
259
+ padding="max_length",
260
+ max_length=self.tokenizer.model_max_length,
261
+ truncation=True,
262
+ return_tensors="pt",
263
+ )
264
+ text_input_ids = text_inputs.input_ids
265
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
266
+
267
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
268
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
269
+ logger.warning(
270
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
271
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
272
+ )
273
+
274
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
275
+ attention_mask = text_inputs.attention_mask.to(device)
276
+ else:
277
+ attention_mask = None
278
+
279
+ text_embeddings = self.text_encoder(
280
+ text_input_ids.to(device),
281
+ attention_mask=attention_mask,
282
+ )
283
+ text_embeddings = text_embeddings[0]
284
+
285
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
286
+ bs_embed, seq_len, _ = text_embeddings.shape
287
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
288
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
289
+
290
+ # get unconditional embeddings for classifier free guidance
291
+ if do_classifier_free_guidance is not None:
292
+ uncond_tokens: List[str]
293
+ if negative_prompt is None:
294
+ uncond_tokens = [""] * batch_size
295
+ elif type(prompt) is not type(negative_prompt):
296
+ raise TypeError(
297
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
298
+ f" {type(prompt)}."
299
+ )
300
+ elif isinstance(negative_prompt, str):
301
+ uncond_tokens = [negative_prompt]
302
+ elif batch_size != len(negative_prompt):
303
+ raise ValueError(
304
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
305
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
306
+ " the batch size of `prompt`."
307
+ )
308
+ else:
309
+ uncond_tokens = negative_prompt
310
+
311
+ max_length = text_input_ids.shape[-1]
312
+ uncond_input = self.tokenizer(
313
+ uncond_tokens,
314
+ padding="max_length",
315
+ max_length=max_length,
316
+ truncation=True,
317
+ return_tensors="pt",
318
+ )
319
+
320
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
321
+ attention_mask = uncond_input.attention_mask.to(device)
322
+ else:
323
+ attention_mask = None
324
+
325
+ uncond_embeddings = self.text_encoder(
326
+ uncond_input.input_ids.to(device),
327
+ attention_mask=attention_mask,
328
+ )
329
+ uncond_embeddings = uncond_embeddings[0]
330
+
331
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
332
+ seq_len = uncond_embeddings.shape[1]
333
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
334
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
335
+
336
+ # For classifier free guidance, we need to do two forward passes.
337
+ # Here we concatenate the unconditional and text embeddings into a single batch
338
+ # to avoid doing two forward passes
339
+ if do_classifier_free_guidance == "text":
340
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
341
+ elif do_classifier_free_guidance == "both":
342
+ text_embeddings = torch.cat([uncond_embeddings, uncond_embeddings, text_embeddings])
343
+
344
+ return text_embeddings
345
+
346
+ def decode_latents(self, latents, first_frames=None):
347
+ video_length = latents.shape[2]
348
+ latents = 1 / self.vae.config.scaling_factor * latents
349
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
350
+ # video = self.vae.decode(latents).sample
351
+ video = []
352
+ for frame_idx in tqdm(range(latents.shape[0]), **self._progress_bar_config):
353
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
354
+ video = torch.cat(video)
355
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
356
+
357
+ if first_frames is not None:
358
+ first_frames = first_frames.unsqueeze(2)
359
+ video = torch.cat([first_frames, video], dim=2)
360
+
361
+ video = (video / 2 + 0.5).clamp(0, 1)
362
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
363
+ video = video.cpu().float().numpy()
364
+ return video
365
+
366
+ def prepare_extra_step_kwargs(self, generator, eta):
367
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
368
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
369
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
370
+ # and should be between [0, 1]
371
+
372
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
373
+ extra_step_kwargs = {}
374
+ if accepts_eta:
375
+ extra_step_kwargs["eta"] = eta
376
+
377
+ # check if the scheduler accepts generator
378
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
379
+ if accepts_generator:
380
+ extra_step_kwargs["generator"] = generator
381
+ return extra_step_kwargs
382
+
383
+ def check_inputs(self, prompt, height, width, callback_steps, first_frame_paths=None):
384
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
385
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
386
+
387
+ if first_frame_paths is not None and (not isinstance(prompt, str) and not isinstance(first_frame_paths, list)):
388
+ raise ValueError(f"`first_frame_paths` has to be of type `str` or `list` but is {type(first_frame_paths)}")
389
+
390
+ if height % 8 != 0 or width % 8 != 0:
391
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
392
+
393
+ if (callback_steps is None) or (
394
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
395
+ ):
396
+ raise ValueError(
397
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
398
+ f" {type(callback_steps)}."
399
+ )
400
+
401
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None, noise_sampling_method="vanilla", noise_alpha=1.0):
402
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
403
+ if isinstance(generator, list) and len(generator) != batch_size:
404
+ raise ValueError(
405
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
406
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
407
+ )
408
+ if latents is None:
409
+ rand_device = "cpu" if device.type == "mps" else device
410
+
411
+ if isinstance(generator, list):
412
+ # shape = shape
413
+ shape = (1,) + shape[1:]
414
+ if noise_sampling_method == "vanilla":
415
+ latents = [
416
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
417
+ for i in range(batch_size)
418
+ ]
419
+ elif noise_sampling_method == "pyoco_mixed":
420
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
421
+ latents = []
422
+ noise_alpha_squared = noise_alpha ** 2
423
+ for i in range(batch_size):
424
+ base_latent = torch.randn(base_shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
425
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
426
+ latents.append(base_latent + ind_latent)
427
+ elif noise_sampling_method == "pyoco_progressive":
428
+ latents = []
429
+ noise_alpha_squared = noise_alpha ** 2
430
+ for i in range(batch_size):
431
+ latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
432
+ ind_latent = torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
433
+ for j in range(1, video_length):
434
+ latent[:, :, j, :, :] = latent[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latent[:, :, j, :, :]
435
+ latents.append(latent)
436
+ latents = torch.cat(latents, dim=0).to(device)
437
+ else:
438
+ if noise_sampling_method == "vanilla":
439
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
440
+ elif noise_sampling_method == "pyoco_mixed":
441
+ noise_alpha_squared = noise_alpha ** 2
442
+ base_shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
443
+ base_latents = torch.randn(base_shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
444
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
445
+ latents = base_latents + ind_latents
446
+ elif noise_sampling_method == "pyoco_progressive":
447
+ noise_alpha_squared = noise_alpha ** 2
448
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype)
449
+ ind_latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype) * math.sqrt(1 / (1 + noise_alpha_squared))
450
+ for j in range(1, video_length):
451
+ latents[:, :, j, :, :] = latents[:, :, j - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_latents[:, :, j, :, :]
452
+ else:
453
+ if latents.shape != shape:
454
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
455
+ latents = latents.to(device)
456
+
457
+ # scale the initial noise by the standard deviation required by the scheduler
458
+ latents = latents * self.scheduler.init_noise_sigma
459
+ return latents
460
+
461
+ @torch.no_grad()
462
+ def __call__(
463
+ self,
464
+ prompt: Union[str, List[str]],
465
+ video_length: Optional[int],
466
+ height: Optional[int] = None,
467
+ width: Optional[int] = None,
468
+ num_inference_steps: int = 50,
469
+ guidance_scale_txt: float = 7.5,
470
+ guidance_scale_img: float = 2.0,
471
+ negative_prompt: Optional[Union[str, List[str]]] = None,
472
+ num_videos_per_prompt: Optional[int] = 1,
473
+ eta: float = 0.0,
474
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
475
+ latents: Optional[torch.FloatTensor] = None,
476
+ output_type: Optional[str] = "tensor",
477
+ return_dict: bool = True,
478
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
479
+ callback_steps: Optional[int] = 1,
480
+ # additional
481
+ first_frame_paths: Optional[Union[str, List[str]]] = None,
482
+ first_frames: Optional[torch.FloatTensor] = None,
483
+ noise_sampling_method: str = "vanilla",
484
+ noise_alpha: float = 1.0,
485
+ guidance_rescale: float = 0.0,
486
+ frame_stride: Optional[int] = None,
487
+ use_frameinit: bool = False,
488
+ frameinit_noise_level: int = 999,
489
+ camera_motion: str = None,
490
+ **kwargs,
491
+ ):
492
+ if first_frame_paths is not None and first_frames is not None:
493
+ raise ValueError("Only one of `first_frame_paths` and `first_frames` can be passed.")
494
+ # Default height and width to unet
495
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
496
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
497
+
498
+ # Check inputs. Raise error if not correct
499
+ self.check_inputs(prompt, height, width, callback_steps, first_frame_paths)
500
+
501
+ # Define call parameters
502
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
503
+ batch_size = 1
504
+ if latents is not None:
505
+ batch_size = latents.shape[0]
506
+ if isinstance(prompt, list):
507
+ batch_size = len(prompt)
508
+ first_frame_input = first_frame_paths if first_frame_paths is not None else first_frames
509
+ if first_frame_input is not None:
510
+ assert len(prompt) == len(first_frame_input), "prompt and first_frame_paths should have the same length"
511
+
512
+ device = self._execution_device
513
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
514
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
515
+ # corresponds to doing no classifier free guidance.
516
+ do_classifier_free_guidance = None
517
+ # two guidance mode: text and text+image
518
+ if guidance_scale_txt > 1.0:
519
+ do_classifier_free_guidance = "text"
520
+ if guidance_scale_img > 1.0:
521
+ do_classifier_free_guidance = "both"
522
+
523
+ # Encode input prompt
524
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
525
+ if negative_prompt is not None:
526
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
527
+ text_embeddings = self._encode_prompt(
528
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
529
+ )
530
+
531
+ # Encode input first frame
532
+ first_frame_latents = None
533
+ if first_frame_paths is not None:
534
+ first_frame_paths = first_frame_paths if isinstance(first_frame_paths, list) else [first_frame_paths] * batch_size
535
+ if camera_motion is None:
536
+ img_transform = T.Compose([
537
+ T.ToTensor(),
538
+ T.Resize(height, antialias=None),
539
+ T.CenterCrop((height, width)),
540
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
541
+ ])
542
+ elif camera_motion == "pan_left" or camera_motion == "pan_right":
543
+ img_transform = T.Compose([
544
+ T.ToTensor(),
545
+ T.Resize(height, antialias=None),
546
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
547
+ ])
548
+ elif camera_motion == "zoom_out" or camera_motion == "zoom_in":
549
+ img_transform = T.Compose([
550
+ T.ToTensor(),
551
+ T.Resize(height * 2, antialias=None),
552
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
553
+ ])
554
+
555
+ first_frames = []
556
+ for first_frame_path in first_frame_paths:
557
+ first_frame = Image.open(first_frame_path).convert('RGB')
558
+ first_frame = img_transform(first_frame)
559
+ if camera_motion is not None:
560
+ if camera_motion == "pan_left":
561
+ first_frame = pan_left(first_frame, num_frames=video_length, crop_width=width)
562
+ elif camera_motion == "pan_right":
563
+ first_frame = pan_right(first_frame, num_frames=video_length, crop_width=width)
564
+ elif camera_motion == "zoom_in":
565
+ first_frame = zoom_in(first_frame, num_frames=video_length, crop_width=width)
566
+ elif camera_motion == "zoom_out":
567
+ first_frame = zoom_out(first_frame, num_frames=video_length, crop_width=width)
568
+ else:
569
+ raise NotImplementedError(f"camera_motion: {camera_motion} is not implemented.")
570
+ first_frames.append(first_frame.unsqueeze(0))
571
+ first_frames = torch.cat(first_frames, dim=0)
572
+ if first_frames is not None:
573
+ first_frames = first_frames.to(device, dtype=self.vae.dtype)
574
+ if camera_motion is not None:
575
+ first_frames = rearrange(first_frames, "b f c h w -> (b f) c h w")
576
+ first_frame_latents = self.vae.encode(first_frames).latent_dist
577
+ first_frame_latents = first_frame_latents.sample()
578
+ first_frame_latents = first_frame_latents * self.vae.config.scaling_factor # b, c, h, w
579
+ first_frame_static_vid = rearrange(first_frame_latents, "(b f) c h w -> b c f h w", f=video_length if camera_motion is not None else 1)
580
+ first_frame_latents = first_frame_static_vid[:, :, 0, :, :]
581
+ first_frame_latents = repeat(first_frame_latents, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
582
+ first_frames = repeat(first_frames, "b c h w -> (b n) c h w", n=num_videos_per_prompt)
583
+
584
+ if use_frameinit and camera_motion is None:
585
+ first_frame_static_vid = repeat(first_frame_static_vid, "b c 1 h w -> b c t h w", t=video_length)
586
+
587
+ # self._progress_bar_config = {}
588
+ # vid = self.decode_latents(first_frame_static_vid)
589
+ # vid = torch.from_numpy(vid)
590
+ # from ..utils.util import save_videos_grid
591
+ # save_videos_grid(vid, "samples/debug/camera_motion/first_frame_static_vid.mp4", fps=8)
592
+
593
+ # Prepare timesteps
594
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
595
+ timesteps = self.scheduler.timesteps
596
+
597
+ # Prepare latent variables
598
+ num_channels_latents = self.unet.config.in_channels
599
+ latents = self.prepare_latents(
600
+ batch_size * num_videos_per_prompt,
601
+ num_channels_latents,
602
+ video_length,
603
+ height,
604
+ width,
605
+ text_embeddings.dtype,
606
+ device,
607
+ generator,
608
+ latents,
609
+ noise_sampling_method,
610
+ noise_alpha,
611
+ )
612
+ latents_dtype = latents.dtype
613
+
614
+ if use_frameinit:
615
+ current_diffuse_timestep = frameinit_noise_level # diffuse to t noise level
616
+ diffuse_timesteps = torch.full((batch_size,),int(current_diffuse_timestep))
617
+ diffuse_timesteps = diffuse_timesteps.long()
618
+ z_T = self.scheduler.add_noise(
619
+ original_samples=first_frame_static_vid.to(device),
620
+ noise=latents.to(device),
621
+ timesteps=diffuse_timesteps.to(device)
622
+ )
623
+ latents = freq_mix_3d(z_T.to(dtype=torch.float32), latents.to(dtype=torch.float32), LPF=self.freq_filter)
624
+ latents = latents.to(dtype=latents_dtype)
625
+
626
+ if first_frame_latents is not None:
627
+ first_frame_noisy_latent = latents[:, :, 0, :, :]
628
+ latents = latents[:, :, 1:, :, :]
629
+
630
+ # Prepare extra step kwargs.
631
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
632
+
633
+ # Denoising loop
634
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
635
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
636
+ for i, t in enumerate(timesteps):
637
+ # expand the latents if we are doing classifier free guidance
638
+ if do_classifier_free_guidance is None:
639
+ latent_model_input = latents
640
+ elif do_classifier_free_guidance == "text":
641
+ latent_model_input = torch.cat([latents] * 2)
642
+ elif do_classifier_free_guidance == "both":
643
+ latent_model_input = torch.cat([latents] * 3)
644
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
645
+ if first_frame_latents is not None:
646
+ if do_classifier_free_guidance is None:
647
+ first_frame_latents_input = first_frame_latents
648
+ elif do_classifier_free_guidance == "text":
649
+ first_frame_latents_input = torch.cat([first_frame_latents] * 2)
650
+ elif do_classifier_free_guidance == "both":
651
+ first_frame_latents_input = torch.cat([first_frame_noisy_latent, first_frame_latents, first_frame_latents])
652
+
653
+ first_frame_latents_input = first_frame_latents_input.unsqueeze(2)
654
+
655
+ # predict the noise residual
656
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, first_frame_latents=first_frame_latents_input, frame_stride=frame_stride).sample.to(dtype=latents_dtype)
657
+ else:
658
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
659
+
660
+ # perform guidance
661
+ if do_classifier_free_guidance:
662
+ if do_classifier_free_guidance == "text":
663
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
664
+ noise_pred = noise_pred_uncond + guidance_scale_txt * (noise_pred_text - noise_pred_uncond)
665
+ elif do_classifier_free_guidance == "both":
666
+ noise_pred_uncond, noise_pred_img, noise_pred_both = noise_pred.chunk(3)
667
+ noise_pred = noise_pred_uncond + guidance_scale_img * (noise_pred_img - noise_pred_uncond) + guidance_scale_txt * (noise_pred_both - noise_pred_img)
668
+
669
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
670
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
671
+ # currently only support text guidance
672
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
673
+
674
+ # compute the previous noisy sample x_t -> x_t-1
675
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
676
+
677
+ # call the callback, if provided
678
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
679
+ progress_bar.update()
680
+ if callback is not None and i % callback_steps == 0:
681
+ callback(i, t, latents)
682
+
683
+ # Post-processing
684
+ latents = torch.cat([first_frame_latents.unsqueeze(2), latents], dim=2)
685
+ # video = self.decode_latents(latents, first_frames)
686
+ video = self.decode_latents(latents)
687
+
688
+ # Convert to tensor
689
+ if output_type == "tensor":
690
+ video = torch.from_numpy(video)
691
+
692
+ if not return_dict:
693
+ return video
694
+
695
+ return AnimationPipelineOutput(videos=video)
consisti2v/utils/frameinit_utils.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/TianxingWu/FreeInit/blob/master/freeinit_utils.py
2
+ import torch
3
+ import torch.fft as fft
4
+ import math
5
+
6
+
7
+ def freq_mix_3d(x, noise, LPF):
8
+ """
9
+ Noise reinitialization.
10
+
11
+ Args:
12
+ x: diffused latent
13
+ noise: randomly sampled noise
14
+ LPF: low pass filter
15
+ """
16
+ # FFT
17
+ x_freq = fft.fftn(x, dim=(-3, -2, -1))
18
+ x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
19
+ noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
20
+ noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
21
+
22
+ # frequency mix
23
+ HPF = 1 - LPF
24
+ x_freq_low = x_freq * LPF
25
+ noise_freq_high = noise_freq * HPF
26
+ x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
27
+
28
+ # IFFT
29
+ x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
30
+ x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
31
+
32
+ return x_mixed
33
+
34
+
35
+ def get_freq_filter(shape, device, filter_type, n, d_s, d_t):
36
+ """
37
+ Form the frequency filter for noise reinitialization.
38
+
39
+ Args:
40
+ shape: shape of latent (B, C, T, H, W)
41
+ filter_type: type of the freq filter
42
+ n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian
43
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
44
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
45
+ """
46
+ if filter_type == "gaussian":
47
+ return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
48
+ elif filter_type == "ideal":
49
+ return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
50
+ elif filter_type == "box":
51
+ return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device)
52
+ elif filter_type == "butterworth":
53
+ return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device)
54
+ else:
55
+ raise NotImplementedError
56
+
57
+
58
+ def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
59
+ """
60
+ Compute the gaussian low pass filter mask.
61
+
62
+ Args:
63
+ shape: shape of the filter (volume)
64
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
65
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
66
+ """
67
+ T, H, W = shape[-3], shape[-2], shape[-1]
68
+ mask = torch.zeros(shape)
69
+ if d_s==0 or d_t==0:
70
+ return mask
71
+ for t in range(T):
72
+ for h in range(H):
73
+ for w in range(W):
74
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
75
+ mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
76
+ return mask
77
+
78
+
79
+ def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
80
+ """
81
+ Compute the butterworth low pass filter mask.
82
+
83
+ Args:
84
+ shape: shape of the filter (volume)
85
+ n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
86
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
87
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
88
+ """
89
+ T, H, W = shape[-3], shape[-2], shape[-1]
90
+ mask = torch.zeros(shape)
91
+ if d_s==0 or d_t==0:
92
+ return mask
93
+ for t in range(T):
94
+ for h in range(H):
95
+ for w in range(W):
96
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
97
+ mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
98
+ return mask
99
+
100
+
101
+ def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
102
+ """
103
+ Compute the ideal low pass filter mask.
104
+
105
+ Args:
106
+ shape: shape of the filter (volume)
107
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
108
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
109
+ """
110
+ T, H, W = shape[-3], shape[-2], shape[-1]
111
+ mask = torch.zeros(shape)
112
+ if d_s==0 or d_t==0:
113
+ return mask
114
+ for t in range(T):
115
+ for h in range(H):
116
+ for w in range(W):
117
+ d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
118
+ mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
119
+ return mask
120
+
121
+
122
+ def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
123
+ """
124
+ Compute the ideal low pass filter mask (approximated version).
125
+
126
+ Args:
127
+ shape: shape of the filter (volume)
128
+ d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
129
+ d_t: normalized stop frequency for temporal dimension (0.0-1.0)
130
+ """
131
+ T, H, W = shape[-3], shape[-2], shape[-1]
132
+ mask = torch.zeros(shape)
133
+ if d_s==0 or d_t==0:
134
+ return mask
135
+
136
+ threshold_s = round(int(H // 2) * d_s)
137
+ threshold_t = round(T // 2 * d_t)
138
+
139
+ cframe, crow, ccol = T // 2, H // 2, W //2
140
+ mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
141
+
142
+ return mask
consisti2v/utils/util.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+ import wandb
10
+
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+
14
+ from torchmetrics.image.fid import _compute_fid
15
+
16
+
17
+ def zero_rank_print(s):
18
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
19
+
20
+
21
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, wandb=False, global_step=0, format="gif"):
22
+ videos = rearrange(videos, "b c t h w -> t b c h w")
23
+ outputs = []
24
+ for x in videos:
25
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
26
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
27
+ if rescale:
28
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
29
+ x = (x * 255).numpy().astype(np.uint8)
30
+ outputs.append(x)
31
+
32
+ if wandb:
33
+ wandb_video = wandb.Video(outputs, fps=fps)
34
+ wandb.log({"val_videos": wandb_video}, step=global_step)
35
+
36
+ os.makedirs(os.path.dirname(path), exist_ok=True)
37
+ if format == "gif":
38
+ imageio.mimsave(path, outputs, fps=fps)
39
+ elif format == "mp4":
40
+ torchvision.io.write_video(path, np.array(outputs), fps=fps, video_codec='h264', options={'crf': '10'})
41
+
42
+ # DDIM Inversion
43
+ @torch.no_grad()
44
+ def init_prompt(prompt, pipeline):
45
+ uncond_input = pipeline.tokenizer(
46
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
47
+ return_tensors="pt"
48
+ )
49
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
50
+ text_input = pipeline.tokenizer(
51
+ [prompt],
52
+ padding="max_length",
53
+ max_length=pipeline.tokenizer.model_max_length,
54
+ truncation=True,
55
+ return_tensors="pt",
56
+ )
57
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
58
+ context = torch.cat([uncond_embeddings, text_embeddings])
59
+
60
+ return context
61
+
62
+
63
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
64
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
65
+ timestep, next_timestep = min(
66
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
67
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
68
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
69
+ beta_prod_t = 1 - alpha_prod_t
70
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
71
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
72
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
73
+ return next_sample
74
+
75
+
76
+ def get_noise_pred_single(latents, t, context, first_frame_latents, frame_stride, unet):
77
+ noise_pred = unet(latents, t, encoder_hidden_states=context, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample
78
+ return noise_pred
79
+
80
+
81
+ @torch.no_grad()
82
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt, first_frame_latents, frame_stride):
83
+ context = init_prompt(prompt, pipeline)
84
+ uncond_embeddings, cond_embeddings = context.chunk(2)
85
+ all_latent = [latent]
86
+ latent = latent.clone().detach()
87
+ for i in tqdm(range(num_inv_steps)):
88
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
89
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, first_frame_latents, frame_stride, pipeline.unet)
90
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
91
+ all_latent.append(latent)
92
+ return all_latent
93
+
94
+
95
+ @torch.no_grad()
96
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt="", first_frame_latents=None, frame_stride=3):
97
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt, first_frame_latents, frame_stride)
98
+ return ddim_latents
99
+
100
+
101
+ def compute_fid(real_features, fake_features, num_features, device):
102
+ orig_dtype = real_features.dtype
103
+
104
+ mx_num_feats = (num_features, num_features)
105
+ real_features_sum = torch.zeros(num_features).double().to(device)
106
+ real_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
107
+ real_features_num_samples = torch.tensor(0).long().to(device)
108
+
109
+ fake_features_sum = torch.zeros(num_features).double().to(device)
110
+ fake_features_cov_sum = torch.zeros(mx_num_feats).double().to(device)
111
+ fake_features_num_samples = torch.tensor(0).long().to(device)
112
+
113
+ real_features = real_features.double()
114
+ fake_features = fake_features.double()
115
+
116
+ real_features_sum += real_features.sum(dim=0)
117
+ real_features_cov_sum += real_features.t().mm(real_features)
118
+ real_features_num_samples += real_features.shape[0]
119
+
120
+ fake_features_sum += fake_features.sum(dim=0)
121
+ fake_features_cov_sum += fake_features.t().mm(fake_features)
122
+ fake_features_num_samples += fake_features.shape[0]
123
+
124
+ """Calculate FID score based on accumulated extracted features from the two distributions."""
125
+ if real_features_num_samples < 2 or fake_features_num_samples < 2:
126
+ raise RuntimeError("More than one sample is required for both the real and fake distributed to compute FID")
127
+ mean_real = (real_features_sum / real_features_num_samples).unsqueeze(0)
128
+ mean_fake = (fake_features_sum / fake_features_num_samples).unsqueeze(0)
129
+
130
+ cov_real_num = real_features_cov_sum - real_features_num_samples * mean_real.t().mm(mean_real)
131
+ cov_real = cov_real_num / (real_features_num_samples - 1)
132
+ cov_fake_num = fake_features_cov_sum - fake_features_num_samples * mean_fake.t().mm(mean_fake)
133
+ cov_fake = cov_fake_num / (fake_features_num_samples - 1)
134
+ return _compute_fid(mean_real.squeeze(0), cov_real, mean_fake.squeeze(0), cov_fake).to(orig_dtype)
135
+
136
+
137
+ def compute_inception_score(gen_probs, num_splits=10):
138
+ num_gen = gen_probs.shape[0]
139
+ gen_probs = gen_probs.detach().cpu().numpy()
140
+ scores = []
141
+ np.random.RandomState(42).shuffle(gen_probs)
142
+ for i in range(num_splits):
143
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
144
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
145
+ kl = np.mean(np.sum(kl, axis=1))
146
+ scores.append(np.exp(kl))
147
+ return float(np.mean(scores)), float(np.std(scores))
148
+ # idx = torch.randperm(features.shape[0])
149
+ # features = features[idx]
150
+ # # calculate probs and logits
151
+ # prob = features.softmax(dim=1)
152
+ # log_prob = features.log_softmax(dim=1)
153
+
154
+ # # split into groups
155
+ # prob = prob.chunk(splits, dim=0)
156
+ # log_prob = log_prob.chunk(splits, dim=0)
157
+
158
+ # # calculate score per split
159
+ # mean_prob = [p.mean(dim=0, keepdim=True) for p in prob]
160
+ # kl_ = [p * (log_p - m_p.log()) for p, log_p, m_p in zip(prob, log_prob, mean_prob)]
161
+ # kl_ = [k.sum(dim=1).mean().exp() for k in kl_]
162
+ # kl = torch.stack(kl_)
163
+
164
+ # return mean and std
165
+ # return kl.mean(), kl.std()
environment.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: consisti2v
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.10
7
+ - pytorch=2.1.0
8
+ - torchvision=0.16.0
9
+ - torchaudio=2.1.0
10
+ - pytorch-cuda=11.8
11
+ - pip
12
+ - pip:
13
+ - diffusers==0.21.2
14
+ - transformers==4.25.1
15
+ - accelerate==0.23.0
16
+ - imageio==2.27.0
17
+ - decord==0.6.0
18
+ - einops
19
+ - omegaconf
20
+ - safetensors
21
+ - gradio==3.42.0
22
+ - wandb
23
+ - moviepy
24
+ - scikit-learn
25
+ - av
26
+ - rotary_embedding_torch
27
+ - torchmetrics
28
+ - torch-fidelity
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.0
2
+ torchvision==0.16.0
3
+ torchaudio==2.1.0
4
+ diffusers==0.21.2
5
+ transformers==4.25.1
6
+ accelerate==0.23.0
7
+ imageio==2.27.0
8
+ decord==0.6.0
9
+ spaces
10
+ einops
11
+ omegaconf
12
+ safetensors
13
+ moviepy
14
+ scikit-learn
15
+ av
16
+ rotary_embedding_torch
17
+ torchmetrics
18
+ torch-fidelity
scripts/animate.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import random
4
+ import os
5
+ import logging
6
+ from omegaconf import OmegaConf
7
+
8
+ import torch
9
+
10
+ import diffusers
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
16
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
17
+ from consisti2v.utils.util import save_videos_grid
18
+ from diffusers.utils.import_utils import is_xformers_available
19
+
20
+ def main(args, config):
21
+ logging.basicConfig(
22
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
23
+ datefmt="%m/%d/%Y %H:%M:%S",
24
+ level=logging.INFO,
25
+ )
26
+ diffusers.utils.logging.set_verbosity_info()
27
+
28
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
29
+ savedir = f"{config.output_dir}/{config.output_name}-{time_str}"
30
+ os.makedirs(savedir)
31
+
32
+ samples = []
33
+ sample_idx = 0
34
+
35
+ ### >>> create validation pipeline >>> ###
36
+ if config.pipeline_pretrained_path is None:
37
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))
38
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
39
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
40
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True)
41
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
42
+ config.pretrained_model_path,
43
+ subfolder="unet",
44
+ variant=config.unet_additional_kwargs['variant'],
45
+ temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'],
46
+ augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'],
47
+ use_temporal=True,
48
+ n_frames=config.sampling_kwargs['n_frames'],
49
+ n_temp_heads=config.unet_additional_kwargs['n_temp_heads'],
50
+ first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'],
51
+ use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'],
52
+ use_safetensors=True
53
+ )
54
+
55
+ # 1. unet ckpt
56
+ if config.unet_path is not None:
57
+ if os.path.isdir(config.unet_path):
58
+ unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path)
59
+ m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False)
60
+ assert len(u) == 0
61
+ del unet_dict
62
+ else:
63
+ checkpoint_dict = torch.load(config.unet_path, map_location="cpu")
64
+ state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict
65
+ if config.unet_ckpt_prefix is not None:
66
+ state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()}
67
+ m, u = unet.load_state_dict(state_dict, strict=False)
68
+ assert len(u) == 0
69
+
70
+ if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
71
+ unet.enable_xformers_memory_efficient_attention()
72
+
73
+ pipeline = ConditionalAnimationPipeline(
74
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler)
75
+
76
+ else:
77
+ pipeline = ConditionalAnimationPipeline.from_pretrained(config.pipeline_pretrained_path)
78
+
79
+ pipeline.to("cuda")
80
+
81
+ # (frameinit) initialize frequency filter for noise reinitialization -------------
82
+ if config.frameinit_kwargs.enable:
83
+ pipeline.init_filter(
84
+ width = config.sampling_kwargs.width,
85
+ height = config.sampling_kwargs.height,
86
+ video_length = config.sampling_kwargs.n_frames,
87
+ filter_params = config.frameinit_kwargs.filter_params,
88
+ )
89
+ # -------------------------------------------------------------------------------
90
+ ### <<< create validation pipeline <<< ###
91
+
92
+ if args.prompt is not None:
93
+ prompts = [args.prompt]
94
+ n_prompts = [args.n_prompt]
95
+ first_frame_paths = [args.path_to_first_frame]
96
+ random_seeds = [int(args.seed)] if args.seed != "random" else "random"
97
+ else:
98
+ prompt_config = OmegaConf.load(args.prompt_config)
99
+ prompts = prompt_config.prompts
100
+ n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts
101
+ first_frame_paths = prompt_config.path_to_first_frames
102
+ random_seeds = prompt_config.seeds
103
+
104
+ if random_seeds == "random":
105
+ random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))]
106
+ else:
107
+ random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
108
+ random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
109
+
110
+ config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths})
111
+ for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(zip(prompts, n_prompts, first_frame_paths, random_seeds)):
112
+ # manually set random seed for reproduction
113
+ if random_seed != -1: torch.manual_seed(random_seed)
114
+ else: torch.seed()
115
+ config.prompt_kwargs.random_seeds.append(torch.initial_seed())
116
+
117
+ print(f"current seed: {torch.initial_seed()}")
118
+ print(f"sampling {prompt} ...")
119
+ sample = pipeline(
120
+ prompt,
121
+ negative_prompt = n_prompt,
122
+ first_frame_paths = first_frame_path,
123
+ num_inference_steps = config.sampling_kwargs.steps,
124
+ guidance_scale_txt = config.sampling_kwargs.guidance_scale_txt,
125
+ guidance_scale_img = config.sampling_kwargs.guidance_scale_img,
126
+ width = config.sampling_kwargs.width,
127
+ height = config.sampling_kwargs.height,
128
+ video_length = config.sampling_kwargs.n_frames,
129
+ noise_sampling_method = config.unet_additional_kwargs['noise_sampling_method'],
130
+ noise_alpha = float(config.unet_additional_kwargs['noise_alpha']),
131
+ eta = config.sampling_kwargs.ddim_eta,
132
+ frame_stride = config.sampling_kwargs.frame_stride,
133
+ guidance_rescale = config.sampling_kwargs.guidance_rescale,
134
+ num_videos_per_prompt = config.sampling_kwargs.num_videos_per_prompt,
135
+ use_frameinit = config.frameinit_kwargs.enable,
136
+ frameinit_noise_level = config.frameinit_kwargs.noise_level,
137
+ camera_motion = config.frameinit_kwargs.camera_motion,
138
+ ).videos
139
+ samples.append(sample)
140
+
141
+ prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "")
142
+ if sample.shape[0] > 1:
143
+ for cnt, samp in enumerate(sample):
144
+ save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format)
145
+ else:
146
+ save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format)
147
+ print(f"save to {savedir}/sample/{prompt}.{args.format}")
148
+
149
+ sample_idx += 1
150
+
151
+ samples = torch.concat(samples)
152
+ save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format)
153
+
154
+ OmegaConf.save(config, f"{savedir}/config.yaml")
155
+
156
+ if args.save_model:
157
+ pipeline.save_pretrained(f"{savedir}/model")
158
+
159
+
160
+ if __name__ == "__main__":
161
+ parser = argparse.ArgumentParser()
162
+ parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml")
163
+ parser.add_argument("--prompt", "-p", type=str, default=None)
164
+ parser.add_argument("--n_prompt", "-n", type=str, default="")
165
+ parser.add_argument("--seed", type=str, default="random")
166
+ parser.add_argument("--path_to_first_frame", "-f", type=str, default=None)
167
+ parser.add_argument("--prompt_config", type=str, default="configs/prompts/default.yaml")
168
+ parser.add_argument("--format", type=str, default="mp4", choices=["gif", "mp4"])
169
+ parser.add_argument("--save_model", action="store_true")
170
+ parser.add_argument("optional_args", nargs='*', default=[])
171
+ args = parser.parse_args()
172
+
173
+ config = OmegaConf.load(args.inference_config)
174
+
175
+ if args.optional_args:
176
+ modified_config = OmegaConf.from_dotlist(args.optional_args)
177
+ config = OmegaConf.merge(config, modified_config)
178
+
179
+ main(args, config)
scripts/animate_autoregress.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import random
4
+ import os
5
+ import logging
6
+ from omegaconf import OmegaConf
7
+
8
+ import torch
9
+
10
+ import diffusers
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+
15
+ from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
16
+ from consisti2v.pipelines.pipeline_autoregress_animation import AutoregressiveAnimationPipeline
17
+ from consisti2v.utils.util import save_videos_grid
18
+ from diffusers.utils.import_utils import is_xformers_available
19
+
20
+ def main(args, config):
21
+ logging.basicConfig(
22
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
23
+ datefmt="%m/%d/%Y %H:%M:%S",
24
+ level=logging.INFO,
25
+ )
26
+ diffusers.utils.logging.set_verbosity_info()
27
+
28
+ time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
29
+ savedir = f"{config.output_dir}/{config.output_name}-{time_str}"
30
+ os.makedirs(savedir)
31
+
32
+ samples = []
33
+ sample_idx = 0
34
+
35
+ ### >>> create validation pipeline >>> ###
36
+ if config.pipeline_pretrained_path is None:
37
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))
38
+ tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer", use_safetensors=True)
39
+ text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
40
+ vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae", use_safetensors=True)
41
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
42
+ config.pretrained_model_path,
43
+ subfolder="unet",
44
+ variant=config.unet_additional_kwargs['variant'],
45
+ temp_pos_embedding=config.unet_additional_kwargs['temp_pos_embedding'],
46
+ augment_temporal_attention=config.unet_additional_kwargs['augment_temporal_attention'],
47
+ use_temporal=True,
48
+ n_frames=config.sampling_kwargs['n_frames'],
49
+ n_temp_heads=config.unet_additional_kwargs['n_temp_heads'],
50
+ first_frame_condition_mode=config.unet_additional_kwargs['first_frame_condition_mode'],
51
+ use_frame_stride_condition=config.unet_additional_kwargs['use_frame_stride_condition'],
52
+ use_safetensors=True
53
+ )
54
+
55
+ params_unet = [p.numel() for n, p in unet.named_parameters()]
56
+ params_vae = [p.numel() for n, p in vae.named_parameters()]
57
+ params_text_encoder = [p.numel() for n, p in text_encoder.named_parameters()]
58
+ params = params_unet + params_vae + params_text_encoder
59
+ print(f"### UNet Parameters: {sum(params) / 1e6} M")
60
+
61
+ # 1. unet ckpt
62
+ if config.unet_path is not None:
63
+ if os.path.isdir(config.unet_path):
64
+ unet_dict = VideoLDMUNet3DConditionModel.from_pretrained(config.unet_path)
65
+ m, u = unet.load_state_dict(unet_dict.state_dict(), strict=False)
66
+ assert len(u) == 0
67
+ del unet_dict
68
+ else:
69
+ checkpoint_dict = torch.load(config.unet_path, map_location="cpu")
70
+ state_dict = checkpoint_dict["state_dict"] if "state_dict" in checkpoint_dict else checkpoint_dict
71
+ if config.unet_ckpt_prefix is not None:
72
+ state_dict = {k.replace(config.unet_ckpt_prefix, ''): v for k, v in state_dict.items()}
73
+ m, u = unet.load_state_dict(state_dict, strict=False)
74
+ assert len(u) == 0
75
+
76
+ if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2:
77
+ unet.enable_xformers_memory_efficient_attention()
78
+
79
+ pipeline = AutoregressiveAnimationPipeline(
80
+ vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler)
81
+
82
+ else:
83
+ pipeline = AutoregressiveAnimationPipeline.from_pretrained(config.pipeline_pretrained_path)
84
+
85
+ pipeline.to("cuda")
86
+
87
+ # (frameinit) initialize frequency filter for noise reinitialization -------------
88
+ if config.frameinit_kwargs.enable:
89
+ pipeline.init_filter(
90
+ width = config.sampling_kwargs.width,
91
+ height = config.sampling_kwargs.height,
92
+ video_length = config.sampling_kwargs.n_frames,
93
+ filter_params = config.frameinit_kwargs.filter_params,
94
+ )
95
+ # -------------------------------------------------------------------------------
96
+ ### <<< create validation pipeline <<< ###
97
+
98
+ if args.prompt is not None:
99
+ prompts = [args.prompt]
100
+ n_prompts = [args.n_prompt]
101
+ first_frame_paths = [args.path_to_first_frame]
102
+ random_seeds = [int(args.seed)] if args.seed != "random" else "random"
103
+ else:
104
+ prompt_config = OmegaConf.load(args.prompt_config)
105
+ prompts = prompt_config.prompts
106
+ n_prompts = list(prompt_config.n_prompts) * len(prompts) if len(prompt_config.n_prompts) == 1 else prompt_config.n_prompts
107
+ first_frame_paths = prompt_config.path_to_first_frames
108
+ random_seeds = prompt_config.seeds
109
+
110
+ if random_seeds == "random":
111
+ random_seeds = [random.randint(0, 1e5) for _ in range(len(prompts))]
112
+ else:
113
+ random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
114
+ random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
115
+
116
+ config.prompt_kwargs = OmegaConf.create({"random_seeds": [], "prompts": prompts, "n_prompts": n_prompts, "first_frame_paths": first_frame_paths})
117
+ for prompt_idx, (prompt, n_prompt, first_frame_path, random_seed) in enumerate(zip(prompts, n_prompts, first_frame_paths, random_seeds)):
118
+ # manually set random seed for reproduction
119
+ if random_seed != -1: torch.manual_seed(random_seed)
120
+ else: torch.seed()
121
+ config.prompt_kwargs.random_seeds.append(torch.initial_seed())
122
+
123
+ print(f"current seed: {torch.initial_seed()}")
124
+ print(f"sampling {prompt} ...")
125
+ sample = pipeline(
126
+ prompt,
127
+ negative_prompt = n_prompt,
128
+ first_frame_paths = first_frame_path,
129
+ num_inference_steps = config.sampling_kwargs.steps,
130
+ guidance_scale_txt = config.sampling_kwargs.guidance_scale_txt,
131
+ guidance_scale_img = config.sampling_kwargs.guidance_scale_img,
132
+ width = config.sampling_kwargs.width,
133
+ height = config.sampling_kwargs.height,
134
+ video_length = config.sampling_kwargs.n_frames,
135
+ noise_sampling_method = config.unet_additional_kwargs['noise_sampling_method'],
136
+ noise_alpha = float(config.unet_additional_kwargs['noise_alpha']),
137
+ eta = config.sampling_kwargs.ddim_eta,
138
+ frame_stride = config.sampling_kwargs.frame_stride,
139
+ guidance_rescale = config.sampling_kwargs.guidance_rescale,
140
+ num_videos_per_prompt = config.sampling_kwargs.num_videos_per_prompt,
141
+ autoregress_steps = config.sampling_kwargs.autoregress_steps,
142
+ use_frameinit = config.frameinit_kwargs.enable,
143
+ frameinit_noise_level = config.frameinit_kwargs.noise_level,
144
+ ).videos
145
+ samples.append(sample)
146
+
147
+ prompt = "-".join((prompt.replace("/", "").split(" ")[:10])).replace(":", "")
148
+ if sample.shape[0] > 1:
149
+ for cnt, samp in enumerate(sample):
150
+ save_videos_grid(samp.unsqueeze(0), f"{savedir}/sample/{sample_idx}-{cnt + 1}-{prompt}.{args.format}", format=args.format)
151
+ else:
152
+ save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{args.format}", format=args.format)
153
+ print(f"save to {savedir}/sample/{prompt}.{args.format}")
154
+
155
+ sample_idx += 1
156
+
157
+ samples = torch.concat(samples)
158
+ save_videos_grid(samples, f"{savedir}/sample.{args.format}", n_rows=4, format=args.format)
159
+
160
+ OmegaConf.save(config, f"{savedir}/config.yaml")
161
+
162
+ if args.save_model:
163
+ pipeline.save_pretrained(f"{savedir}/model")
164
+
165
+
166
+ if __name__ == "__main__":
167
+ parser = argparse.ArgumentParser()
168
+ parser.add_argument("--inference_config", type=str, default="configs/inference/inference_autoregress.yaml")
169
+ parser.add_argument("--prompt", "-p", type=str, default=None)
170
+ parser.add_argument("--n_prompt", "-n", type=str, default="")
171
+ parser.add_argument("--seed", type=str, default="random")
172
+ parser.add_argument("--path_to_first_frame", "-f", type=str, default=None)
173
+ parser.add_argument("--prompt_config", type=str, default="configs/prompts/default.yaml")
174
+ parser.add_argument("--format", type=str, default="gif", choices=["gif", "mp4"])
175
+ parser.add_argument("--save_model", action="store_true")
176
+ parser.add_argument("optional_args", nargs='*', default=[])
177
+ args = parser.parse_args()
178
+
179
+ config = OmegaConf.load(args.inference_config)
180
+
181
+ if args.optional_args:
182
+ modified_config = OmegaConf.from_dotlist(args.optional_args)
183
+ config = OmegaConf.merge(config, modified_config)
184
+
185
+ main(args, config)
train.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import wandb
4
+ import random
5
+ import time
6
+ import logging
7
+ import inspect
8
+ import argparse
9
+ import datetime
10
+ import numpy as np
11
+
12
+ from pathlib import Path
13
+ from tqdm.auto import tqdm
14
+ from einops import rearrange, repeat
15
+ from omegaconf import OmegaConf
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+
21
+ import diffusers
22
+ from diffusers import AutoencoderKL, DDIMScheduler
23
+ from diffusers.optimization import get_scheduler
24
+ from diffusers.utils import check_min_version
25
+ from diffusers.utils.import_utils import is_xformers_available
26
+ from diffusers.training_utils import EMAModel
27
+
28
+ import transformers
29
+ from transformers import CLIPTextModel, CLIPTokenizer
30
+
31
+ from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs
32
+ from accelerate.logging import get_logger
33
+ from accelerate.utils import set_seed
34
+
35
+ from consisti2v.data.dataset import WebVid10M, Pexels, JointDataset
36
+ from consisti2v.models.videoldm_unet import VideoLDMUNet3DConditionModel
37
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
38
+ from consisti2v.utils.util import save_videos_grid
39
+
40
+ logger = get_logger(__name__, log_level="INFO")
41
+
42
+ def main(
43
+ name: str,
44
+ use_wandb: bool,
45
+
46
+ is_image: bool,
47
+
48
+ output_dir: str,
49
+ pretrained_model_path: str,
50
+
51
+ train_data: Dict,
52
+ validation_data: Dict,
53
+
54
+ cfg_random_null_text_ratio: float = 0.1,
55
+ cfg_random_null_img_ratio: float = 0.0,
56
+
57
+ resume_from_checkpoint: Optional[str] = None,
58
+ unet_additional_kwargs: Dict = {},
59
+ use_ema: bool = False,
60
+ ema_decay: float = 0.9999,
61
+ noise_scheduler_kwargs = None,
62
+
63
+ max_train_epoch: int = -1,
64
+ max_train_steps: int = 100,
65
+ validation_steps: int = 100,
66
+
67
+ learning_rate: float = 3e-5,
68
+ scale_lr: bool = False,
69
+ lr_warmup_steps: int = 0,
70
+ lr_scheduler: str = "constant",
71
+
72
+ trainable_modules: Tuple[str] = (None, ),
73
+ num_workers: int = 32,
74
+ train_batch_size: int = 1,
75
+ adam_beta1: float = 0.9,
76
+ adam_beta2: float = 0.999,
77
+ adam_weight_decay: float = 1e-2,
78
+ adam_epsilon: float = 1e-08,
79
+ max_grad_norm: float = 1.0,
80
+ gradient_accumulation_steps: int = 1,
81
+ gradient_checkpointing: bool = False,
82
+ checkpointing_epochs: int = 5,
83
+ checkpointing_steps: int = -1,
84
+
85
+ mixed_precision: Optional[str] = "fp16",
86
+ enable_xformers_memory_efficient_attention: bool = True,
87
+
88
+ seed: Optional[int] = 42,
89
+ is_debug: bool = False,
90
+ ):
91
+ check_min_version("0.10.0.dev0")
92
+ *_, config = inspect.getargvalues(inspect.currentframe())
93
+ config = {k: v for k, v in config.items() if k != 'config' and k != '_'}
94
+
95
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True if not is_image else False)
96
+ init_kwargs = InitProcessGroupKwargs(timeout=datetime.timedelta(seconds=3600))
97
+
98
+ accelerator = Accelerator(
99
+ gradient_accumulation_steps=gradient_accumulation_steps,
100
+ mixed_precision=mixed_precision,
101
+ kwargs_handlers=[ddp_kwargs, init_kwargs],
102
+ )
103
+
104
+ if seed is not None:
105
+ set_seed(seed)
106
+
107
+ # Logging folder
108
+ folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
109
+ output_dir = os.path.join(output_dir, folder_name)
110
+ if is_debug and os.path.exists(output_dir):
111
+ os.system(f"rm -rf {output_dir}")
112
+
113
+ # Make one log on every process with the configuration for debugging.
114
+ logging.basicConfig(
115
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
116
+ datefmt="%m/%d/%Y %H:%M:%S",
117
+ level=logging.INFO,
118
+ )
119
+ logger.info(accelerator.state, main_process_only=False)
120
+
121
+ if accelerator.is_local_main_process:
122
+ transformers.utils.logging.set_verbosity_warning()
123
+ diffusers.utils.logging.set_verbosity_info()
124
+ else:
125
+ transformers.utils.logging.set_verbosity_error()
126
+ diffusers.utils.logging.set_verbosity_error()
127
+
128
+ if accelerator.is_main_process and (not is_debug) and use_wandb:
129
+ project_name = "text_image_to_video" if not is_image else "image_finetune"
130
+ wandb.init(project=project_name, name=folder_name, config=config)
131
+ accelerator.wait_for_everyone()
132
+
133
+ # Handle the output folder creation
134
+ if accelerator.is_main_process:
135
+ os.makedirs(output_dir, exist_ok=True)
136
+ os.makedirs(f"{output_dir}/samples", exist_ok=True)
137
+ os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
138
+ os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
139
+ OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
140
+
141
+ # TODO: change all datasets to fps+duration in the future
142
+ if train_data.dataset == "pexels":
143
+ train_data.sample_n_frames = train_data.sample_duration * train_data.sample_fps
144
+ elif train_data.dataset == "joint":
145
+ if train_data.sample_duration is not None:
146
+ train_data.sample_n_frames = train_data.sample_duration * train_data.sample_fps
147
+ # Load scheduler, tokenizer and models.
148
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
149
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
150
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
151
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
152
+ unet = VideoLDMUNet3DConditionModel.from_pretrained(
153
+ pretrained_model_path,
154
+ subfolder="unet",
155
+ variant=unet_additional_kwargs['variant'],
156
+ use_temporal=True if not is_image else False,
157
+ temp_pos_embedding=unet_additional_kwargs['temp_pos_embedding'],
158
+ augment_temporal_attention=unet_additional_kwargs['augment_temporal_attention'],
159
+ n_frames=train_data.sample_n_frames if not is_image else 2,
160
+ n_temp_heads=unet_additional_kwargs['n_temp_heads'],
161
+ first_frame_condition_mode=unet_additional_kwargs['first_frame_condition_mode'],
162
+ use_frame_stride_condition=unet_additional_kwargs['use_frame_stride_condition'],
163
+ use_safetensors=True
164
+ )
165
+
166
+ # Freeze vae and text_encoder
167
+ vae.requires_grad_(False)
168
+ text_encoder.requires_grad_(False)
169
+ unet.train()
170
+
171
+ if use_ema:
172
+ ema_unet = VideoLDMUNet3DConditionModel.from_pretrained(
173
+ pretrained_model_path,
174
+ subfolder="unet",
175
+ variant=unet_additional_kwargs['variant'],
176
+ use_temporal=True if not is_image else False,
177
+ temp_pos_embedding=unet_additional_kwargs['temp_pos_embedding'],
178
+ augment_temporal_attention=unet_additional_kwargs['augment_temporal_attention'],
179
+ n_frames=train_data.sample_n_frames if not is_image else 2,
180
+ n_temp_heads=unet_additional_kwargs['n_temp_heads'],
181
+ first_frame_condition_mode=unet_additional_kwargs['first_frame_condition_mode'],
182
+ use_frame_stride_condition=unet_additional_kwargs['use_frame_stride_condition'],
183
+ use_safetensors=True
184
+ )
185
+ ema_unet = EMAModel(ema_unet.parameters(), decay=ema_decay, model_cls=VideoLDMUNet3DConditionModel, model_config=ema_unet.config)
186
+
187
+ # Set unet trainable parameters
188
+ train_all_parameters = False
189
+ for trainable_module_name in trainable_modules:
190
+ if trainable_module_name == 'all':
191
+ unet.requires_grad_(True)
192
+ train_all_parameters = True
193
+ break
194
+
195
+ if not train_all_parameters:
196
+ unet.requires_grad_(False)
197
+ for name, param in unet.named_parameters():
198
+ for trainable_module_name in trainable_modules:
199
+ if trainable_module_name in name:
200
+ param.requires_grad = True
201
+ break
202
+
203
+ # Enable xformers
204
+ if enable_xformers_memory_efficient_attention and int(torch.__version__.split(".")[0]) < 2:
205
+ if is_xformers_available():
206
+ unet.enable_xformers_memory_efficient_attention()
207
+ else:
208
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
209
+
210
+ def save_model_hook(models, weights, output_dir):
211
+ if accelerator.is_main_process:
212
+ if use_ema:
213
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
214
+
215
+ for i, model in enumerate(models):
216
+ model.save_pretrained(os.path.join(output_dir, "unet"))
217
+
218
+ # make sure to pop weight so that corresponding model is not saved again
219
+ weights.pop()
220
+
221
+ def load_model_hook(models, input_dir):
222
+ if use_ema:
223
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), VideoLDMUNet3DConditionModel)
224
+ ema_unet.load_state_dict(load_model.state_dict())
225
+ ema_unet.to(accelerator.device)
226
+ del load_model
227
+
228
+ for i in range(len(models)):
229
+ # pop models so that they are not loaded again
230
+ model = models.pop()
231
+
232
+ # load diffusers style into model
233
+ load_model = VideoLDMUNet3DConditionModel.from_pretrained(input_dir, subfolder="unet")
234
+ model.register_to_config(**load_model.config)
235
+
236
+ model.load_state_dict(load_model.state_dict())
237
+ del load_model
238
+
239
+ accelerator.register_save_state_pre_hook(save_model_hook)
240
+ accelerator.register_load_state_pre_hook(load_model_hook)
241
+
242
+ # Enable gradient checkpointing
243
+ if gradient_checkpointing:
244
+ unet.enable_gradient_checkpointing()
245
+
246
+ if scale_lr:
247
+ learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes)
248
+
249
+ trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
250
+ optimizer = torch.optim.AdamW(
251
+ trainable_params,
252
+ lr=learning_rate,
253
+ betas=(adam_beta1, adam_beta2),
254
+ weight_decay=adam_weight_decay,
255
+ eps=adam_epsilon,
256
+ )
257
+
258
+ logger.info(f"trainable params number: {len(trainable_params)}")
259
+ logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
260
+
261
+ # Get the training dataset
262
+ if train_data['dataset'] == "webvid":
263
+ train_dataset = WebVid10M(**train_data, is_image=is_image)
264
+ elif train_data['dataset'] == "pexels":
265
+ train_dataset = Pexels(**train_data, is_image=is_image)
266
+ elif train_data['dataset'] == "joint":
267
+ train_dataset = JointDataset(**train_data, is_image=is_image)
268
+ else:
269
+ raise ValueError(f"Unknown dataset {train_data['dataset']}")
270
+
271
+ # DataLoaders creation:
272
+ train_dataloader = torch.utils.data.DataLoader(
273
+ train_dataset,
274
+ shuffle=True,
275
+ batch_size=train_batch_size,
276
+ num_workers=num_workers,
277
+ pin_memory=True,
278
+ )
279
+
280
+ # Get the training iteration
281
+ if max_train_steps == -1:
282
+ assert max_train_epoch != -1
283
+ max_train_steps = max_train_epoch * len(train_dataloader)
284
+
285
+ if checkpointing_steps == -1:
286
+ assert checkpointing_epochs != -1
287
+ checkpointing_steps = checkpointing_epochs * len(train_dataloader)
288
+
289
+ # Scheduler
290
+ lr_scheduler = get_scheduler(
291
+ lr_scheduler,
292
+ optimizer=optimizer,
293
+ num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
294
+ num_training_steps=max_train_steps * gradient_accumulation_steps,
295
+ )
296
+
297
+ # Validation pipeline
298
+ validation_pipeline = ConditionalAnimationPipeline(
299
+ unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
300
+ )
301
+ validation_pipeline.enable_vae_slicing()
302
+
303
+ # Prepare everything with our `accelerator`.
304
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
305
+ unet, optimizer, train_dataloader, lr_scheduler
306
+ )
307
+
308
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
309
+ # as these models are only used for inference, keeping weights in full precision is not required.
310
+ weight_dtype = torch.float32
311
+ if accelerator.mixed_precision == "fp16":
312
+ weight_dtype = torch.float16
313
+ elif accelerator.mixed_precision == "bf16":
314
+ weight_dtype = torch.bfloat16
315
+
316
+ if use_ema:
317
+ ema_unet.to(accelerator.device)
318
+
319
+ # Move text_encode and vae to gpu and cast to weight_dtype
320
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
321
+ vae.to(accelerator.device, dtype=weight_dtype)
322
+
323
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
324
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
325
+ # Afterwards we recalculate our number of training epochs
326
+ num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
327
+
328
+ # Train!
329
+ total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps
330
+
331
+ logger.info("***** Running training *****")
332
+ logger.info(f" Num examples = {len(train_dataset)}")
333
+ logger.info(f" Num Epochs = {num_train_epochs}")
334
+ logger.info(f" Instantaneous batch size per device = {train_batch_size}")
335
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
336
+ logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
337
+ logger.info(f" Total optimization steps = {max_train_steps}")
338
+
339
+ global_step = 0
340
+ first_epoch = 0
341
+
342
+ # Load pretrained unet weights
343
+ if resume_from_checkpoint is not None:
344
+ logger.info(f"Resuming from checkpoint: {resume_from_checkpoint}")
345
+ accelerator.load_state(resume_from_checkpoint)
346
+ global_step = int(resume_from_checkpoint.split("-")[-1])
347
+
348
+ initial_global_step = global_step
349
+ first_epoch = global_step // num_update_steps_per_epoch
350
+ logger.info(f"global_step: {global_step}")
351
+ logger.info(f"first_epoch: {first_epoch}")
352
+ else:
353
+ initial_global_step = 0
354
+
355
+ # Only show the progress bar once on each machine.
356
+ progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps", disable=not accelerator.is_main_process)
357
+
358
+ for epoch in range(first_epoch, num_train_epochs):
359
+ train_loss = 0.0
360
+ train_grad_norm = 0.0
361
+ data_loading_time = 0.0
362
+ prepare_everything_time = 0.0
363
+ network_forward_time = 0.0
364
+ network_backward_time = 0.0
365
+
366
+ t0 = time.time()
367
+ for step, batch in enumerate(train_dataloader):
368
+ t1 = time.time()
369
+ if cfg_random_null_text_ratio > 0.0:
370
+ batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]
371
+
372
+ # Data batch sanity check
373
+ if accelerator.is_main_process and epoch == first_epoch and step == 0:
374
+ pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
375
+ pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
376
+ for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
377
+ pixel_value = pixel_value[None, ...]
378
+ save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'no_text-{idx}'}.gif", rescale=True)
379
+
380
+ ### >>>> Training >>>> ###
381
+ with accelerator.accumulate(unet):
382
+ # Convert videos to latent space
383
+ pixel_values = batch["pixel_values"].to(weight_dtype)
384
+ video_length = pixel_values.shape[1]
385
+ pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
386
+ latents = vae.encode(pixel_values).latent_dist
387
+ latents = latents.sample()
388
+ latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
389
+
390
+ latents = latents * vae.config.scaling_factor
391
+
392
+ if unet_additional_kwargs["first_frame_condition_mode"] != "none":
393
+ # Get first frame latents
394
+ first_frame_latents = latents[:, :, 0:1, :, :]
395
+
396
+ # Sample noise that we'll add to the latents
397
+ if unet_additional_kwargs['noise_sampling_method'] == 'vanilla':
398
+ noise = torch.randn_like(latents)
399
+ elif unet_additional_kwargs['noise_sampling_method'] == 'pyoco_mixed':
400
+ noise_alpha_squared = float(unet_additional_kwargs['noise_alpha']) ** 2
401
+ shared_noise = torch.randn_like(latents[:, :, 0:1, :, :]) * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared))
402
+ ind_noise = torch.randn_like(latents) * math.sqrt(1 / (1 + noise_alpha_squared))
403
+ noise = shared_noise + ind_noise
404
+ elif unet_additional_kwargs['noise_sampling_method'] == 'pyoco_progressive':
405
+ noise_alpha_squared = float(unet_additional_kwargs['noise_alpha']) ** 2
406
+ noise = torch.randn_like(latents)
407
+ ind_noise = torch.randn_like(latents) * math.sqrt(1 / (1 + noise_alpha_squared))
408
+ for i in range(1, noise.shape[2]):
409
+ noise[:, :, i, :, :] = noise[:, :, i - 1, :, :] * math.sqrt((noise_alpha_squared) / (1 + noise_alpha_squared)) + ind_noise[:, :, i, :, :]
410
+ else:
411
+ raise ValueError(f"Unknown noise sampling method {unet_additional_kwargs['noise_sampling_method']}")
412
+
413
+ bsz = latents.shape[0]
414
+
415
+ # Sample a random timestep for each video
416
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
417
+ timesteps = timesteps.long()
418
+
419
+ # Add noise to the latents according to the noise magnitude at each timestep
420
+ # (this is the forward diffusion process)
421
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
422
+
423
+ if cfg_random_null_img_ratio > 0.0:
424
+ for i in range(first_frame_latents.shape[0]):
425
+ if random.random() <= cfg_random_null_img_ratio:
426
+ first_frame_latents[i, :, :, :, :] = noisy_latents[i, :, 0:1, :, :]
427
+
428
+ # Remove the first noisy latent from the latents if we're conditioning on the first frame
429
+ if unet_additional_kwargs["first_frame_condition_mode"] != "none":
430
+ noisy_latents = noisy_latents[:, :, 1:, :, :]
431
+
432
+ # Get the text embedding for conditioning
433
+ prompt_ids = tokenizer(
434
+ batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
435
+ ).input_ids.to(latents.device)
436
+ encoder_hidden_states = text_encoder(prompt_ids)[0]
437
+
438
+ # Get the target for loss depending on the prediction type
439
+ if noise_scheduler.config.prediction_type == "epsilon":
440
+ target = noise
441
+ elif noise_scheduler.config.prediction_type == "v_prediction":
442
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
443
+ else:
444
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
445
+
446
+ timesteps = repeat(timesteps, "b -> b f", f=video_length)
447
+ timesteps = rearrange(timesteps, "b f -> (b f)")
448
+
449
+ frame_stride = None
450
+ if unet_additional_kwargs["use_frame_stride_condition"]:
451
+ frame_stride = batch['stride'].to(latents.device)
452
+ frame_stride = frame_stride.long()
453
+ frame_stride = repeat(frame_stride, "b -> b f", f=video_length)
454
+ frame_stride = rearrange(frame_stride, "b f -> (b f)")
455
+
456
+ t2 = time.time()
457
+
458
+ # Predict the noise residual and compute loss
459
+ if unet_additional_kwargs["first_frame_condition_mode"] != "none":
460
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, first_frame_latents=first_frame_latents, frame_stride=frame_stride).sample
461
+ loss = F.mse_loss(model_pred.float(), target.float()[:, :, 1:, :, :], reduction="mean")
462
+ else:
463
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
464
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
465
+
466
+ t3 = time.time()
467
+
468
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
469
+ train_loss += avg_loss.item() / gradient_accumulation_steps
470
+
471
+ # Backpropagate
472
+ accelerator.backward(loss)
473
+ if accelerator.sync_gradients:
474
+ grad_norm = accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm)
475
+ avg_grad_norm = accelerator.gather(grad_norm.repeat(train_batch_size)).mean()
476
+ train_grad_norm += avg_grad_norm.item() / gradient_accumulation_steps
477
+
478
+ optimizer.step()
479
+ lr_scheduler.step()
480
+ optimizer.zero_grad()
481
+
482
+ t4 = time.time()
483
+
484
+ data_loading_time += (t1 - t0) / gradient_accumulation_steps
485
+ prepare_everything_time += (t2 - t1) / gradient_accumulation_steps
486
+ network_forward_time += (t3 - t2) / gradient_accumulation_steps
487
+ network_backward_time += (t4 - t3) / gradient_accumulation_steps
488
+
489
+ t0 = time.time()
490
+
491
+ ### <<<< Training <<<< ###
492
+
493
+ # Checks if the accelerator has performed an optimization step behind the scenes
494
+ if accelerator.sync_gradients:
495
+ if use_ema:
496
+ ema_unet.step(unet.parameters())
497
+ progress_bar.update(1)
498
+ global_step += 1
499
+
500
+ # Wandb logging
501
+ if accelerator.is_main_process and (not is_debug) and use_wandb:
502
+ wandb.log({"metrics/train_loss": train_loss}, step=global_step)
503
+ wandb.log({"metrics/train_grad_norm": train_grad_norm}, step=global_step)
504
+
505
+ wandb.log({"profiling/train_data_loading_time": data_loading_time}, step=global_step)
506
+ wandb.log({"profiling/train_prepare_everything_time": prepare_everything_time}, step=global_step)
507
+ wandb.log({"profiling/train_network_forward_time": network_forward_time}, step=global_step)
508
+ wandb.log({"profiling/train_network_backward_time": network_backward_time}, step=global_step)
509
+ # accelerator.log({"train_loss": train_loss}, step=global_step)
510
+ train_loss = 0.0
511
+ train_grad_norm = 0.0
512
+ data_loading_time = 0.0
513
+ prepare_everything_time = 0.0
514
+ network_forward_time = 0.0
515
+ network_backward_time = 0.0
516
+
517
+ # Save checkpoint
518
+ if global_step % checkpointing_steps == 0:
519
+ if accelerator.is_main_process:
520
+ save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}")
521
+ accelerator.save_state(save_path)
522
+ logger.info(f"Saved state to {save_path} (global_step: {global_step})")
523
+
524
+ # Periodically validation
525
+ if accelerator.is_main_process and global_step % validation_steps == 0:
526
+ if use_ema:
527
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
528
+ ema_unet.store(unet.parameters())
529
+ ema_unet.copy_to(unet.parameters())
530
+
531
+ samples = []
532
+ wandb_samples = []
533
+
534
+ generator = torch.Generator(device=latents.device)
535
+ generator.manual_seed(seed)
536
+
537
+ height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
538
+ width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size
539
+
540
+ prompts = validation_data.prompts
541
+
542
+ first_frame_paths = [None] * len(prompts)
543
+ if unet_additional_kwargs["first_frame_condition_mode"] != "none":
544
+ first_frame_paths = validation_data.path_to_first_frames
545
+
546
+ for idx, (prompt, first_frame_path) in enumerate(zip(prompts, first_frame_paths)):
547
+ sample = validation_pipeline(
548
+ prompt,
549
+ generator = generator,
550
+ video_length = train_data.sample_n_frames if not is_image else 2,
551
+ height = height,
552
+ width = width,
553
+ first_frame_paths = first_frame_path,
554
+ noise_sampling_method = unet_additional_kwargs['noise_sampling_method'],
555
+ noise_alpha = float(unet_additional_kwargs['noise_alpha']),
556
+ **validation_data,
557
+ ).videos
558
+ save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
559
+ samples.append(sample)
560
+
561
+ numpy_sample = (sample.squeeze(0).permute(1, 0, 2, 3) * 255).cpu().numpy().astype(np.uint8)
562
+ wandb_video = wandb.Video(numpy_sample, fps=8, caption=prompt)
563
+ wandb_samples.append(wandb_video)
564
+
565
+ if (not is_debug) and use_wandb:
566
+ val_title = 'val_videos'
567
+ wandb.log({val_title: wandb_samples}, step=global_step)
568
+
569
+ samples = torch.concat(samples)
570
+ save_path = f"{output_dir}/samples/sample-{global_step}.gif"
571
+ save_videos_grid(samples, save_path)
572
+
573
+ logger.info(f"Saved samples to {save_path}")
574
+
575
+ if use_ema:
576
+ # Switch back to the original UNet parameters.
577
+ ema_unet.restore(unet.parameters())
578
+
579
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
580
+ progress_bar.set_postfix(**logs)
581
+
582
+ if accelerator.is_main_process and (not is_debug) and use_wandb:
583
+ wandb.log({"metrics/train_lr": lr_scheduler.get_last_lr()[0]}, step=global_step)
584
+
585
+ if global_step >= max_train_steps:
586
+ break
587
+
588
+ # Create the pipeline using the trained modules and save it.
589
+ accelerator.wait_for_everyone()
590
+ if accelerator.is_main_process:
591
+ unet = accelerator.unwrap_model(unet)
592
+ pipeline = ConditionalAnimationPipeline(
593
+ text_encoder=text_encoder,
594
+ vae=vae,
595
+ unet=unet,
596
+ tokenizer=tokenizer,
597
+ scheduler=noise_scheduler,
598
+ )
599
+ pipeline.save_pretrained(f"{output_dir}/final_checkpoint")
600
+
601
+
602
+ if __name__ == "__main__":
603
+ parser = argparse.ArgumentParser()
604
+ parser.add_argument("--config", type=str, required=True)
605
+ parser.add_argument("--name", "-n", type=str, default="")
606
+ parser.add_argument("--wandb", action="store_true")
607
+ parser.add_argument("optional_args", nargs='*', default=[])
608
+ args = parser.parse_args()
609
+
610
+ name = args.name + "_" + Path(args.config).stem
611
+ config = OmegaConf.load(args.config)
612
+
613
+ if args.optional_args:
614
+ modified_config = OmegaConf.from_dotlist(args.optional_args)
615
+ config = OmegaConf.merge(config, modified_config)
616
+
617
+ main(name=name, use_wandb=args.wandb, **config)