heheyas committed on
Commit d85898f (1 parent: cfb7702)

add app.py

Files changed (1)
  1. app.py +290 -0
app.py ADDED
@@ -0,0 +1,290 @@
+ # TODO
+ import numpy as np
+ import argparse
+ import torch
+ from torchvision.utils import make_grid
+ import tempfile
+ import gradio as gr
+ from omegaconf import OmegaConf
+ from einops import rearrange, repeat
+ from scripts.pub.V3D_512 import (
+     sample_one,
+     get_batch,
+     get_unique_embedder_keys_from_conditioner,
+     load_model,
+ )
+ from sgm.util import default, instantiate_from_config
+ from safetensors.torch import load_file as load_safetensors
+ from PIL import Image
+ from kiui.op import recenter
+ from torchvision.transforms import ToTensor
+ import rembg
+ import os
+ from glob import glob
+ from mediapy import write_video
+ from pathlib import Path
+
+
+ def do_sample(
+     image,
+     model,
+     clip_model,
+     ae_model,
+     device,
+     num_frames,
+     num_steps,
+     decoding_t,
+     border_ratio,
+     ignore_alpha,
+     rembg_session,
+     output_folder,
+ ):
+     # if image.mode == "RGBA":
+     #     image = image.convert("RGB")
+     image = Image.fromarray(image)
+     w, h = image.size
+
+     if border_ratio > 0:
+         # Remove the background (unless the input already has alpha) and
+         # recenter the object with the requested border ratio.
+         if image.mode != "RGBA" or ignore_alpha:
+             image = image.convert("RGB")
+             image = np.asarray(image)
+             carved_image = rembg.remove(image, session=rembg_session)  # [H, W, 4]
+         else:
+             image = np.asarray(image)
+             carved_image = image
+         mask = carved_image[..., -1] > 0
+         image = recenter(carved_image, mask, border_ratio=border_ratio)
+         image = image.astype(np.float32) / 255.0
+         if image.shape[-1] == 4:
+             # Composite the RGBA result onto a white background.
+             image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+         image = Image.fromarray((image * 255).astype(np.uint8))
+     else:
+         print("Ignoring border ratio")
+     image = image.resize((512, 512))
+
+     image = ToTensor()(image)
+     image = image * 2.0 - 1.0
+
+     image = image.unsqueeze(0).to(device)
+     H, W = image.shape[2:]
+     assert image.shape[1] == 3
+     F = 8  # spatial downsampling factor of the autoencoder
+     C = 4  # latent channels
+     shape = (num_frames, C, H // F, W // F)
+
+     # Conditioning: CLIP image embedding plus the noise-augmented latent of the input view.
+     value_dict = {}
+     value_dict["motion_bucket_id"] = 0
+     value_dict["fps_id"] = 0
+     value_dict["cond_aug"] = 0.05
+     value_dict["cond_frames_without_noise"] = clip_model(image)
+     value_dict["cond_frames"] = ae_model.encode(image)
+     value_dict["cond_frames"] += 0.05 * torch.randn_like(value_dict["cond_frames"])
+
+     with torch.no_grad():
+         with torch.autocast(device):
+             batch, batch_uc = get_batch(
+                 get_unique_embedder_keys_from_conditioner(model.conditioner),
+                 value_dict,
+                 [1, num_frames],
+                 T=num_frames,
+                 device=device,
+             )
+             c, uc = model.conditioner.get_unconditional_conditioning(
+                 batch,
+                 batch_uc=batch_uc,
+                 force_uc_zero_embeddings=[
+                     "cond_frames",
+                     "cond_frames_without_noise",
+                 ],
+             )
+
+             # Broadcast the per-image conditioning to every frame.
+             for k in ["crossattn", "concat"]:
+                 uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                 uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                 c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                 c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+             randn = torch.randn(shape, device=device)
+
+             additional_model_inputs = {}
+             additional_model_inputs["image_only_indicator"] = torch.zeros(
+                 2, num_frames
+             ).to(device)
+             additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+             def denoiser(input, sigma, c):
+                 return model.denoiser(
+                     model.model, input, sigma, c, **additional_model_inputs
+                 )
+
+             samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+             model.en_and_decode_n_samples_a_time = decoding_t
+             samples_x = model.decode_first_stage(samples_z)
+             samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+             os.makedirs(output_folder, exist_ok=True)
+             base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+             video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+
+             frames = (
+                 (rearrange(samples, "t c h w -> t h w c") * 255)
+                 .cpu()
+                 .numpy()
+                 .astype(np.uint8)
+             )
+             write_video(video_path, frames, fps=6)
+
+     return video_path
+
+
+ def change_model_params(model, min_cfg, max_cfg):
+     model.sampler.guider.max_scale = max_cfg
+     model.sampler.guider.min_scale = min_cfg
+
+
+ def launch(device="cuda", port=4321, share=False):
+     model_config = "scripts/pub/configs/V3D_512.yaml"
+     num_frames = OmegaConf.load(
+         model_config
+     ).model.params.sampler_config.params.guider_config.params.num_frames
+     print("Detected num_frames:", num_frames)
+     # num_steps = default(num_steps, 25)
+     num_steps = 25
+     output_folder = "outputs/V3D_512"
+
+     # Load the SVD-XT weights once and split them into the CLIP image embedder
+     # and the video autoencoder used for conditioning.
+     sd = load_safetensors("./ckpts/svd_xt.safetensors")
+     clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml")
+     clip_model = instantiate_from_config(clip_model_config).eval()
+     clip_sd = dict()
+     for k, v in sd.items():
+         if "conditioner.embedders.0" in k:
+             clip_sd[k.replace("conditioner.embedders.0.", "")] = v
+     clip_model.load_state_dict(clip_sd)
+     clip_model = clip_model.to(device)
+
+     ae_model_config = OmegaConf.load("configs/ae/video.yaml")
+     ae_model = instantiate_from_config(ae_model_config).eval()
+     encoder_sd = dict()
+     for k, v in sd.items():
+         if "first_stage_model" in k:
+             encoder_sd[k.replace("first_stage_model.", "")] = v
+     ae_model.load_state_dict(encoder_sd)
+     ae_model = ae_model.to(device)
+     rembg_session = rembg.new_session()
+
+     model, _ = load_model(
+         model_config, device, num_frames, num_steps, min_cfg=3.5, max_cfg=3.5
+     )
+
+     with gr.Blocks(title="V3D", theme=gr.themes.Monochrome()) as demo:
+         with gr.Row(equal_height=True):
+             with gr.Column():
+                 input_image = gr.Image(value=None, label="Input Image")
+
+                 border_ratio_slider = gr.Slider(
+                     value=0.3,
+                     label="Border Ratio",
+                     minimum=0.05,
+                     maximum=0.5,
+                     step=0.05,
+                 )
+                 decoding_t_slider = gr.Slider(
+                     value=1,
+                     label="Number of Decoding Frames",
+                     minimum=1,
+                     maximum=num_frames,
+                     step=1,
+                 )
+                 # CFG scale sliders: the range must contain the 3.5 default.
+                 min_guidance_slider = gr.Slider(
+                     value=3.5,
+                     label="Min CFG Value",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                 )
+                 max_guidance_slider = gr.Slider(
+                     value=3.5,
+                     label="Max CFG Value",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                 )
+                 run_button = gr.Button(value="Run V3D")
+
+             with gr.Column():
+                 output_video = gr.Video(value=None, label="Output Orbit Video")
+
+         @run_button.click(
+             inputs=[
+                 input_image,
+                 border_ratio_slider,
+                 min_guidance_slider,
+                 max_guidance_slider,
+                 decoding_t_slider,
+             ],
+             outputs=[output_video],
+         )
+         def _(image, border_ratio, min_guidance, max_guidance, decoding_t):
+             change_model_params(model, min_guidance, max_guidance)
+             return do_sample(
+                 image,
+                 model,
+                 clip_model,
+                 ae_model,
+                 device,
+                 num_frames,
+                 num_steps,
+                 int(decoding_t),
+                 border_ratio,
+                 False,
+                 rembg_session,
+                 output_folder,
+             )
+
+     # do_sample(
+     #     np.asarray(Image.open("assets/baby_yoda.png")),
+     #     model,
+     #     clip_model,
+     #     ae_model,
+     #     device,
+     #     num_frames,
+     #     num_steps,
+     #     1,
+     #     0.3,
+     #     False,
+     #     rembg_session,
+     #     output_folder,
+     # )
+     demo.launch(
+         inbrowser=True, inline=False, server_port=port, share=share, show_error=True
+     )
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=4321)
+     parser.add_argument("--device", type=str, default="cuda")
+     parser.add_argument("--share", action="store_true")
+
+     opt = parser.parse_args()
+
+     def download_if_need(path, url):
+         # Download `url` to `path` unless the file is already present.
+         path = Path(path)
+         if path.exists():
+             return
+         import wget
+
+         path.parent.mkdir(parents=True, exist_ok=True)
+         wget.download(url, out=str(path))
+
+     download_if_need(
+         "ckpts/svd_xt.safetensors",
+         "https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors",
+     )
+     download_if_need(
+         "ckpts/V3D_512.ckpt",
+         "https://huggingface.co/heheyas/V3D/resolve/main/V3D.ckpt",
+     )
+
+     launch(opt.device, opt.port, opt.share)