fffiloni commited on
Commit
a9ed66b
1 Parent(s): d8793bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  #import torch
3
  #from torch import autocast // only for GPU
4
- from datasets import load_dataset
5
  from PIL import Image
6
 
7
  import os
@@ -19,24 +19,26 @@ device="cpu"
19
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
20
  pipe.to(device)
21
 
22
- #When running locally, you won`t have access to this, so you can remove this part
23
- word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=YOUR_TOKEN)
24
- word_list = word_list_dataset["train"]['text']
25
 
26
  def infer(prompt):
27
- for filter in word_list:
28
- if re.search(rf"\b{filter}\b", prompt):
29
- raise gr.Error("Unsafe content found. Please try again with different prompts.")
30
 
31
- #image = pipe(prompt, init_image=init_image)["sample"][0]
32
- image = pipe(prompt)["sample"][0]
33
 
34
- return image
 
 
 
 
 
 
 
 
 
35
 
36
  print("Great sylvain ! Everything is working fine !")
37
 
38
  title="Stable Diffusion CPU"
39
  description="Stable Diffusion example using CPU and HF token. Warning: Slow process... ~5/10 min inference time"
40
 
41
- gr.Interface(fn=infer, inputs="text", outputs="image",title=title,description=description).launch(enable_queue=True)
42
 
 
1
  import gradio as gr
2
  #import torch
3
  #from torch import autocast // only for GPU
4
+
5
  from PIL import Image
6
 
7
  import os
 
19
  pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=YOUR_TOKEN)
20
  pipe.to(device)
21
 
22
+ gallery = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
 
 
23
 
24
  def infer(prompt):
 
 
 
25
 
 
 
26
 
27
+ #image = pipe(prompt, init_image=init_image)["sample"][0]
28
+ images_list = pipe([prompt] * 1)
29
+ images = []
30
+ safe_image = Image.open(r"unsafe.png")
31
+ for i, image in enumerate(images_list["sample"]):
32
+ if(images_list["nsfw_content_detected"][i]):
33
+ images.append(safe_image)
34
+ else:
35
+ images.append(image)
36
+ return images
37
 
38
  print("Great sylvain ! Everything is working fine !")
39
 
40
  title="Stable Diffusion CPU"
41
  description="Stable Diffusion example using CPU and HF token. Warning: Slow process... ~5/10 min inference time"
42
 
43
+ gr.Interface(fn=infer, inputs="text", outputs=gallery,title=title,description=description).launch(enable_queue=True)
44