Commit
•
61902e5
1
Parent(s):
7da1ebd
Update app.py
Browse files
app.py
CHANGED
@@ -40,8 +40,6 @@ def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
|
|
40 |
has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
|
41 |
return has_nsfw_concepts
|
42 |
|
43 |
-
# Function
|
44 |
-
@spaces.GPU(enable_queue=True)
|
45 |
def generate_image(prompt, base, motion, step, progress=gr.Progress()):
|
46 |
global step_loaded
|
47 |
global base_loaded
|
@@ -71,36 +69,41 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
|
|
71 |
|
72 |
output = pipe(
|
73 |
prompt=prompt,
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
guidance_scale=1.0,
|
77 |
num_inference_steps=step,
|
78 |
callback=progress_callback,
|
79 |
callback_steps=1
|
80 |
)
|
81 |
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
86 |
|
87 |
name = str(uuid.uuid4()).replace("-", "")
|
88 |
path = f"/tmp/{name}.mp4"
|
|
|
|
|
|
|
89 |
export_to_video(output.frames[0], path, fps=10)
|
|
|
90 |
return path
|
91 |
|
92 |
|
93 |
# Gradio Interface
|
94 |
-
with gr.Blocks(
|
95 |
-
|
96 |
-
"<h1><center>AnimateDiff-Lightning ⚡</center></h1>" +
|
97 |
-
"<p><center>Lightning-fast text-to-video generation</center></p>" +
|
98 |
-
"<p><center><a href='https://huggingface.co/ByteDance/AnimateDiff-Lightning'>https://huggingface.co/ByteDance/AnimateDiff-Lightning</a></center></p>"
|
99 |
-
)
|
100 |
with gr.Group():
|
101 |
with gr.Row():
|
102 |
prompt = gr.Textbox(
|
103 |
-
label='Prompt
|
104 |
)
|
105 |
with gr.Row():
|
106 |
select_base = gr.Dropdown(
|
@@ -135,7 +138,7 @@ with gr.Blocks(css="style.css") as demo:
|
|
135 |
('2-Step', 2),
|
136 |
('4-Step', 4),
|
137 |
('8-Step', 8)],
|
138 |
-
value=
|
139 |
interactive=True
|
140 |
)
|
141 |
submit = gr.Button(
|
@@ -145,8 +148,8 @@ with gr.Blocks(css="style.css") as demo:
|
|
145 |
video = gr.Video(
|
146 |
label='AnimateDiff-Lightning',
|
147 |
autoplay=True,
|
148 |
-
height=
|
149 |
-
width=
|
150 |
elem_id="video_output"
|
151 |
)
|
152 |
|
|
|
40 |
has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
|
41 |
return has_nsfw_concepts
|
42 |
|
|
|
|
|
43 |
def generate_image(prompt, base, motion, step, progress=gr.Progress()):
|
44 |
global step_loaded
|
45 |
global base_loaded
|
|
|
69 |
|
70 |
output = pipe(
|
71 |
prompt=prompt,
|
72 |
+
|
73 |
+
# this corresponds roughly to 16:9
|
74 |
+
# which is the aspect ratio video used by AiTube
|
75 |
+
width=910, # 1024,
|
76 |
+
height=512, # 576,
|
77 |
+
|
78 |
guidance_scale=1.0,
|
79 |
num_inference_steps=step,
|
80 |
callback=progress_callback,
|
81 |
callback_steps=1
|
82 |
)
|
83 |
|
84 |
+
# AiTube aims for real time, but we are loosing FPS if we perform this step
|
85 |
+
#has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
|
86 |
+
#if has_nsfw_concepts[0]:
|
87 |
+
# gr.Warning("NSFW content detected.")
|
88 |
+
# return None
|
89 |
|
90 |
name = str(uuid.uuid4()).replace("-", "")
|
91 |
path = f"/tmp/{name}.mp4"
|
92 |
+
|
93 |
+
# I think we are looking time here too, converting to mp4 is too slow, we should return
|
94 |
+
# the frames unencoded to the frontend renderer
|
95 |
export_to_video(output.frames[0], path, fps=10)
|
96 |
+
|
97 |
return path
|
98 |
|
99 |
|
100 |
# Gradio Interface
|
101 |
+
with gr.Blocks() as demo:
|
102 |
+
|
|
|
|
|
|
|
|
|
103 |
with gr.Group():
|
104 |
with gr.Row():
|
105 |
prompt = gr.Textbox(
|
106 |
+
label='Prompt'
|
107 |
)
|
108 |
with gr.Row():
|
109 |
select_base = gr.Dropdown(
|
|
|
138 |
('2-Step', 2),
|
139 |
('4-Step', 4),
|
140 |
('8-Step', 8)],
|
141 |
+
value=2,
|
142 |
interactive=True
|
143 |
)
|
144 |
submit = gr.Button(
|
|
|
148 |
video = gr.Video(
|
149 |
label='AnimateDiff-Lightning',
|
150 |
autoplay=True,
|
151 |
+
height=512,
|
152 |
+
width=910,
|
153 |
elem_id="video_output"
|
154 |
)
|
155 |
|