Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| import os | |
| from glob import glob | |
| from diffusers import StableVideoDiffusionPipeline | |
| from diffusers.utils import export_to_video | |
| from PIL import Image | |
# Directory that receives the generated .mp4 files (created on demand by sample()).
output_folder = "outputs"

# Load the half-precision weight files AND keep them in half precision.
# `variant="fp16"` only selects which checkpoint files are downloaded; without
# an explicit `torch_dtype` the weights are loaded as float32, which roughly
# doubles GPU memory use for no quality benefit.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")
def sample(
    image: Image.Image,
    width: int = 1024,
    height: int = 576,
    motion_bucket_id: int = 127,
    fps_id: int = 30,
):
    """Generate a 25-frame video from a still image and return its file path.

    Args:
        image: Conditioning image. ``None`` (nothing uploaded) is rejected.
        width: Requested output width in pixels. Rounded down to a multiple
            of 8 (minimum 8) because the pipeline rejects other sizes.
        height: Requested output height in pixels, rounded like ``width``.
        motion_bucket_id: Forwarded unchanged to the pipeline's
            ``motion_bucket_id`` argument.
        fps_id: Frame rate used when the frames are written to the video file.

    Returns:
        Path of the written ``.mp4`` file inside ``output_folder``.

    Raises:
        gr.Error: If no image was supplied or width/height is missing.
    """
    if image is None:
        # Clicking "Generate" with an empty image input passes None here.
        raise gr.Error("Please upload an image first.")
    if width is None or height is None:
        # A cleared gr.Number field is delivered as None.
        raise gr.Error("Width and height must be set.")

    # The pipeline requires dimensions divisible by 8; snap down rather than
    # failing deep inside the model call.
    width = max(8, int(width) // 8 * 8)
    height = max(8, int(height) // 8 * 8)

    # The model expects 3-channel input; normalise anything else (RGBA, L, P).
    if image.mode != "RGB":
        image = image.convert("RGB")
    img = image.resize((width, height))

    os.makedirs(output_folder, exist_ok=True)
    # Number files sequentially, but never reuse a name that still exists
    # (the plain file count can collide after older videos were deleted).
    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
    while os.path.exists(video_path):
        base_count += 1
        video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")

    frames = pipe(
        img,
        # Without explicit height/width the pipeline falls back to its own
        # 576x1024 defaults and resizes the image again, ignoring the request.
        height=height,
        width=width,
        decode_chunk_size=3,
        generator=None,
        motion_bucket_id=motion_bucket_id,
        noise_aug_strength=0.1,
        num_frames=25,
    ).frames[0]
    export_to_video(frames, video_path, fps=fps_id)
    return video_path
# Gradio front end: an image input and a video output side by side, a
# "Generate" button, and a collapsed panel with the tunable parameters.
# NOTE(review): the nesting below was reconstructed from a copy of the file
# whose indentation had been stripped -- confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        image = gr.Image(type="pil", label="Upload your image")
        video = gr.Video()

    with gr.Column():
        generate_btn = gr.Button("Generate")

        with gr.Accordion("Advanced options", open=False):
            width = gr.Number(minimum=1, value=1024, label="Width")
            height = gr.Number(minimum=1, value=576, label="Height")
            motion_bucket_id = gr.Slider(
                minimum=1,
                maximum=255,
                value=60,
                label="Motion bucket id",
                info="Controls how much motion to add/remove from the image",
            )
            fps_id = gr.Slider(
                minimum=5,
                maximum=60,
                step=2,
                value=30,
                label="Frames per second",
                info="Video length will be 25 frames.",
            )

    # Input order must match the positional parameters of sample().
    generate_btn.click(
        fn=sample,
        inputs=[image, width, height, motion_bucket_id, fps_id],
        outputs=[video],
    )
if __name__ == "__main__":
    # Queue at most 20 pending requests and keep the programmatic API closed
    # and hidden; queue() returns the Blocks instance, so launch() can be chained.
    demo.queue(api_open=False, max_size=20).launch(show_api=False)