jbilcke-hf (HF staff) committed
Commit 0a535f7 · verified · 1 Parent(s): ed63142

Update app.py

Files changed (1):
  app.py +36 -40
app.py CHANGED
@@ -1,7 +1,7 @@
 import gradio as gr
 import torch
 import os
-import spaces
+import base64
 import uuid
 
 from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
@@ -10,6 +10,8 @@ from huggingface_hub import hf_hub_download
 from safetensors.torch import load_file
 from PIL import Image
 
+SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')
+
 # Constants
 bases = {
     "ToonYou": "frankjoshua/toonyou_beta6",
@@ -28,23 +30,16 @@ dtype = torch.float16
 pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 
-# Safety checkers
-from safety_checker import StableDiffusionSafetyChecker
-from transformers import CLIPFeatureExtractor
-
-safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
-feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
-
-def check_nsfw_images(images: list[Image.Image]) -> list[bool]:
-    safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
-    has_nsfw_concepts = safety_checker(images=[images], clip_input=safety_checker_input.pixel_values.to(device))
-    return has_nsfw_concepts
-
-def generate_image(prompt, base, motion, step, progress=gr.Progress()):
+def generate_image(secret_token, prompt, base, motion, step):
+    if secret_token != SECRET_TOKEN:
+        raise gr.Error(
+            'Invalid secret token. Please fork the original space if you want to use it for yourself.')
+
+
     global step_loaded
     global base_loaded
     global motion_loaded
-    print(prompt, base, step)
+    # print(prompt, base, step)
 
     if step_loaded != step:
         repo = "ByteDance/AnimateDiff-Lightning"
@@ -81,25 +76,40 @@ def generate_image(prompt, base, motion, step, progress=gr.Progress()):
         callback_steps=1
     )
 
-    # AiTube aims for real time, but we are losing FPS if we perform this step
-    #has_nsfw_concepts = check_nsfw_images([output.frames[0][0]])
-    #if has_nsfw_concepts[0]:
-    #    gr.Warning("NSFW content detected.")
-    #    return None
-
     name = str(uuid.uuid4()).replace("-", "")
     path = f"/tmp/{name}.mp4"
 
     # I think we are losing time here too: converting to mp4 is too slow, we should return
    # the frames unencoded to the frontend renderer
     export_to_video(output.frames[0], path, fps=10)
+
+    # Read the content of the video file and encode it to base64
+    with open(path, "rb") as video_file:
+        video_base64 = base64.b64encode(video_file.read()).decode('utf-8')
+
+    # Prepend the appropriate data URI header with the MIME type
+    video_data_uri = 'data:video/mp4;base64,' + video_base64
 
-    return path
+    # Clean up (otherwise there is always a risk of "ghosting", e.g. someone seeing a
+    # previously generated video if one of the steps goes wrong)
+    os.remove(path)
+
+    return video_data_uri
 
 
 # Gradio Interface
 with gr.Blocks() as demo:
+    gr.HTML("""
+    <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
+      <div style="text-align: center; color: black;">
+        <p style="color: black;">This space is a REST API to programmatically generate MP4 videos for AiTube, the next generation video platform.</p>
+        <p style="color: black;">Interested in using it? Look no further than the <a href="https://huggingface.co/spaces/ByteDance/AnimateDiff-Lightning" target="_blank">original space</a>!</p>
+      </div>
+    </div>""")
+
 
+    secret_token = gr.Text(label='Secret Token', max_lines=1)
+
     with gr.Group():
         with gr.Row():
             prompt = gr.Textbox(
@@ -138,30 +148,17 @@ with gr.Blocks() as demo:
                 ('2-Step', 2),
                 ('4-Step', 4),
                 ('8-Step', 8)],
-            value=2,
+            value=4,
             interactive=True
         )
-    submit = gr.Button(
-        scale=1,
-        variant='primary'
-    )
-    video = gr.Video(
-        label='AnimateDiff-Lightning',
-        autoplay=True,
-        height=512,
-        width=912,
-        elem_id="video_output"
-    )
+    submit = gr.Button()
+
+    output_video_base64 = gr.Text()
 
-    prompt.submit(
-        fn=generate_image,
-        inputs=[prompt, select_base, select_motion, select_step],
-        outputs=video,
-    )
     submit.click(
         fn=generate_image,
-        inputs=[prompt, select_base, select_motion, select_step],
-        outputs=video,
+        inputs=[secret_token, prompt, select_base, select_motion, select_step],
+        outputs=output_video_base64,
    )
 
-demo.queue().launch()
+demo.queue(max_size=12).launch(show_api=True)
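With launch(show_api=True), the click handler becomes a callable API endpoint, so the Space can be driven programmatically. Below is a minimal client sketch using gradio_client; the Space id, the motion value, and the endpoint name are assumptions (none of them are visible in this diff), so check the Space's "Use via API" page for the real values.

import base64
from gradio_client import Client

# Hypothetical Space id -- replace with the actual one. By default Gradio names
# the endpoint after the bound function, hence "/generate_image" (an assumption).
client = Client("jbilcke-hf/ai-tube-model-animatediff")

data_uri = client.predict(
    "my-secret-token",   # secret_token: must match the Space's SECRET_TOKEN env var
    "a girl smiling",    # prompt
    "ToonYou",           # base: a key of the `bases` dict
    "",                  # motion: LoRA id; empty string assumed to mean "no motion LoRA"
    4,                   # step: 1, 2, 4 or 8
    api_name="/generate_image",
)

# The endpoint returns "data:video/mp4;base64,<payload>"; strip the header
# and decode the payload back into a playable .mp4 file.
_, payload = data_uri.split(",", 1)
with open("output.mp4", "wb") as f:
    f.write(base64.b64decode(payload))

Returning a data URI instead of a server-side file path keeps the response self-contained for the AiTube frontend, at the cost of the roughly 33% size overhead inherent to base64; deleting the temporary file right after encoding avoids the "ghosting" risk mentioned in the code comment.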