Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
4 |
-
import spaces
|
5 |
import uuid
|
6 |
|
7 |
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
|
@@ -10,6 +9,8 @@ from huggingface_hub import hf_hub_download
|
|
10 |
from safetensors.torch import load_file
|
11 |
from PIL import Image
|
12 |
|
|
|
|
|
13 |
# Constants
|
14 |
bases = {
|
15 |
"ToonYou": "frankjoshua/toonyou_beta6",
|
@@ -28,23 +29,16 @@ dtype = torch.float16
|
|
28 |
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
|
29 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
|
36 |
-
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
|
37 |
|
38 |
-
def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
|
39 |
-
safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
|
40 |
-
has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
|
41 |
-
return has_nsfw_concepts
|
42 |
|
43 |
-
def generate_image(prompt, base, motion, step, progress=gr.Progress()):
|
44 |
global step_loaded
|
45 |
global base_loaded
|
46 |
global motion_loaded
|
47 |
-
print(prompt, base, step)
|
48 |
|
49 |
if step_loaded != step:
|
50 |
repo = "ByteDance/AnimateDiff-Lightning"
|
@@ -81,25 +75,40 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
|
|
81 |
callback_steps=1
|
82 |
)
|
83 |
|
84 |
-
# AiTube aims for real time, but we are losing FPS if we perform this step
|
85 |
-
#has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
|
86 |
-
#if has_nsfw_concepts[0]:
|
87 |
-
# gr.Warning("NSFW content detected.")
|
88 |
-
# return None
|
89 |
-
|
90 |
name = str(uuid.uuid4()).replace("-", "")
|
91 |
path = f"/tmp/{name}.mp4"
|
92 |
|
93 |
# I think we are losing time here too: converting to mp4 is too slow, we should return
|
94 |
# the frames unencoded to the frontend renderer
|
95 |
export_to_video(output.frames[0], path, fps=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
|
99 |
|
100 |
# Gradio Interface
|
101 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
|
|
|
|
|
103 |
with gr.Group():
|
104 |
with gr.Row():
|
105 |
prompt = gr.Textbox(
|
@@ -138,30 +147,17 @@ with gr.Blocks() as demo:
|
|
138 |
('2-Step', 2),
|
139 |
('4-Step', 4),
|
140 |
('8-Step', 8)],
|
141 |
-
value=
|
142 |
interactive=True
|
143 |
)
|
144 |
-
submit = gr.Button(
|
145 |
-
|
146 |
-
|
147 |
-
)
|
148 |
-
video = gr.Video(
|
149 |
-
label='AnimateDiff-Lightning',
|
150 |
-
autoplay=True,
|
151 |
-
height=512,
|
152 |
-
width=912,
|
153 |
-
elem_id="video_output"
|
154 |
-
)
|
155 |
|
156 |
-
prompt.submit(
|
157 |
-
fn=generate_image,
|
158 |
-
inputs=[prompt, select_base, select_motion, select_step],
|
159 |
-
outputs=video,
|
160 |
-
)
|
161 |
submit.click(
|
162 |
fn=generate_image,
|
163 |
-
inputs=[prompt, select_base, select_motion, select_step],
|
164 |
-
outputs=
|
165 |
)
|
166 |
|
167 |
-
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
|
|
4 |
import uuid
|
5 |
|
6 |
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
|
|
|
9 |
from safetensors.torch import load_file
|
10 |
from PIL import Image
|
11 |
|
12 |
+
SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
|
13 |
+
|
14 |
# Constants
|
15 |
bases = {
|
16 |
"ToonYou": "frankjoshua/toonyou_beta6",
|
|
|
29 |
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
|
30 |
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
31 |
|
32 |
+
def generate_image(secret_token, prompt, base, motion, step):
|
33 |
+
if secret_token != SECRET_TOKEN:
|
34 |
+
raise gr.Error(
|
35 |
+
f'Invalid secret token. Please fork the original space if you want to use it for yourself.')
|
|
|
|
|
36 |
|
|
|
|
|
|
|
|
|
37 |
|
|
|
38 |
global step_loaded
|
39 |
global base_loaded
|
40 |
global motion_loaded
|
41 |
+
# print(prompt, base, step)
|
42 |
|
43 |
if step_loaded != step:
|
44 |
repo = "ByteDance/AnimateDiff-Lightning"
|
|
|
75 |
callback_steps=1
|
76 |
)
|
77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
name = str(uuid.uuid4()).replace("-", "")
|
79 |
path = f"/tmp/{name}.mp4"
|
80 |
|
81 |
# I think we are losing time here too: converting to mp4 is too slow, we should return
|
82 |
# the frames unencoded to the frontend renderer
|
83 |
export_to_video(output.frames[0], path, fps=10)
|
84 |
+
|
85 |
+
# Read the content of the video file and encode it to base64
|
86 |
+
with open(path, "rb") as video_file:
|
87 |
+
video_base64 = base64.b64encode(video_file.read()).decode('utf-8')
|
88 |
+
|
89 |
+
# Prepend the appropriate data URI header with MIME type
|
90 |
+
video_data_uri = 'data:video/mp4;base64,' + video_base64
|
91 |
|
92 |
+
# clean-up (otherwise there is always a risk of "ghosting", e.g. someone seeing the previously generated video
|
93 |
+
# if one of the steps goes wrong)
|
94 |
+
os.remove(path)
|
95 |
+
|
96 |
+
return video_data_uri
|
97 |
|
98 |
|
99 |
# Gradio Interface
|
100 |
with gr.Blocks() as demo:
|
101 |
+
gr.HTML("""
|
102 |
+
<div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
|
103 |
+
<div style="text-align: center; color: black;">
|
104 |
+
<p style="color: black;">This space is a REST API to programmatically generate MP4 videos for AiTube, the next generation video platform.</p>
|
105 |
+
<p style="color: black;">Interested in using it? Look no further than the <a href="https://huggingface.co/spaces/ByteDance/AnimateDiff-Lightning" target="_blank">original space</a>!</p>
|
106 |
+
</div>
|
107 |
+
</div>""")
|
108 |
+
|
109 |
|
110 |
+
secret_token = gr.Text(label='Secret Token', max_lines=1)
|
111 |
+
|
112 |
with gr.Group():
|
113 |
with gr.Row():
|
114 |
prompt = gr.Textbox(
|
|
|
147 |
('2-Step', 2),
|
148 |
('4-Step', 4),
|
149 |
('8-Step', 8)],
|
150 |
+
value=4,
|
151 |
interactive=True
|
152 |
)
|
153 |
+
submit = gr.Button()
|
154 |
+
|
155 |
+
output_video_base64 = gr.Text()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
|
|
|
|
|
|
|
|
|
|
|
157 |
submit.click(
|
158 |
fn=generate_image,
|
159 |
+
inputs=[secret_token, prompt, select_base, select_motion, select_step],
|
160 |
+
outputs=output_video_base64,
|
161 |
)
|
162 |
|
163 |
+
# Start the Gradio app with a request queue capped at 12 pending jobs and the
# API page enabled. The Blocks instance is bound to `demo` by the
# `with gr.Blocks() as demo:` statement; no `app` name is defined in the
# visible portion of this file, so calling `app.queue(...)` would raise
# NameError at startup.
demo.queue(max_size=12).launch(show_api=True)
|