israelgonzalezb commited on
Commit
db58866
1 Parent(s): c34bc8e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -24
app.py CHANGED
@@ -21,7 +21,7 @@ word_list = word_list_dataset["train"]['text']
21
 
22
  is_gpu_busy = False
23
  def infer(prompt):
24
- #global is_gpu_busy
25
  samples = 4
26
  steps = 50
27
  scale = 7.5
@@ -30,31 +30,31 @@ def infer(prompt):
30
  if re.search(rf"\b{filter}\b", prompt):
31
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
32
 
33
- #generator = torch.Generator(device=device).manual_seed(seed)
34
- #print("Is GPU busy? ", is_gpu_busy)
35
  images = []
36
- #if(not is_gpu_busy):
37
- # is_gpu_busy = True
38
- # images_list = pipe(
39
- # [prompt] * samples,
40
- # num_inference_steps=steps,
41
- # guidance_scale=scale,
42
  #generator=generator,
43
- # )
44
- # is_gpu_busy = False
45
- # safe_image = Image.open(r"unsafe.png")
46
- # for i, image in enumerate(images_list["sample"]):
47
- # if(images_list["nsfw_content_detected"][i]):
48
- # images.append(safe_image)
49
- # else:
50
- # images.append(image)
51
- #else:
52
- url = os.getenv('JAX_BACKEND_URL')
53
- payload = {'prompt': prompt}
54
- images_request = requests.post(url, json = payload)
55
- for image in images_request.json()["images"]:
56
- image_b64 = (f"data:image/jpeg;base64,{image}")
57
- images.append(image_b64)
58
 
59
  return images
60
 
 
21
 
22
  is_gpu_busy = False
23
  def infer(prompt):
24
+ global is_gpu_busy
25
  samples = 4
26
  steps = 50
27
  scale = 7.5
 
30
  if re.search(rf"\b{filter}\b", prompt):
31
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
32
 
33
+ generator = torch.Generator(device=device).manual_seed(seed)
34
+ print("Is GPU busy? ", is_gpu_busy)
35
  images = []
36
+ if(not is_gpu_busy):
37
+ is_gpu_busy = True
38
+ images_list = pipe(
39
+ [prompt] * samples,
40
+ num_inference_steps=steps,
41
+ guidance_scale=scale,
42
  #generator=generator,
43
+ )
44
+ is_gpu_busy = False
45
+ safe_image = Image.open(r"unsafe.png")
46
+ for i, image in enumerate(images_list["sample"]):
47
+ if(images_list["nsfw_content_detected"][i]):
48
+ images.append(safe_image)
49
+ else:
50
+ images.append(image)
51
+ else:
52
+ url = os.getenv('JAX_BACKEND_URL')
53
+ payload = {'prompt': prompt}
54
+ images_request = requests.post(url, json = payload)
55
+ for image in images_request.json()["images"]:
56
+ image_b64 = (f"data:image/jpeg;base64,{image}")
57
+ images.append(image_b64)
58
 
59
  return images
60