cocktailpeanut committed
Commit: ad9639a
1 Parent(s): c2a3eed
update
app2.py
ADDED
@@ -0,0 +1,391 @@
import os
import json
import torch
import random

import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open

from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.utils.util import save_videos_grid
from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatelcm.utils.lcm_utils import convert_lcm_lora
import copy

sample_idx = 0
scheduler_dict = {
    "LCM": LCMScheduler,
}

css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

# Prefer Apple's MPS backend, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

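# AnimateController bundles the demo's state: it scans the local model
# folders, loads the Stable Diffusion components on demand, and runs the
# animation pipeline when "Generate" is clicked.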
class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(
            self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(
            self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(
            self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(
            self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference.yaml")

    def refresh_stable_diffusion(self):
        self.stable_diffusion_list = glob(
            os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(
            self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [
            os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(
            self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [
            os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        stable_diffusion_dropdown = os.path.join(
            self.stable_diffusion_dir, stable_diffusion_dropdown)
        self.tokenizer = CLIPTokenizer.from_pretrained(
            stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            stable_diffusion_dropdown, subfolder="text_encoder").to(device)
        self.vae = AutoencoderKL.from_pretrained(
            stable_diffusion_dropdown, subfolder="vae").to(device)
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            stable_diffusion_dropdown, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(
                self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(
                motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(
                motion_module_state_dict, strict=False)
            del motion_module_state_dict
            # The motion module may only add weights, never rename existing ones.
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(
                self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
                base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(
                base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
            del converted_unet_checkpoint
            del converted_vae_checkpoint
            del base_model_state_dict

            # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        self.lora_model_state_dict = {}
        # "none" clears any selected LoRA; check before joining the path,
        # otherwise the comparison can never match the joined filename.
        if lora_model_dropdown != "none":
            lora_model_dropdown = os.path.join(
                self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()
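    # Sampling entry point. The LCM LoRA is merged into the UNet only for the
    # duration of this call and the original spatial weights are restored
    # afterwards, so repeated runs do not compound the merge.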
    @torch.no_grad()
    def animate(
        self,
        lora_alpha_slider,  # currently unused; kept so the UI inputs line up
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        global sample_idx

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to(device)

        # Snapshot the spatial (non-motion) UNet weights before merging the
        # LCM LoRA so they can be restored after sampling.
        original_state_dict = {k: v.cpu().clone() for k, v in pipeline.unet.state_dict().items()
                               if "motion_modules." not in k}
        pipeline.unet = convert_lcm_lora(
            pipeline.unet, self.lcm_lora_path, spatial_lora_slider)

        pipeline.to(device)

        # An empty seed box or "-1" means "pick a random seed". The textbox
        # delivers a string, so compare against both forms.
        if seed_textbox not in ("", "-1", -1):
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        with torch.autocast(device):
            sample = pipeline(
                prompt_textbox,
                negative_prompt=negative_prompt_textbox,
                num_inference_steps=sample_step_slider,
                guidance_scale=cfg_scale_slider,
                width=width_slider,
                height=height_slider,
                video_length=length_slider,
            ).videos

        # Undo the LoRA merge so the next request starts from clean weights.
        pipeline.unet.load_state_dict(original_state_dict, strict=False)
        del original_state_dict

        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)
        sample_idx += 1  # avoid overwriting earlier samples in this session

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")
        return gr.Video.update(value=save_sample_path)


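# Instantiate the controller and preload the default checkpoint, motion
# module, and personalized base model so the Space can generate right away.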
controller = AnimateController()

controller.update_stable_diffusion("stable-diffusion-v1-5")
controller.update_motion_module("sd15_t2v_beta_motion.ckpt")
controller.update_base_model("realistic2.safetensors")


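# Build the Gradio Blocks UI: model pickers, sampling controls, and the
# output video, wired to the controller's callbacks.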
def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
            Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
            [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM) | [Civitai](https://civitai.com/models/290375/animatelcm-fast-video-generation) | [Replicate](https://replicate.com/camenduru/animate-lcm)

            Important Notes:
            1. Generation takes around a few seconds per video, plus some queueing delay in this Space.
            2. Increase the sampling steps and CFG scale for more detailed videos.
            """
        )
        with gr.Column(variant="panel"):
            with gr.Row():

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                    value="realistic2.safetensors"
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[
                                           base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none",],
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
                                           lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(
                    label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                spatial_lora_slider = gr.Slider(
                    label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)

                personalized_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(
                            choices=controller.personalized_model_list),
                        gr.Dropdown.update(
                            choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
                                                  base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateLCM.
                """
            )

            prompt_textbox = gr.Textbox(
                label="Prompt", lines=2, value="a boy holding a rabbit")
            negative_prompt_textbox = gr.Textbox(
                label="Negative prompt", lines=2, value="bad quality")

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
                            scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=6, minimum=1, maximum=25, step=1)

                    width_slider = gr.Slider(
                        label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(
                        label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(
                        label="Animation length", value=16, minimum=12, maximum=20, step=1)
                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale", value=1.5, minimum=1, maximum=2)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        # random.randint needs integer bounds, so cast 1e8.
                        seed_button.click(fn=lambda: gr.Textbox.update(
                            value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(
                        value="Generate", variant='primary')

                result_video = gr.Video(
                    label="Generated Animation", interactive=False)

            # The input order must match the signature of AnimateController.animate.
            generate_button.click(
                fn=controller.animate,
                inputs=[
                    lora_alpha_slider,
                    spatial_lora_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

            examples = [
                [0.8, 0.8, "a boy is holding a rabbit", "bad quality", "LCM", 8, 512, 16, 512, 1.5, 123],
                [0.8, 0.8, "1girl smiling", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1233],
                [0.8, 0.8, "1girl,face,white background,", "bad quality", "LCM", 6, 512, 16, 512, 1.5, 1234],
                [0.8, 0.8, "clouds in the sky, best quality", "bad quality", "LCM", 4, 512, 16, 512, 1.5, 1234],
            ]
            gr.Examples(
                examples=examples,
                inputs=[
                    lora_alpha_slider,
                    spatial_lora_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video],
                fn=controller.animate,
                cache_examples=True,
            )

    return demo


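# Serve through Gradio's queue; api_open=False keeps the queued endpoints
# out of the public API.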
if __name__ == "__main__":
    demo = ui()
    # gr.close_all()
    demo.queue(api_open=False)
    demo.launch()