Spaces:
Runtime error
Runtime error
Alexander McKinney
commited on
Commit
•
92ba1f6
1
Parent(s):
3d237d0
updates for CPU only diffusion
Browse files
app.py
CHANGED
@@ -15,6 +15,7 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
|
|
15 |
from diffusers import StableDiffusionInpaintPipeline
|
16 |
|
17 |
auth_token = os.environ.get("READ_TOKEN")
|
|
|
18 |
|
19 |
torch.inference_mode()
|
20 |
torch.no_grad()
|
@@ -32,7 +33,7 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
|
|
32 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
33 |
model_name,
|
34 |
revision='fp16',
|
35 |
-
torch_dtype=torch.float16,
|
36 |
use_auth_token=auth_token
|
37 |
)
|
38 |
|
@@ -60,7 +61,7 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
|
60 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
61 |
pipe = load_diffusion_pipeline()
|
62 |
|
63 |
-
device = get_device()
|
64 |
pipe = pipe.to(device)
|
65 |
|
66 |
# Callback function that runs segmentation and updates CheckboxGroup
|
@@ -161,7 +162,9 @@ demo = gr.Blocks(css=open('app.css').read())
|
|
161 |
|
162 |
with demo:
|
163 |
gr.HTML(open('app_header.html').read())
|
164 |
-
|
|
|
|
|
165 |
|
166 |
# Input image control
|
167 |
input_image = gr.Image(value="example.png", type='pil', label="Input Image")
|
|
|
15 |
from diffusers import StableDiffusionInpaintPipeline
|
16 |
|
17 |
auth_token = os.environ.get("READ_TOKEN")
|
18 |
+
try_cuda = True
|
19 |
|
20 |
torch.inference_mode()
|
21 |
torch.no_grad()
|
|
|
33 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
34 |
model_name,
|
35 |
revision='fp16',
|
36 |
+
torch_dtype=torch.float16 if try_cuda and torch.cuda.is_available() else torch.float32,
|
37 |
use_auth_token=auth_token
|
38 |
)
|
39 |
|
|
|
61 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
62 |
pipe = load_diffusion_pipeline()
|
63 |
|
64 |
+
device = get_device(try_cuda=try_cuda)
|
65 |
pipe = pipe.to(device)
|
66 |
|
67 |
# Callback function that runs segmentation and updates CheckboxGroup
|
|
|
162 |
|
163 |
with demo:
|
164 |
gr.HTML(open('app_header.html').read())
|
165 |
+
|
166 |
+
if not try_cuda or not torch.cuda.is_available():
|
167 |
+
gr.HTML('<div class="alert alert-warning" role="alert" style="color:red"><b>Warning: GPU not available! Diffusion will be slow.</b></div>')
|
168 |
|
169 |
# Input image control
|
170 |
input_image = gr.Image(value="example.png", type='pil', label="Input Image")
|