multimodalart HF staff commited on
Commit
895d93f
1 Parent(s): 63a6e70

Pure external backend (#804)

Browse files

- Pure external backend (b7a47493cada27eb0e5727eff02d877dad5684c8)

Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -17,9 +17,9 @@ model_id = "CompVis/stable-diffusion-v1-4"
17
  device = "cuda"
18
 
19
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
20
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
21
- pipe = pipe.to(device)
22
- torch.backends.cudnn.benchmark = True
23
 
24
  #When running locally, you won`t have access to this, so you can remove this part
25
  word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
@@ -37,30 +37,30 @@ def infer(prompt):
37
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
38
 
39
  #generator = torch.Generator(device=device).manual_seed(seed)
40
- print("Is GPU busy? ", is_gpu_busy)
41
  images = []
42
- if(not is_gpu_busy):
43
- is_gpu_busy = True
44
- images_list = pipe(
45
- [prompt] * samples,
46
- num_inference_steps=steps,
47
- guidance_scale=scale,
48
  #generator=generator,
49
- )
50
- is_gpu_busy = False
51
- safe_image = Image.open(r"unsafe.png")
52
- for i, image in enumerate(images_list["sample"]):
53
- if(images_list["nsfw_content_detected"][i]):
54
- images.append(safe_image)
55
- else:
56
- images.append(image)
57
- else:
58
- url = os.getenv('JAX_BACKEND_URL')
59
- payload = {'prompt': prompt}
60
- images_request = requests.post(url, json = payload)
61
- for image in images_request.json()["images"]:
62
- image_decoded = Image.open(BytesIO(base64.b64decode(image)))
63
- images.append(image_decoded)
64
 
65
 
66
  return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
@@ -372,4 +372,4 @@ Despite how impressive being able to turn text into image is, beware to the fact
372
  """
373
  )
374
 
375
- block.queue(max_size=25, concurrency_count=2).launch()
 
17
  device = "cuda"
18
 
19
  #If you are running this code locally, you need to either do a 'huggingface-cli login` or paste your User Access Token from here https://huggingface.co/settings/tokens into the use_auth_token field below.
20
+ #pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True, revision="fp16", torch_dtype=torch.float16)
21
+ #pipe = pipe.to(device)
22
+ #torch.backends.cudnn.benchmark = True
23
 
24
  #When running locally, you won`t have access to this, so you can remove this part
25
  word_list_dataset = load_dataset("stabilityai/word-list", data_files="list.txt", use_auth_token=True)
 
37
  raise gr.Error("Unsafe content found. Please try again with different prompts.")
38
 
39
  #generator = torch.Generator(device=device).manual_seed(seed)
40
+ #print("Is GPU busy? ", is_gpu_busy)
41
  images = []
42
+ #if(not is_gpu_busy):
43
+ # is_gpu_busy = True
44
+ # images_list = pipe(
45
+ # [prompt] * samples,
46
+ # num_inference_steps=steps,
47
+ # guidance_scale=scale,
48
  #generator=generator,
49
+ # )
50
+ # is_gpu_busy = False
51
+ # safe_image = Image.open(r"unsafe.png")
52
+ # for i, image in enumerate(images_list["sample"]):
53
+ # if(images_list["nsfw_content_detected"][i]):
54
+ # images.append(safe_image)
55
+ # else:
56
+ # images.append(image)
57
+ #else:
58
+ url = os.getenv('JAX_BACKEND_URL')
59
+ payload = {'prompt': prompt}
60
+ images_request = requests.post(url, json = payload)
61
+ for image in images_request.json()["images"]:
62
+ image_decoded = Image.open(BytesIO(base64.b64decode(image)))
63
+ images.append(image_decoded)
64
 
65
 
66
  return images, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
 
372
  """
373
  )
374
 
375
+ block.queue(max_size=50, concurrency_count=40).launch()