Commit ef16dc7 · committed by wren93 · 1 Parent(s): 09d869f
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ samples/
2
+ wandb/
3
+ outputs/
4
+ __pycache__/
5
+ scripts/animate_inter.py
6
+ scripts/gradio_app.py
7
+ *.ipynb
8
+ *.safetensors
9
+ *.ckpt
10
+ .ossutil_checkpoint/
11
+ ossutil_output/
12
+ debugs/
13
+ .vscode
14
+ .env
15
+ models
16
+ !*/models
17
+ .ipynb_checkpoints
18
+ checkpoints
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 TIGER Lab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: ConsistI2V
3
- emoji: 💻
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
  title: ConsistI2V
3
+ emoji: 🎥
4
  colorFrom: purple
5
  colorTo: green
6
  sdk: gradio
app.py ADDED
@@ -0,0 +1,307 @@
1
+ import spaces
2
+ import os
3
+ import json
4
+ import torch
5
+ import random
6
+ import requests
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ import gradio as gr
11
+ from datetime import datetime
12
+
13
+ import torchvision.transforms as T
14
+
15
+ from diffusers import DDIMScheduler
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
18
+ from consisti2v.utils.util import save_videos_grid
19
+ from omegaconf import OmegaConf
20
+
21
+
22
+ sample_idx = 0
23
+ scheduler_dict = {
24
+ "DDIM": DDIMScheduler,
25
+ }
26
+
27
+ css = """
28
+ .toolbutton {
29
+ margin-bottom: 0em;
30
+ max-width: 2.5em;
31
+ min-width: 2.5em !important;
32
+ height: 2.5em;
33
+ }
34
+ """
35
+
36
+ class AnimateController:
37
+ def __init__(self):
38
+
39
+ # config dirs
40
+ self.basedir = os.getcwd()
41
+ self.savedir = os.path.join(self.basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
42
+ self.savedir_sample = os.path.join(self.savedir, "sample")
43
+ os.makedirs(self.savedir, exist_ok=True)
44
+
45
+ self.image_resolution = (256, 256)
46
+ # config models
47
+ self.pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16,)
48
+ self.pipeline.to("cuda")
49
+
50
+ def update_textbox_and_save_image(self, input_image, height_slider, width_slider, center_crop):
51
+ pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
52
+ img_path = os.path.join(self.savedir, "input_image.png")
53
+ pil_image.save(img_path)
54
+ self.image_resolution = pil_image.size
55
+ pil_image = pil_image.resize((width_slider, height_slider))
56
+ if center_crop:
57
+ width, height = width_slider, height_slider
58
+ aspect_ratio = width / height
59
+ if aspect_ratio > 16 / 10:
60
+ pil_image = pil_image.crop((int((width - height * 16 / 10) / 2), 0, int((width + height * 16 / 10) / 2), height))
61
+ elif aspect_ratio < 16 / 10:
62
+ pil_image = pil_image.crop((0, int((height - width * 10 / 16) / 2), width, int((height + width * 10 / 16) / 2)))
63
+ return gr.Textbox.update(value=img_path), gr.Image.update(value=np.array(pil_image))
64
+
65
+ @spaces.GPU
66
+ def animate(
67
+ self,
68
+ prompt_textbox,
69
+ negative_prompt_textbox,
70
+ input_image_path,
71
+ sampler_dropdown,
72
+ sample_step_slider,
73
+ width_slider,
74
+ height_slider,
75
+ txt_cfg_scale_slider,
76
+ img_cfg_scale_slider,
77
+ center_crop,
78
+ frame_stride,
79
+ use_frameinit,
80
+ frame_init_noise_level,
81
+ seed_textbox
82
+ ):
83
+ if self.pipeline is None:
84
+ raise gr.Error(f"Please select a pretrained pipeline path.")
85
+ if input_image_path == "":
86
+ raise gr.Error(f"Please upload an input image.")
87
+ if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0):
88
+ raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")
89
+ if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0):
90
+ raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.")
91
+
92
+ if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: self.pipeline.unet.enable_xformers_memory_efficient_attention()
93
+
94
+ if seed_textbox not in ("", "-1", -1): torch.manual_seed(int(seed_textbox))
95
+ else: torch.seed()
96
+ seed = torch.initial_seed()
97
+
98
+ if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
99
+ first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
100
+ else:
101
+ first_frame = Image.open(input_image_path).convert('RGB')
102
+
103
+ original_width, original_height = first_frame.size
104
+
105
+ if not center_crop:
106
+ img_transform = T.Compose([
107
+ T.ToTensor(),
108
+ T.Resize((height_slider, width_slider), antialias=None),
109
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
110
+ ])
111
+ else:
112
+ aspect_ratio = original_width / original_height
113
+ crop_aspect_ratio = width_slider / height_slider
114
+ if aspect_ratio > crop_aspect_ratio:
115
+ center_crop_width = int(crop_aspect_ratio * original_height)
116
+ center_crop_height = original_height
117
+ elif aspect_ratio < crop_aspect_ratio:
118
+ center_crop_width = original_width
119
+ center_crop_height = int(original_width / crop_aspect_ratio)
120
+ else:
121
+ center_crop_width = original_width
122
+ center_crop_height = original_height
123
+ img_transform = T.Compose([
124
+ T.ToTensor(),
125
+ T.CenterCrop((center_crop_height, center_crop_width)),
126
+ T.Resize((height_slider, width_slider), antialias=None),
127
+ T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
128
+ ])
129
+
130
+ first_frame = img_transform(first_frame).unsqueeze(0)
131
+ first_frame = first_frame.to("cuda")
132
+
133
+ if use_frameinit:
134
+ self.pipeline.init_filter(
135
+ width = width_slider,
136
+ height = height_slider,
137
+ video_length = 16,
138
+ filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,})
139
+ )
140
+
141
+
142
+ sample = self.pipeline(
143
+ prompt_textbox,
144
+ negative_prompt = negative_prompt_textbox,
145
+ first_frames = first_frame,
146
+ num_inference_steps = sample_step_slider,
147
+ guidance_scale_txt = txt_cfg_scale_slider,
148
+ guidance_scale_img = img_cfg_scale_slider,
149
+ width = width_slider,
150
+ height = height_slider,
151
+ video_length = 16,
152
+ noise_sampling_method = "pyoco_mixed",
153
+ noise_alpha = 1.0,
154
+ frame_stride = frame_stride,
155
+ use_frameinit = use_frameinit,
156
+ frameinit_noise_level = frame_init_noise_level,
157
+ camera_motion = None,
158
+ ).videos
159
+
160
+ global sample_idx
161
+ sample_idx += 1
162
+ save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
163
+ save_videos_grid(sample, save_sample_path, format="mp4")
164
+
165
+ sample_config = {
166
+ "prompt": prompt_textbox,
167
+ "n_prompt": negative_prompt_textbox,
168
+ "first_frame_path": input_image_path,
169
+ "sampler": sampler_dropdown,
170
+ "num_inference_steps": sample_step_slider,
171
+ "guidance_scale_text": txt_cfg_scale_slider,
172
+ "guidance_scale_image": img_cfg_scale_slider,
173
+ "width": width_slider,
174
+ "height": height_slider,
175
+ "video_length": 16,
176
+ "seed": seed
177
+ }
178
+ json_str = json.dumps(sample_config, indent=4)
179
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
180
+ f.write(json_str)
181
+ f.write("\n\n")
182
+
183
+ return gr.Video.update(value=save_sample_path)
184
+
185
+
186
+ controller = AnimateController()
187
+
188
+
189
+ def ui():
190
+ with gr.Blocks(css=css) as demo:
191
+ gr.Markdown(
192
+ """
193
+ # ConsistI2V Text+Image to Video Generation
194
+ Input image will be used as the first frame of the video. Text prompts will be used to control the output video content.
195
+ """
196
+ )
197
+
198
+ with gr.Column(variant="panel"):
199
+ gr.Markdown(
200
+ """
201
+ - Input image can be specified using the "Input Image Path/URL" text box (this can be either a local image path or an image URL) or uploaded by clicking or dragging the image to the "Input Image" box. The uploaded image will be temporarily stored in the "samples/Gradio" folder under the project root folder.
202
+ - Input image can be resized and/or center cropped to a given resolution by adjusting the "Width" and "Height" sliders. It is recommended to use the same resolution as the training resolution (256x256).
203
+ - After setting the input image path or changing the width/height of the input image, press the "Preview" button to visualize the resized input image.
204
+ """
205
+ )
206
+
207
+ with gr.Row():
208
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2)
209
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
210
+
211
+ with gr.Row().style(equal_height=False):
212
+ with gr.Column():
213
+ with gr.Row():
214
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
215
+ sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)
216
+
217
+ with gr.Row():
218
+ center_crop = gr.Checkbox(label="Center Crop the Image", value=True)
219
+ width_slider = gr.Slider(label="Width", value=256, minimum=0, maximum=512, step=64)
220
+ height_slider = gr.Slider(label="Height", value=256, minimum=0, maximum=512, step=64)
221
+ with gr.Row():
222
+ txt_cfg_scale_slider = gr.Slider(label="Text CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.5)
223
+ img_cfg_scale_slider = gr.Slider(label="Image CFG Scale", value=1.0, minimum=1.0, maximum=20.0, step=0.5)
224
+ frame_stride = gr.Slider(label="Frame Stride", value=3, minimum=1, maximum=5, step=1)
225
+
226
+ with gr.Row():
227
+ use_frameinit = gr.Checkbox(label="Enable FrameInit", value=True)
228
+ frameinit_noise_level = gr.Slider(label="FrameInit Noise Level", value=850, minimum=1, maximum=999, step=1)
229
+
230
+
231
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
232
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
233
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
234
+
235
+
236
+
237
+ generate_button = gr.Button(value="Generate", variant='primary')
238
+
239
+ with gr.Column():
240
+ with gr.Row():
241
+ input_image_path = gr.Textbox(label="Input Image Path/URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
242
+ preview_button = gr.Button(value="Preview")
243
+
244
+ with gr.Row():
245
+ input_image = gr.Image(label="Input Image", interactive=True)
246
+ input_image.upload(fn=controller.update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
247
+ result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)
248
+
249
+ def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop):
250
+ if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
251
+ pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
252
+ else:
253
+ pil_image = Image.open(input_image_path).convert('RGB')
254
+ controller.image_resolution = pil_image.size
255
+ original_width, original_height = pil_image.size
256
+
257
+ if center_crop:
258
+ crop_aspect_ratio = width_slider / height_slider
259
+ aspect_ratio = original_width / original_height
260
+ if aspect_ratio > crop_aspect_ratio:
261
+ new_width = int(crop_aspect_ratio * original_height)
262
+ left = (original_width - new_width) / 2
263
+ top = 0
264
+ right = left + new_width
265
+ bottom = original_height
266
+ pil_image = pil_image.crop((left, top, right, bottom))
267
+ elif aspect_ratio < crop_aspect_ratio:
268
+ new_height = int(original_width / crop_aspect_ratio)
269
+ top = (original_height - new_height) / 2
270
+ left = 0
271
+ right = original_width
272
+ bottom = top + new_height
273
+ pil_image = pil_image.crop((left, top, right, bottom))
274
+
275
+ pil_image = pil_image.resize((width_slider, height_slider))
276
+ return gr.Image.update(value=np.array(pil_image))
277
+
278
+ preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
279
+ input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
280
+
281
+ generate_button.click(
282
+ fn=controller.animate,
283
+ inputs=[
284
+ prompt_textbox,
285
+ negative_prompt_textbox,
286
+ input_image_path,
287
+ sampler_dropdown,
288
+ sample_step_slider,
289
+ width_slider,
290
+ height_slider,
291
+ txt_cfg_scale_slider,
292
+ img_cfg_scale_slider,
293
+ center_crop,
294
+ frame_stride,
295
+ use_frameinit,
296
+ frameinit_noise_level,
297
+ seed_textbox,
298
+ ],
299
+ outputs=[result_video]
300
+ )
301
+
302
+ return demo
303
+
304
+
305
+ if __name__ == "__main__":
306
+ demo = ui()
307
+ demo.launch(share=True)
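Note: for readers who want to drive the model outside Gradio, the sketch below mirrors the pipeline call inside animate() above. It only uses arguments that app.py itself passes; the checkpoint name comes from the code, the prompt and image path are placeholders, and FrameInit is switched off so that pipeline.init_filter(...) does not have to be called first.

import torch
import torchvision.transforms as T
from PIL import Image
from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
from consisti2v.utils.util import save_videos_grid

pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16).to("cuda")

# same preprocessing as app.py: resize to 256x256 and normalize to [-1, 1]
transform = T.Compose([
    T.ToTensor(),
    T.Resize((256, 256), antialias=None),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
first_frame = transform(Image.open("input.png").convert("RGB")).unsqueeze(0).to("cuda")  # placeholder image path

sample = pipeline(
    "a placeholder prompt",
    negative_prompt="",
    first_frames=first_frame,
    num_inference_steps=50,
    guidance_scale_txt=7.5,
    guidance_scale_img=1.0,
    width=256,
    height=256,
    video_length=16,
    noise_sampling_method="pyoco_mixed",
    noise_alpha=1.0,
    frame_stride=3,
    use_frameinit=False,  # enable only after calling pipeline.init_filter(...) as animate() does
    frameinit_noise_level=850,
    camera_motion=None,
).videos
save_videos_grid(sample, "sample.mp4", format="mp4")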
configs/inference/inference.yaml ADDED
@@ -0,0 +1,48 @@
1
+ output_dir: "samples/inference"
2
+ output_name: "i2v"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+
21
+ unet_additional_kwargs:
22
+ variant: null
23
+ n_temp_heads: 8
24
+ augment_temporal_attention: true
25
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
26
+ first_frame_condition_mode: "concat"
27
+ use_frame_stride_condition: true
28
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
29
+ noise_alpha: 1.0
30
+
31
+ noise_scheduler_kwargs:
32
+ beta_start: 0.00085
33
+ beta_end: 0.012
34
+ beta_schedule: "linear"
35
+ steps_offset: 1
36
+ clip_sample: false
37
+ rescale_betas_zero_snr: false # true if using zero terminal snr
38
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
39
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
40
+
41
+ frameinit_kwargs:
42
+ enable: true
43
+ camera_motion: null
44
+ noise_level: 850
45
+ filter_params:
46
+ method: 'gaussian'
47
+ d_s: 0.25
48
+ d_t: 0.25
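Note: the inference script that consumes this file is not included in this commit, so the following is only a sketch of how an OmegaConf config like this is typically loaded and overridden from the command line (the same pattern applies to inference_autoregress.yaml below):

from omegaconf import OmegaConf

config = OmegaConf.load("configs/inference/inference.yaml")
# merge command-line dotlist overrides on top of the file, e.g. sampling_kwargs.steps=30 frameinit_kwargs.enable=false
config = OmegaConf.merge(config, OmegaConf.from_cli())
print(config.pretrained_model_path, config.sampling_kwargs.n_frames, config.frameinit_kwargs.noise_level)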
configs/inference/inference_autoregress.yaml ADDED
@@ -0,0 +1,49 @@
1
+ output_dir: "samples/inference"
2
+ output_name: "long_video"
3
+
4
+ pretrained_model_path: "TIGER-Lab/ConsistI2V"
5
+ unet_path: null
6
+ unet_ckpt_prefix: "module."
7
+ pipeline_pretrained_path: null
8
+
9
+ sampling_kwargs:
10
+ height: 256
11
+ width: 256
12
+ n_frames: 16
13
+ steps: 50
14
+ ddim_eta: 0.0
15
+ guidance_scale_txt: 7.5
16
+ guidance_scale_img: 1.0
17
+ guidance_rescale: 0.0
18
+ num_videos_per_prompt: 1
19
+ frame_stride: 3
20
+ autoregress_steps: 3
21
+
22
+ unet_additional_kwargs:
23
+ variant: null
24
+ n_temp_heads: 8
25
+ augment_temporal_attention: true
26
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
27
+ first_frame_condition_mode: "concat"
28
+ use_frame_stride_condition: true
29
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
30
+ noise_alpha: 1.0
31
+
32
+ noise_scheduler_kwargs:
33
+ beta_start: 0.00085
34
+ beta_end: 0.012
35
+ beta_schedule: "linear"
36
+ steps_offset: 1
37
+ clip_sample: false
38
+ rescale_betas_zero_snr: false # true if using zero terminal snr
39
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
40
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
41
+
42
+
43
+ frameinit_kwargs:
44
+ enable: true
45
+ noise_level: 850
46
+ filter_params:
47
+ method: 'gaussian'
48
+ d_s: 0.25
49
+ d_t: 0.25
configs/prompts/default.yaml ADDED
@@ -0,0 +1,16 @@
1
+ seeds: random
2
+
3
+ prompts:
4
+ - "timelapse at the snow land with aurora in the sky."
5
+ - "fireworks."
6
+ - "clown fish swimming through the coral reef."
7
+ - "melting ice cream dripping down the cone."
8
+
9
+ n_prompts:
10
+ - ""
11
+
12
+ path_to_first_frames:
13
+ - "assets/example/example_01.png"
14
+ - "assets/example/example_02.png"
15
+ - "assets/example/example_03.png"
16
+ - "assets/example/example_04.png"
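Note: an assumption about how these lists pair up, not something shown in this commit — each prompt appears to line up positionally with an entry in path_to_first_frames, while a single n_prompts entry would be reused for every sample:

from itertools import cycle
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/prompts/default.yaml")
for prompt, n_prompt, first_frame in zip(cfg.prompts, cycle(cfg.n_prompts), cfg.path_to_first_frames):
    print(prompt, "|", n_prompt, "|", first_frame)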
configs/training/training.yaml ADDED
@@ -0,0 +1,92 @@
1
+ output_dir: "checkpoints"
2
+ pretrained_model_path: "stabilityai/stable-diffusion-2-1-base"
3
+
4
+ noise_scheduler_kwargs:
5
+ num_train_timesteps: 1000
6
+ beta_start: 0.00085
7
+ beta_end: 0.012
8
+ beta_schedule: "linear"
9
+ steps_offset: 1
10
+ clip_sample: false
11
+ rescale_betas_zero_snr: false # true if using zero terminal snr
12
+ timestep_spacing: "leading" # "trailing" if using zero terminal snr
13
+ prediction_type: "epsilon" # "v_prediction" if using zero terminal snr
14
+
15
+ train_data:
16
+ dataset: "joint"
17
+ pexels_config:
18
+ enable: false
19
+ json_path: null
20
+ caption_json_path: null
21
+ video_folder: null
22
+ webvid_config:
23
+ enable: true
24
+ json_path: "/path/to/webvid/annotation"
25
+ video_folder: "/path/to/webvid/data"
26
+ sample_size: 256
27
+ sample_duration: null
28
+ sample_fps: null
29
+ sample_stride: [1, 5]
30
+ sample_n_frames: 16
31
+
32
+ validation_data:
33
+ prompts:
34
+ - "timelapse at the snow land with aurora in the sky."
35
+ - "fireworks."
36
+ - "clown fish swimming through the coral reef."
37
+ - "melting ice cream dripping down the cone."
38
+
39
+ path_to_first_frames:
40
+ - "assets/example/example_01.jpg"
41
+ - "assets/example/example_02.jpg"
42
+ - "assets/example/example_03.jpg"
43
+ - "assets/example/example_04.jpg"
44
+
45
+ num_inference_steps: 50
46
+ ddim_eta: 0.0
47
+ guidance_scale_txt: 7.5
48
+ guidance_scale_img: 1.0
49
+ guidance_rescale: 0.0
50
+ frame_stride: 3
51
+
52
+ trainable_modules:
53
+ - "all"
54
+ # - "conv3ds."
55
+ # - "tempo_attns."
56
+
57
+ resume_from_checkpoint: null
58
+
59
+ unet_additional_kwargs:
60
+ variant: null
61
+ n_temp_heads: 8
62
+ augment_temporal_attention: true
63
+ temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
64
+ first_frame_condition_mode: "concat"
65
+ use_frame_stride_condition: true
66
+ noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
67
+ noise_alpha: 1.0
68
+
69
+ cfg_random_null_text_ratio: 0.1
70
+ cfg_random_null_img_ratio: 0.1
71
+
72
+ use_ema: false
73
+ ema_decay: 0.9999
74
+
75
+ learning_rate: 5.e-5
76
+ train_batch_size: 3
77
+ gradient_accumulation_steps: 1
78
+ max_grad_norm: 0.5
79
+
80
+ max_train_epoch: -1
81
+ max_train_steps: 200000
82
+ checkpointing_epochs: -1
83
+ checkpointing_steps: 2000
84
+ validation_steps: 1000
85
+
86
+ seed: 42
87
+ mixed_precision: "bf16"
88
+ num_workers: 32
89
+ enable_xformers_memory_efficient_attention: true
90
+
91
+ is_image: false
92
+ is_debug: false
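Note: back-of-the-envelope numbers implied by the values above, assuming single-GPU training (scale by the number of processes otherwise):

train_batch_size = 3
gradient_accumulation_steps = 1
max_train_steps = 200_000
effective_batch_size = train_batch_size * gradient_accumulation_steps  # clips per optimizer step, per GPU
clips_seen = effective_batch_size * max_train_steps                    # = 600,000 sixteen-frame clips
print(effective_batch_size, clips_seen)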
consisti2v/data/dataset.py ADDED
@@ -0,0 +1,315 @@
1
+ import os, io, csv, math, random
2
+ import json
3
+ import numpy as np
4
+ from einops import rearrange
5
+ from decord import VideoReader
6
+
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ from diffusers.utils import logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ class WebVid10M(Dataset):
16
+ def __init__(
17
+ self,
18
+ json_path, video_folder=None,
19
+ sample_size=256, sample_stride=4, sample_n_frames=16,
20
+ is_image=False,
21
+ **kwargs,
22
+ ):
23
+ logger.info(f"loading annotations from {json_path} ...")
24
+ with open(json_path, 'rb') as json_file:
25
+ json_list = list(json_file)
26
+ self.dataset = [json.loads(json_str) for json_str in json_list]
27
+ self.length = len(self.dataset)
28
+ logger.info(f"data scale: {self.length}")
29
+
30
+ self.video_folder = video_folder
31
+ self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
32
+ self.sample_n_frames = sample_n_frames
33
+ self.is_image = is_image
34
+
35
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
36
+ self.pixel_transforms = transforms.Compose([
37
+ transforms.RandomHorizontalFlip(),
38
+ transforms.Resize(sample_size[0], antialias=None),
39
+ transforms.CenterCrop(sample_size),
40
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
41
+ ])
42
+
43
+ def get_batch(self, idx):
44
+ video_dict = self.dataset[idx]
45
+ video_relative_path, name = video_dict['file'], video_dict['text']
46
+
47
+ if self.video_folder is not None:
48
+ if video_relative_path[0] == '/':
49
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
50
+ else:
51
+ video_dir = os.path.join(self.video_folder, video_relative_path)
52
+ else:
53
+ video_dir = video_relative_path
54
+ video_reader = VideoReader(video_dir)
55
+ video_length = len(video_reader)
56
+
57
+ if not self.is_image:
58
+ if isinstance(self.sample_stride, int):
59
+ stride = self.sample_stride
60
+ elif isinstance(self.sample_stride, tuple):
61
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
62
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
63
+ start_idx = random.randint(0, video_length - clip_length)
64
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
65
+ else:
66
+ frame_difference = random.randint(2, self.sample_n_frames)
67
+ clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1)
68
+ start_idx = random.randint(0, video_length - clip_length)
69
+ batch_index = [start_idx, start_idx + clip_length - 1]
70
+
71
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
72
+ pixel_values = pixel_values / 255.
73
+ del video_reader
74
+
75
+ return pixel_values, name
76
+
77
+ def __len__(self):
78
+ return self.length
79
+
80
+ def __getitem__(self, idx):
81
+ while True:
82
+ try:
83
+ pixel_values, name = self.get_batch(idx)
84
+ break
85
+
86
+ except Exception as e:
87
+ idx = random.randint(0, self.length-1)
88
+
89
+ pixel_values = self.pixel_transforms(pixel_values)
90
+ sample = dict(pixel_values=pixel_values, text=name)
91
+ return sample
92
+
93
+
94
+ class Pexels(Dataset):
95
+ def __init__(
96
+ self,
97
+ json_path, caption_json_path, video_folder=None,
98
+ sample_size=256, sample_duration=1, sample_fps=8,
99
+ is_image=False,
100
+ **kwargs,
101
+ ):
102
+ logger.info(f"loading captions from {caption_json_path} ...")
103
+ with open(caption_json_path, 'rb') as caption_json_file:
104
+ caption_json_list = list(caption_json_file)
105
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
106
+
107
+ logger.info(f"loading annotations from {json_path} ...")
108
+ with open(json_path, 'rb') as json_file:
109
+ json_list = list(json_file)
110
+ dataset = [json.loads(json_str) for json_str in json_list]
111
+
112
+ self.dataset = []
113
+ for data in dataset:
114
+ data['text'] = self.caption_dict[data['id']]
115
+ if data['height'] / data['width'] < 0.625:
116
+ self.dataset.append(data)
117
+ self.length = len(self.dataset)
118
+ logger.info(f"data scale: {self.length}")
119
+
120
+ self.video_folder = video_folder
121
+ self.sample_duration = sample_duration
122
+ self.sample_fps = sample_fps
123
+ self.sample_n_frames = sample_duration * sample_fps
124
+ self.is_image = is_image
125
+
126
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
+ self.pixel_transforms = transforms.Compose([
128
+ transforms.RandomHorizontalFlip(),
129
+ transforms.Resize(sample_size[0], antialias=None),
130
+ transforms.CenterCrop(sample_size),
131
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
+ ])
133
+
134
+ def get_batch(self, idx):
135
+ video_dict = self.dataset[idx]
136
+ video_relative_path, name = video_dict['file'], video_dict['text']
137
+ fps = video_dict['fps']
138
+
139
+ if self.video_folder is not None:
140
+ if video_relative_path[0] == '/':
141
+ video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
142
+ else:
143
+ video_dir = os.path.join(self.video_folder, video_relative_path)
144
+ else:
145
+ video_dir = video_relative_path
146
+ video_reader = VideoReader(video_dir)
147
+ video_length = len(video_reader)
148
+
149
+ if not self.is_image:
150
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
151
+ start_idx = random.randint(0, video_length - clip_length)
152
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
153
+ else:
154
+ frame_difference = random.randint(2, self.sample_n_frames)
155
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
156
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
157
+ start_idx = random.randint(0, video_length - clip_length)
158
+ batch_index = [start_idx, start_idx + clip_length - 1]
159
+
160
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
161
+ pixel_values = pixel_values / 255.
162
+ del video_reader
163
+
164
+ return pixel_values, name
165
+
166
+ def __len__(self):
167
+ return self.length
168
+
169
+ def __getitem__(self, idx):
170
+ while True:
171
+ try:
172
+ pixel_values, name = self.get_batch(idx)
173
+ break
174
+
175
+ except Exception as e:
176
+ idx = random.randint(0, self.length-1)
177
+
178
+ pixel_values = self.pixel_transforms(pixel_values)
179
+ sample = dict(pixel_values=pixel_values, text=name)
180
+ return sample
181
+
182
+
183
+ class JointDataset(Dataset):
184
+ def __init__(
185
+ self,
186
+ webvid_config, pexels_config,
187
+ sample_size=256,
188
+ sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
189
+ is_image=False,
190
+ **kwargs,
191
+ ):
192
+ assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
193
+ if sample_duration is not None and sample_fps is not None:
194
+ assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
195
+ if sample_stride is not None:
196
+ assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"
197
+
198
+ self.dataset = []
199
+
200
+ if pexels_config.enable:
201
+ logger.info(f"loading pexels dataset")
202
+ logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
203
+ with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
204
+ caption_json_list = list(caption_json_file)
205
+ self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}
206
+
207
+ logger.info(f"loading annotations from {pexels_config.json_path} ...")
208
+ with open(pexels_config.json_path, 'rb') as json_file:
209
+ json_list = list(json_file)
210
+ dataset = [json.loads(json_str) for json_str in json_list]
211
+
212
+ for data in dataset:
213
+ data['text'] = self.caption_dict[data['id']]
214
+ data['dataset'] = 'pexels'
215
+ if data['height'] / data['width'] < 0.625:
216
+ self.dataset.append(data)
217
+
218
+ if webvid_config.enable:
219
+ logger.info(f"loading webvid dataset")
220
+ logger.info(f"loading annotations from {webvid_config.json_path} ...")
221
+ with open(webvid_config.json_path, 'rb') as json_file:
222
+ json_list = list(json_file)
223
+ dataset = [json.loads(json_str) for json_str in json_list]
224
+ for data in dataset:
225
+ data['dataset'] = 'webvid'
226
+ self.dataset.extend(dataset)
227
+
228
+ self.length = len(self.dataset)
229
+ logger.info(f"data scale: {self.length}")
230
+
231
+ self.pexels_folder = pexels_config.video_folder
232
+ self.webvid_folder = webvid_config.video_folder
233
+ self.sample_duration = sample_duration
234
+ self.sample_fps = sample_fps
235
+ self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames
236
+ self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride)
237
+ self.is_image = is_image
238
+
239
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
240
+ self.pixel_transforms = transforms.Compose([
241
+ transforms.RandomHorizontalFlip(),
242
+ transforms.Resize(sample_size[0], antialias=None),
243
+ transforms.CenterCrop(sample_size),
244
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
245
+ ])
246
+
247
+ def get_batch(self, idx):
248
+ video_dict = self.dataset[idx]
249
+ video_relative_path, name = video_dict['file'], video_dict['text']
250
+
251
+ if video_dict['dataset'] == 'pexels':
252
+ video_folder = self.pexels_folder
253
+ elif video_dict['dataset'] == 'webvid':
254
+ video_folder = self.webvid_folder
255
+ else:
256
+ raise NotImplementedError
257
+
258
+ if video_folder is not None:
259
+ if video_relative_path[0] == '/':
260
+ video_dir = os.path.join(video_folder, os.path.basename(video_relative_path))
261
+ else:
262
+ video_dir = os.path.join(video_folder, video_relative_path)
263
+ else:
264
+ video_dir = video_relative_path
265
+ video_reader = VideoReader(video_dir)
266
+ video_length = len(video_reader)
267
+
268
+ stride = None
269
+ if not self.is_image:
270
+ if self.sample_duration is not None:
271
+ fps = video_dict['fps']
272
+ clip_length = min(video_length, math.ceil(fps * self.sample_duration))
273
+ elif self.sample_stride is not None:
274
+ if isinstance(self.sample_stride, int):
275
+ stride = self.sample_stride
276
+ elif isinstance(self.sample_stride, tuple):
277
+ stride = random.randint(self.sample_stride[0], self.sample_stride[1])
278
+ clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
279
+
280
+ start_idx = random.randint(0, video_length - clip_length)
281
+ batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
282
+
283
+ else:
284
+ frame_difference = random.randint(2, self.sample_n_frames)
285
+ if self.sample_duration is not None:
286
+ fps = video_dict['fps']
287
+ sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
288
+ elif self.sample_stride is not None:
289
+ sample_stride = self.sample_stride
290
+
291
+ clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
292
+ start_idx = random.randint(0, video_length - clip_length)
293
+ batch_index = [start_idx, start_idx + clip_length - 1]
294
+
295
+ pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
296
+ pixel_values = pixel_values / 255.
297
+ del video_reader
298
+
299
+ return pixel_values, name, stride
300
+
301
+ def __len__(self):
302
+ return self.length
303
+
304
+ def __getitem__(self, idx):
305
+ while True:
306
+ try:
307
+ pixel_values, name, stride = self.get_batch(idx)
308
+ break
309
+
310
+ except Exception as e:
311
+ idx = random.randint(0, self.length-1)
312
+
313
+ pixel_values = self.pixel_transforms(pixel_values)
314
+ sample = dict(pixel_values=pixel_values, text=name, stride=stride)
315
+ return sample
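Note: a usage sketch tying this file back to configs/training/training.yaml, assuming the webvid json_path and video_folder in that config point at real data. The sub-configs are passed through as DictConfig objects so that pexels_config.enable / webvid_config.enable keep working as attribute access:

from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from consisti2v.data.dataset import JointDataset

cfg = OmegaConf.load("configs/training/training.yaml")
# train_data.dataset == "joint" corresponds to JointDataset; the extra "dataset" key is absorbed by **kwargs
dataset = JointDataset(**cfg.train_data)
loader = DataLoader(dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.num_workers)

batch = next(iter(loader))
# pixel_values: (batch, frames, channels, height, width) normalized to [-1, 1]
# text: list of captions; stride: frame stride sampled for each clip
print(batch["pixel_values"].shape, batch["stride"])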
consisti2v/models/rotary_embedding.py ADDED
@@ -0,0 +1,280 @@
1
+ from math import pi, log
2
+
3
+ import torch
4
+ from torch.nn import Module, ModuleList
5
+ from torch.cuda.amp import autocast
6
+ from torch import nn, einsum, broadcast_tensors, Tensor
7
+
8
+ from einops import rearrange, repeat
9
+
10
+ from beartype import beartype
11
+ from beartype.typing import Literal, Union, Optional
12
+
13
+ # helper functions
14
+
15
+ def exists(val):
16
+ return val is not None
17
+
18
+ def default(val, d):
19
+ return val if exists(val) else d
20
+
21
+ # broadcat, as tortoise-tts was using it
22
+
23
+ def broadcat(tensors, dim = -1):
24
+ broadcasted_tensors = broadcast_tensors(*tensors)
25
+ return torch.cat(broadcasted_tensors, dim = dim)
26
+
27
+ # rotary embedding helper functions
28
+
29
+ def rotate_half(x):
30
+ x = rearrange(x, '... (d r) -> ... d r', r = 2)
31
+ x1, x2 = x.unbind(dim = -1)
32
+ x = torch.stack((-x2, x1), dim = -1)
33
+ return rearrange(x, '... d r -> ... (d r)')
34
+
35
+ @autocast(enabled = False)
36
+ def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
37
+ if t.ndim == 3:
38
+ seq_len = t.shape[seq_dim]
39
+ freqs = freqs[-seq_len:].to(t)
40
+
41
+ rot_dim = freqs.shape[-1]
42
+ end_index = start_index + rot_dim
43
+
44
+ assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
45
+
46
+ t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
47
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
48
+ return torch.cat((t_left, t, t_right), dim = -1)
49
+
50
+ # learned rotation helpers
51
+
52
+ def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
53
+ if exists(freq_ranges):
54
+ rotations = einsum('..., f -> ... f', rotations, freq_ranges)
55
+ rotations = rearrange(rotations, '... r f -> ... (r f)')
56
+
57
+ rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
58
+ return apply_rotary_emb(rotations, t, start_index = start_index)
59
+
60
+ # classes
61
+
62
+ class RotaryEmbedding(Module):
63
+ @beartype
64
+ def __init__(
65
+ self,
66
+ dim,
67
+ custom_freqs: Optional[Tensor] = None,
68
+ freqs_for: Union[
69
+ Literal['lang'],
70
+ Literal['pixel'],
71
+ Literal['constant']
72
+ ] = 'lang',
73
+ theta = 10000,
74
+ max_freq = 10,
75
+ num_freqs = 1,
76
+ learned_freq = False,
77
+ use_xpos = False,
78
+ xpos_scale_base = 512,
79
+ interpolate_factor = 1.,
80
+ theta_rescale_factor = 1.,
81
+ seq_before_head_dim = False
82
+ ):
83
+ super().__init__()
84
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
85
+ # has some connection to NTK literature
86
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
87
+
88
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
89
+
90
+ self.freqs_for = freqs_for
91
+
92
+ if exists(custom_freqs):
93
+ freqs = custom_freqs
94
+ elif freqs_for == 'lang':
95
+ freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
96
+ elif freqs_for == 'pixel':
97
+ freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
98
+ elif freqs_for == 'constant':
99
+ freqs = torch.ones(num_freqs).float()
100
+
101
+ self.tmp_store('cached_freqs', None)
102
+ self.tmp_store('cached_scales', None)
103
+
104
+ self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
105
+
106
+ self.learned_freq = learned_freq
107
+
108
+ # dummy for device
109
+
110
+ self.tmp_store('dummy', torch.tensor(0))
111
+
112
+ # default sequence dimension
113
+
114
+ self.seq_before_head_dim = seq_before_head_dim
115
+ self.default_seq_dim = -3 if seq_before_head_dim else -2
116
+
117
+ # interpolation factors
118
+
119
+ assert interpolate_factor >= 1.
120
+ self.interpolate_factor = interpolate_factor
121
+
122
+ # xpos
123
+
124
+ self.use_xpos = use_xpos
125
+ if not use_xpos:
126
+ self.tmp_store('scale', None)
127
+ return
128
+
129
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
130
+ self.scale_base = xpos_scale_base
131
+ self.tmp_store('scale', scale)
132
+
133
+ @property
134
+ def device(self):
135
+ return self.dummy.device
136
+
137
+ def tmp_store(self, key, value):
138
+ self.register_buffer(key, value, persistent = False)
139
+
140
+ def get_seq_pos(self, seq_len, device, dtype, offset = 0):
141
+ return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
142
+
143
+ def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None):
144
+ seq_dim = default(seq_dim, self.default_seq_dim)
145
+
146
+ assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
147
+
148
+ device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
149
+
150
+ if exists(freq_seq_len):
151
+ assert freq_seq_len >= seq_len
152
+ seq_len = freq_seq_len
153
+
154
+ if seq_pos is None:
155
+ seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
156
+ else:
157
+ assert seq_pos.shape[0] == seq_len
158
+
159
+ freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset)
160
+
161
+ if seq_dim == -3:
162
+ freqs = rearrange(freqs, 'n d -> n 1 d')
163
+
164
+ return apply_rotary_emb(freqs, t, seq_dim = seq_dim)
165
+
166
+ def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
167
+ seq_dim = default(seq_dim, self.default_seq_dim)
168
+
169
+ q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
170
+ assert q_len <= k_len
171
+ rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
172
+ rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)
173
+
174
+ rotated_q = rotated_q.type(q.dtype)
175
+ rotated_k = rotated_k.type(k.dtype)
176
+
177
+ return rotated_q, rotated_k
178
+
179
+ def rotate_queries_and_keys(self, q, k, seq_dim = None):
180
+ seq_dim = default(seq_dim, self.default_seq_dim)
181
+
182
+ assert self.use_xpos
183
+ device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
184
+
185
+ seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
186
+
187
+ freqs = self.forward(seq, seq_len = seq_len)
188
+ scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
189
+
190
+ if seq_dim == -3:
191
+ freqs = rearrange(freqs, 'n d -> n 1 d')
192
+ scale = rearrange(scale, 'n d -> n 1 d')
193
+
194
+ rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
195
+ rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)
196
+
197
+ rotated_q = rotated_q.type(q.dtype)
198
+ rotated_k = rotated_k.type(k.dtype)
199
+
200
+ return rotated_q, rotated_k
201
+
202
+ @beartype
203
+ def get_scale(
204
+ self,
205
+ t: Tensor,
206
+ seq_len: Optional[int] = None,
207
+ offset = 0
208
+ ):
209
+ assert self.use_xpos
210
+
211
+ should_cache = exists(seq_len)
212
+
213
+ if (
214
+ should_cache and \
215
+ exists(self.cached_scales) and \
216
+ (seq_len + offset) <= self.cached_scales.shape[0]
217
+ ):
218
+ return self.cached_scales[offset:(offset + seq_len)]
219
+
220
+ scale = 1.
221
+ if self.use_xpos:
222
+ power = (t - len(t) // 2) / self.scale_base
223
+ scale = self.scale ** rearrange(power, 'n -> n 1')
224
+ scale = torch.cat((scale, scale), dim = -1)
225
+
226
+ if should_cache:
227
+ self.tmp_store('cached_scales', scale)
228
+
229
+ return scale
230
+
231
+ def get_axial_freqs(self, *dims):
232
+ Colon = slice(None)
233
+ all_freqs = []
234
+
235
+ for ind, dim in enumerate(dims):
236
+ if self.freqs_for == 'pixel':
237
+ pos = torch.linspace(-1, 1, steps = dim, device = self.device)
238
+ else:
239
+ pos = torch.arange(dim, device = self.device)
240
+
241
+ freqs = self.forward(pos, seq_len = dim)
242
+
243
+ all_axis = [None] * len(dims)
244
+ all_axis[ind] = Colon
245
+
246
+ new_axis_slice = (Ellipsis, *all_axis, Colon)
247
+ all_freqs.append(freqs[new_axis_slice])
248
+
249
+ all_freqs = broadcast_tensors(*all_freqs)
250
+ return torch.cat(all_freqs, dim = -1)
251
+
252
+ @autocast(enabled = False)
253
+ def forward(
254
+ self,
255
+ t: Tensor,
256
+ seq_len = None,
257
+ offset = 0
258
+ ):
259
+ # should_cache = (
260
+ # not self.learned_freq and \
261
+ # exists(seq_len) and \
262
+ # self.freqs_for != 'pixel'
263
+ # )
264
+
265
+ # if (
266
+ # should_cache and \
267
+ # exists(self.cached_freqs) and \
268
+ # (offset + seq_len) <= self.cached_freqs.shape[0]
269
+ # ):
270
+ # return self.cached_freqs[offset:(offset + seq_len)].detach()
271
+
272
+ freqs = self.freqs
273
+
274
+ freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
275
+ freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
276
+
277
+ # if should_cache:
278
+ # self.tmp_store('cached_freqs', freqs.detach())
279
+
280
+ return freqs
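Note: a quick sketch of how the RotaryEmbedding class above is used (shapes are illustrative; presumably this backs the temp_pos_embedding: "rotary" option in the configs):

import torch
from consisti2v.models.rotary_embedding import RotaryEmbedding

rot = RotaryEmbedding(dim=32)   # rotates the first 32 channels of each head
q = torch.randn(2, 8, 16, 32)   # (batch, heads, frames, head_dim); the sequence axis defaults to -2
k = torch.randn(2, 8, 16, 32)
q = rot.rotate_queries_or_keys(q)
k = rot.rotate_queries_or_keys(k)
# after rotation, q @ k.transpose(-1, -2) depends only on relative frame offsets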
consisti2v/models/videoldm_attention.py ADDED
@@ -0,0 +1,809 @@