multimodalart HF staff commited on
Commit
d56d267
1 Parent(s): 5e1ee6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -62
app.py CHANGED
@@ -8,6 +8,7 @@ import cv2
8
  import numpy as np
9
  import torch
10
  from einops import rearrange, repeat
 
11
  from omegaconf import OmegaConf
12
  from PIL import Image
13
  from torchvision.transforms import ToTensor
@@ -16,37 +17,198 @@ from scripts.util.detection.nsfw_and_watermark_dectection import \
16
  DeepFloydDataFiltering
17
  from sgm.inference.helpers import embed_watermark
18
  from sgm.util import default, instantiate_from_config
19
- from huggingface_hub import hf_hub_download
20
 
21
- import gradio as gr
22
- import uuid
23
-
24
- from simple_video_sample import sample
25
 
26
- num_frames = 25
27
- num_steps = 30
28
- model_config = "scripts/sampling/configs/svd_xt.yaml"
29
  device = "cuda"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints", token=os.getenv("HF_TOKEN"))
 
32
 
33
- def run_sampling(
34
- input_path: str,
35
- num_frames: Optional[int] = 25,
36
- num_steps: Optional[int] = 30,
 
 
 
 
 
 
 
 
 
 
 
 
37
  version: str = "svd_xt",
38
  fps_id: int = 6,
39
  motion_bucket_id: int = 127,
40
  cond_aug: float = 0.02,
41
  seed: int = 23,
42
  decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
 
 
43
  ):
44
- output_folder = str(uuid.uuid4())
45
- print(output_folder)
46
- print(version)
47
- print(input_path)
48
- sample(input_path, version=version, output_folder=output_folder, decoding_t=decoding_t)
49
- return f"{output_folder}/000000.mp4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def get_unique_embedder_keys_from_conditioner(conditioner):
52
  return list(set([x.input_key for x in conditioner.embedders]))
@@ -92,58 +254,52 @@ def get_batch(keys, value_dict, N, T, device):
92
  batch_uc[key] = torch.clone(batch[key])
93
  return batch, batch_uc
94
 
 
 
95
  def resize_image(image_path, output_size=(1024, 576)):
96
- with Image.open(image_path) as image:
97
- # Calculate aspect ratios
98
- target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
99
- image_aspect = image.width / image.height # Aspect ratio of the original image
100
-
101
- # Resize then crop if the original image is larger
102
- if image_aspect > target_aspect:
103
- # Resize the image to match the target height, maintaining aspect ratio
104
- new_height = output_size[1]
105
- new_width = int(new_height * image_aspect)
106
- resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
107
- # Calculate coordinates for cropping
108
- left = (new_width - output_size[0]) / 2
109
- top = 0
110
- right = (new_width + output_size[0]) / 2
111
- bottom = output_size[1]
112
- else:
113
- # Resize the image to match the target width, maintaining aspect ratio
114
- new_width = output_size[0]
115
- new_height = int(new_width / image_aspect)
116
- resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
117
- # Calculate coordinates for cropping
118
- left = 0
119
- top = (new_height - output_size[1]) / 2
120
- right = output_size[0]
121
- bottom = (new_height + output_size[1]) / 2
122
-
123
- # Crop the image
124
- cropped_image = resized_image.crop((left, top, right, bottom))
125
 
126
- return cropped_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- css = '''
129
- .gradio-container{max-width:850px !important}
130
- '''
 
131
 
132
- with gr.Blocks(css=css) as demo:
133
  gr.Markdown('''# Stable Video Diffusion - Image2Video - XT
134
- Generate 25 frames of video from a single image with SDV-XT. [Join the waitlist](https://stability.ai/contact) for the text-to-video web experience
135
  ''')
136
  with gr.Column():
137
  image = gr.Image(label="Upload your image (it will be center cropped to 1024x576)", type="filepath")
138
  generate_btn = gr.Button("Generate")
139
- #with gr.Accordion("Advanced options", open=False):
140
- # cond_aug = gr.Slider(label="Conditioning augmentation", value=0.02, minimum=0.0)
141
- # seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=int(1e9), step=1)
142
- #decoding_t = gr.Slider(label="Decode frames at a time", value=6, minimum=1, maximum=14, interactive=False)
143
- # saving_fps = gr.Slider(label="Saving FPS", value=6, minimum=6, maximum=48, step=6)
144
  with gr.Column():
145
  video = gr.Video()
146
- image.upload(fn=resize_image, inputs=image, outputs=image)
147
- generate_btn.click(fn=run_sampling, inputs=[image], outputs=video, api_name="video")
148
 
