multimodalart committed
Commit 3b06696
1 Parent(s): 262138f

Create app.py

Files changed (1)
  1. app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
+ import math
+ import os
+ import uuid
+ from glob import glob
+ from pathlib import Path
+ from typing import Optional
+
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import torch
+ from einops import rearrange, repeat
+ from huggingface_hub import hf_hub_download
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from torchvision.transforms import ToTensor
+
+ from scripts.util.detection.nsfw_and_watermark_dectection import \
+     DeepFloydDataFiltering
+ from sgm.inference.helpers import embed_watermark
+ from sgm.util import default, instantiate_from_config
+
+ num_frames = 25
+ num_steps = 30
+ model_config = "scripts/sampling/configs/svd_xt.yaml"
+ device = "cuda"
+
+ # Download the SVD-XT checkpoint into ./checkpoints (HF_TOKEN is used for gated access)
+ hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints", token=os.getenv("HF_TOKEN"))
+
+ def load_model(
+     config: str,
+     device: str,
+     num_frames: int,
+     num_steps: int,
+ ):
+     config = OmegaConf.load(config)
+     if device == "cuda":
+         config.model.params.conditioner_config.params.emb_models[
+             0
+         ].params.open_clip_embedding_config.params.init_device = device
+
+     config.model.params.sampler_config.params.num_steps = num_steps
+     config.model.params.sampler_config.params.guider_config.params.num_frames = (
+         num_frames
+     )
+     if device == "cuda":
+         # Instantiating under the device context creates parameters directly on the GPU
+         with torch.device(device):
+             model = instantiate_from_config(config.model).to(device).eval()
+     else:
+         model = instantiate_from_config(config.model).to(device).eval()
+
+     filter = DeepFloydDataFiltering(verbose=False, device=device)
+     return model, filter
+
+ model, filter = load_model(
+     model_config,
+     device,
+     num_frames,
+     num_steps,
+ )
+
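The `with torch.device(device)` guard in `load_model` is a small but useful detail: constructing the model inside a device context allocates its parameters directly on the GPU instead of building them on the CPU and copying them over. A minimal sketch of the pattern, assuming PyTorch 2.0+ and an available CUDA device:

import torch

# Parameters are created on the GPU at construction time, skipping a CPU allocation + copy
with torch.device("cuda"):
    layer = torch.nn.Linear(4, 4)

print(layer.weight.device)  # cuda:0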
+ def sample(
+     image: Image.Image,
+     num_frames: Optional[int] = 25,
+     num_steps: Optional[int] = 30,
+     version: str = "svd_xt",
+     fps_id: int = 6,
+     motion_bucket_id: int = 127,
+     cond_aug: float = 0.02,
+     seed: int = 23,
+     decoding_t: int = 7,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+ ):
+     output_folder = str(uuid.uuid4())
+     torch.manual_seed(seed)
+
+     if image.mode == "RGBA":
+         image = image.convert("RGB")
+     w, h = image.size
+
+     if h % 64 != 0 or w % 64 != 0:
+         width, height = map(lambda x: x - x % 64, (w, h))
+         image = image.resize((width, height))
+         print(
+             f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+         )
+
+     # Scale pixels from [0, 1] to the [-1, 1] range the model expects
+     image = ToTensor()(image)
+     image = image * 2.0 - 1.0
+
+     image = image.unsqueeze(0).to(device)
+     H, W = image.shape[2:]
+     assert image.shape[1] == 3
+     F = 8  # VAE spatial downsampling factor
+     C = 4  # latent channels
+     shape = (num_frames, C, H // F, W // F)
+     if (H, W) != (576, 1024):
+         print(
+             "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as the model was only trained on 576x1024. Consider increasing `cond_aug`."
+         )
+     if motion_bucket_id > 255:
+         print(
+             "WARNING: High motion bucket! This may lead to suboptimal performance."
+         )
+
+     if fps_id < 5:
+         print("WARNING: Small fps value! This may lead to suboptimal performance.")
+
+     if fps_id > 30:
+         print("WARNING: Large fps value! This may lead to suboptimal performance.")
+
+     value_dict = {}
+     value_dict["motion_bucket_id"] = motion_bucket_id
+     value_dict["fps_id"] = fps_id
+     value_dict["cond_aug"] = cond_aug
+     value_dict["cond_frames_without_noise"] = image
+     # The noised conditioning frame: a little noise (cond_aug) makes conditioning more robust
+     value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
+
+     with torch.no_grad():
+         with torch.autocast(device):
+             batch, batch_uc = get_batch(
+                 get_unique_embedder_keys_from_conditioner(model.conditioner),
+                 value_dict,
+                 [1, num_frames],
+                 T=num_frames,
+                 device=device,
+             )
+             c, uc = model.conditioner.get_unconditional_conditioning(
+                 batch,
+                 batch_uc=batch_uc,
+                 force_uc_zero_embeddings=[
+                     "cond_frames",
+                     "cond_frames_without_noise",
+                 ],
+             )
+
+             # Broadcast per-clip conditioning to every frame, then flatten (b, t) into one batch dim
+             for k in ["crossattn", "concat"]:
+                 uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                 uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                 c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                 c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+             randn = torch.randn(shape, device=device)
+
+             additional_model_inputs = {}
+             additional_model_inputs["image_only_indicator"] = torch.zeros(
+                 2, num_frames
+             ).to(device)
+             additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+             def denoiser(input, sigma, c):
+                 return model.denoiser(
+                     model.model, input, sigma, c, **additional_model_inputs
+                 )
+
+             samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+             model.en_and_decode_n_samples_a_time = decoding_t
+             samples_x = model.decode_first_stage(samples_z)
+             samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+             os.makedirs(output_folder, exist_ok=True)
+             base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+             video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+             writer = cv2.VideoWriter(
+                 video_path,
+                 cv2.VideoWriter_fourcc(*"avc1"),
+                 fps_id + 1,
+                 (samples.shape[-1], samples.shape[-2]),
+             )
+
+             samples = embed_watermark(samples)
+             samples = filter(samples)
+             vid = (
+                 (rearrange(samples, "t c h w -> t h w c") * 255)
+                 .cpu()
+                 .numpy()
+                 .astype(np.uint8)
+             )
+             for frame in vid:
+                 # OpenCV expects BGR channel order
+                 frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                 writer.write(frame)
+             writer.release()
+     return video_path
+
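For the trained 576x1024 resolution, the latent `shape` computed in `sample` works out as follows (a worked example of the arithmetic above, nothing model-specific):

# Latent grid for a 576x1024 conditioning frame: the VAE downsamples 8x spatially
H, W, F, C, num_frames = 576, 1024, 8, 4, 25
shape = (num_frames, C, H // F, W // F)
print(shape)  # (25, 4, 72, 128)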
+ def get_unique_embedder_keys_from_conditioner(conditioner):
+     return list(set([x.input_key for x in conditioner.embedders]))
+
+
+ def get_batch(keys, value_dict, N, T, device):
+     # Builds the conditioning batch (and its unconditional copy): scalars are
+     # tiled to batch_size * num_frames entries, conditioning frames to batch_size
+     batch = {}
+     batch_uc = {}
+
+     for key in keys:
+         if key == "fps_id":
+             batch[key] = (
+                 torch.tensor([value_dict["fps_id"]])
+                 .to(device)
+                 .repeat(int(math.prod(N)))
+             )
+         elif key == "motion_bucket_id":
+             batch[key] = (
+                 torch.tensor([value_dict["motion_bucket_id"]])
+                 .to(device)
+                 .repeat(int(math.prod(N)))
+             )
+         elif key == "cond_aug":
+             batch[key] = repeat(
+                 torch.tensor([value_dict["cond_aug"]]).to(device),
+                 "1 -> b",
+                 b=math.prod(N),
+             )
+         elif key == "cond_frames":
+             batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+         elif key == "cond_frames_without_noise":
+             batch[key] = repeat(
+                 value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+             )
+         else:
+             batch[key] = value_dict[key]
+
+     if T is not None:
+         batch["num_video_frames"] = T
+
+     for key in batch.keys():
+         if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+             batch_uc[key] = torch.clone(batch[key])
+     return batch, batch_uc
+
+
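`get_batch` leans on einops `repeat` to tile per-clip scalars across every frame. A self-contained illustration of that broadcasting pattern (the values are illustrative):

import torch
from einops import repeat

cond_aug = torch.tensor([0.02])
# "1 -> b" tiles the single scalar to one entry per (batch, frame) pair
tiled = repeat(cond_aug, "1 -> b", b=1 * 25)
print(tiled.shape)  # torch.Size([25])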
+ def resize_image(image, output_size=(1024, 576)):
+     # Cover-crop: scale the image to fill the target size, then center-crop the overflow
+     target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
+     image_aspect = image.width / image.height  # Aspect ratio of the original image
+
+     if image_aspect > target_aspect:
+         # Wider than target: match the target height, then crop the sides
+         new_height = output_size[1]
+         new_width = int(new_height * image_aspect)
+         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+         left = (new_width - output_size[0]) // 2
+         top = 0
+         right = left + output_size[0]
+         bottom = output_size[1]
+     else:
+         # Taller than target: match the target width, then crop top and bottom
+         new_width = output_size[0]
+         new_height = int(new_width / image_aspect)
+         resized_image = image.resize((new_width, new_height), Image.LANCZOS)
+         left = 0
+         top = (new_height - output_size[1]) // 2
+         right = output_size[0]
+         bottom = top + output_size[1]
+
+     # Crop the image
+     cropped_image = resized_image.crop((left, top, right, bottom))
+
+     return cropped_image
+
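A quick sanity check of the cover-crop behaviour (the input size is arbitrary): a 4:3 image is first resized to match the target width, then center-cropped vertically to 16:9.

from PIL import Image

img = Image.new("RGB", (4000, 3000))  # 4:3 test input
out = resize_image(img)  # resized to 1024x768, then cropped to the middle 576 rows
print(out.size)  # (1024, 576)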
+ with gr.Blocks() as demo:
+     gr.Markdown('''# Stable Video Diffusion - Image2Video - XT
+ Generate 25 frames of video from a single image using SVD-XT.
+ ''')
+     with gr.Column():
+         image = gr.Image(label="Upload your image (it will be center cropped to 1024x576)", type="pil")
+         generate_btn = gr.Button("Generate")
+         with gr.Accordion("Advanced options", open=False):
+             cond_aug = gr.Slider(label="Conditioning augmentation", value=0.02, minimum=0.0)
+             seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=int(1e9), step=1)
+             #decoding_t = gr.Slider(label="Decode frames at a time", value=6, minimum=1, maximum=14, interactive=False)
+             saving_fps = gr.Slider(label="Saving FPS", value=6, minimum=6, maximum=48, step=6)
+     with gr.Column():
+         video = gr.Video()
+     # Uploads are cover-cropped to the model's training resolution (1024x576)
+     image.upload(fn=resize_image, inputs=image, outputs=image)
+     # Note: only the image is passed to sample(); the advanced sliders above are
+     # not wired in, so sample() runs with its default cond_aug, seed, and fps_id
+     generate_btn.click(fn=sample, inputs=[image], outputs=video, api_name="video")
+
+ demo.launch()
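Since the click handler sets api_name="video", the Space can also be driven programmatically. A minimal client-side sketch (the Space id and image path are placeholders, not taken from this commit):

from gradio_client import Client

client = Client("user/space-name")  # hypothetical Space id
video_path = client.predict("input.jpg", api_name="/video")  # local path to the generated .mp4
print(video_path)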