mps
Browse files
app.py
CHANGED
@@ -11,14 +11,14 @@ import numpy as np
|
|
11 |
from PIL import Image
|
12 |
from datetime import datetime
|
13 |
|
14 |
-
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
-
preferred_dtype = torch.
|
16 |
|
17 |
def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
|
18 |
|
19 |
segmodel = load_learner("camvid-512.pkl")
|
20 |
|
21 |
-
if
|
22 |
segmodel = segmodel.to_fp16()
|
23 |
|
24 |
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
@@ -29,7 +29,7 @@ inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
|
29 |
|
30 |
working_size = (512, 512)
|
31 |
|
32 |
-
default_inpainting_prompt = "
|
33 |
|
34 |
seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
|
35 |
'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
|
@@ -64,7 +64,7 @@ def app(img, prompt):
|
|
64 |
image=img,
|
65 |
mask_image=mask,
|
66 |
strength=0.95,
|
67 |
-
num_inference_steps=
|
68 |
).images[0]
|
69 |
end_time = datetime.now().timestamp()
|
70 |
draw = ImageDraw.Draw(overlay_img)
|
|
|
11 |
from PIL import Image
|
12 |
from datetime import datetime
|
13 |
|
14 |
+
preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
15 |
+
preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
|
16 |
|
17 |
def label_func(fn): return path/"labels"/f"{fn.stem}_P{fn.suffix}"
|
18 |
|
19 |
segmodel = load_learner("camvid-512.pkl")
|
20 |
|
21 |
+
if preferred_dtype == torch.float16:
|
22 |
segmodel = segmodel.to_fp16()
|
23 |
|
24 |
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
|
|
29 |
|
30 |
working_size = (512, 512)
|
31 |
|
32 |
+
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
|
33 |
|
34 |
seg_vocabulary = ['Animal', 'Archway', 'Bicyclist', 'Bridge', 'Building', 'Car',
|
35 |
'CartLuggagePram', 'Child', 'Column_Pole', 'Fence', 'LaneMkgsDriv',
|
|
|
64 |
image=img,
|
65 |
mask_image=mask,
|
66 |
strength=0.95,
|
67 |
+
num_inference_steps=20,
|
68 |
).images[0]
|
69 |
end_time = datetime.now().timestamp()
|
70 |
draw = ImageDraw.Draw(overlay_img)
|