seg preferred device vs inpaint, and compile everything
Browse files
app.py
CHANGED
@@ -12,14 +12,15 @@ from datetime import datetime
|
|
12 |
# ideally:
|
13 |
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
14 |
# but segformer does not work on mps lolololol
|
15 |
-
|
16 |
-
|
17 |
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
|
|
18 |
torch.backends.cuda.matmul.allow_tf32 = True
|
19 |
|
20 |
print(f"backends: {torch._dynamo.list_backends()}")
|
21 |
|
22 |
-
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else
|
23 |
|
24 |
seg_model_img_size = 768
|
25 |
seg_model_size = 0
|
@@ -27,17 +28,18 @@ seg_model_size = 0
|
|
27 |
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
|
28 |
seg_model = SegformerForSemanticSegmentation.from_pretrained(
|
29 |
f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
|
30 |
-
).to(
|
31 |
|
32 |
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
33 |
"SimianLuo/LCM_Dreamshaper_v7",
|
34 |
-
torch_dtype=
|
35 |
safety_checker=None,
|
36 |
).to(inpaint_preferred_device)
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
41 |
|
42 |
seg_working_size = (seg_model_img_size, seg_model_img_size)
|
43 |
|
@@ -54,7 +56,7 @@ ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
|
|
54 |
|
55 |
|
56 |
def get_seg_mask(img):
|
57 |
-
inputs = seg_feature_extractor(images=img, return_tensors="pt").to(
|
58 |
outputs = seg_model(**inputs)
|
59 |
logits = outputs.logits[0]
|
60 |
mask = Image.fromarray((ban_cars_mask[ torch.argmax(logits, dim=0).cpu().numpy() ]) * 255)
|
|
|
12 |
# ideally:
|
13 |
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
14 |
# but segformer does not work on mps lolololol
|
15 |
+
seg_preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
+
seg_preferred_dtype = torch.float16 if seg_preferred_device == 'cuda' else torch.float32
|
17 |
inpaint_preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
18 |
+
inpaint_preferred_dtype = torch.float32 if inpaint_preferred_device == 'cpu' else torch.float16
|
19 |
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
|
21 |
print(f"backends: {torch._dynamo.list_backends()}")
|
22 |
|
23 |
+
preferred_backend = "aot_eager" if inpaint_preferred_device == "mps" else "inductor"
|
24 |
|
25 |
# SegFormer checkpoint selection: b{seg_model_size} backbone fine-tuned on
# Cityscapes at seg_model_img_size x seg_model_img_size input resolution.
seg_model_img_size = 768
seg_model_size = 0

# The same Hub id serves both the feature extractor and the model weights.
_seg_checkpoint = f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(_seg_checkpoint)
seg_model = (
    SegformerForSemanticSegmentation.from_pretrained(_seg_checkpoint)
    .to(seg_preferred_device)
    .to(seg_preferred_dtype)
)
|
32 |
|
33 |
# LCM Dreamshaper inpainting pipeline; the safety checker is disabled
# deliberately (safety_checker=None).
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7",
    torch_dtype=inpaint_preferred_dtype,
    safety_checker=None,
).to(inpaint_preferred_device)

# Compile the heavy pipeline submodules, and the segmentation model,
# with the backend chosen above.
for _component_name in ("text_encoder", "unet", "vae"):
    _compiled = torch.compile(
        getattr(inpainting_pipeline, _component_name), backend=preferred_backend
    )
    setattr(inpainting_pipeline, _component_name, _compiled)
seg_model = torch.compile(seg_model, backend=preferred_backend)
|
43 |
|
44 |
# Square working resolution matching the segmentation checkpoint's input size.
seg_working_size = (seg_model_img_size,) * 2
|
45 |
|
|
|
56 |
|
57 |
|
58 |
def get_seg_mask(img):
|
59 |
+
inputs = seg_feature_extractor(images=img, return_tensors="pt").to(seg_preferred_device).to(seg_preferred_dtype)
|
60 |
outputs = seg_model(**inputs)
|
61 |
logits = outputs.logits[0]
|
62 |
mask = Image.fromarray((ban_cars_mask[ torch.argmax(logits, dim=0).cpu().numpy() ]) * 255)
|