hlydecker commited on
Commit
496c7b8
1 Parent(s): 33c9935

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -18
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
  from PIL import Image
 
 
4
 
5
  import re
6
  import os
@@ -8,28 +10,17 @@ import requests
8
 
9
  from share_btn import community_icon_html, loading_icon_html, share_js
10
 
11
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
12
- word_list = word_list_dataset["train"]['text']
13
 
14
  is_gpu_busy = False
15
  def infer(person1, person2, scale):
16
  global is_gpu_busy
17
  #FIXME: did somebody say f"prompt" injections? :P
18
  prompt = f"Cross between {person1} and {person2}"
19
- for filter in word_list:
20
- if re.search(rf"\b{filter}\b", prompt):
21
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
22
 
23
- images = []
24
- person_1 = ''
25
- url = os.getenv('JAX_BACKEND_URL')
26
- payload = {'prompt': prompt, 'guidance_scale': scale}
27
- images_request = requests.post(url, json = payload)
28
- for image in images_request.json()["images"]:
29
- image_b64 = (f"data:image/jpeg;base64,{image}")
30
- images.append(image_b64)
31
 
32
- return images
33
 
34
 
35
  css = """
@@ -300,11 +291,11 @@ with block:
300
  # randomize=True,
301
  # )
302
 
303
- ex = gr.Examples(examples=examples, fn=infer, inputs=[person1, person2, guidance_scale], outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=False)
304
  ex.dataset.headers = [""]
305
- negative.submit(infer, inputs=[person1, person2, guidance_scale], outputs=[gallery], postprocess=False)
306
- text.submit(infer, inputs=[person1, person2, guidance_scale], outputs=[gallery], postprocess=False)
307
- btn.click(infer, inputs=[person1, person2, guidance_scale], outputs=[gallery], postprocess=False)
308
 
309
  #advanced_button.click(
310
  # None,
 
1
  import gradio as gr
2
  from datasets import load_dataset
3
  from PIL import Image
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
 
7
  import re
8
  import os
 
10
 
11
  from share_btn import community_icon_html, loading_icon_html, share_js
12
 
13
+ pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
 
14
 
15
  is_gpu_busy = False
16
  def infer(person1, person2, scale):
17
  global is_gpu_busy
18
  #FIXME: did somebody say f"prompt" injections? :P
19
  prompt = f"Cross between {person1} and {person2}"
 
 
 
20
 
21
+ image = pipe(prompt).images[0]
 
 
 
 
 
 
 
22
 
23
+ return image
24
 
25
 
26
  css = """
 
291
  # randomize=True,
292
  # )
293
 
294
+ ex = gr.Examples(examples=examples, fn=infer, inputs=[person1, person2], outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=False)
295
  ex.dataset.headers = [""]
296
+ negative.submit(infer, inputs=[person1, person2], outputs=[gallery], postprocess=False)
297
+ text.submit(infer, inputs=[person1, person2], outputs=[gallery], postprocess=False)
298
+ btn.click(infer, inputs=[person1, person2], outputs=[gallery], postprocess=False)
299
 
300
  #advanced_button.click(
301
  # None,