cocktailpeanut committed on
Commit
ad9639a
1 Parent(s): c2a3eed
Files changed (1)
  1. app2.py +391 -0
app2.py ADDED
@@ -0,0 +1,391 @@
+ import os
+ import json
+ import torch
+ import random
+ 
+ import gradio as gr
+ from glob import glob
+ from omegaconf import OmegaConf
+ from datetime import datetime
+ from safetensors import safe_open
+ 
+ from diffusers import AutoencoderKL
+ from diffusers.utils.import_utils import is_xformers_available
+ from transformers import CLIPTextModel, CLIPTokenizer
+ 
+ from animatelcm.scheduler.lcm_scheduler import LCMScheduler
+ from animatelcm.models.unet import UNet3DConditionModel
+ from animatelcm.pipelines.pipeline_animation import AnimationPipeline
+ from animatelcm.utils.util import save_videos_grid
+ from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+ from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
+ from animatelcm.utils.lcm_utils import convert_lcm_lora
+ import copy
+ 
+ sample_idx = 0
+ scheduler_dict = {
+     "LCM": LCMScheduler,
+ }
+ 
+ css = """
+ .toolbutton {
+     margin: 0em 0em 0em 0em;
+     max-width: 2.5em;
+     min-width: 2.5em !important;
+     height: 2.5em;
+ }
+ """
+ 
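+ # Pick the best available accelerator: Apple Silicon (MPS) first, then CUDA, otherwise fall back to CPU.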
+ if torch.backends.mps.is_available():
+     device = "mps"
+ elif torch.cuda.is_available():
+     device = "cuda"
+ else:
+     device = "cpu"
+ 
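+ # AnimateController holds the model components (tokenizer, text encoder, VAE, UNet) and
+ # exposes the loading/refresh/animation methods that the Gradio UI calls into.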
+ class AnimateController:
+     def __init__(self):
+ 
+         # config dirs
+         self.basedir = os.getcwd()
+         self.stable_diffusion_dir = os.path.join(
+             self.basedir, "models", "StableDiffusion")
+         self.motion_module_dir = os.path.join(
+             self.basedir, "models", "Motion_Module")
+         self.personalized_model_dir = os.path.join(
+             self.basedir, "models", "DreamBooth_LoRA")
+         self.savedir = os.path.join(
+             self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+         self.savedir_sample = os.path.join(self.savedir, "sample")
+         self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
+         os.makedirs(self.savedir, exist_ok=True)
+ 
+         self.stable_diffusion_list = []
+         self.motion_module_list = []
+         self.personalized_model_list = []
+ 
+         self.refresh_stable_diffusion()
+         self.refresh_motion_module()
+         self.refresh_personalized_model()
+ 
+         # config models
+         self.tokenizer = None
+         self.text_encoder = None
+         self.vae = None
+         self.unet = None
+         self.pipeline = None
+         self.lora_model_state_dict = {}
+ 
+         self.inference_config = OmegaConf.load("configs/inference.yaml")
+ 
+     def refresh_stable_diffusion(self):
+         self.stable_diffusion_list = glob(
+             os.path.join(self.stable_diffusion_dir, "*/"))
+ 
+     def refresh_motion_module(self):
+         motion_module_list = glob(os.path.join(
+             self.motion_module_dir, "*.ckpt"))
+         self.motion_module_list = [
+             os.path.basename(p) for p in motion_module_list]
+ 
+     def refresh_personalized_model(self):
+         personalized_model_list = glob(os.path.join(
+             self.personalized_model_dir, "*.safetensors"))
+         self.personalized_model_list = [
+             os.path.basename(p) for p in personalized_model_list]
+ 
+     def update_stable_diffusion(self, stable_diffusion_dropdown):
+         stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir, stable_diffusion_dropdown)
+         self.tokenizer = CLIPTokenizer.from_pretrained(
+             stable_diffusion_dropdown, subfolder="tokenizer")
+         self.text_encoder = CLIPTextModel.from_pretrained(
+             stable_diffusion_dropdown, subfolder="text_encoder").to(device)
+         self.vae = AutoencoderKL.from_pretrained(
+             stable_diffusion_dropdown, subfolder="vae").to(device)
+         self.unet = UNet3DConditionModel.from_pretrained_2d(
+             stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
+         return gr.Dropdown.update()
+ 
+     def update_motion_module(self, motion_module_dropdown):
+         if self.unet is None:
+             gr.Info("Please select a pretrained model path.")
+             return gr.Dropdown.update(value=None)
+         else:
+             motion_module_dropdown = os.path.join(
+                 self.motion_module_dir, motion_module_dropdown)
+             motion_module_state_dict = torch.load(
+                 motion_module_dropdown, map_location="cpu")
+             missing, unexpected = self.unet.load_state_dict(
+                 motion_module_state_dict, strict=False)
+             del motion_module_state_dict
+             assert len(unexpected) == 0
+             return gr.Dropdown.update()
+ 
+     def update_base_model(self, base_model_dropdown):
+         if self.unet is None:
+             gr.Info("Please select a pretrained model path.")
+             return gr.Dropdown.update(value=None)
+         else:
+             base_model_dropdown = os.path.join(
+                 self.personalized_model_dir, base_model_dropdown)
+             base_model_state_dict = {}
+             with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+                 for key in f.keys():
+                     base_model_state_dict[key] = f.get_tensor(key)
+ 
+             converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+                 base_model_state_dict, self.vae.config)
+             self.vae.load_state_dict(converted_vae_checkpoint)
+ 
+             converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+                 base_model_state_dict, self.unet.config)
+             self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+             del converted_unet_checkpoint
+             del converted_vae_checkpoint
+             del base_model_state_dict
+ 
+             # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
+             return gr.Dropdown.update()
+ 
+     def update_lora_model(self, lora_model_dropdown):
+         self.lora_model_state_dict = {}
+         if lora_model_dropdown == "none":
+             pass
+         else:
+             lora_model_dropdown = os.path.join(
+                 self.personalized_model_dir, lora_model_dropdown)
+             with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
+                 for key in f.keys():
+                     self.lora_model_state_dict[key] = f.get_tensor(key)
+         return gr.Dropdown.update()
+ 
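+     # animate() builds an AnimationPipeline from the loaded components, temporarily merges the
+     # LCM LoRA weights into the UNet (scaled by spatial_lora_slider), samples the video, and then
+     # restores the saved spatial weights so later runs start from a clean UNet.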
+     @torch.no_grad()
+     def animate(
+         self,
+         lora_alpha_slider,
+         spatial_lora_slider,
+         prompt_textbox,
+         negative_prompt_textbox,
+         sampler_dropdown,
+         sample_step_slider,
+         width_slider,
+         length_slider,
+         height_slider,
+         cfg_scale_slider,
+         seed_textbox
+     ):
+ 
+         if is_xformers_available():
+             self.unet.enable_xformers_memory_efficient_attention()
+ 
+         pipeline = AnimationPipeline(
+             vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
+             scheduler=scheduler_dict[sampler_dropdown](
+                 **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+         ).to(device)
+ 
+         original_state_dict = {k: v.cpu().clone() for k, v in pipeline.unet.state_dict().items() if "motion_modules." not in k}
+         pipeline.unet = convert_lcm_lora(pipeline.unet, self.lcm_lora_path, spatial_lora_slider)
+ 
+         pipeline.to(device)
+ 
+         if seed_textbox != -1 and seed_textbox != "":
+             torch.manual_seed(int(seed_textbox))
+         else:
+             torch.seed()
+         seed = torch.initial_seed()
+ 
+         with torch.autocast(device):
+             sample = pipeline(
+                 prompt_textbox,
+                 negative_prompt=negative_prompt_textbox,
+                 num_inference_steps=sample_step_slider,
+                 guidance_scale=cfg_scale_slider,
+                 width=width_slider,
+                 height=height_slider,
+                 video_length=length_slider,
+             ).videos
+ 
+         pipeline.unet.load_state_dict(original_state_dict, strict=False)
+         del original_state_dict
+ 
+         save_sample_path = os.path.join(
+             self.savedir_sample, f"{sample_idx}.mp4")
+         save_videos_grid(sample, save_sample_path)
+ 
+         sample_config = {
+             "prompt": prompt_textbox,
+             "n_prompt": negative_prompt_textbox,
+             "sampler": sampler_dropdown,
+             "num_inference_steps": sample_step_slider,
+             "guidance_scale": cfg_scale_slider,
+             "width": width_slider,
+             "height": height_slider,
+             "video_length": length_slider,
+             "seed": seed
+         }
+         json_str = json.dumps(sample_config, indent=4)
+         with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+             f.write(json_str)
+             f.write("\n\n")
+         return gr.Video.update(value=save_sample_path)
+ 
+ 
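+ # Instantiate the controller and pre-load the default checkpoint, motion module, and personalized base model.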
+ controller = AnimateController()
+ 
+ controller.update_stable_diffusion("stable-diffusion-v1-5")
+ controller.update_motion_module("sd15_t2v_beta_motion.ckpt")
+ controller.update_base_model("realistic2.safetensors")
+ 
+ 
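+ # Build the Gradio Blocks interface that drives AnimateController.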
+ def ui():
+     with gr.Blocks(css=css) as demo:
+         gr.Markdown(
+             """
+             # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
+             Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
+             [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)
+             """
+ 
+             '''
+             Important Notes:
+             1. Generation takes only a few seconds, though the Space may add some delay.
+             2. Increase the sampling steps and CFG scale if you want fancier videos.
+             '''
+         )
+         with gr.Column(variant="panel"):
+             with gr.Row():
+ 
+                 base_model_dropdown = gr.Dropdown(
+                     label="Select base Dreambooth model (required)",
+                     choices=controller.personalized_model_list,
+                     interactive=True,
+                     value="realistic2.safetensors"
+                 )
+                 base_model_dropdown.change(fn=controller.update_base_model, inputs=[
+                     base_model_dropdown], outputs=[base_model_dropdown])
+ 
+                 lora_model_dropdown = gr.Dropdown(
+                     label="Select LoRA model (optional)",
+                     choices=["none",],
+                     value="none",
+                     interactive=True,
+                 )
+                 lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
+                     lora_model_dropdown], outputs=[lora_model_dropdown])
+ 
+                 lora_alpha_slider = gr.Slider(
+                     label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
+                 spatial_lora_slider = gr.Slider(
+                     label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)
+ 
+                 personalized_refresh_button = gr.Button(
+                     value="\U0001F503", elem_classes="toolbutton")
+ 
+                 def update_personalized_model():
+                     controller.refresh_personalized_model()
+                     return [
+                         gr.Dropdown.update(
+                             choices=controller.personalized_model_list),
+                         gr.Dropdown.update(
+                             choices=["none"] + controller.personalized_model_list)
+                     ]
+                 personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
+                     base_model_dropdown, lora_model_dropdown])
+ 
+         with gr.Column(variant="panel"):
+             gr.Markdown(
+                 """
+                 ### 2. Configs for AnimateLCM.
+                 """
+             )
+ 
+             prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="a boy holding a rabbit")
+             negative_prompt_textbox = gr.Textbox(
+                 label="Negative prompt", lines=2, value="bad quality")
+ 
+             with gr.Row().style(equal_height=False):
+                 with gr.Column():
+                     with gr.Row():
+                         sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
+                             scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                         sample_step_slider = gr.Slider(
+                             label="Sampling steps", value=6, minimum=1, maximum=25, step=1)
+ 
+                     width_slider = gr.Slider(
+                         label="Width", value=512, minimum=256, maximum=1024, step=64)
+                     height_slider = gr.Slider(
+                         label="Height", value=512, minimum=256, maximum=1024, step=64)
+                     length_slider = gr.Slider(
+                         label="Animation length", value=16, minimum=12, maximum=20, step=1)
+                     cfg_scale_slider = gr.Slider(
+                         label="CFG Scale", value=1.5, minimum=1, maximum=2)
+ 
+                     with gr.Row():
+                         seed_textbox = gr.Textbox(label="Seed", value=-1)
+                         seed_button = gr.Button(
+                             value="\U0001F3B2", elem_classes="toolbutton")
+                         seed_button.click(fn=lambda: gr.Textbox.update(
+                             value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+ 
+                     generate_button = gr.Button(
+                         value="Generate", variant='primary')
+ 
+                 result_video = gr.Video(
+                     label="Generated Animation", interactive=False)
+ 
+             generate_button.click(
+                 fn=controller.animate,
+                 inputs=[
+                     lora_alpha_slider,
+                     spatial_lora_slider,
+                     prompt_textbox,
+                     negative_prompt_textbox,
+                     sampler_dropdown,
+                     sample_step_slider,
+                     width_slider,
+                     length_slider,
+                     height_slider,
+                     cfg_scale_slider,
+                     seed_textbox,
+                 ],
+                 outputs=[result_video]
+             )
+ 
+         examples = [
+             [0.8, 0.8, "a boy is holding a rabbit", "bad quality", "LCM", 8, 512, 16, 512, 1.5, 123],
+             [0.8, 0.8, "1girl smiling", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1233],
+             [0.8, 0.8, "1girl,face,white background,", "bad quality", "LCM", 6, 512, 16, 512, 1.5, 1234],
+             [0.8, 0.8, "clouds in the sky, best quality", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1234],
+         ]
+         gr.Examples(
+             examples=examples,
+             inputs=[
+                 lora_alpha_slider,
+                 spatial_lora_slider,
+                 prompt_textbox,
+                 negative_prompt_textbox,
+                 sampler_dropdown,
+                 sample_step_slider,
+                 width_slider,
+                 length_slider,
+                 height_slider,
+                 cfg_scale_slider,
+                 seed_textbox,
+             ],
+             outputs=[result_video],
+             fn=controller.animate,
+             cache_examples=True,
+         )
+ 
+     return demo
+ 
+ 
+ 
+ if __name__ == "__main__":
+     demo = ui()
+     # gr.close_all()
+     demo.queue(api_open=False)
+     demo.launch()
+ 