149
- demo.launch()
 
 
8
  import numpy as np
9
  import torch
10
  from einops import rearrange, repeat
11
+ from fire import Fire
12
  from omegaconf import OmegaConf
13
  from PIL import Image
14
  from torchvision.transforms import ToTensor
 
17
  DeepFloydDataFiltering
18
  from sgm.inference.helpers import embed_watermark
19
  from sgm.util import default, instantiate_from_config
 
20
 
21
+ hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
 
 
 
22
 
23
+ version = "svd_xt"
 
 
24
  device = "cuda"
25
+ def load_model(
26
+ config: str,
27
+ device: str,
28
+ num_frames: int,
29
+ num_steps: int,
30
+ ):
31
+ config = OmegaConf.load(config)
32
+ if device == "cuda":
33
+ config.model.params.conditioner_config.params.emb_models[
34
+ 0
35
+ ].params.open_clip_embedding_config.params.init_device = device
36
+
37
+ config.model.params.sampler_config.params.num_steps = num_steps
38
+ config.model.params.sampler_config.params.guider_config.params.num_frames = (
39
+ num_frames
40
+ )
41
+ if device == "cuda":
42
+ with torch.device(device):
43
+ model = instantiate_from_config(config.model).to(device).eval()
44
+ else:
45
+ model = instantiate_from_config(config.model).to(device).eval()
46
 
47
+ filter = DeepFloydDataFiltering(verbose=False, device=device)
48
+ return model, filter
49
 
50
+ if version == "svd_xt":
51
+ num_frames = 25
52
+ num_steps = 30
53
+ model_config = "scripts/sampling/configs/svd_xt.yaml"
54
+ else:
55
+ raise ValueError(f"Version {version} does not exist.")
56
+
57
+ model, filter = load_model(
58
+ model_config,
59
+ device,
60
+ num_frames,
61
+ num_steps,
62
+ )
63
+
64
+ def sample(
65
+ input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
66
  version: str = "svd_xt",
67
  fps_id: int = 6,
68
  motion_bucket_id: int = 127,
69
  cond_aug: float = 0.02,
70
  seed: int = 23,
71
  decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
72
+ device: str = "cuda",
73
+ output_folder: str = "outputs",
74
  ):
75
+ """
76
+ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
77
+ image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
78
+ """
79
+ torch.manual_seed(seed)
80
+
81
+ path = Path(input_path)
82
+ all_img_paths = []
83
+ if path.is_file():
84
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
85
+ all_img_paths = [input_path]
86
+ else:
87
+ raise ValueError("Path is not valid image file.")
88
+ elif path.is_dir():
89
+ all_img_paths = sorted(
90
+ [
91
+ f
92
+ for f in path.iterdir()
93
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
94
+ ]
95
+ )
96
+ if len(all_img_paths) == 0:
97
+ raise ValueError("Folder does not contain any images.")
98
+ else:
99
+ raise ValueError
100
+
101
+ for input_img_path in all_img_paths:
102
+ with Image.open(input_img_path) as image:
103
+ if image.mode == "RGBA":
104
+ image = image.convert("RGB")
105
+ w, h = image.size
106
+
107
+ if h % 64 != 0 or w % 64 != 0:
108
+ width, height = map(lambda x: x - x % 64, (w, h))
109
+ image = image.resize((width, height))
110
+ print(
111
+ f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
112
+ )
113
+
114
+ image = ToTensor()(image)
115
+ image = image * 2.0 - 1.0
116
+
117
+ image = image.unsqueeze(0).to(device)
118
+ H, W = image.shape[2:]
119
+ assert image.shape[1] == 3
120
+ F = 8
121
+ C = 4
122
+ shape = (num_frames, C, H // F, W // F)
123
+ if (H, W) != (576, 1024):
124
+ print(
125
+ "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
126
+ )
127
+ if motion_bucket_id > 255:
128
+ print(
129
+ "WARNING: High motion bucket! This may lead to suboptimal performance."
130
+ )
131
+
132
+ if fps_id < 5:
133
+ print("WARNING: Small fps value! This may lead to suboptimal performance.")
134
+
135
+ if fps_id > 30:
136
+ print("WARNING: Large fps value! This may lead to suboptimal performance.")
137
+
138
+ value_dict = {}
139
+ value_dict["motion_bucket_id"] = motion_bucket_id
140
+ value_dict["fps_id"] = fps_id
141
+ value_dict["cond_aug"] = cond_aug
142
+ value_dict["cond_frames_without_noise"] = image
143
+ value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
144
+ value_dict["cond_aug"] = cond_aug
145
+
146
+ with torch.no_grad():
147
+ with torch.autocast(device):
148
+ batch, batch_uc = get_batch(
149
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
150
+ value_dict,
151
+ [1, num_frames],
152
+ T=num_frames,
153
+ device=device,
154
+ )
155
+ c, uc = model.conditioner.get_unconditional_conditioning(
156
+ batch,
157
+ batch_uc=batch_uc,
158
+ force_uc_zero_embeddings=[
159
+ "cond_frames",
160
+ "cond_frames_without_noise",
161
+ ],
162
+ )
163
+
164
+ for k in ["crossattn", "concat"]:
165
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
166
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
167
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
168
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
169
+
170
+ randn = torch.randn(shape, device=device)
171
+
172
+ additional_model_inputs = {}
173
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
174
+ 2, num_frames
175
+ ).to(device)
176
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
177
+
178
+ def denoiser(input, sigma, c):
179
+ return model.denoiser(
180
+ model.model, input, sigma, c, **additional_model_inputs
181
+ )
182
+
183
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
184
+ model.en_and_decode_n_samples_a_time = decoding_t
185
+ samples_x = model.decode_first_stage(samples_z)
186
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
187
+
188
+ os.makedirs(output_folder, exist_ok=True)
189
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
190
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
191
+ writer = cv2.VideoWriter(
192
+ video_path,
193
+ cv2.VideoWriter_fourcc(*"mp4v"),
194
+ fps_id + 1,
195
+ (samples.shape[-1], samples.shape[-2]),
196
+ )
197
+
198
+ samples = embed_watermark(samples)
199
+ samples = filter(samples)
200
+ vid = (
201
+ (rearrange(samples, "t c h w -> t h w c") * 255)
202
+ .cpu()
203
+ .numpy()
204
+ .astype(np.uint8)
205
+ )
206
+ for frame in vid:
207
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
208
+ writer.write(frame)
209
+ writer.release()
210
+
211
+ return video_path
212
 
