Garrett Goon commited on
Commit
4811b12
1 Parent(s): f5c02a5
Files changed (2) hide show
  1. app.py +16 -14
  2. nsfw.png +0 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import pathlib
2
  import os
 
3
 
4
  import gradio as gr
5
  import torch
@@ -8,6 +9,7 @@ from diffusers import StableDiffusionPipeline
8
  import utils
9
 
10
  use_auth_token = os.environ["HF_AUTH_TOKEN"]
 
11
 
12
  # Instantiate the pipeline.
13
  device, revision, torch_dtype = (
@@ -51,19 +53,25 @@ def replace_concept_tokens(text: str):
51
  text = text.replace(concept_token, dummy_tokens)
52
  return text
53
 
54
-
55
  def inference(
56
  prompt: str, guidance_scale: int, num_inference_steps: int, seed: int
57
  ):
58
  prompt = replace_concept_tokens(prompt)
59
  generator = torch.Generator(device=device).manual_seed(seed)
60
- img_list = pipeline(
61
  prompt=[prompt] * 2,
62
  num_inference_steps=num_inference_steps,
63
  guidance_scale=guidance_scale,
64
  generator=generator,
65
- ).images
66
- return img_list
 
 
 
 
 
 
67
 
68
  DEFAULT_PROMPT = (
69
  "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
@@ -77,11 +85,11 @@ with gr.Blocks() as demo:
77
  interactive=True,
78
  )
79
  guidance_scale = gr.Slider(
80
- minimum=1.0, maximum=10.0, value=3.0, label="Guidance Scale", interactive=True
81
  )
82
  num_inference_steps = gr.Slider(
83
  minimum=25,
84
- maximum=60,
85
  value=40,
86
  label="Num Inference Steps",
87
  interactive=True,
@@ -90,18 +98,12 @@ with gr.Blocks() as demo:
90
  seed = gr.Slider(
91
  minimum=2147483147,
92
  maximum=2147483647,
93
- value=2147483397,
94
  label="Seed",
95
  interactive=True,
 
96
  )
97
- # output = gr.Textbox(
98
- # label="output", placeholder=use_auth_token[:5], interactive=False
99
- # )
100
- # gr.Button("test").click(
101
- # lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output
102
- # )
103
 
104
- generate_btn = gr.Button(label="Generate")
105
  gallery = gr.Gallery(
106
  label="Generated Images",
107
  value=[],
 
1
  import pathlib
2
  import os
3
+ from PIL import Image
4
 
5
  import gradio as gr
6
  import torch
 
9
  import utils
10
 
11
  use_auth_token = os.environ["HF_AUTH_TOKEN"]
12
+ NSFW_IMAGE = Image.open("nsfw.png")
13
 
14
  # Instantiate the pipeline.
15
  device, revision, torch_dtype = (
 
53
  text = text.replace(concept_token, dummy_tokens)
54
  return text
55
 
56
+ all_generated_images = []
57
  def inference(
58
  prompt: str, guidance_scale: int, num_inference_steps: int, seed: int
59
  ):
60
  prompt = replace_concept_tokens(prompt)
61
  generator = torch.Generator(device=device).manual_seed(seed)
62
+ output = pipeline(
63
  prompt=[prompt] * 2,
64
  num_inference_steps=num_inference_steps,
65
  guidance_scale=guidance_scale,
66
  generator=generator,
67
+ )
68
+ img_list, nsfw_list = output.images, output.nsfw_content_detected
69
+ for img, nsfw in zip(img_list, nsfw_list):
70
+ if nsfw:
71
+ all_generated_images.append(NSFW_IMAGE)
72
+ else:
73
+ all_generated_images.append(img)
74
+ return all_generated_images
75
 
76
  DEFAULT_PROMPT = (
77
  "A watercolor painting on textured paper of a <det-logo> using soft strokes,"
 
85
  interactive=True,
86
  )
87
  guidance_scale = gr.Slider(
88
+ minimum=1.0, maximum=20.0, value=3.0, label="Guidance Scale", interactive=True
89
  )
90
  num_inference_steps = gr.Slider(
91
  minimum=25,
92
+ maximum=75,
93
  value=40,
94
  label="Num Inference Steps",
95
  interactive=True,
 
98
  seed = gr.Slider(
99
  minimum=2147483147,
100
  maximum=2147483647,
 
101
  label="Seed",
102
  interactive=True,
103
+ randomize=True
104
  )
 
 
 
 
 
 
105
 
106
+ generate_btn = gr.Button(value="Generate")
107
  gallery = gr.Gallery(
108
  label="Generated Images",
109
  value=[],
nsfw.png ADDED