Commit: replace the fastai UNet segmentation model (camvid-512.pkl) with an NVIDIA SegFormer model
File: app.py (changed)
@@ -1,79 +1,92 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
import torch
|
4 |
-
from
|
5 |
-
from
|
6 |
-
from
|
|
|
7 |
from tqdm import tqdm
|
8 |
-
from diffusers import AutoPipelineForInpainting, LCMScheduler, DDIMScheduler
|
9 |
-
from diffusers import StableDiffusionInpaintPipeline, ControlNetModel
|
10 |
import numpy as np
|
11 |
-
from PIL import Image
|
12 |
from datetime import datetime
|
13 |
|
14 |
-
|
|
|
|
|
|
|
15 |
preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
segmodel = load_learner("camvid-512.pkl")
|
20 |
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
|
24 |
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
25 |
"runwayml/stable-diffusion-inpainting",
|
26 |
variant="fp16",
|
27 |
torch_dtype=preferred_dtype,
|
|
|
28 |
).to(preferred_device)
|
29 |
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
|
32 |
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
|
33 |
|
34 |
-
seg_vocabulary =
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
0, 0, 1, 0, 1,
|
44 |
-
1, 1, 0, 0,
|
45 |
-
1, 0, 1, 1, 1,
|
46 |
-
1, 0, 1, 1,
|
47 |
-
1, 0, 0, 0, 1,
|
48 |
-
0, 1, 0], dtype=np.uint8)
|
49 |
|
50 |
def get_seg_mask(img):
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
def app(img, prompt):
|
56 |
start_time = datetime.now().timestamp()
|
57 |
old_size = Image.fromarray(img).size
|
58 |
img = np.array(Image.fromarray(img).resize(working_size))
|
59 |
-
mask =
|
|
|
60 |
mask_time = datetime.now().timestamp()
|
61 |
-
print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape)
|
62 |
overlay_img = inpainting_pipeline(
|
63 |
prompt=prompt,
|
64 |
-
image=img,
|
65 |
-
mask_image=mask,
|
66 |
strength=0.95,
|
67 |
-
num_inference_steps=
|
68 |
).images[0]
|
|
|
69 |
end_time = datetime.now().timestamp()
|
70 |
draw = ImageDraw.Draw(overlay_img)
|
71 |
# replace spaces with newlines after many words to line break prompt
|
72 |
prompt = " ".join([prompt.split(" ")[i] if (i+1) % 5 else prompt.split(" ")[i] + "\n" for i in range(len(prompt.split(" ")))])
|
73 |
|
74 |
-
draw.text((
|
|
|
75 |
return overlay_img
|
76 |
|
|
|
|
|
|
|
|
|
|
|
77 |
#ideally:
|
78 |
#iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
|
79 |
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
import torch
|
4 |
+
from PIL import Image, ImageFilter, ImageEnhance, ImageDraw
|
5 |
+
from diffusers import LCMScheduler, StableDiffusionInpaintPipeline
|
6 |
+
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
|
7 |
+
|
8 |
from tqdm import tqdm
|
|
|
|
|
9 |
import numpy as np
|
|
|
10 |
from datetime import datetime
|
11 |
|
12 |
+
# ideally:
|
13 |
+
# preferred_device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
14 |
+
# but segformer does not work on mps lolololol
|
15 |
+
preferred_device = "cuda" if torch.cuda.is_available() else "cpu"
|
16 |
preferred_dtype = torch.float16 if preferred_device == 'cuda' else torch.float32
|
17 |
|
18 |
+
seg_model_img_size = 768
|
19 |
+
seg_model_size = 0
|
|
|
20 |
|
21 |
+
seg_feature_extractor = SegformerFeatureExtractor.from_pretrained(f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}")
|
22 |
+
seg_model = SegformerForSemanticSegmentation.from_pretrained(
|
23 |
+
f"nvidia/segformer-b{seg_model_size}-finetuned-cityscapes-{seg_model_img_size}-{seg_model_img_size}"
|
24 |
+
).to(preferred_device).to(preferred_dtype)
|
25 |
|
26 |
inpainting_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
27 |
"runwayml/stable-diffusion-inpainting",
|
28 |
variant="fp16",
|
29 |
torch_dtype=preferred_dtype,
|
30 |
+
safety_checker=None,
|
31 |
).to(preferred_device)
|
32 |
|
33 |
+
inpainting_pipeline.scheduler = LCMScheduler.from_config(inpainting_pipeline.scheduler.config)
|
34 |
+
inpainting_pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
|
35 |
+
inpainting_pipeline.fuse_lora()
|
36 |
+
|
37 |
+
working_size = (seg_model_img_size, seg_model_img_size)
|
38 |
|
39 |
default_inpainting_prompt = "award-winning photo of a leafy pedestrian mall full of people, with multiracial genderqueer joggers and bicyclists and wheelchair users talking and laughing"
|
40 |
|
41 |
+
seg_vocabulary = seg_model.config.label2id
|
42 |
+
print(f"vocab: {seg_vocabulary}")
|
43 |
+
|
44 |
+
ban_cars_mask = [0] * len(seg_vocabulary)
|
45 |
+
banned_classes = ["car", "road", "sidewalk", "traffic light", "traffic sign"]
|
46 |
+
for c in banned_classes:
|
47 |
+
ban_cars_mask[seg_vocabulary[c]] = 1
|
48 |
+
ban_cars_mask = np.array(ban_cars_mask, dtype=np.uint8)
|
49 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def get_seg_mask(img):
|
52 |
+
inputs = seg_feature_extractor(images=img, return_tensors="pt").to(preferred_device)
|
53 |
+
outputs = seg_model(**inputs)
|
54 |
+
logits = outputs.logits[0]
|
55 |
+
mask = Image.fromarray((ban_cars_mask[ torch.argmax(logits, dim=0).cpu().numpy() ]) * 255)
|
56 |
+
blurred_widened_mask = ImageEnhance.Contrast(mask.filter(ImageFilter.GaussianBlur(5))).enhance(9000)
|
57 |
+
return blurred_widened_mask
|
58 |
|
59 |
|
60 |
def app(img, prompt):
|
61 |
start_time = datetime.now().timestamp()
|
62 |
old_size = Image.fromarray(img).size
|
63 |
img = np.array(Image.fromarray(img).resize(working_size))
|
64 |
+
mask = get_seg_mask(img)
|
65 |
+
mask.save("mask.jpg")
|
66 |
mask_time = datetime.now().timestamp()
|
67 |
+
#print(prompt.__class__, img.__class__, mask.__class__, img.shape, mask.shape, mask.dtype, img.dtype)
|
68 |
overlay_img = inpainting_pipeline(
|
69 |
prompt=prompt,
|
70 |
+
image=Image.fromarray(img),
|
71 |
+
mask_image=(mask),
|
72 |
strength=0.95,
|
73 |
+
num_inference_steps=4,
|
74 |
).images[0]
|
75 |
+
#overlay_img.save("overlay_raw.jpg")
|
76 |
end_time = datetime.now().timestamp()
|
77 |
draw = ImageDraw.Draw(overlay_img)
|
78 |
# replace spaces with newlines after many words to line break prompt
|
79 |
prompt = " ".join([prompt.split(" ")[i] if (i+1) % 5 else prompt.split(" ")[i] + "\n" for i in range(len(prompt.split(" ")))])
|
80 |
|
81 |
+
draw.text((10, 50), f"Old size: {old_size}\nTotal duration: {int(1000 * (end_time - start_time))}ms\nSegmentation {int(1000 * (mask_time - start_time))}ms / inpainting {int(1000 * (end_time - mask_time))} \n<{prompt}>", fill=(0, 255, 0))
|
82 |
+
#overlay_img.save("overlay_with_text.jpg")
|
83 |
return overlay_img
|
84 |
|
85 |
+
### kick the tires before we start
|
86 |
+
|
87 |
+
for i in tqdm(range(2)):
|
88 |
+
app(np.array(Image.fromarray(np.zeros((1024,1024,3), dtype=np.uint8))), default_inpainting_prompt).save("zeros_inpainting_oneshot.jpg")
|
89 |
+
|
90 |
#ideally:
|
91 |
#iface = gr.Interface(app, gr.Image(sources=["webcam"], streaming=True), "image", live=True)
|
92 |
iface = gr.Interface(app, [gr.Image(), gr.Textbox(value=default_inpainting_prompt)], "image")
|