213
  def get_unique_embedder_keys_from_conditioner(conditioner):
214
  return list(set([x.input_key for x in conditioner.embedders]))
 
254
  batch_uc[key] = torch.clone(batch[key])
255
  return batch, batch_uc
256
 
257
+ import gradio as gr
258
+ import uuid
259
  def resize_image(image_path, output_size=(1024, 576)):
260
+ image = Image.open(image_path)
261
+ # Calculate aspect ratios
262
+ target_aspect = output_size[0] / output_size[1] # Aspect ratio of the desired size
263
+ image_aspect = image.width / image.height # Aspect ratio of the original image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
 
265
+ # Resize then crop if the original image is larger
266
+ if image_aspect > target_aspect:
267
+ # Resize the image to match the target height, maintaining aspect ratio
268
+ new_height = output_size[1]
269
+ new_width = int(new_height * image_aspect)
270
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
271
+ # Calculate coordinates for cropping
272
+ left = (new_width - output_size[0]) / 2
273
+ top = 0
274
+ right = (new_width + output_size[0]) / 2
275
+ bottom = output_size[1]
276
+ else:
277
+ # Resize the image to match the target width, maintaining aspect ratio
278
+ new_width = output_size[0]
279
+ new_height = int(new_width / image_aspect)
280
+ resized_image = image.resize((new_width, new_height), Image.LANCZOS)
281
+ # Calculate coordinates for cropping
282
+ left = 0
283
+ top = (new_height - output_size[1]) / 2
284
+ right = output_size[0]
285
+ bottom = (new_height + output_size[1]) / 2
286
 
287
+ # Crop the image
288
+ cropped_image = resized_image.crop((left, top, right, bottom))
289
+
290
+ return cropped_image
291
 
292
+ with gr.Blocks() as demo:
293
  gr.Markdown('''# Stable Video Diffusion - Image2Video - XT
294
+ Generate 25 frames of video from a single image using SDV-XT.
295
  ''')
296
  with gr.Column():
297
  image = gr.Image(label="Upload your image (it will be center cropped to 1024x576)", type="filepath")
298
  generate_btn = gr.Button("Generate")
 
 
 
 
 
299
  with gr.Column():
300
  video = gr.Video()
301
+ image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
302
+ generate_btn.click(fn=sample, inputs=image, outputs=video, api_name="video")
303
 
304
+ if __name__ == "__main__":
305
+ demo.launch(share=True)