import gradio as gr
import torch
import spaces
import uuid
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Constants
base = "frankjoshua/toonyou_beta6"
repo = "ByteDance/AnimateDiff-Lightning"
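# Map each UI option to its checkpoint filename and inference step count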
checkpoints = {
"1-Step" : ["animatediff_lightning_1step_diffusers.safetensors", 1],
"2-Step" : ["animatediff_lightning_2step_diffusers.safetensors", 2],
"4-Step" : ["animatediff_lightning_4step_diffusers.safetensors", 4],
"8-Step" : ["animatediff_lightning_8step_diffusers.safetensors", 8],
}
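# Step count of the checkpoint currently loaded into the UNet (None until first load)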
loaded = None

# Initialize the model and scheduler once at startup (a CUDA GPU is required)
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
    adapter = MotionAdapter().to(device, dtype)
    pipe = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
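    # The AnimateDiff-Lightning model card pairs the distilled weights with an Euler
    # scheduler using trailing timestep spacing and a linear beta schedule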
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
else:
    raise NotImplementedError("No GPU detected!")

# Generation function (spaces.GPU requests GPU time on ZeroGPU hardware)
@spaces.GPU(enable_queue=True)
def generate_image(prompt, ckpt):
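    """Generate a short video from the text prompt using the selected Lightning checkpoint."""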
    global loaded
    print(prompt, ckpt)
    checkpoint, num_inference_steps = checkpoints[ckpt]
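    # Lazily swap in the distilled UNet weights only when the selected step count changes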
    if loaded != num_inference_steps:
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device=device), strict=False)
        loaded = num_inference_steps
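    # guidance_scale=1.0 effectively disables classifier-free guidance,
    # as recommended for the distilled Lightning checkpoints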
    output = pipe(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_inference_steps)
    # Export the generated frames to a uniquely named MP4 under /tmp
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path

# Gradio Interface
with gr.Blocks(css="style.css") as demo:
gr.HTML("<h1><center>AnimateDiff-Lightning ⚡</center></h1>")
gr.HTML("<p><center>Lightning-fast text-to-video generation</center></p><p><center><a href='https://huggingface.co/ByteDance/AnimateDiff-Lightning'>https://huggingface.co/ByteDance/AnimateDiff-Lightning</a></center></p>")
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
            ckpt = gr.Dropdown(label='Select inference steps', choices=['1-Step', '2-Step', '4-Step', '8-Step'], value='4-Step', interactive=True)
            submit = gr.Button(scale=1, variant='primary')
    video = gr.Video(label='AnimateDiff-Lightning Generated Video')
    prompt.submit(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=video,
    )
    submit.click(
        fn=generate_image,
        inputs=[prompt, ckpt],
        outputs=video,
    )
demo.queue().launch()