sfmig committed on
Commit
6c333c9
1 Parent(s): 912be9c

added a different color palette

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +19 -162
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  scrap*
2
  .DS_Store
3
  requirements_conda.txt
 
 
 
1
  scrap*
2
  .DS_Store
3
  requirements_conda.txt
4
+ app_0.py
5
+ test.py
app.py CHANGED
@@ -12,6 +12,13 @@ https://arxiv.org/abs/2005.12872
12
  Additions
13
  - add shown labels as strings
14
  - show only animal masks (ask an nlp model?)
 
 
 
 
 
 
 
15
  """
16
 
17
  from transformers import DetrFeatureExtractor, DetrForSegmentation
@@ -21,163 +28,8 @@ import numpy as np
21
  import torch
22
  import torchvision
23
 
24
- # Returns a list with a color per ADE class (150 classes)
25
- # from https://huggingface.co/spaces/chansung/segformer-tf-transformers/blob/main/app.py
26
- def ade_palette():
27
- """ADE20K palette that maps each class to RGB values."""
28
- return [
29
- [120, 120, 120],
30
- [180, 120, 120],
31
- [6, 230, 230],
32
- [80, 50, 50],
33
- [4, 200, 3],
34
- [120, 120, 80],
35
- [140, 140, 140],
36
- [204, 5, 255],
37
- [230, 230, 230],
38
- [4, 250, 7],
39
- [224, 5, 255],
40
- [235, 255, 7],
41
- [150, 5, 61],
42
- [120, 120, 70],
43
- [8, 255, 51],
44
- [255, 6, 82],
45
- [143, 255, 140],
46
- [204, 255, 4],
47
- [255, 51, 7],
48
- [204, 70, 3],
49
- [0, 102, 200],
50
- [61, 230, 250],
51
- [255, 6, 51],
52
- [11, 102, 255],
53
- [255, 7, 71],
54
- [255, 9, 224],
55
- [9, 7, 230],
56
- [220, 220, 220],
57
- [255, 9, 92],
58
- [112, 9, 255],
59
- [8, 255, 214],
60
- [7, 255, 224],
61
- [255, 184, 6],
62
- [10, 255, 71],
63
- [255, 41, 10],
64
- [7, 255, 255],
65
- [224, 255, 8],
66
- [102, 8, 255],
67
- [255, 61, 6],
68
- [255, 194, 7],
69
- [255, 122, 8],
70
- [0, 255, 20],
71
- [255, 8, 41],
72
- [255, 5, 153],
73
- [6, 51, 255],
74
- [235, 12, 255],
75
- [160, 150, 20],
76
- [0, 163, 255],
77
- [140, 140, 140],
78
- [250, 10, 15],
79
- [20, 255, 0],
80
- [31, 255, 0],
81
- [255, 31, 0],
82
- [255, 224, 0],
83
- [153, 255, 0],
84
- [0, 0, 255],
85
- [255, 71, 0],
86
- [0, 235, 255],
87
- [0, 173, 255],
88
- [31, 0, 255],
89
- [11, 200, 200],
90
- [255, 82, 0],
91
- [0, 255, 245],
92
- [0, 61, 255],
93
- [0, 255, 112],
94
- [0, 255, 133],
95
- [255, 0, 0],
96
- [255, 163, 0],
97
- [255, 102, 0],
98
- [194, 255, 0],
99
- [0, 143, 255],
100
- [51, 255, 0],
101
- [0, 82, 255],
102
- [0, 255, 41],
103
- [0, 255, 173],
104
- [10, 0, 255],
105
- [173, 255, 0],
106
- [0, 255, 153],
107
- [255, 92, 0],
108
- [255, 0, 255],
109
- [255, 0, 245],
110
- [255, 0, 102],
111
- [255, 173, 0],
112
- [255, 0, 20],
113
- [255, 184, 184],
114
- [0, 31, 255],
115
- [0, 255, 61],
116
- [0, 71, 255],
117
- [255, 0, 204],
118
- [0, 255, 194],
119
- [0, 255, 82],
120
- [0, 10, 255],
121
- [0, 112, 255],
122
- [51, 0, 255],
123
- [0, 194, 255],
124
- [0, 122, 255],
125
- [0, 255, 163],
126
- [255, 153, 0],
127
- [0, 255, 10],
128
- [255, 112, 0],
129
- [143, 255, 0],
130
- [82, 0, 255],
131
- [163, 255, 0],
132
- [255, 235, 0],
133
- [8, 184, 170],
134
- [133, 0, 255],
135
- [0, 255, 92],
136
- [184, 0, 255],
137
- [255, 0, 31],
138
- [0, 184, 255],
139
- [0, 214, 255],
140
- [255, 0, 112],
141
- [92, 255, 0],
142
- [0, 224, 255],
143
- [112, 224, 255],
144
- [70, 184, 160],
145
- [163, 0, 255],
146
- [153, 0, 255],
147
- [71, 255, 0],
148
- [255, 0, 163],
149
- [255, 204, 0],
150
- [255, 0, 143],
151
- [0, 255, 235],
152
- [133, 255, 0],
153
- [255, 0, 235],
154
- [245, 0, 255],
155
- [255, 0, 122],
156
- [255, 245, 0],
157
- [10, 190, 212],
158
- [214, 255, 0],
159
- [0, 204, 255],
160
- [20, 0, 255],
161
- [255, 255, 0],
162
- [0, 153, 255],
163
- [0, 41, 255],
164
- [0, 255, 204],
165
- [41, 0, 255],
166
- [41, 255, 0],
167
- [173, 0, 255],
168
- [0, 245, 255],
169
- [71, 0, 255],
170
- [122, 0, 255],
171
- [0, 255, 184],
172
- [0, 92, 255],
173
- [184, 255, 0],
174
- [0, 133, 255],
175
- [255, 214, 0],
176
- [25, 194, 194],
177
- [102, 255, 0],
178
- [92, 0, 255],
179
- ]
180
-
181
 
182
  def predict_animal_mask(im,
183
  gr_slider_confidence):
@@ -187,9 +39,9 @@ def predict_animal_mask(im,
187
  # encoding is a dict with pixel_values and pixel_mask
188
  encoding = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow
189
  outputs = model(**encoding) # odict with keys: ['logits', 'pred_boxes', 'pred_masks', 'last_hidden_state', 'encoder_last_hidden_state']
190
- logits = outputs.logits # torch.Size([1, 100, 251]); why 251?
191
  bboxes = outputs.pred_boxes
192
- masks = outputs.pred_masks # torch.Size([1, 100, 200, 200]); for every pixel, score in each of the 100 classes? there is a mask per class
193
 
194
  # keep only the masks with high confidence?--------------------------------
195
  # compute the prob per mask (i.e., class), excluding the "no-object" class (the last one)
@@ -200,8 +52,13 @@ def predict_animal_mask(im,
200
  # postprocess the mask (numpy arrays)
201
  label_per_pixel = torch.argmax(masks[keep].squeeze(),dim=0).detach().numpy() # from the masks per class, select the highest per pixel
202
  color_mask = np.zeros(image.size+(3,))
203
- for lbl, color in enumerate(ade_palette()):
204
- color_mask[label_per_pixel==lbl,:] = color
 
 
 
 
 
205
 
206
  # Show image + mask
207
  pred_img = np.array(image.convert('RGB'))*0.5 + color_mask*0.5
@@ -227,7 +84,7 @@ gr.Interface(predict_animal_mask,
227
  inputs = [gr_image_input,gr_slider_confidence],
228
  outputs = gr_image_output,
229
  title = 'Image segmentation with varying confidence',
230
- description = "An image segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
231
 
232
 
233
  ####################################
 
12
  Additions
13
  - add shown labels as strings
14
  - show only animal masks (ask an nlp model?)
15
+
16
+ For next time
17
+ - for different 'confidence' values the set of high-confidence masks should change...
18
+ - colors are not great and should be constant per class? add text?
19
+ - I'm getting a core dump (segmentation fault) when loading the Hugging Face model... :(
20
+ https://github.com/huggingface/transformers/issues/16939
21
+ - cap slider to 95?
22
  """
