add files
- .gitignore +18 -0
- LICENSE +21 -0
- README.md +1 -1
- app.py +307 -0
- configs/inference/inference.yaml +48 -0
- configs/inference/inference_autoregress.yaml +49 -0
- configs/prompts/default.yaml +16 -0
- configs/training/training.yaml +92 -0
- consisti2v/data/dataset.py +315 -0
- consisti2v/models/rotary_embedding.py +280 -0
- consisti2v/models/videoldm_attention.py +809 -0
- consisti2v/models/videoldm_transformer_blocks.py +564 -0
- consisti2v/models/videoldm_unet.py +1371 -0
- consisti2v/models/videoldm_unet_blocks.py +1159 -0
- consisti2v/pipelines/pipeline_autoregress_animation.py +615 -0
- consisti2v/pipelines/pipeline_conditional_animation.py +695 -0
- consisti2v/utils/frameinit_utils.py +142 -0
- consisti2v/utils/util.py +165 -0
- environment.yaml +28 -0
- requirements.txt +18 -0
- scripts/animate.py +179 -0
- scripts/animate_autoregress.py +185 -0
- train.py +617 -0
.gitignore
ADDED
@@ -0,0 +1,18 @@
samples/
wandb/
outputs/
__pycache__/
scripts/animate_inter.py
scripts/gradio_app.py
*.ipynb
*.safetensors
*.ckpt
.ossutil_checkpoint/
ossutil_output/
debugs/
.vscode
.env
models
!*/models
.ipynb_checkpoints
checkpoints
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 TIGER Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,6 +1,6 @@
 ---
 title: ConsistI2V
-emoji:
+emoji: 🎥
 colorFrom: purple
 colorTo: green
 sdk: gradio
app.py
ADDED
@@ -0,0 +1,307 @@
import spaces
import os
import json
import torch
import random
import requests
from PIL import Image
import numpy as np

import gradio as gr
from datetime import datetime

import torchvision.transforms as T

from diffusers import DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
from consisti2v.utils.util import save_videos_grid
from omegaconf import OmegaConf


sample_idx = 0
scheduler_dict = {
    "DDIM": DDIMScheduler,
}

css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.savedir = os.path.join(self.basedir, "samples/Gradio", datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.image_resolution = (256, 256)
        # config models
        self.pipeline = ConditionalAnimationPipeline.from_pretrained("TIGER-Lab/ConsistI2V", torch_dtype=torch.float16,)
        self.pipeline.to("cuda")

    def update_textbox_and_save_image(self, input_image, height_slider, width_slider, center_crop):
        pil_image = Image.fromarray(input_image.astype(np.uint8)).convert("RGB")
        img_path = os.path.join(self.savedir, "input_image.png")
        pil_image.save(img_path)
        self.image_resolution = pil_image.size
        pil_image = pil_image.resize((width_slider, height_slider))
        if center_crop:
            width, height = width_slider, height_slider
            aspect_ratio = width / height
            if aspect_ratio > 16 / 10:
                pil_image = pil_image.crop((int((width - height * 16 / 10) / 2), 0, int((width + height * 16 / 10) / 2), height))
            elif aspect_ratio < 16 / 10:
                pil_image = pil_image.crop((0, int((height - width * 10 / 16) / 2), width, int((height + width * 10 / 16) / 2)))
        return gr.Textbox.update(value=img_path), gr.Image.update(value=np.array(pil_image))

    @spaces.GPU
    def animate(
        self,
        prompt_textbox,
        negative_prompt_textbox,
        input_image_path,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        height_slider,
        txt_cfg_scale_slider,
        img_cfg_scale_slider,
        center_crop,
        frame_stride,
        use_frameinit,
        frame_init_noise_level,
        seed_textbox
    ):
        if self.pipeline is None:
            raise gr.Error(f"Please select a pretrained pipeline path.")
        if input_image_path == "":
            raise gr.Error(f"Please upload an input image.")
        if (not center_crop) and (width_slider % 8 != 0 or height_slider % 8 != 0):
            raise gr.Error(f"`height` and `width` have to be divisible by 8 but are {height_slider} and {width_slider}.")
        if center_crop and (width_slider % 8 != 0 or height_slider % 8 != 0):
            raise gr.Error(f"`height` and `width` (after cropping) have to be divisible by 8 but are {height_slider} and {width_slider}.")

        if is_xformers_available() and int(torch.__version__.split(".")[0]) < 2: self.pipeline.unet.enable_xformers_memory_efficient_attention()

        if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
        else: torch.seed()
        seed = torch.initial_seed()

        if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
            first_frame = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
        else:
            first_frame = Image.open(input_image_path).convert('RGB')

        original_width, original_height = first_frame.size

        if not center_crop:
            img_transform = T.Compose([
                T.ToTensor(),
                T.Resize((height_slider, width_slider), antialias=None),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])
        else:
            aspect_ratio = original_width / original_height
            crop_aspect_ratio = width_slider / height_slider
            if aspect_ratio > crop_aspect_ratio:
                center_crop_width = int(crop_aspect_ratio * original_height)
                center_crop_height = original_height
            elif aspect_ratio < crop_aspect_ratio:
                center_crop_width = original_width
                center_crop_height = int(original_width / crop_aspect_ratio)
            else:
                center_crop_width = original_width
                center_crop_height = original_height
            img_transform = T.Compose([
                T.ToTensor(),
                T.CenterCrop((center_crop_height, center_crop_width)),
                T.Resize((height_slider, width_slider), antialias=None),
                T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
            ])

        first_frame = img_transform(first_frame).unsqueeze(0)
        first_frame = first_frame.to("cuda")

        if use_frameinit:
            self.pipeline.init_filter(
                width = width_slider,
                height = height_slider,
                video_length = 16,
                filter_params = OmegaConf.create({'method': 'gaussian', 'd_s': 0.25, 'd_t': 0.25,})
            )

        sample = self.pipeline(
            prompt_textbox,
            negative_prompt = negative_prompt_textbox,
            first_frames = first_frame,
            num_inference_steps = sample_step_slider,
            guidance_scale_txt = txt_cfg_scale_slider,
            guidance_scale_img = img_cfg_scale_slider,
            width = width_slider,
            height = height_slider,
            video_length = 16,
            noise_sampling_method = "pyoco_mixed",
            noise_alpha = 1.0,
            frame_stride = frame_stride,
            use_frameinit = use_frameinit,
            frameinit_noise_level = frame_init_noise_level,
            camera_motion = None,
        ).videos

        global sample_idx
        sample_idx += 1
        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path, format="mp4")

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "first_frame_path": input_image_path,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale_text": txt_cfg_scale_slider,
            "guidance_scale_image": img_cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": 8,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)


controller = AnimateController()


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # ConsistI2V Text+Image to Video Generation
            Input image will be used as the first frame of the video. Text prompts will be used to control the output video content.
            """
        )

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                - Input image can be specified using the "Input Image Path/URL" text box (this can be either a local image path or an image URL) or uploaded by clicking or dragging the image to the "Input Image" box. The uploaded image will be temporarily stored in the "samples/Gradio" folder under the project root folder.
                - Input image can be resized and/or center cropped to a given resolution by adjusting the "Width" and "Height" sliders. It is recommended to use the same resolution as the training resolution (256x256).
                - After setting the input image path or changed the width/height of the input image, press the "Preview" button to visualize the resized input image.
                """
            )

            with gr.Row():
                prompt_textbox = gr.Textbox(label="Prompt", lines=2)
                negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=50, minimum=10, maximum=250, step=1)

                    with gr.Row():
                        center_crop = gr.Checkbox(label="Center Crop the Image", value=True)
                        width_slider = gr.Slider(label="Width", value=256, minimum=0, maximum=512, step=64)
                        height_slider = gr.Slider(label="Height", value=256, minimum=0, maximum=512, step=64)
                    with gr.Row():
                        txt_cfg_scale_slider = gr.Slider(label="Text CFG Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.5)
                        img_cfg_scale_slider = gr.Slider(label="Image CFG Scale", value=1.0, minimum=1.0, maximum=20.0, step=0.5)
                        frame_stride = gr.Slider(label="Frame Stride", value=3, minimum=1, maximum=5, step=1)

                    with gr.Row():
                        use_frameinit = gr.Checkbox(label="Enable FrameInit", value=True)
                        frameinit_noise_level = gr.Slider(label="FrameInit Noise Level", value=850, minimum=1, maximum=999, step=1)

                    seed_textbox = gr.Textbox(label="Seed", value=-1)
                    seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                    seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                with gr.Column():
                    with gr.Row():
                        input_image_path = gr.Textbox(label="Input Image Path/URL", lines=1, scale=10, info="Press Enter or the Preview button to confirm the input image.")
                        preview_button = gr.Button(value="Preview")

                    with gr.Row():
                        input_image = gr.Image(label="Input Image", interactive=True)
                        input_image.upload(fn=controller.update_textbox_and_save_image, inputs=[input_image, height_slider, width_slider, center_crop], outputs=[input_image_path, input_image])
                    result_video = gr.Video(label="Generated Animation", interactive=False, autoplay=True)

            def update_and_resize_image(input_image_path, height_slider, width_slider, center_crop):
                if input_image_path.startswith("http://") or input_image_path.startswith("https://"):
                    pil_image = Image.open(requests.get(input_image_path, stream=True).raw).convert('RGB')
                else:
                    pil_image = Image.open(input_image_path).convert('RGB')
                controller.image_resolution = pil_image.size
                original_width, original_height = pil_image.size

                if center_crop:
                    crop_aspect_ratio = width_slider / height_slider
                    aspect_ratio = original_width / original_height
                    if aspect_ratio > crop_aspect_ratio:
                        new_width = int(crop_aspect_ratio * original_height)
                        left = (original_width - new_width) / 2
                        top = 0
                        right = left + new_width
                        bottom = original_height
                        pil_image = pil_image.crop((left, top, right, bottom))
                    elif aspect_ratio < crop_aspect_ratio:
                        new_height = int(original_width / crop_aspect_ratio)
                        top = (original_height - new_height) / 2
                        left = 0
                        right = original_width
                        bottom = top + new_height
                        pil_image = pil_image.crop((left, top, right, bottom))

                pil_image = pil_image.resize((width_slider, height_slider))
                return gr.Image.update(value=np.array(pil_image))

            preview_button.click(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])
            input_image_path.submit(fn=update_and_resize_image, inputs=[input_image_path, height_slider, width_slider, center_crop], outputs=[input_image])

            generate_button.click(
                fn=controller.animate,
                inputs=[
                    prompt_textbox,
                    negative_prompt_textbox,
                    input_image_path,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    height_slider,
                    txt_cfg_scale_slider,
                    img_cfg_scale_slider,
                    center_crop,
                    frame_stride,
                    use_frameinit,
                    frameinit_noise_level,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.launch(share=True)
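For readers who want to call the pipeline outside the Gradio app, here is a condensed sketch of the same generation path as animate() above. Argument values mirror the defaults in app.py and the example asset path comes from configs/prompts/default.yaml; treat this as an illustrative sketch, not the project's documented API.

import torch
import torchvision.transforms as T
from PIL import Image
from consisti2v.pipelines.pipeline_conditional_animation import ConditionalAnimationPipeline
from consisti2v.utils.util import save_videos_grid

# Load the released pipeline in half precision, as app.py does.
pipeline = ConditionalAnimationPipeline.from_pretrained(
    "TIGER-Lab/ConsistI2V", torch_dtype=torch.float16
).to("cuda")

# Prepare a normalized first frame, mirroring the non-center-crop transform above.
img_transform = T.Compose([
    T.ToTensor(),
    T.Resize((256, 256), antialias=None),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
first_frame = img_transform(Image.open("assets/example/example_01.png").convert("RGB")).unsqueeze(0).to("cuda")

videos = pipeline(
    "fireworks.",
    negative_prompt="",
    first_frames=first_frame,
    num_inference_steps=50,
    guidance_scale_txt=7.5,
    guidance_scale_img=1.0,
    width=256, height=256, video_length=16,
    noise_sampling_method="pyoco_mixed",
    noise_alpha=1.0,
    frame_stride=3,
    use_frameinit=False,       # skip FrameInit so pipeline.init_filter() is not required
    frameinit_noise_level=850,
    camera_motion=None,
).videos
save_videos_grid(videos, "samples/demo.mp4", format="mp4")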
configs/inference/inference.yaml
ADDED
@@ -0,0 +1,48 @@
output_dir: "samples/inference"
output_name: "i2v"

pretrained_model_path: "TIGER-Lab/ConsistI2V"
unet_path: null
unet_ckpt_prefix: "module."
pipeline_pretrained_path: null

sampling_kwargs:
  height: 256
  width: 256
  n_frames: 16
  steps: 50
  ddim_eta: 0.0
  guidance_scale_txt: 7.5
  guidance_scale_img: 1.0
  guidance_rescale: 0.0
  num_videos_per_prompt: 1
  frame_stride: 3

unet_additional_kwargs:
  variant: null
  n_temp_heads: 8
  augment_temporal_attention: true
  temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
  first_frame_condition_mode: "concat"
  use_frame_stride_condition: true
  noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
  noise_alpha: 1.0

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false
  rescale_betas_zero_snr: false # true if using zero terminal snr
  timestep_spacing: "leading" # "trailing" if using zero terminal snr
  prediction_type: "epsilon" # "v_prediction" if using zero terminal snr

frameinit_kwargs:
  enable: true
  camera_motion: null
  noise_level: 850
  filter_params:
    method: 'gaussian'
    d_s: 0.25
    d_t: 0.25
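The inference YAML above is an OmegaConf config. Below is a minimal sketch of loading it and overriding sampling options from the command line using dotlist syntax; how scripts/animate.py actually parses its arguments may differ.

import sys
from omegaconf import OmegaConf

# Load the base config, then merge CLI overrides such as sampling_kwargs.steps=30.
config = OmegaConf.load("configs/inference/inference.yaml")
overrides = OmegaConf.from_dotlist(sys.argv[1:])
config = OmegaConf.merge(config, overrides)

print(config.sampling_kwargs.height, config.sampling_kwargs.steps, config.frameinit_kwargs.enable)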
configs/inference/inference_autoregress.yaml
ADDED
@@ -0,0 +1,49 @@
output_dir: "samples/inference"
output_name: "long_video"

pretrained_model_path: "TIGER-Lab/ConsistI2V"
unet_path: null
unet_ckpt_prefix: "module."
pipeline_pretrained_path: null

sampling_kwargs:
  height: 256
  width: 256
  n_frames: 16
  steps: 50
  ddim_eta: 0.0
  guidance_scale_txt: 7.5
  guidance_scale_img: 1.0
  guidance_rescale: 0.0
  num_videos_per_prompt: 1
  frame_stride: 3
  autoregress_steps: 3

unet_additional_kwargs:
  variant: null
  n_temp_heads: 8
  augment_temporal_attention: true
  temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
  first_frame_condition_mode: "concat"
  use_frame_stride_condition: true
  noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
  noise_alpha: 1.0

noise_scheduler_kwargs:
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false
  rescale_betas_zero_snr: false # true if using zero terminal snr
  timestep_spacing: "leading" # "trailing" if using zero terminal snr
  prediction_type: "epsilon" # "v_prediction" if using zero terminal snr

frameinit_kwargs:
  enable: true
  noise_level: 850
  filter_params:
    method: 'gaussian'
    d_s: 0.25
    d_t: 0.25
configs/prompts/default.yaml
ADDED
@@ -0,0 +1,16 @@
seeds: random

prompts:
  - "timelapse at the snow land with aurora in the sky."
  - "fireworks."
  - "clown fish swimming through the coral reef."
  - "melting ice cream dripping down the cone."

n_prompts:
  - ""

path_to_first_frames:
  - "assets/example/example_01.png"
  - "assets/example/example_02.png"
  - "assets/example/example_03.png"
  - "assets/example/example_04.png"
configs/training/training.yaml
ADDED
@@ -0,0 +1,92 @@
output_dir: "checkpoints"
pretrained_model_path: "stabilityai/stable-diffusion-2-1-base"

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false
  rescale_betas_zero_snr: false # true if using zero terminal snr
  timestep_spacing: "leading" # "trailing" if using zero terminal snr
  prediction_type: "epsilon" # "v_prediction" if using zero terminal snr

train_data:
  dataset: "joint"
  pexels_config:
    enable: false
    json_path: null
    caption_json_path: null
    video_folder: null
  webvid_config:
    enable: true
    json_path: "/path/to/webvid/annotation"
    video_folder: "/path/to/webvid/data"
  sample_size: 256
  sample_duration: null
  sample_fps: null
  sample_stride: [1, 5]
  sample_n_frames: 16

validation_data:
  prompts:
    - "timelapse at the snow land with aurora in the sky."
    - "fireworks."
    - "clown fish swimming through the coral reef."
    - "melting ice cream dripping down the cone."

  path_to_first_frames:
    - "assets/example/example_01.jpg"
    - "assets/example/example_02.jpg"
    - "assets/example/example_03.jpg"
    - "assets/example/example_04.jpg"

  num_inference_steps: 50
  ddim_eta: 0.0
  guidance_scale_txt: 7.5
  guidance_scale_img: 1.0
  guidance_rescale: 0.0
  frame_stride: 3

trainable_modules:
  - "all"
  # - "conv3ds."
  # - "tempo_attns."

resume_from_checkpoint: null

unet_additional_kwargs:
  variant: null
  n_temp_heads: 8
  augment_temporal_attention: true
  temp_pos_embedding: "rotary" # "rotary" or "sinusoidal"
  first_frame_condition_mode: "concat"
  use_frame_stride_condition: true
  noise_sampling_method: "pyoco_mixed" # "vanilla" or "pyoco_mixed" or "pyoco_progressive"
  noise_alpha: 1.0

cfg_random_null_text_ratio: 0.1
cfg_random_null_img_ratio: 0.1

use_ema: false
ema_decay: 0.9999

learning_rate: 5.e-5
train_batch_size: 3
gradient_accumulation_steps: 1
max_grad_norm: 0.5

max_train_epoch: -1
max_train_steps: 200000
checkpointing_epochs: -1
checkpointing_steps: 2000
validation_steps: 1000

seed: 42
mixed_precision: "bf16"
num_workers: 32
enable_xformers_memory_efficient_attention: true

is_image: false
is_debug: false
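For reference, the noise_scheduler_kwargs block above lines up with the constructor arguments of the diffusers schedulers. A small, assumed sketch of turning it into a scheduler object follows; train.py may instantiate a different scheduler class or add arguments of its own.

from diffusers import DDIMScheduler
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/training/training.yaml")
# Every key in noise_scheduler_kwargs is a valid DDIMScheduler constructor argument.
scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
print(scheduler.config.prediction_type)  # "epsilon" unless zero terminal SNR is enabled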
consisti2v/data/dataset.py
ADDED
@@ -0,0 +1,315 @@
import os, io, csv, math, random
import json
import numpy as np
from einops import rearrange
from decord import VideoReader

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset

from diffusers.utils import logging

logger = logging.get_logger(__name__)

class WebVid10M(Dataset):
    def __init__(
        self,
        json_path, video_folder=None,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
        **kwargs,
    ):
        logger.info(f"loading annotations from {json_path} ...")
        with open(json_path, 'rb') as json_file:
            json_list = list(json_file)
        self.dataset = [json.loads(json_str) for json_str in json_list]
        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_stride = sample_stride if isinstance(sample_stride, int) else tuple(sample_stride)
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']

        if self.video_folder is not None:
            if video_relative_path[0] == '/':
                video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
            else:
                video_dir = os.path.join(self.video_folder, video_relative_path)
        else:
            video_dir = video_relative_path
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            if isinstance(self.sample_stride, int):
                stride = self.sample_stride
            elif isinstance(self.sample_stride, tuple):
                stride = random.randint(self.sample_stride[0], self.sample_stride[1])
            clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            frame_difference = random.randint(2, self.sample_n_frames)
            clip_length = min(video_length, (frame_difference - 1) * self.sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = [start_idx, start_idx + clip_length - 1]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                idx = random.randint(0, self.length-1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample


class Pexels(Dataset):
    def __init__(
        self,
        json_path, caption_json_path, video_folder=None,
        sample_size=256, sample_duration=1, sample_fps=8,
        is_image=False,
        **kwargs,
    ):
        logger.info(f"loading captions from {caption_json_path} ...")
        with open(caption_json_path, 'rb') as caption_json_file:
            caption_json_list = list(caption_json_file)
        self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}

        logger.info(f"loading annotations from {json_path} ...")
        with open(json_path, 'rb') as json_file:
            json_list = list(json_file)
        dataset = [json.loads(json_str) for json_str in json_list]

        self.dataset = []
        for data in dataset:
            data['text'] = self.caption_dict[data['id']]
            if data['height'] / data['width'] < 0.625:
                self.dataset.append(data)
        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.video_folder = video_folder
        self.sample_duration = sample_duration
        self.sample_fps = sample_fps
        self.sample_n_frames = sample_duration * sample_fps
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']
        fps = video_dict['fps']

        if self.video_folder is not None:
            if video_relative_path[0] == '/':
                video_dir = os.path.join(self.video_folder, os.path.basename(video_relative_path))
            else:
                video_dir = os.path.join(self.video_folder, video_relative_path)
        else:
            video_dir = video_relative_path
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            clip_length = min(video_length, math.ceil(fps * self.sample_duration))
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            frame_difference = random.randint(2, self.sample_n_frames)
            sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
            clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = [start_idx, start_idx + clip_length - 1]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break

            except Exception as e:
                idx = random.randint(0, self.length-1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample


class JointDataset(Dataset):
    def __init__(
        self,
        webvid_config, pexels_config,
        sample_size=256,
        sample_duration=None, sample_fps=None, sample_stride=None, sample_n_frames=None,
        is_image=False,
        **kwargs,
    ):
        assert (sample_duration is None and sample_fps is None) or (sample_duration is not None and sample_fps is not None), "sample_duration and sample_fps should be both None or not None"
        if sample_duration is not None and sample_fps is not None:
            assert sample_stride is None, "when sample_duration and sample_fps are not None, sample_stride should be None"
        if sample_stride is not None:
            assert sample_fps is None and sample_duration is None, "when sample_stride is not None, sample_duration and sample_fps should be both None"

        self.dataset = []

        if pexels_config.enable:
            logger.info(f"loading pexels dataset")
            logger.info(f"loading captions from {pexels_config.caption_json_path} ...")
            with open(pexels_config.caption_json_path, 'rb') as caption_json_file:
                caption_json_list = list(caption_json_file)
            self.caption_dict = {json.loads(json_str)['id']: json.loads(json_str)['text'] for json_str in caption_json_list}

            logger.info(f"loading annotations from {pexels_config.json_path} ...")
            with open(pexels_config.json_path, 'rb') as json_file:
                json_list = list(json_file)
            dataset = [json.loads(json_str) for json_str in json_list]

            for data in dataset:
                data['text'] = self.caption_dict[data['id']]
                data['dataset'] = 'pexels'
                if data['height'] / data['width'] < 0.625:
                    self.dataset.append(data)

        if webvid_config.enable:
            logger.info(f"loading webvid dataset")
            logger.info(f"loading annotations from {webvid_config.json_path} ...")
            with open(webvid_config.json_path, 'rb') as json_file:
                json_list = list(json_file)
            dataset = [json.loads(json_str) for json_str in json_list]
            for data in dataset:
                data['dataset'] = 'webvid'
            self.dataset.extend(dataset)

        self.length = len(self.dataset)
        logger.info(f"data scale: {self.length}")

        self.pexels_folder = pexels_config.video_folder
        self.webvid_folder = webvid_config.video_folder
        self.sample_duration = sample_duration
        self.sample_fps = sample_fps
        self.sample_n_frames = sample_duration * sample_fps if sample_n_frames is None else sample_n_frames
        self.sample_stride = sample_stride if (sample_stride is None) or (sample_stride is not None and isinstance(sample_stride, int)) else tuple(sample_stride)
        self.is_image = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=None),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        video_relative_path, name = video_dict['file'], video_dict['text']

        if video_dict['dataset'] == 'pexels':
            video_folder = self.pexels_folder
        elif video_dict['dataset'] == 'webvid':
            video_folder = self.webvid_folder
        else:
            raise NotImplementedError

        if video_folder is not None:
            if video_relative_path[0] == '/':
                video_dir = os.path.join(video_folder, os.path.basename(video_relative_path))
            else:
                video_dir = os.path.join(video_folder, video_relative_path)
        else:
            video_dir = video_relative_path
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        stride = None
        if not self.is_image:
            if self.sample_duration is not None:
                fps = video_dict['fps']
                clip_length = min(video_length, math.ceil(fps * self.sample_duration))
            elif self.sample_stride is not None:
                if isinstance(self.sample_stride, int):
                    stride = self.sample_stride
                elif isinstance(self.sample_stride, tuple):
                    stride = random.randint(self.sample_stride[0], self.sample_stride[1])
                clip_length = min(video_length, (self.sample_n_frames - 1) * stride + 1)

            start_idx = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)

        else:
            frame_difference = random.randint(2, self.sample_n_frames)
            if self.sample_duration is not None:
                fps = video_dict['fps']
                sample_stride = math.ceil((fps * self.sample_duration) / (self.sample_n_frames - 1) - 1)
            elif self.sample_stride is not None:
                sample_stride = self.sample_stride

            clip_length = min(video_length, (frame_difference - 1) * sample_stride + 1)
            start_idx = random.randint(0, video_length - clip_length)
            batch_index = [start_idx, start_idx + clip_length - 1]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        return pixel_values, name, stride

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name, stride = self.get_batch(idx)
                break

            except Exception as e:
                idx = random.randint(0, self.length-1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name, stride=stride)
        return sample
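A minimal usage sketch for the WebVid10M dataset defined above, wired into a PyTorch DataLoader. The annotation path and video folder are the same placeholders used in configs/training/training.yaml, and the sampling parameters mirror that config.

from torch.utils.data import DataLoader
from consisti2v.data.dataset import WebVid10M

dataset = WebVid10M(
    json_path="/path/to/webvid/annotation",   # JSONL with one {"file": ..., "text": ...} record per line
    video_folder="/path/to/webvid/data",
    sample_size=256,
    sample_stride=(1, 5),                     # a (min, max) range samples a random stride per clip
    sample_n_frames=16,
    is_image=False,
)
loader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=4)

batch = next(iter(loader))
# pixel_values: (B, 16, 3, 256, 256) in [-1, 1]; text: a list of B captions
print(batch["pixel_values"].shape, batch["text"][0])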
consisti2v/models/rotary_embedding.py
ADDED
@@ -0,0 +1,280 @@
from math import pi, log

import torch
from torch.nn import Module, ModuleList
from torch.cuda.amp import autocast
from torch import nn, einsum, broadcast_tensors, Tensor

from einops import rearrange, repeat

from beartype import beartype
from beartype.typing import Literal, Union, Optional

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# broadcat, as tortoise-tts was using it

def broadcat(tensors, dim = -1):
    broadcasted_tensors = broadcast_tensors(*tensors)
    return torch.cat(broadcasted_tensors, dim = dim)

# rotary embedding helper functions

def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d r -> ... (d r)')

@autocast(enabled = False)
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
    if t.ndim == 3:
        seq_len = t.shape[seq_dim]
        freqs = freqs[-seq_len:].to(t)

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'

    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t_left, t, t_right), dim = -1)

# learned rotation helpers

def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
    if exists(freq_ranges):
        rotations = einsum('..., f -> ... f', rotations, freq_ranges)
        rotations = rearrange(rotations, '... r f -> ... (r f)')

    rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
    return apply_rotary_emb(rotations, t, start_index = start_index)

# classes

class RotaryEmbedding(Module):
    @beartype
    def __init__(
        self,
        dim,
        custom_freqs: Optional[Tensor] = None,
        freqs_for: Union[
            Literal['lang'],
            Literal['pixel'],
            Literal['constant']
        ] = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
        learned_freq = False,
        use_xpos = False,
        xpos_scale_base = 512,
        interpolate_factor = 1.,
        theta_rescale_factor = 1.,
        seq_before_head_dim = False
    ):
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()

        self.tmp_store('cached_freqs', None)
        self.tmp_store('cached_scales', None)

        self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)

        self.learned_freq = learned_freq

        # dummy for device

        self.tmp_store('dummy', torch.tensor(0))

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos
        if not use_xpos:
            self.tmp_store('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base
        self.tmp_store('scale', scale)

    @property
    def device(self):
        return self.dummy.device

    def tmp_store(self, key, value):
        self.register_buffer(key, value, persistent = False)

    def get_seq_pos(self, seq_len, device, dtype, offset = 0):
        return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, seq_dim = None, offset = 0, freq_seq_len = None, seq_pos = None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        if exists(freq_seq_len):
            assert freq_seq_len >= seq_len
            seq_len = freq_seq_len

        if seq_pos is None:
            seq_pos = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
        else:
            assert seq_pos.shape[0] == seq_len

        freqs = self.forward(seq_pos, seq_len = seq_len, offset = offset)

        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')

        return apply_rotary_emb(freqs, t, seq_dim = seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
        seq_dim = default(seq_dim, self.default_seq_dim)

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len
        rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, freq_seq_len = k_len)
        rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, seq_dim = None):
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)

        freqs = self.forward(seq, seq_len = seq_len)
        scale = self.get_scale(seq, seq_len = seq_len).to(dtype)

        if seq_dim == -3:
            freqs = rearrange(freqs, 'n d -> n 1 d')
            scale = rearrange(scale, 'n d -> n 1 d')

        rotated_q = apply_rotary_emb(freqs, q, scale = scale, seq_dim = seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale = scale ** -1, seq_dim = seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    @beartype
    def get_scale(
        self,
        t: Tensor,
        seq_len: Optional[int] = None,
        offset = 0
    ):
        assert self.use_xpos

        should_cache = exists(seq_len)

        if (
            should_cache and \
            exists(self.cached_scales) and \
            (seq_len + offset) <= self.cached_scales.shape[0]
        ):
            return self.cached_scales[offset:(offset + seq_len)]

        scale = 1.
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, 'n -> n 1')
            scale = torch.cat((scale, scale), dim = -1)

        if should_cache:
            self.tmp_store('cached_scales', scale)

        return scale

    def get_axial_freqs(self, *dims):
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            if self.freqs_for == 'pixel':
                pos = torch.linspace(-1, 1, steps = dim, device = self.device)
            else:
                pos = torch.arange(dim, device = self.device)

            freqs = self.forward(pos, seq_len = dim)

            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim = -1)

    @autocast(enabled = False)
    def forward(
        self,
        t: Tensor,
        seq_len = None,
        offset = 0
    ):
        # should_cache = (
        #     not self.learned_freq and \
        #     exists(seq_len) and \
        #     self.freqs_for != 'pixel'
        # )

        # if (
        #     should_cache and \
        #     exists(self.cached_freqs) and \
        #     (offset + seq_len) <= self.cached_freqs.shape[0]
        # ):
        #     return self.cached_freqs[offset:(offset + seq_len)].detach()

        freqs = self.freqs

        freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

        # if should_cache:
        #     self.tmp_store('cached_freqs', freqs.detach())

        return freqs
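A small sketch of how the RotaryEmbedding module above can rotate query/key tensors along the frame axis of temporal attention. The tensor shapes are illustrative only; the actual wiring lives in videoldm_attention.py and may differ.

import torch
from consisti2v.models.rotary_embedding import RotaryEmbedding

rotary_emb = RotaryEmbedding(dim=32)       # rotary dimension, typically a slice of the head dimension

q = torch.randn(2, 8, 16, 64)              # (batch, heads, frames, head_dim)
k = torch.randn(2, 8, 16, 64)

# rotate_queries_or_keys applies the rotation along seq_dim=-2 (the frame axis here);
# only the first 32 channels of each head are rotated, the rest pass through.
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

attn = torch.softmax(q @ k.transpose(-1, -2) / 64 ** 0.5, dim=-1)
print(attn.shape)                          # (2, 8, 16, 16)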
consisti2v/models/videoldm_attention.py
ADDED
@@ -0,0 +1,809 @@