multimodalart HF staff commited on
Commit
4be9315
1 Parent(s): b49ce0a

Create simple_video_sample.py

Browse files
Files changed (1) hide show
  1. simple_video_sample.py +277 -0
simple_video_sample.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from glob import glob
4
+ from pathlib import Path
5
+ from typing import Optional
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from einops import rearrange, repeat
11
+ from fire import Fire
12
+ from omegaconf import OmegaConf
13
+ from PIL import Image
14
+ from torchvision.transforms import ToTensor
15
+
16
+ from scripts.util.detection.nsfw_and_watermark_dectection import \
17
+ DeepFloydDataFiltering
18
+ from sgm.inference.helpers import embed_watermark
19
+ from sgm.util import default, instantiate_from_config
20
+
21
+ def sample(
22
+ input_path: str = "assets/doggo.png", # Can either be image file or folder with image files
23
+ num_frames: Optional[int] = None,
24
+ num_steps: Optional[int] = None,
25
+ version: str = "svd",
26
+ fps_id: int = 6,
27
+ motion_bucket_id: int = 127,
28
+ cond_aug: float = 0.02,
29
+ seed: int = 23,
30
+ decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
31
+ device: str = "cuda",
32
+ output_folder: Optional[str] = None,
33
+ ):
34
+ """
35
+ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
36
+ image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
37
+ """
38
+
39
+ if version == "svd":
40
+ num_frames = default(num_frames, 14)
41
+ num_steps = default(num_steps, 25)
42
+ output_folder = default(output_folder, "outputs/simple_video_sample/svd/")
43
+ model_config = "scripts/sampling/configs/svd.yaml"
44
+ elif version == "svd_xt":
45
+ num_frames = default(num_frames, 25)
46
+ num_steps = default(num_steps, 30)
47
+ output_folder = default(output_folder, "outputs/simple_video_sample/svd_xt/")
48
+ model_config = "scripts/sampling/configs/svd_xt.yaml"
49
+ elif version == "svd_image_decoder":
50
+ num_frames = default(num_frames, 14)
51
+ num_steps = default(num_steps, 25)
52
+ output_folder = default(
53
+ output_folder, "outputs/simple_video_sample/svd_image_decoder/"
54
+ )
55
+ model_config = "scripts/sampling/configs/svd_image_decoder.yaml"
56
+ elif version == "svd_xt_image_decoder":
57
+ num_frames = default(num_frames, 25)
58
+ num_steps = default(num_steps, 30)
59
+ output_folder = default(
60
+ output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
61
+ )
62
+ model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
63
+ else:
64
+ raise ValueError(f"Version {version} does not exist.")
65
+
66
+ model, filter = load_model(
67
+ model_config,
68
+ device,
69
+ num_frames,
70
+ num_steps,
71
+ )
72
+ torch.manual_seed(seed)
73
+
74
+ path = Path(input_path)
75
+ all_img_paths = []
76
+ if path.is_file():
77
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
78
+ all_img_paths = [input_path]
79
+ else:
80
+ raise ValueError("Path is not valid image file.")
81
+ elif path.is_dir():
82
+ all_img_paths = sorted(
83
+ [
84
+ f
85
+ for f in path.iterdir()
86
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
87
+ ]
88
+ )
89
+ if len(all_img_paths) == 0:
90
+ raise ValueError("Folder does not contain any images.")
91
+ else:
92
+ raise ValueError
93
+
94
+ for input_img_path in all_img_paths:
95
+ with Image.open(input_img_path) as image:
96
+ if image.mode == "RGBA":
97
+ image = image.convert("RGB")
98
+ w, h = image.size
99
+
100
+ if h % 64 != 0 or w % 64 != 0:
101
+ width, height = map(lambda x: x - x % 64, (w, h))
102
+ image = image.resize((width, height))
103
+ print(
104
+ f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
105
+ )
106
+
107
+ image = ToTensor()(image)
108
+ image = image * 2.0 - 1.0
109
+
110
+ image = image.unsqueeze(0).to(device)
111
+ H, W = image.shape[2:]
112
+ assert image.shape[1] == 3
113
+ F = 8
114
+ C = 4
115
+ shape = (num_frames, C, H // F, W // F)
116
+ if (H, W) != (576, 1024):
117
+ print(
118
+ "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
119
+ )
120
+ if motion_bucket_id > 255:
121
+ print(
122
+ "WARNING: High motion bucket! This may lead to suboptimal performance."
123
+ )
124
+
125
+ if fps_id < 5:
126
+ print("WARNING: Small fps value! This may lead to suboptimal performance.")
127
+
128
+ if fps_id > 30:
129
+ print("WARNING: Large fps value! This may lead to suboptimal performance.")
130
+
131
+ value_dict = {}
132
+ value_dict["motion_bucket_id"] = motion_bucket_id
133
+ value_dict["fps_id"] = fps_id
134
+ value_dict["cond_aug"] = cond_aug
135
+ value_dict["cond_frames_without_noise"] = image
136
+ value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
137
+ value_dict["cond_aug"] = cond_aug
138
+
139
+ with torch.no_grad():
140
+ with torch.autocast(device):
141
+ batch, batch_uc = get_batch(
142
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
143
+ value_dict,
144
+ [1, num_frames],
145
+ T=num_frames,
146
+ device=device,
147
+ )
148
+ c, uc = model.conditioner.get_unconditional_conditioning(
149
+ batch,
150
+ batch_uc=batch_uc,
151
+ force_uc_zero_embeddings=[
152
+ "cond_frames",
153
+ "cond_frames_without_noise",
154
+ ],
155
+ )
156
+
157
+ for k in ["crossattn", "concat"]:
158
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
159
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
160
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
161
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
162
+
163
+ randn = torch.randn(shape, device=device)
164
+
165
+ additional_model_inputs = {}
166
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
167
+ 2, num_frames
168
+ ).to(device)
169
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
170
+
171
+ def denoiser(input, sigma, c):
172
+ return model.denoiser(
173
+ model.model, input, sigma, c, **additional_model_inputs
174
+ )
175
+
176
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
177
+ model.en_and_decode_n_samples_a_time = decoding_t
178
+ samples_x = model.decode_first_stage(samples_z)
179
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
180
+
181
+ os.makedirs(output_folder, exist_ok=True)
182
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
183
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
184
+ writer = cv2.VideoWriter(
185
+ video_path,
186
+ cv2.VideoWriter_fourcc(*"MP4V"),
187
+ fps_id + 1,
188
+ (samples.shape[-1], samples.shape[-2]),
189
+ )
190
+
191
+ samples = embed_watermark(samples)
192
+ samples = filter(samples)
193
+ vid = (
194
+ (rearrange(samples, "t c h w -> t h w c") * 255)
195
+ .cpu()
196
+ .numpy()
197
+ .astype(np.uint8)
198
+ )
199
+ for frame in vid:
200
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
201
+ writer.write(frame)
202
+ writer.release()
203
+
204
+
205
+ def get_unique_embedder_keys_from_conditioner(conditioner):
206
+ return list(set([x.input_key for x in conditioner.embedders]))
207
+
208
+
209
+ def get_batch(keys, value_dict, N, T, device):
210
+ batch = {}
211
+ batch_uc = {}
212
+
213
+ for key in keys:
214
+ if key == "fps_id":
215
+ batch[key] = (
216
+ torch.tensor([value_dict["fps_id"]])
217
+ .to(device)
218
+ .repeat(int(math.prod(N)))
219
+ )
220
+ elif key == "motion_bucket_id":
221
+ batch[key] = (
222
+ torch.tensor([value_dict["motion_bucket_id"]])
223
+ .to(device)
224
+ .repeat(int(math.prod(N)))
225
+ )
226
+ elif key == "cond_aug":
227
+ batch[key] = repeat(
228
+ torch.tensor([value_dict["cond_aug"]]).to(device),
229
+ "1 -> b",
230
+ b=math.prod(N),
231
+ )
232
+ elif key == "cond_frames":
233
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
234
+ elif key == "cond_frames_without_noise":
235
+ batch[key] = repeat(
236
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
237
+ )
238
+ else:
239
+ batch[key] = value_dict[key]
240
+
241
+ if T is not None:
242
+ batch["num_video_frames"] = T
243
+
244
+ for key in batch.keys():
245
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
246
+ batch_uc[key] = torch.clone(batch[key])
247
+ return batch, batch_uc
248
+
249
+
250
+ def load_model(
251
+ config: str,
252
+ device: str,
253
+ num_frames: int,
254
+ num_steps: int,
255
+ ):
256
+ config = OmegaConf.load(config)
257
+ if device == "cuda":
258
+ config.model.params.conditioner_config.params.emb_models[
259
+ 0
260
+ ].params.open_clip_embedding_config.params.init_device = device
261
+
262
+ config.model.params.sampler_config.params.num_steps = num_steps
263
+ config.model.params.sampler_config.params.guider_config.params.num_frames = (
264
+ num_frames
265
+ )
266
+ if device == "cuda":
267
+ with torch.device(device):
268
+ model = instantiate_from_config(config.model).to(device).eval()
269
+ else:
270
+ model = instantiate_from_config(config.model).to(device).eval()
271
+
272
+ filter = DeepFloydDataFiltering(verbose=False, device=device)
273
+ return model, filter
274
+
275
+
276
+ if __name__ == "__main__":
277
+ Fire(sample)