Alexander McKinney commited on
Commit
92ba1f6
1 Parent(s): 3d237d0

updates for CPU only diffusion

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -15,6 +15,7 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
15
  from diffusers import StableDiffusionInpaintPipeline
16
 
17
  auth_token = os.environ.get("READ_TOKEN")
 
18
 
19
  torch.inference_mode()
20
  torch.no_grad()
@@ -32,7 +33,7 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
32
  return StableDiffusionInpaintPipeline.from_pretrained(
33
  model_name,
34
  revision='fp16',
35
- torch_dtype=torch.float16,
36
  use_auth_token=auth_token
37
  )
38
 
@@ -60,7 +61,7 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
60
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
61
  pipe = load_diffusion_pipeline()
62
 
63
- device = get_device()
64
  pipe = pipe.to(device)
65
 
66
  # Callback function that runs segmentation and updates CheckboxGroup
@@ -161,7 +162,9 @@ demo = gr.Blocks(css=open('app.css').read())
161
 
162
  with demo:
163
  gr.HTML(open('app_header.html').read())
164
- # gr.Markdown("# Stable Diffusion Inpainting Segmentation")
 
 
165
 
166
  # Input image control
167
  input_image = gr.Image(value="example.png", type='pil', label="Input Image")
 
15
  from diffusers import StableDiffusionInpaintPipeline
16
 
17
  auth_token = os.environ.get("READ_TOKEN")
18
+ try_cuda = True
19
 
20
  torch.inference_mode()
21
  torch.no_grad()
 
33
  return StableDiffusionInpaintPipeline.from_pretrained(
34
  model_name,
35
  revision='fp16',
36
+ torch_dtype=torch.float16 if try_cuda and torch.cuda.is_available() else torch.float32,
37
  use_auth_token=auth_token
38
  )
39
 
 
61
  feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
62
  pipe = load_diffusion_pipeline()
63
 
64
+ device = get_device(try_cuda=try_cuda)
65
  pipe = pipe.to(device)
66
 
67
  # Callback function that runs segmentation and updates CheckboxGroup
 
162
 
163
  with demo:
164
  gr.HTML(open('app_header.html').read())
165
+
166
+ if not try_cuda or not torch.cuda.is_available():
167
+ gr.HTML('<div class="alert alert-warning" role="alert" style="color:red"><b>Warning: GPU not available! Diffusion will be slow.</b></div>')
168
 
169
  # Input image control
170
  input_image = gr.Image(value="example.png", type='pil', label="Input Image")