23
 
24
  from transformers import DetrFeatureExtractor, DetrForSegmentation
 
28
  import torch
29
  import torchvision
30
 
31
+ import itertools
32
+ import seaborn as sns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def predict_animal_mask(im,
35
  gr_slider_confidence):
 
39
  # encoding is a dict with pixel_values and pixel_mask
40
  encoding = feature_extractor(images=image, return_tensors="pt") #pt=Pytorch, tf=TensorFlow
41
  outputs = model(**encoding) # odict with keys: ['logits', 'pred_boxes', 'pred_masks', 'last_hidden_state', 'encoder_last_hidden_state']
42
+ logits = outputs.logits # torch.Size([1, 100, 251]); class logits? but why 251?
43
  bboxes = outputs.pred_boxes
44
+ masks = outputs.pred_masks # torch.Size([1, 100, 200, 200]); mask logits? for every pixel, score in each of the 100 classes? there is a mask per class
45
 
46
  # keep only the masks with high confidence?--------------------------------
47
  # compute the prob per mask (i.e., class), excluding the "no-object" class (the last one)
 
52
  # postprocess the mask (numpy arrays)
53
  label_per_pixel = torch.argmax(masks[keep].squeeze(),dim=0).detach().numpy() # from the masks per class, select the highest per pixel
54
  color_mask = np.zeros(image.size+(3,))
55
+ palette = itertools.cycle(sns.color_palette())
56
+ for lbl in np.unique(label_per_pixel): #enumerate(palette()):
57
+ color_mask[label_per_pixel==lbl,:] = np.asarray(next(palette))*255 #color
58
+
59
+ # color_mask = np.zeros(image.size+(3,))
60
+ # for lbl, color in enumerate(ade_palette()):
61
+ # color_mask[label_per_pixel==lbl,:] = color
62
 
63
  # Show image + mask
64
  pred_img = np.array(image.convert('RGB'))*0.5 + color_mask*0.5
 
84
  inputs = [gr_image_input,gr_slider_confidence],
85
  outputs = gr_image_output,
86
  title = 'Image segmentation with varying confidence',
87
+ description = "A panoptic (semantic+instance) segmentation webapp using DETR (End-to-End Object Detection) model with ResNet-50 backbone").launch()
88
 
89
 
90
  ####################################