TheoBH commited on
Commit
7762f58
1 Parent(s): 0eb741c

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +38 -0
  2. color_palet.py +89 -0
  3. predict.py +79 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from predict import predict_masks, get_mask_for_label
3
+ import glob
4
+
5
+ demo = gr.Blocks()
6
+
7
+ with demo:
8
+
9
+ gr.Markdown("# **<p align='center'>FurnishAI</p>**")
10
+
11
+ with gr.Box():
12
+
13
+ with gr.Row():
14
+ with gr.Column():
15
+ gr.Markdown("**Inputs**")
16
+ input_image = gr.Image(type='filepath',label="Input Image", show_label=True)
17
+ labels_dropdown = gr.Dropdown(label="Labels", show_label=True)
18
+
19
+ with gr.Column():
20
+ gr.Markdown("**Outputs**")
21
+ output_heading = gr.Textbox(label="Output Type", show_label=True)
22
+ output_mask = gr.Image(label="Predicted Masks", show_label=True)
23
+ selected_mask = gr.Image(label="Selected Mask", show_label=True)
24
+
25
+ gr.Markdown("**Predict**")
26
+
27
+ with gr.Box():
28
+ with gr.Row():
29
+ submit_button = gr.Button("Submit")
30
+ generate_mask_button = gr.Button("Generate Mask")
31
+
32
+ gr.Markdown("**Examples:**")
33
+ submit_button.click(predict_masks, inputs=[input_image], outputs=[output_mask, output_heading, labels_dropdown])
34
+ generate_mask_button.click(get_mask_for_label, inputs=[predict_masks(input_image), labels_dropdown], outputs=[selected_mask])
35
+
36
+ gr.Markdown('\n Demo created by: <a href=\"https://www.linkedin.com/in/theo-belen-halimi/">Théo Belen-Halimi</a>')
37
+
38
+ demo.launch(debug=True)
color_palet.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # color palattes for COCO, cityscapes and ADE datasets
2
+
3
+ def coco_panoptic_palette():
4
+ return [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230),
5
+ (106, 0, 228), (0, 60, 100), (0, 80, 100), (0, 0, 70),
6
+ (0, 0, 192), (250, 170, 30), (100, 170, 30), (220, 220, 0),
7
+ (175, 116, 175), (250, 0, 30), (165, 42, 42), (255, 77, 255),
8
+ (0, 226, 252), (182, 182, 255), (0, 82, 0), (120, 166, 157),
9
+ (110, 76, 0), (174, 57, 255), (199, 100, 0), (72, 0, 118),
10
+ (255, 179, 240), (0, 125, 92), (209, 0, 151), (188, 208, 182),
11
+ (0, 220, 176), (255, 99, 164), (92, 0, 73), (133, 129, 255),
12
+ (78, 180, 255), (0, 228, 0), (174, 255, 243), (45, 89, 255),
13
+ (134, 134, 103), (145, 148, 174), (255, 208, 186),
14
+ (197, 226, 255), (171, 134, 1), (109, 63, 54), (207, 138, 255),
15
+ (151, 0, 95), (9, 80, 61), (84, 105, 51), (74, 65, 105),
16
+ (166, 196, 102), (208, 195, 210), (255, 109, 65), (0, 143, 149),
17
+ (179, 0, 194), (209, 99, 106), (5, 121, 0), (227, 255, 205),
18
+ (147, 186, 208), (153, 69, 1), (3, 95, 161), (163, 255, 0),
19
+ (119, 0, 170), (0, 182, 199), (0, 165, 120), (183, 130, 88),
20
+ (95, 32, 0), (130, 114, 135), (110, 129, 133), (166, 74, 118),
21
+ (219, 142, 185), (79, 210, 114), (178, 90, 62), (65, 70, 15),
22
+ (127, 167, 115), (59, 105, 106), (142, 108, 45), (196, 172, 0),
23
+ (95, 54, 80), (128, 76, 255), (201, 57, 1), (246, 0, 122),
24
+ (191, 162, 208), (255, 255, 128), (147, 211, 203),
25
+ (150, 100, 100), (168, 171, 172), (146, 112, 198),
26
+ (210, 170, 100), (92, 136, 89), (218, 88, 184), (241, 129, 0),
27
+ (217, 17, 255), (124, 74, 181), (70, 70, 70), (255, 228, 255),
28
+ (154, 208, 0), (193, 0, 92), (76, 91, 113), (255, 180, 195),
29
+ (106, 154, 176),
30
+ (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55),
31
+ (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255),
32
+ (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74),
33
+ (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149),
34
+ (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153),
35
+ (146, 139, 141),
36
+ (70, 130, 180), (134, 199, 156), (209, 226, 140), (96, 36, 108),
37
+ (96, 96, 96), (64, 170, 64), (152, 251, 152), (208, 229, 228),
38
+ (206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
39
+ (102, 102, 156), (250, 141, 255)]
40
+
41
+ def cityscapes_palette():
42
+ return [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156],[190, 153, 153],
43
+ [153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35],[152, 251, 152],
44
+ [70, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], [0, 0, 70],
45
+ [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32]]
46
+
47
+ def ade_palette():
48
+ """Color palette that maps each class to RGB values.
49
+
50
+ This one is actually taken from ADE20k.
51
+ """
52
+ return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
53
+ [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
54
+ [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
55
+ [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
56
+ [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
57
+ [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
58
+ [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
59
+ [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
60
+ [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
61
+ [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
62
+ [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
63
+ [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
64
+ [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
65
+ [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
66
+ [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
67
+ [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
68
+ [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
69
+ [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
70
+ [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
71
+ [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
72
+ [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
73
+ [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
74
+ [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
75
+ [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
76
+ [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
77
+ [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
78
+ [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
79
+ [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
80
+ [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
81
+ [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
82
+ [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
83
+ [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
84
+ [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
85
+ [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
86
+ [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
87
+ [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
88
+ [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
89
+ [102, 255, 0], [92, 0, 255]]
predict.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import random
3
+ import numpy as np
4
+ from PIL import Image
5
+ from collections import defaultdict
6
+ import os
7
+ # Mentioning detectron2 as a dependency directly in requirements.txt tries to install detectron2 before torch and results in an error even if torch is listed as a dependency before detectron2.
8
+ # Hence, installing detectron2 this way when using Gradio HF spaces.
9
+ os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
10
+
11
+ from detectron2.data import MetadataCatalog
12
+ from detectron2.utils.visualizer import ColorMode, Visualizer
13
+ from color_palette import ade_palette
14
+ from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
15
+
16
+ def load_model_and_processor(model_ckpt: str):
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+ model = Mask2FormerForUniversalSegmentation.from_pretrained(model_ckpt).to(torch.device(device))
19
+ model.eval()
20
+ image_preprocessor = Mask2FormerImageProcessor.from_pretrained(model_ckpt)
21
+ return model, image_preprocessor
22
+
23
+ def load_default_ckpt():
24
+ default_ckpt = "facebook/mask2former-swin-tiny-coco-panoptic"
25
+ return default_ckpt
26
+
27
+ def draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image):
28
+ metadata = MetadataCatalog.get("coco_2017_val_panoptic")
29
+ for res in seg_info:
30
+ res['category_id'] = res.pop('label_id')
31
+ pred_class = res['category_id']
32
+ isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values()
33
+ res['isthing'] = bool(isthing)
34
+
35
+ visualizer = Visualizer(np.array(image)[:, :, ::-1], metadata=metadata, instance_mode=ColorMode.IMAGE)
36
+ out = visualizer.draw_panoptic_seg_predictions(
37
+ predicted_segmentation_map.cpu(), seg_info, alpha=0.5
38
+ )
39
+ output_img = Image.fromarray(out.get_image())
40
+ labels = [res['category_id'] for res in seg_info]
41
+ return output_img, labels
42
+
43
+
44
+
45
+ def predict_masks(input_img_path: str):
46
+
47
+ #load model and image processor
48
+ default_ckpt = load_default_ckpt()
49
+ model, image_processor = load_model_and_processor()
50
+
51
+ ## pass input image through image processor
52
+ image = Image.open(input_img_path)
53
+ inputs = image_processor(images=image, return_tensors="pt")
54
+
55
+ ## pass inputs to model for prediction
56
+ with torch.no_grad():
57
+ outputs = model(**inputs)
58
+ result = image_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
59
+ predicted_segmentation_map = result["segmentation"]
60
+ seg_info = result['segments_info']
61
+ output_result, labels = draw_panoptic_segmentation(predicted_segmentation_map, seg_info, image)
62
+ output_heading = "Panoptic Segmentation Output"
63
+
64
+ return output_result, output_heading, labels
65
+
66
+
67
+
68
+
69
+
70
+ def get_mask_for_label(results, label):
71
+ import numpy as np
72
+ from PIL import Image
73
+
74
+ mask = (results['segmentation'].numpy() == label)
75
+ visual_mask = (mask * 255).astype(np.uint8)
76
+ visual_mask = Image.fromarray(visual_mask)
77
+
78
+ return visual_mask
79
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ opencv-python
4
+ git+https://github.com/huggingface/transformers.git
5
+ pillow
6
+ scipy
7
+ torchvision