Upload 4 files
- app.py +38 -0
- color_palette.py +89 -0
- predict.py +79 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,38 @@
import gradio as gr
from predict import predict_masks, get_mask_for_label
import glob

demo = gr.Blocks()

with demo:

    gr.Markdown("# **<p align='center'>FurnishAI</p>**")

    with gr.Box():

        with gr.Row():
            with gr.Column():
                gr.Markdown("**Inputs**")
                input_image = gr.Image(type='filepath', label="Input Image", show_label=True)
                labels_dropdown = gr.Dropdown(label="Labels", show_label=True)

            with gr.Column():
                gr.Markdown("**Outputs**")
                output_heading = gr.Textbox(label="Output Type", show_label=True)
                output_mask = gr.Image(label="Predicted Masks", show_label=True)
                selected_mask = gr.Image(label="Selected Mask", show_label=True)

    gr.Markdown("**Predict**")

    with gr.Box():
        with gr.Row():
            submit_button = gr.Button("Submit")
            generate_mask_button = gr.Button("Generate Mask")

    gr.Markdown("**Examples:**")

    # Gradio event inputs must be components, so the raw segmentation result
    # returned by predict_masks is kept in a State and reused when a single
    # mask is requested, instead of calling predict_masks on the component.
    segmentation_state = gr.State()

    submit_button.click(
        predict_masks,
        inputs=[input_image],
        outputs=[output_mask, output_heading, labels_dropdown, segmentation_state],
    )
    generate_mask_button.click(
        get_mask_for_label,
        inputs=[segmentation_state, labels_dropdown],
        outputs=[selected_mask],
    )

    gr.Markdown('\n Demo created by: <a href="https://www.linkedin.com/in/theo-belen-halimi/">Théo Belen-Halimi</a>')

demo.launch(debug=True)
color_palette.py
ADDED
@@ -0,0 +1,89 @@
# color palettes for the COCO, Cityscapes and ADE datasets

def coco_panoptic_palette():
    return [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
            (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
            (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
            (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
            (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
            (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
            (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
            (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
            (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
            (134, 134, 103), (145, 148, 174), (255, 208, 186),
            (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
            (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
            (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
            (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
            (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
            (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
            (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
            (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
            (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
            (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
            (191, 162, 208), (255, 255, 128), (147, 211, 203),
            (150, 100, 100), (168, 171, 172), (146, 112, 198),
            (210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
            (217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
            (154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
            (106, 154, 176),
            (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
            (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
            (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
            (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
            (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
            (146, 139, 141),
            (70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
            (96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
            (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
            (102, 102, 156), (250, 141, 255)]

def cityscapes_palette():
    return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], [190, 153, 153],
            [153, 153, 153], [250, 170, 30], [220, 220, 0], [107, 142, 35], [152, 251, 152],
            [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
            [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]

def ade_palette():
    """Color palette that maps each class to RGB values.

    This one is actually taken from ADE20k.
    """
    return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
            [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
            [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
            [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
            [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
            [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
            [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
            [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
            [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
            [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
            [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
            [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
            [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
            [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
            [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
            [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
            [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
            [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
            [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
            [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
            [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
            [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
            [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
            [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
            [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
            [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
            [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
            [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
            [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
            [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
            [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
            [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
            [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
            [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
            [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
            [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
            [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
            [102, 255, 0], [92, 0, 255]]
predict.py
ADDED
@@ -0,0 +1,79 @@
import torch
import random
import numpy as np
from PIL import Image
from collections import defaultdict
import os

# Listing detectron2 directly in requirements.txt makes pip try to build detectron2
# before torch is installed and fails, even if torch appears earlier in the file.
# Hence detectron2 is installed this way when running on Gradio HF Spaces.
os.system('pip install git+https://github.com/facebookresearch/detectron2.git')

from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import ColorMode, Visualizer
from color_palette import ade_palette
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation


def load_model_and_processor(model_ckpt: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
    model.eval()
    image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
    return model, image_preprocessor


def load_default_ckpt():
    default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
    return default_ckpt


def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
    metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    for res in seg_info:
        # detectron2's visualizer expects 'category_id' and 'isthing' keys
        res['category_id'] = res.pop('label_id')
        pred_class = res['category_id']
        isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
        res['isthing'] = bool(isthing)

    visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
    out = visualizer.draw_panoptic_seg_predictions(
        predicted_segmentation_map.cpu(), seg_info, alpha=0.5
    )
    output_img = Image.fromarray(out.get_image())
    # expose the segment ids (the values written into the segmentation map),
    # so that an individual segment's mask can be looked up later
    labels = [res['id'] for res in seg_info]
    return output_img, labels


def predict_masks(input_img_path: str):
    # load model and image processor
    default_ckpt = load_default_ckpt()
    model, image_processor = load_model_and_processor(default_ckpt)

    # pass input image through image processor
    image = Image.open(input_img_path).convert("RGB")
    inputs = image_processor(images=image, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # pass inputs to model for prediction
    with torch.no_grad():
        outputs = model(**inputs)
    result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
    predicted_segmentation_map = result["segmentation"]
    seg_info = result['segments_info']
    output_result, labels = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
    output_heading = "Panoptic Segmentation Output"

    # also return the raw post-processing result so the app can extract
    # single-segment masks later without re-running the model
    return output_result, output_heading, labels, result


def get_mask_for_label(results, label):
    # binary mask for the selected segment id, rendered as a grayscale image
    mask = (results['segmentation'].cpu().numpy() == int(label))
    visual_mask = (mask * 255).astype(np.uint8)
    visual_mask = Image.fromarray(visual_mask)

    return visual_mask
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
gradio
opencv-python
git+https://github.com/huggingface/transformers.git
pillow
scipy
torchvision