apolinario commited on
Commit
53066e3
1 Parent(s): 38d05ac

Swap to hybrid backend

Browse files
Files changed (1) hide show
  1. app.py +41 -19
app.py CHANGED
@@ -5,7 +5,11 @@ from torch import autocast
5
  from diffusers import StableDiffusionPipeline
6
  from datasets import load_dataset
7
  from PIL import Image
 
 
8
  import re
 
 
9
 
10
  from share_btn import community_icon_html, loading_icon_html, share_js
11
 
@@ -21,27 +25,44 @@ torch.backends.cudnn.benchmark = True
21
  word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
22
  word_list = word_list_dataset["train"]['text']
23
 
24
- def infer(prompt, samples, steps, scale, seed):
 
 
 
 
 
25
  #When running locally you can also remove this filter
26
  for filter in word_list:
27
  if re.search(rf"\b{filter}\b", prompt):
28
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
29
 
30
- generator = torch.Generator(device=device).manual_seed(seed)
31
-
32
- images_list = pipe(
33
- [prompt] * samples,
34
- num_inference_steps=steps,
35
- guidance_scale=scale,
36
- generator=generator,
37
- )
38
  images = []
39
- safe_image = Image.open(r"unsafe.png")
40
- for i, image in enumerate(images_list["sample"]):
41
- if(images_list["nsfw_content_detected"][i]):
42
- images.append(safe_image)
43
- else:
44
- images.append(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
46
 
47
 
@@ -298,6 +319,7 @@ with block:
298
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
299
 
300
  with gr.Row(elem_id="advanced-options"):
 
301
  samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
302
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
303
  scale = gr.Slider(
@@ -311,13 +333,13 @@ with block:
311
  randomize=True,
312
  )
313
 
314
- ex = gr.Examples(examples=examples, fn=infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
315
  ex.dataset.headers = [""]
316
 
317
 
318
- text.submit(infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery, community_icon, loading_icon, share_button])
319
 
320
- btn.click(infer, inputs=[text, samples, steps, scale, seed], outputs=[gallery, community_icon, loading_icon, share_button])
321
 
322
  advanced_button.click(
323
  None,
@@ -350,4 +372,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
350
  """
351
  )
352
 
353
- block.queue(max_size=25).launch()
 
5
  from diffusers import StableDiffusionPipeline
6
  from datasets import load_dataset
7
  from PIL import Image
8
+ from io import BytesIO
9
+ import base64
10
  import re
11
+ import os
12
+ import requests
13
 
14
  from share_btn import community_icon_html, loading_icon_html, share_js
15
 
 
25
  word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
26
  word_list = word_list_dataset["train"]['text']
27
 
28
+ is_gpu_busy = False
29
+ def infer(prompt):
30
+ global is_gpu_busy
31
+ samples = 4
32
+ steps = 50
33
+ scale = 7.5
34
  #When running locally you can also remove this filter
35
  for filter in word_list:
36
  if re.search(rf"\b{filter}\b", prompt):
37
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
38
 
39
+ #generator = torch.Generator(device=device).manual_seed(seed)
40
+ print("Is GPU busy? ", is_gpu_busy)
 
 
 
 
 
 
41
  images = []
42
+ if(not is_gpu_busy):
43
+ is_gpu_busy = True
44
+ images_list = pipe(
45
+ [prompt] * samples,
46
+ num_inference_steps=steps,
47
+ guidance_scale=scale,
48
+ #generator=generator,
49
+ )
50
+ is_gpu_busy = False
51
+ safe_image = Image.open(r"unsafe.png")
52
+ for i, image in enumerate(images_list["sample"]):
53
+ if(images_list["nsfw_content_detected"][i]):
54
+ images.append(safe_image)
55
+ else:
56
+ images.append(image)
57
+ else:
58
+ url = os.getenv('JAX_BACKEND_URL')
59
+ payload = {'prompt': prompt}
60
+ images_request = requests.post(url, json = payload)
61
+ for image in images_request.json()["images"]:
62
+ image_decoded = Image.open(BytesIO(base64.b64decode(image)))
63
+ images.append(image_decoded)
64
+
65
+
66
  return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
67
 
68
 
 
319
  share_button = gr.Button("Share to community", elem_id="share-btn", visible=False)
320
 
321
  with gr.Row(elem_id="advanced-options"):
322
+ gr.Markdown("Advanced settings are temporarily unavailable")
323
  samples = gr.Slider(label="Images", minimum=1, maximum=4, value=4, step=1)
324
  steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=45, step=1)
325
  scale = gr.Slider(
 
333
  randomize=True,
334
  )
335
 
336
+ ex = gr.Examples(examples=examples, fn=infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button], cache_examples=True)
337
  ex.dataset.headers = [""]
338
 
339
 
340
+ text.submit(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
341
 
342
+ btn.click(infer, inputs=text, outputs=[gallery, community_icon, loading_icon, share_button])
343
 
344
  advanced_button.click(
345
  None,
 
372
  """
373
  )
374
 
375
+ block.queue(max_size=25, concurrency_count=2).launch()