nroggendorff commited on
Commit
495bf76
1 Parent(s): 3ae9f6d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import spaces
3
+ import torch
4
+ import os
5
+ from glob import glob
6
+ from diffusers import StableVideoDiffusionPipeline
7
+ from diffusers.utils import export_to_video
8
+ from PIL import Image
9
+
10
+ output_folder = "outputs"
11
+
12
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
13
+ "stabilityai/stable-video-diffusion-img2vid-xt", variant="fp16"
14
+ ).to("cuda")
15
+
16
+ @spaces.GPU(duration=480)
17
+ def sample(
18
+ image: Image,
19
+ width: int = 1024,
20
+ height: int = 576,
21
+ motion_bucket_id: int = 127,
22
+ fps_id: int = 30,
23
+ ):
24
+ width = int(width)
25
+ height = int(height)
26
+ img = image.resize((width, height))
27
+
28
+ os.makedirs(output_folder, exist_ok=True)
29
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
30
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
31
+
32
+ frames = pipe(
33
+ img,
34
+ decode_chunk_size=3,
35
+ generator=None,
36
+ motion_bucket_id=motion_bucket_id,
37
+ noise_aug_strength=0.1,
38
+ num_frames=25,
39
+ ).frames[0]
40
+
41
+ export_to_video(frames, video_path, fps=fps_id)
42
+ return video_path
43
+
44
+ with gr.Blocks() as demo:
45
+ with gr.Row():
46
+ image = gr.Image(label="Upload your image", type="pil")
47
+ video = gr.Video()
48
+
49
+ with gr.Column():
50
+ generate_btn = gr.Button("Generate")
51
+
52
+ with gr.Accordion("Advanced options", open=False):
53
+ width = gr.Number(label="Width", value=1024, minimum=1)
54
+ height = gr.Number(label="Height", value=576, minimum=1)
55
+ motion_bucket_id = gr.Slider(
56
+ label="Motion bucket id",
57
+ info="Controls how much motion to add/remove from the image",
58
+ value=60,
59
+ minimum=1,
60
+ maximum=255,
61
+ )
62
+ fps_id = gr.Slider(
63
+ label="Frames per second",
64
+ info="Video length will be 25 frames.",
65
+ value=30,
66
+ minimum=5,
67
+ maximum=60,
68
+ )
69
+
70
+ image.upload(fn=lambda img: img, inputs=image, outputs=image, queue=False)
71
+ generate_btn.click(
72
+ fn=sample,
73
+ inputs=[image, width, height, motion_bucket_id, fps_id],
74
+ outputs=[video],
75
+ api_name="video",
76
+ )
77
+
78
+ if __name__ == "__main__":
79
+ demo.queue(max_size=20, api_open=False)
80
+ demo.launch(show_api=False)