Artvik committed
Commit e1042be
1 Parent(s): 95400a9

Update app.py

Files changed (1):
  1. app.py +103 -23
app.py CHANGED
@@ -5,35 +5,66 @@ from torch import autocast
 from diffusers import StableDiffusionPipeline
 from datasets import load_dataset
 from PIL import Image
+from io import BytesIO
+import base64
 import re
+import os
+import requests
+
+from share_btn import community_icon_html, loading_icon_html, share_js

 model_id = "CompVis/stable-diffusion-v1-4"
 device = "cuda"

 #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
-pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token="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", revision="fp16", torch_dtype=torch.float16)
+pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
 pipe = pipe.to(device)
 torch.backends.cudnn.benchmark = True

+#When running locally, you won`t have access to this, so you can remove this part
+word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
+word_list = word_list_dataset["train"]['text']

-def infer(prompt, samples, steps, scale, seed):
+is_gpu_busy = False
+def infer(prompt):
+    global is_gpu_busy
+    samples = 4
+    steps = 50
+    scale = 7.5
+    #When running locally you can also remove this filter
+    for filter in word_list:
+        if re.search(rf"\b{filter}\b", prompt):
+            raise gr.Error("Unsafe content found. Please try again with different prompts.")

-    generator = torch.Generator(device=device).manual_seed(seed)
-
-    images_list = pipe(
-        [prompt] * samples,
-        num_inference_steps=steps,
-        guidance_scale=scale,
-        generator=generator,
-    )
+    #generator = torch.Generator(device=device).manual_seed(seed)
+    print("Is GPU busy? ", is_gpu_busy)
     images = []
-    safe_image = Image.open(r"unsafe.png")
-    for i, image in enumerate(images_list["sample"]):
-        if(images_list["nsfw_content_detected"][i]):
-            images.append(safe_image)
-        else:
-            images.append(image)
-    return images
+    if(not is_gpu_busy):
+        is_gpu_busy = True
+        images_list = pipe(
+            [prompt] * samples,
+            num_inference_steps=steps,
+            guidance_scale=scale,
+            #generator=generator,
+        )
+        is_gpu_busy = False
+        safe_image = Image.open(r"unsafe.png")
+        for i, image in enumerate(images_list["sample"]):
+            if(images_list["nsfw_content_detected"][i]):
+                images.append(safe_image)
+            else:
+                images.append(image)
+    else:
+        url = os.getenv('JAX_BACKEND_URL')
+        payload = {'prompt': prompt}
+        images_request = requests.post(url, json = payload)
+        for image in images_request.json()["images"]:
+            image_decoded = Image.open(BytesIO(base64.b64decode(image)))
+            images.append(image_decoded)
+
+
+    return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
+

 css = """
 .gradio-container {
@@ -118,6 +149,38 @@ css = """
     font-weight: bold;
     font-size: 115%;
 }
+#container-advanced-btns{
+    display: flex;
+    flex-wrap: wrap;
+    justify-content: space-between;
+    align-items: center;
+}
+.animate-spin {
+    animation: spin 1s linear infinite;
+}
+@keyframes spin {
+    from {
+        transform: rotate(0deg);
+    }
+    to {
+        transform: rotate(360deg);
+    }
+}
+#share-btn-container {
+    display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
+}
+#share-btn {
+    all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
+}
+#share-btn * {
+    all: unset;
+}
+.gr-form{
+    flex: 1 1 50%; border-top-right-radius: 0; border-bottom-right-radius: 0;
+}
+#prompt-container{
+    gap: 0;
+}
 """

 block = gr.Blocks(css=css)
@@ -160,6 +223,7 @@ examples = [
     ],
 ]

+
 with block:
     gr.HTML(
         """
@@ -225,12 +289,13 @@ with block:
     )
     with gr.Group():
         with gr.Box():
-            with gr.Row().style(mobile_collapse=False, equal_height=True):
+            with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                 text = gr.Textbox(
                     label="Enter your prompt",
                     show_label=False,
                     max_lines=1,
                     placeholder="Enter your prompt",
+                    elem_id="prompt-text-input",
                 ).style(
                     border=(True, False, True, True),
                     rounded=(True, False, False, True),
@@ -239,15 +304,22 @@ with block:
                 btn = gr.Button("Generate image").style(
                     margin=False,
                     rounded=(False, True, True, False),
+                    full_width=False,
                 )

         gallery = gr.Gallery(
             label="Generated images", show_label=False, elem_id="gallery"
         ).style(grid=[2], height="auto")

-        advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+        with gr.Group(elem_id="container-advanced-btns"):
+            advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
+            with gr.Group(elem_id="share-btn-container"):
+                community_icon = gr.HTML(community_icon_html, visible=False)
+                loading_icon = gr.HTML(loading_icon_html, visible=False)
+                share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)

         with gr.Row(elem_id="advanced-options"):
+            gr.Markdown("Advanced settings are temporarily unavailable")
             samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
             steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
             scale = gr.Slider(
@@ -261,12 +333,14 @@ with block:
                 randomize=True,
             )

-        ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale, seed], outputs=gallery, cache_examples=True)
+        ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
         ex.dataset.headers = [""]


-        text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
-        btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=gallery)
+        text.submit(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
+
+        btn.click(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
+
         advanced_button.click(
             None,
             [],
@@ -277,6 +351,12 @@ with block:
                 options.style.display = ["none", ""].includes(options.style.display) ? "flex" : "none";
             }""",
         )
+        share_button.click(
+            None,
+            [],
+            [],
+            _js=share_js,
+        )
         gr.HTML(
             """
                 <div class="footer">
@@ -292,4 +372,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
             """
         )

-block.queue(max_size=25).launch()
+block.queue(max_size=25, concurrency_count=2).launch()
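
For readers following the in-file comment about running this app locally: the snippet below is not part of the commit, just a minimal, hedged sketch of the local setup that comment describes. It assumes a CUDA GPU, a token cached with `huggingface-cli login`, and the same diffusers version this Space was pinned to (the `revision="fp16"` argument and the dict-style `["sample"]` output mirror the diff); the prompt string and output filename are placeholders.

# Local-use sketch (illustrative only, not part of this commit).
# Authenticate first with `huggingface-cli login`, then load the fp16 weights
# on CUDA, mirroring the from_pretrained() call in the diff above.
import torch
from diffusers import StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# use_auth_token=True reads the token cached by `huggingface-cli login`
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16
)
pipe = pipe.to(device)

# One placeholder prompt instead of the Space's [prompt] * samples batch;
# steps and scale match the values hard-coded in the new infer() (50 and 7.5).
result = pipe("a photograph of an astronaut riding a horse",
              num_inference_steps=50, guidance_scale=7.5)

# Older diffusers (the version this Space assumes, given images_list["sample"]
# in the diff) returns a dict-style output; newer releases expose .images instead.
result["sample"][0].save("astronaut.png")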