canadianjosieharrison committed (verified)
Commit b93a8f5 · 1 Parent(s): e6b5eea

Update semantic_seg_model.py

Files changed (1)
  1. semantic_seg_model.py  +316 -316
semantic_seg_model.py CHANGED
@@ -1,317 +1,317 @@
- import torch
- from transformers import pipeline, AutoImageProcessor, SegformerForSemanticSegmentation
- from typing import List
- from PIL import Image, ImageDraw, ImageFont, ImageChops, ImageMorph
- import numpy as np
- import datasets
-
- def find_center_of_non_black_pixels(image):
-     # Get image dimensions
-     width, height = image.size
-
-     # Iterate over the pixels to find the center of the non-black pixels
-     total_x = 0
-     total_y = 0
-     num_non_black_pixels = 0
-     top, left, bottom, right = height, width, 0, 0
-     for y in range(height):
-         for x in range(width):
-             pixel = image.getpixel((x, y))
-             if pixel != (255, 255, 255): # Non-black pixel
-                 total_x += x
-                 total_y += y
-                 num_non_black_pixels += 1
-                 top = min(top, y)
-                 left = min(left, x)
-                 bottom = max(bottom, y)
-                 right = max(right, x)
-
-     bbox_width = right - left
-     bbox_height = bottom - top
-     bbox_size = max(bbox_height, bbox_width)
-     # Calculate the center of the non-black pixels
-     if num_non_black_pixels == 0:
-         return None # No non-black pixels found
-     center_x = total_x // num_non_black_pixels
-     center_y = total_y // num_non_black_pixels
-     return (center_x, center_y), bbox_size
-
- def create_centered_image(image, center, bbox_size):
-     # Get image dimensions
-     width, height = image.size
-
-     # Calculate the offset to center the non-black pixels in the new image
-     offset_x = bbox_size // 2 - center[0]
-     offset_y = bbox_size // 2 - center[1]
-
-     # Create a new image with the same size as the original image
-     new_image = Image.new("RGB", (bbox_size, bbox_size), color=(255, 255, 255))
-
-     # Paste the non-black pixels onto the new image
-     new_image.paste(image, (offset_x, offset_y))
-
-     return new_image
-
- def ade_palette():
-     """ADE20K palette that maps each class to RGB values."""
-     return [
-         [180, 120, 20],
-         [180, 120, 120],
-         [6, 230, 230],
-         [80, 50, 50],
-         [4, 200, 3],
-         [120, 120, 80],
-         [140, 140, 140],
-         [204, 5, 255],
-         [230, 230, 230],
-         [4, 250, 7],
-         [224, 5, 255],
-         [235, 255, 7],
-         [150, 5, 61],
-         [120, 120, 70],
-         [8, 255, 51],
-         [255, 6, 82],
-         [143, 255, 140],
-         [204, 255, 4],
-         [255, 51, 7],
-         [204, 70, 3],
-         [0, 102, 200],
-         [61, 230, 250],
-         [255, 6, 51],
-         [11, 102, 255],
-         [255, 7, 71],
-         [255, 9, 224],
-         [9, 7, 230],
-         [220, 220, 220],
-         [255, 9, 92],
-         [112, 9, 255],
-         [8, 255, 214],
-         [7, 255, 224],
-         [255, 184, 6],
-         [10, 255, 71],
-         [255, 41, 10],
-         [7, 255, 255],
-         [224, 255, 8],
-         [102, 8, 255],
-         [255, 61, 6],
-         [255, 194, 7],
-         [255, 122, 8],
-         [0, 255, 20],
-         [255, 8, 41],
-         [255, 5, 153],
-         [6, 51, 255],
-         [235, 12, 255],
-         [160, 150, 20],
-         [0, 163, 255],
-         [140, 140, 140],
-         [250, 10, 15],
-         [20, 255, 0],
-         [31, 255, 0],
-         [255, 31, 0],
-         [255, 224, 0],
-         [153, 255, 0],
-         [0, 0, 255],
-         [255, 71, 0],
-         [0, 235, 255],
-         [0, 173, 255],
-         [31, 0, 255],
-         [11, 200, 200],
-         [255, 82, 0],
-         [0, 255, 245],
-         [0, 61, 255],
-         [0, 255, 112],
-         [0, 255, 133],
-         [255, 0, 0],
-         [255, 163, 0],
-         [255, 102, 0],
-         [194, 255, 0],
-         [0, 143, 255],
-         [51, 255, 0],
-         [0, 82, 255],
-         [0, 255, 41],
-         [0, 255, 173],
-         [10, 0, 255],
-         [173, 255, 0],
-         [0, 255, 153],
-         [255, 92, 0],
-         [255, 0, 255],
-         [255, 0, 245],
-         [255, 0, 102],
-         [255, 173, 0],
-         [255, 0, 20],
-         [255, 184, 184],
-         [0, 31, 255],
-         [0, 255, 61],
-         [0, 71, 255],
-         [255, 0, 204],
-         [0, 255, 194],
-         [0, 255, 82],
-         [0, 10, 255],
-         [0, 112, 255],
-         [51, 0, 255],
-         [0, 194, 255],
-         [0, 122, 255],
-         [0, 255, 163],
-         [255, 153, 0],
-         [0, 255, 10],
-         [255, 112, 0],
-         [143, 255, 0],
-         [82, 0, 255],
-         [163, 255, 0],
-         [255, 235, 0],
-         [8, 184, 170],
-         [133, 0, 255],
-         [0, 255, 92],
-         [184, 0, 255],
-         [255, 0, 31],
-         [0, 184, 255],
-         [0, 214, 255],
-         [255, 0, 112],
-         [92, 255, 0],
-         [0, 224, 255],
-         [112, 224, 255],
-         [70, 184, 160],
-         [163, 0, 255],
-         [153, 0, 255],
-         [71, 255, 0],
-         [255, 0, 163],
-         [255, 204, 0],
-         [255, 0, 143],
-         [0, 255, 235],
-         [133, 255, 0],
-         [255, 0, 235],
-         [245, 0, 255],
-         [255, 0, 122],
-         [255, 245, 0],
-         [10, 190, 212],
-         [214, 255, 0],
-         [0, 204, 255],
-         [20, 0, 255],
-         [255, 255, 0],
-         [0, 153, 255],
-         [0, 41, 255],
-         [0, 255, 204],
-         [41, 0, 255],
-         [41, 255, 0],
-         [173, 0, 255],
-         [0, 245, 255],
-         [71, 0, 255],
-         [122, 0, 255],
-         [0, 255, 184],
-         [0, 92, 255],
-         [184, 255, 0],
-         [0, 133, 255],
-         [255, 214, 0],
-         [25, 194, 194],
-         [102, 255, 0],
-         [92, 0, 255],
-     ]
-
- def label_to_color_image(label, colormap):
-     if label.ndim != 2:
-         raise ValueError("Expect 2-D input label")
-
-     if np.max(label) >= len(colormap):
-         raise ValueError("label value too large.")
-
-     return colormap[label]
-
- labels_list = []
-
- with open(r'labels.txt', 'r') as fp:
-     for line in fp:
-         labels_list.append(line[:-1])
-
- colormap = np.asarray(ade_palette())
- LABEL_NAMES = np.asarray(labels_list)
- LABEL_TO_INDEX = {label: i for i, label in enumerate(labels_list)}
- FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
- FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP, colormap)
- FONT = ImageFont.truetype("Arial.ttf", 1000)
-
- def lift_black_value(image, lift_amount):
-     """
-     Increase the black values of an image by a specified amount.
-
-     Parameters:
-     image (PIL.Image): The image to adjust.
-     lift_amount (int): The amount to increase the brightness of the darker pixels.
-
-     Returns:
-     PIL.Image: The adjusted image with lifted black values.
-     """
-     # Ensure that we don't go out of the 0-255 range for any pixel value
-     def adjust_value(value):
-         return min(255, max(0, value + lift_amount))
-
-     # Apply the point function to each channel
-     return image.point(adjust_value)
-
- torch.set_grad_enabled(False)
-
- DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
- # MIN_AREA_THRESHOLD = 0.01
-
- pipe = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640")
-
- def segmentation_inference(
-     image_rgb_pil: Image.Image,
-     savepath: str
- ):
-     outputs = pipe(image_rgb_pil, points_per_batch=32)
-
-     for i, prediction in enumerate(outputs):
-         label = prediction['label']
-         if (label == "floor") | (label == "wall") | (label == "ceiling"):
-             mask = prediction['mask']
-
-             ## Save mask
-             label_savepath = savepath + label + str(i) + '.png'
-             fill_image = Image.new("RGB", image_rgb_pil.size, color=(255,255,255))
-             cutout_image = Image.composite(image_rgb_pil, fill_image, mask)
-
-             # Crop mask
-             center, bbox_size = find_center_of_non_black_pixels(cutout_image)
-             if center is not None:
-                 centered_image = create_centered_image(cutout_image, center, bbox_size)
-                 centered_image.save(label_savepath)
-
-             ## Inspect masks
-             # inverted_mask = ImageChops.invert(mask)
-             # mask_adjusted = lift_black_value(inverted_mask, 100)
-             # color_index = LABEL_TO_INDEX[label]
-             # color = tuple(FULL_COLOR_MAP[color_index][0])
-             # fill_image = Image.new("RGB", image_rgb_pil.size, color=color)
-             # image_rgb_pil = Image.composite(image_rgb_pil, fill_image, mask_adjusted)
-
-     # Display the final image
-     # image_rgb_pil.show()
-
- def online_segmentation_inference(
-     image_rgb_pil: Image.Image
- ):
-     outputs = pipe(image_rgb_pil, points_per_batch=32)
-
-     # Create an image dictionary
-     image_dict = {"image": [], "label":[]}
-
-     for i, prediction in enumerate(outputs):
-         label = prediction['label']
-         if (label == "floor") | (label == "wall") | (label == "ceiling"):
-             mask = prediction['mask']
-
-             fill_image = Image.new("RGB", image_rgb_pil.size, color=(255,255,255))
-             cutout_image = Image.composite(image_rgb_pil, fill_image, mask)
-
-             # Crop mask
-             center, bbox_size = find_center_of_non_black_pixels(cutout_image)
-             if center is not None:
-                 centered_image = create_centered_image(cutout_image, center, bbox_size)
-
-                 # Add image to image dictionary
-                 image_dict["image"].append(centered_image)
-                 image_dict["label"].append(label)
-
-     segmented_ds = datasets.Dataset.from_dict(image_dict).cast_column("image", datasets.Image())
-     return segmented_ds
 
 
+ import torch
+ from transformers import pipeline, AutoImageProcessor, SegformerForSemanticSegmentation
+ from typing import List
+ from PIL import Image, ImageDraw, ImageFont, ImageChops, ImageMorph
+ import numpy as np
+ import datasets
+
+ def find_center_of_non_black_pixels(image):
+     # Get image dimensions
+     width, height = image.size
+
+     # Iterate over the pixels to find the center of the non-black pixels
+     total_x = 0
+     total_y = 0
+     num_non_black_pixels = 0
+     top, left, bottom, right = height, width, 0, 0
+     for y in range(height):
+         for x in range(width):
+             pixel = image.getpixel((x, y))
+             if pixel != (255, 255, 255): # Non-black pixel
+                 total_x += x
+                 total_y += y
+                 num_non_black_pixels += 1
+                 top = min(top, y)
+                 left = min(left, x)
+                 bottom = max(bottom, y)
+                 right = max(right, x)
+
+     bbox_width = right - left
+     bbox_height = bottom - top
+     bbox_size = max(bbox_height, bbox_width)
+     # Calculate the center of the non-black pixels
+     if num_non_black_pixels == 0:
+         return None # No non-black pixels found
+     center_x = total_x // num_non_black_pixels
+     center_y = total_y // num_non_black_pixels
+     return (center_x, center_y), bbox_size
+
+ def create_centered_image(image, center, bbox_size):
+     # Get image dimensions
+     width, height = image.size
+
+     # Calculate the offset to center the non-black pixels in the new image
+     offset_x = bbox_size // 2 - center[0]
+     offset_y = bbox_size // 2 - center[1]
+
+     # Create a new image with the same size as the original image
+     new_image = Image.new("RGB", (bbox_size, bbox_size), color=(255, 255, 255))
+
+     # Paste the non-black pixels onto the new image
+     new_image.paste(image, (offset_x, offset_y))
+
+     return new_image
+
+ def ade_palette():
+     """ADE20K palette that maps each class to RGB values."""
+     return [
+         [180, 120, 20],
+         [180, 120, 120],
+         [6, 230, 230],
+         [80, 50, 50],
+         [4, 200, 3],
+         [120, 120, 80],
+         [140, 140, 140],
+         [204, 5, 255],
+         [230, 230, 230],
+         [4, 250, 7],
+         [224, 5, 255],
+         [235, 255, 7],
+         [150, 5, 61],
+         [120, 120, 70],
+         [8, 255, 51],
+         [255, 6, 82],
+         [143, 255, 140],
+         [204, 255, 4],
+         [255, 51, 7],
+         [204, 70, 3],
+         [0, 102, 200],
+         [61, 230, 250],
+         [255, 6, 51],
+         [11, 102, 255],
+         [255, 7, 71],
+         [255, 9, 224],
+         [9, 7, 230],
+         [220, 220, 220],
+         [255, 9, 92],
+         [112, 9, 255],
+         [8, 255, 214],
+         [7, 255, 224],
+         [255, 184, 6],
+         [10, 255, 71],
+         [255, 41, 10],
+         [7, 255, 255],
+         [224, 255, 8],
+         [102, 8, 255],
+         [255, 61, 6],
+         [255, 194, 7],
+         [255, 122, 8],
+         [0, 255, 20],
+         [255, 8, 41],
+         [255, 5, 153],
+         [6, 51, 255],
+         [235, 12, 255],
+         [160, 150, 20],
+         [0, 163, 255],
+         [140, 140, 140],
+         [250, 10, 15],
+         [20, 255, 0],
+         [31, 255, 0],
+         [255, 31, 0],
+         [255, 224, 0],
+         [153, 255, 0],
+         [0, 0, 255],
+         [255, 71, 0],
+         [0, 235, 255],
+         [0, 173, 255],
+         [31, 0, 255],
+         [11, 200, 200],
+         [255, 82, 0],
+         [0, 255, 245],
+         [0, 61, 255],
+         [0, 255, 112],
+         [0, 255, 133],
+         [255, 0, 0],
+         [255, 163, 0],
+         [255, 102, 0],
+         [194, 255, 0],
+         [0, 143, 255],
+         [51, 255, 0],
+         [0, 82, 255],
+         [0, 255, 41],
+         [0, 255, 173],
+         [10, 0, 255],
+         [173, 255, 0],
+         [0, 255, 153],
+         [255, 92, 0],
+         [255, 0, 255],
+         [255, 0, 245],
+         [255, 0, 102],
+         [255, 173, 0],
+         [255, 0, 20],
+         [255, 184, 184],
+         [0, 31, 255],
+         [0, 255, 61],
+         [0, 71, 255],
+         [255, 0, 204],
+         [0, 255, 194],
+         [0, 255, 82],
+         [0, 10, 255],
+         [0, 112, 255],
+         [51, 0, 255],
+         [0, 194, 255],
+         [0, 122, 255],
+         [0, 255, 163],
+         [255, 153, 0],
+         [0, 255, 10],
+         [255, 112, 0],
+         [143, 255, 0],
+         [82, 0, 255],
+         [163, 255, 0],
+         [255, 235, 0],
+         [8, 184, 170],
+         [133, 0, 255],
+         [0, 255, 92],
+         [184, 0, 255],
+         [255, 0, 31],
+         [0, 184, 255],
+         [0, 214, 255],
+         [255, 0, 112],
+         [92, 255, 0],
+         [0, 224, 255],
+         [112, 224, 255],
+         [70, 184, 160],
+         [163, 0, 255],
+         [153, 0, 255],
+         [71, 255, 0],
+         [255, 0, 163],
+         [255, 204, 0],
+         [255, 0, 143],
+         [0, 255, 235],
+         [133, 255, 0],
+         [255, 0, 235],
+         [245, 0, 255],
+         [255, 0, 122],
+         [255, 245, 0],
+         [10, 190, 212],
+         [214, 255, 0],
+         [0, 204, 255],
+         [20, 0, 255],
+         [255, 255, 0],
+         [0, 153, 255],
+         [0, 41, 255],
+         [0, 255, 204],
+         [41, 0, 255],
+         [41, 255, 0],
+         [173, 0, 255],
+         [0, 245, 255],
+         [71, 0, 255],
+         [122, 0, 255],
+         [0, 255, 184],
+         [0, 92, 255],
+         [184, 255, 0],
+         [0, 133, 255],
+         [255, 214, 0],
+         [25, 194, 194],
+         [102, 255, 0],
+         [92, 0, 255],
+     ]
+
+ def label_to_color_image(label, colormap):
+     if label.ndim != 2:
+         raise ValueError("Expect 2-D input label")
+
+     if np.max(label) >= len(colormap):
+         raise ValueError("label value too large.")
+
+     return colormap[label]
+
+ labels_list = []
+
+ with open(r'labels.txt', 'r') as fp:
+     for line in fp:
+         labels_list.append(line[:-1])
+
+ colormap = np.asarray(ade_palette())
+ LABEL_NAMES = np.asarray(labels_list)
+ LABEL_TO_INDEX = {label: i for i, label in enumerate(labels_list)}
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP, colormap)
+ # FONT = ImageFont.truetype("Arial.ttf", 1000)
+
+ def lift_black_value(image, lift_amount):
+     """
+     Increase the black values of an image by a specified amount.
+
+     Parameters:
+     image (PIL.Image): The image to adjust.
+     lift_amount (int): The amount to increase the brightness of the darker pixels.
+
+     Returns:
+     PIL.Image: The adjusted image with lifted black values.
+     """
+     # Ensure that we don't go out of the 0-255 range for any pixel value
+     def adjust_value(value):
+         return min(255, max(0, value + lift_amount))
+
+     # Apply the point function to each channel
+     return image.point(adjust_value)
+
+ torch.set_grad_enabled(False)
+
+ DEVICE = 'cuda' if torch.cuda.is_available() else "cpu"
+ # MIN_AREA_THRESHOLD = 0.01
+
+ pipe = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640")
+
+ def segmentation_inference(
+     image_rgb_pil: Image.Image,
+     savepath: str
+ ):
+     outputs = pipe(image_rgb_pil, points_per_batch=32)
+
+     for i, prediction in enumerate(outputs):
+         label = prediction['label']
+         if (label == "floor") | (label == "wall") | (label == "ceiling"):
+             mask = prediction['mask']
+
+             ## Save mask
+             label_savepath = savepath + label + str(i) + '.png'
+             fill_image = Image.new("RGB", image_rgb_pil.size, color=(255,255,255))
+             cutout_image = Image.composite(image_rgb_pil, fill_image, mask)
+
+             # Crop mask
+             center, bbox_size = find_center_of_non_black_pixels(cutout_image)
+             if center is not None:
+                 centered_image = create_centered_image(cutout_image, center, bbox_size)
+                 centered_image.save(label_savepath)
+
+             ## Inspect masks
+             # inverted_mask = ImageChops.invert(mask)
+             # mask_adjusted = lift_black_value(inverted_mask, 100)
+             # color_index = LABEL_TO_INDEX[label]
+             # color = tuple(FULL_COLOR_MAP[color_index][0])
+             # fill_image = Image.new("RGB", image_rgb_pil.size, color=color)
+             # image_rgb_pil = Image.composite(image_rgb_pil, fill_image, mask_adjusted)
+
+     # Display the final image
+     # image_rgb_pil.show()
+
+ # def online_segmentation_inference(
+ #     image_rgb_pil: Image.Image
+ # ):
+ #     outputs = pipe(image_rgb_pil, points_per_batch=32)
+
+ #     # Create an image dictionary
+ #     image_dict = {"image": [], "label":[]}
+
+ #     for i, prediction in enumerate(outputs):
+ #         label = prediction['label']
+ #         if (label == "floor") | (label == "wall") | (label == "ceiling"):
+ #             mask = prediction['mask']
+
+ #             fill_image = Image.new("RGB", image_rgb_pil.size, color=(255,255,255))
+ #             cutout_image = Image.composite(image_rgb_pil, fill_image, mask)
+
+ #             # Crop mask
+ #             center, bbox_size = find_center_of_non_black_pixels(cutout_image)
+ #             if center is not None:
+ #                 centered_image = create_centered_image(cutout_image, center, bbox_size)
+
+ #                 # Add image to image dictionary
+ #                 image_dict["image"].append(centered_image)
+ #                 image_dict["label"].append(label)
+
+ #     segmented_ds = datasets.Dataset.from_dict(image_dict).cast_column("image", datasets.Image())
+ #     return segmented_ds