jepz commited on
Commit
28e5d96
·
verified ·
1 Parent(s): ad36db4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -10,14 +10,23 @@ move_cache()
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
 
 
 
 
 
 
 
 
 
13
  if torch.cuda.is_available():
14
  torch.cuda.max_memory_allocated(device=device)
15
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16)
16
  #pipe = DiffusionPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16, variant="fp16")
17
  pipe.enable_xformers_memory_efficient_attention()
18
  pipe = pipe.to(device)
19
  else:
20
- pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16)
21
  pipe = pipe.to(device)
22
 
23
  MAX_SEED = np.iinfo(np.int32).max
@@ -81,7 +90,7 @@ with gr.Blocks(css=css) as demo:
81
 
82
 
83
  with gr.Row():
84
- base_img = gr.Interface(fn=generate_image, inputs=gr.inputs.Image(type="pil"), outputs=gr.outputs.Image(type="pil"))
85
 
86
  with gr.Row():
87
 
 
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # Check if a GPU is available and set the appropriate torch_dtype and device
14
+ if torch.cuda.is_available():
15
+ torch_dtype = torch.float16
16
+ device = "cuda"
17
+ else:
18
+ torch_dtype = torch.float32
19
+ device = "cpu"
20
+
21
+
22
  if torch.cuda.is_available():
23
  torch.cuda.max_memory_allocated(device=device)
24
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch_dtype)
25
  #pipe = DiffusionPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch.float16, variant="fp16")
26
  pipe.enable_xformers_memory_efficient_attention()
27
  pipe = pipe.to(device)
28
  else:
29
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("Envvi/Inkpunk-Diffusion", torch_dtype=torch_dtype)
30
  pipe = pipe.to(device)
31
 
32
  MAX_SEED = np.iinfo(np.int32).max
 
90
 
91
 
92
  with gr.Row():
93
+ base_img = gr.Interface(fn=generate_image, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"))
94
 
95
  with gr.Row():
96