UdacityNoob committed on
Commit
c8b9d47
•
1 Parent(s): 16d79ac

Check if gpu is available

Files changed (1)
  1. app.py +22 -14
app.py CHANGED
@@ -7,27 +7,35 @@ from diffusers import StableDiffusionPipeline
 # get hf user access token as an environment variable
 TOKEN_KEY = os.getenv('AUTH_TOKEN')
 
+# choose GPU else fallback to CPU
+device = "cuda" if torch.cuda.is_available() else "cpu"
+device_name = torch.cuda.get_device_name(0)
+
 # setup pipeline
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16, use_auth_token=TOKEN_KEY)
-pipe = pipe.to('cuda')
+pipe = pipe.to(device)
 
 # define gradio function
 def generate(prompt:str, seed:int, guidance:float):
-    generator = torch.Generator("cuda").manual_seed(int(seed))
-    with autocast("cuda"):
+    generator = torch.Generator(device).manual_seed(int(seed))
+    with autocast(device):
         image = pipe(prompt=prompt, generator=generator, guidance_scale=guidance, steps=50).images[0]
     return image
 
-# create the gradio UI
-demo = gr.Interface(
-    fn=generate,
-    inputs=[gr.Textbox(placeholder="castle on a mountain"), gr.Number(value=123456), gr.Slider(0,10)],
-    outputs="image",
-    allow_flagging="never",
-)
+if device == "cuda":
+    print(device_name + " available.")
+    # create the gradio UI
+    demo = gr.Interface(
+        fn=generate,
+        inputs=[gr.Textbox(placeholder="castle on a mountain"), gr.Number(value=123456), gr.Slider(0,10)],
+        outputs="image",
+        allow_flagging="never",
+    )
 
-# allow queueing or incoming requests, max=3
-demo.queue(concurrency_count=3)
+    # allow queueing or incoming requests, max=3
+    demo.queue(concurrency_count=3)
 
-# launch demo
-demo.launch()
+    # launch demo
+    demo.launch()
+else:
+    print("GPU unavailable.")
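
For reference, the GPU check this commit introduces can also be written so that torch.cuda.get_device_name(0) is only called once CUDA is confirmed available; as committed, that call runs unconditionally and would typically raise on a CPU-only machine. Below is a minimal sketch of that pattern using plain PyTorch, not the committed code; pick_device is a hypothetical helper name.

import torch

def pick_device() -> torch.device:
    # Prefer the first CUDA GPU; otherwise fall back to the CPU.
    # The device-name lookup is guarded so it never runs without CUDA
    # (assumption: a single-GPU setup, hence index 0).
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0) + " available.")
        return torch.device("cuda")
    print("GPU unavailable, falling back to CPU.")
    return torch.device("cpu")

device = pick_device()
# usage, mirroring the diff above: pipe = pipe.to(device)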