Spaces:
Runtime error
Runtime error
import gradio as gr | |
import spaces | |
import torch | |
import os | |
from glob import glob | |
from diffusers import StableVideoDiffusionPipeline | |
from diffusers.utils import export_to_video | |
from PIL import Image | |
# Directory where generated MP4 files are written (created on first use in sample()).
output_folder = "outputs"

# Stable Video Diffusion XT image-to-video pipeline, moved to the GPU at import time.
# variant="fp16" only selects the fp16 weight files; torch_dtype keeps them in fp16
# instead of upcasting to float32, roughly halving GPU memory for the model.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
# ZeroGPU Spaces attach a GPU only while a @spaces.GPU-decorated function runs.
# NOTE(review): if generations time out, raise the budget via spaces.GPU(duration=...).
@spaces.GPU
def sample(
    image: Image.Image,
    width: int = 1024,
    height: int = 576,
    motion_bucket_id: int = 127,
    fps_id: int = 30,
):
    """Generate a 25-frame video from a still image and return the MP4 path.

    Args:
        image: Conditioning image as a PIL image; ``None`` if nothing was uploaded.
        width: Requested output width in pixels; rounded down to a multiple of 8
            (minimum 8).
        height: Requested output height in pixels; rounded down to a multiple of 8
            (minimum 8).
        motion_bucket_id: Motion conditioning value forwarded to the pipeline.
        fps_id: Frame rate used when encoding the MP4.

    Returns:
        Path of the newly written ``<output_folder>/NNNNNN.mp4`` file.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # gr.Number may deliver floats, and the pipeline requires sizes divisible by 8.
    width = max(8, int(width) // 8 * 8)
    height = max(8, int(height) // 8 * 8)
    img = image.resize((width, height))

    os.makedirs(output_folder, exist_ok=True)
    # Sequential file name from the number of existing videos.
    # NOTE: two concurrent calls could pick the same name.
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

    frames = pipe(
        img,
        # Without explicit sizes the pipeline resizes to its own default
        # resolution, silently ignoring the user's width/height.
        width=width,
        height=height,
        decode_chunk_size=3,  # decode a few frames at a time to limit VAE memory
        generator=None,
        motion_bucket_id=int(motion_bucket_id),
        noise_aug_strength=0.1,
        num_frames=25,
    ).frames[0]
    export_to_video(frames, video_path, fps=int(fps_id))
    return video_path
with gr.Blocks() as demo:
    # Input image and resulting video side by side.
    with gr.Row():
        image = gr.Image(type="pil", label="Upload your image")
        video = gr.Video()

    # Trigger button plus collapsed generation settings.
    with gr.Column():
        generate_btn = gr.Button("Generate")
        with gr.Accordion("Advanced options", open=False):
            width = gr.Number(value=1024, minimum=1, label="Width")
            height = gr.Number(value=576, minimum=1, label="Height")
            motion_bucket_id = gr.Slider(
                minimum=1,
                maximum=255,
                value=60,
                label="Motion bucket id",
                info="Controls how much motion to add/remove from the image",
            )
            fps_id = gr.Slider(
                minimum=5,
                maximum=60,
                value=30,
                label="Frames per second",
                info="Video length will be 25 frames.",
            )

    # On upload, write the image straight back into the same component, bypassing the queue.
    image.upload(
        fn=lambda img: img,
        inputs=image,
        outputs=image,
        queue=False,
    )

    # Button -> sample(image, width, height, motion_bucket_id, fps_id) -> video path.
    generate_btn.click(
        fn=sample,
        inputs=[image, width, height, motion_bucket_id, fps_id],
        outputs=[video],
        api_name="video",
    )
if __name__ == "__main__":
    # Cap the pending-job queue at 20, close the queue's API, hide the API page, then serve.
    demo.queue(max_size=20, api_open=False).launch(show_api=False)