oValach committed
Commit d2eab27
1 parent: 9cf0472

Initial commit

Files changed (2)
  1. TheDistanceAssessor.py +905 -0
  2. test_filtered_cls.py +283 -0
TheDistanceAssessor.py ADDED
@@ -0,0 +1,905 @@
import cv2
import os
import time
import json
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
import matplotlib.path as mplPath
import matplotlib.patches as patches
from ultralyticsplus import YOLO
from scripts.test_filtered_cls import load, load_model, process

PATH_jpgs = 'assets/rs19val/jpgs/test'
PATH_model_seg = 'assets/models_pretrained/segformer/SegFormer_B3_1024_finetuned.pth'
PATH_model_det = 'assets/models_pretrained/ultralyticsplus/yolov8s'
PATH_base = 'assets/pilsen_railway_dataset/'
eda_path = "assets/pilsen_railway_dataset/eda_table.table.json"
data_json = json.load(open(eda_path, 'r'))
19
+
def load_yolo(PATH_model):
    model = YOLO(PATH_model)

    model.overrides['conf'] = 0.25  # NMS confidence threshold
    model.overrides['iou'] = 0.45  # NMS IoU threshold
    model.overrides['agnostic_nms'] = False  # NMS class-agnostic
    model.overrides['max_det'] = 1000  # maximum number of detections per image
    return model

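Editor's note: a minimal usage sketch for load_yolo, assuming the ultralyticsplus YOLO API used above; the frame file name is hypothetical and the snippet is not part of the commit:

detector = load_yolo(PATH_model_det)
frame = cv2.imread('assets/rs19val/jpgs/test/rs00000.jpg')  # hypothetical frame
results = detector.predict(frame)
print(results[0].boxes.xywh)  # one (cx, cy, w, h) row per detection
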
def find_extreme_y_values(arr, values=[0, 6]):
    """
    Find the lowest and highest y-values (row indices) in a 2D array where 0 or 6 appears.

    Parameters:
    - arr: The input 2D NumPy array.
    - values: The values to search for (default is [0, 6]).

    Returns:
    A tuple (lowest_y, highest_y). If the values are not found, returns (None, None).
    """
    mask = np.isin(arr, values)
    rows_with_values = np.any(mask, axis=1)

    y_indices = np.nonzero(rows_with_values)[0]  # Directly find the non-zero (True) indices

    if y_indices.size == 0:
        return None, None  # Early return if the values are not found

    return y_indices[0], y_indices[-1]

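Editor's note: a quick sanity check for find_extreme_y_values with toy values (not part of the commit):

seg = np.array([[5, 5, 5],
                [5, 0, 5],
                [5, 5, 5],
                [6, 5, 5],
                [5, 5, 5]])
print(find_extreme_y_values(seg))  # (1, 3): rows 1 and 3 contain class 0 or 6
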
def find_nearest_pairs(arr1, arr2):
    # Convert lists to NumPy arrays for vectorized operations
    arr1_np = np.array(arr1)
    arr2_np = np.array(arr2)

    # Determine which array is shorter
    if len(arr1_np) < len(arr2_np):
        base_array, compare_array = arr1_np, arr2_np
    else:
        base_array, compare_array = arr2_np, arr1_np

    paired_base = []
    paired_compare = []

    # Mask to keep track of paired elements
    paired_mask = np.zeros(len(compare_array), dtype=bool)

    for item in base_array:
        # Calculate distances from the current item to all items in compare_array
        distances = np.linalg.norm(compare_array - item, axis=1)
        nearest_index = np.argmin(distances)
        paired_base.append(item)
        paired_compare.append(compare_array[nearest_index])
        # Mark the paired element to exclude it from further pairing
        paired_mask[nearest_index] = True

        # Stop once every element of compare_array has been paired
        if paired_mask.all():
            break

    paired_base = np.array(paired_base)
    paired_compare = compare_array[paired_mask]

    return (paired_base, paired_compare) if len(arr1_np) < len(arr2_np) else (paired_compare, paired_base)

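Editor's note: a toy check of the nearest-neighbour pairing (not part of the commit); the shorter list drives the pairing and the far-away point stays unpaired:

a = [(0, 0), (10, 10)]
b = [(1, 1), (9, 9), (50, 50)]
paired_a, paired_b = find_nearest_pairs(a, b)
print(paired_a, paired_b)  # pairs (0,0)->(1,1) and (10,10)->(9,9); (50,50) is left out
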
def filter_crossings(image, edges_dict):
    filtered_edges = {}
    for key, values in edges_dict.items():
        merged = [values[0]]
        for start, end in values[1:]:
            if start - merged[-1][1] < 50:

                # Probe two nearby rows to estimate the local edge slope
                key_up = max([0, key - 10])
                key_down = min([image.shape[0] - 1, key + 10])
                if key_up == 0:
                    key_up = key + 20
                if key_down == image.shape[0] - 1:
                    key_down = key - 20

                edges_to_test_slope1 = robust_edges(image, [key_up], values=[0, 6], min_width=19)
                edges_to_test_slope2 = robust_edges(image, [key_down], values=[0, 6], min_width=19)

                values1, edges_to_test_slope1 = find_nearest_pairs(values, edges_to_test_slope1)
                values2, edges_to_test_slope2 = find_nearest_pairs(values, edges_to_test_slope2)

                differences_y = []
                for i, value in enumerate(values1):
                    if start in value:
                        idx = list(value).index(start)
                        try:
                            differences_y.append(abs(start - edges_to_test_slope1[i][idx]))
                        except IndexError:  # no paired edge at this index
                            pass
                    if merged[-1][1] in value:
                        idx = list(value).index(merged[-1][1])
                        try:
                            differences_y.append(abs(merged[-1][1] - edges_to_test_slope1[i][idx]))
                        except IndexError:
                            pass
                for i, value in enumerate(values2):
                    if start in value:
                        idx = list(value).index(start)
                        try:
                            differences_y.append(abs(start - edges_to_test_slope2[i][idx]))
                        except IndexError:
                            pass
                    if merged[-1][1] in value:
                        idx = list(value).index(merged[-1][1])
                        try:
                            differences_y.append(abs(merged[-1][1] - edges_to_test_slope2[i][idx]))
                        except IndexError:
                            pass

                if any(element > 30 for element in differences_y):
                    merged[-1] = (merged[-1][0], end)
                else:
                    merged.append((start, end))
            else:
                merged.append((start, end))
        filtered_edges[key] = merged

    return filtered_edges

def robust_edges(image, y_levels, values=[0, 6], min_width=19):

    for y in y_levels:
        row = image[y, :]
        mask = np.isin(row, values).astype(int)
        padded_mask = np.pad(mask, (1, 1), 'constant', constant_values=0)
        diff = np.diff(padded_mask)
        starts = np.where(diff == 1)[0]
        ends = np.where(diff == -1)[0] - 1

        # Filter sequences based on the minimum-width criterion
        filtered_edges = [(start, end) for start, end in zip(starts, ends) if end - start + 1 >= min_width]
        # Drop sequences touching the image borders (0 and 1919 assume a 1920-px-wide map)
        filtered_edges = [(start, end) for start, end in filtered_edges if 0 not in (start, end) and 1919 not in (start, end)]

    return filtered_edges

def find_edges(image, y_levels, values=[0, 6], min_width=19):
    """
    Find start and end positions of continuous sequences of specified values at given y-levels in a 2D array,
    keeping only sequences that meet or exceed a specified minimum width.

    Parameters:
    - image: 2D NumPy array to search within.
    - y_levels: List of y-levels (row indices) to examine.
    - values: Values to search for (default is [0, 6]).
    - min_width: Minimum width of sequences to be included in the results.

    Returns:
    A dict with y-levels as keys and lists of (start, end) tuples for each sequence found in that row that meets the width criterion.
    """
    edges_dict = {}
    for y in y_levels:
        row = image[y, :]
        mask = np.isin(row, values).astype(int)
        padded_mask = np.pad(mask, (1, 1), 'constant', constant_values=0)
        diff = np.diff(padded_mask)
        starts = np.where(diff == 1)[0]
        ends = np.where(diff == -1)[0] - 1

        # Filter sequences based on the minimum-width criterion
        filtered_edges = [(start, end) for start, end in zip(starts, ends) if end - start + 1 >= min_width]
        filtered_edges = [(start, end) for start, end in filtered_edges if 0 not in (start, end) and 1919 not in (start, end)]

        # Extend each edge across directly adjacent pixels of class 1 (guard rails)
        edges_with_guard_rails = []
        for edge in filtered_edges:
            cutout_left = image[y, edge[0]-50:edge[0]][::-1]
            cutout_right = image[y, edge[1]:edge[1]+50]

            not_ones = np.where(cutout_left != 1)[0]
            if len(not_ones) > 0 and not_ones[0] > 0:
                last_one_index = not_ones[0] - 1
                edge = (edge[0] - last_one_index,) + edge[1:]
            else:
                last_one_index = None if len(not_ones) == 0 else not_ones[-1] - 1

            not_ones = np.where(cutout_right != 1)[0]
            if len(not_ones) > 0 and not_ones[0] > 0:
                last_one_index = not_ones[0] - 1
                edge = (edge[0], edge[1] - last_one_index) + edge[2:]
            else:
                last_one_index = None if len(not_ones) == 0 else not_ones[-1] - 1

            edges_with_guard_rails.append(edge)

        edges_dict[y] = edges_with_guard_rails

    edges_dict = {k: v for k, v in edges_dict.items() if v}

    edges_dict = filter_crossings(image, edges_dict)

    return edges_dict

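Editor's note: a toy check of the edge finder (not part of the commit); a single 6-pixel run of class 0 away from the array borders:

seg = np.full((5, 30), 5, dtype=np.uint8)
seg[2, 10:16] = 0
print(find_edges(seg, [2], min_width=5))  # {2: [(10, 15)]}
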
def find_rails(arr, y_levels, values=[9, 10], min_width=5):
    edges_all = []
    for y in y_levels:
        row = arr[y, :]
        mask = np.isin(row, values).astype(int)
        padded_mask = np.pad(mask, (1, 1), 'constant', constant_values=0)
        diff = np.diff(padded_mask)
        starts = np.where(diff == 1)[0]
        ends = np.where(diff == -1)[0] - 1

        # Filter sequences based on the minimum-width criterion
        filtered_edges = [(start, end) for start, end in zip(starts, ends) if end - start + 1 >= min_width]
        filtered_edges = [(start, end) for start, end in filtered_edges if 0 not in (start, end) and 1919 not in (start, end)]
        edges_all = filtered_edges

    return edges_all

def mark_edges(arr, edges_dict, mark_value=30):
    """
    Marks a 5x5 zone around the edges found in the array with a specific value.

    Parameters:
    - arr: The original 2D NumPy array.
    - edges_dict: A dictionary with y-levels as keys and lists of (start, end) tuples for edges.
    - mark_value: The value used to mark the edges.

    Returns:
    The modified array with marked zones.
    """
    marked_arr = np.copy(arr)  # Copy the array to avoid modifying the original
    offset = 2  # To mark a 5x5 area, go 2 pixels in each direction from the center

    for y, edges in edges_dict.items():
        for start, end in edges:
            # Mark a 5x5 zone around the start and end positions
            for dy in range(-offset, offset + 1):
                for dx in range(-offset, offset + 1):
                    # Check array bounds before marking
                    if 0 <= y + dy < marked_arr.shape[0] and 0 <= start + dx < marked_arr.shape[1]:
                        marked_arr[y + dy, start + dx] = mark_value
                    if 0 <= y + dy < marked_arr.shape[0] and 0 <= end + dx < marked_arr.shape[1]:
                        marked_arr[y + dy, end + dx] = mark_value

    return marked_arr

def find_rail_sides(img, edges_dict):
    left_border = []
    right_border = []
    for y, xs in edges_dict.items():
        rails = find_rails(img, [y], values=[9, 10], min_width=5)
        left_border_actual = [min(xs)[0], y]
        right_border_actual = [max(xs)[1], y]

        for zone in rails:
            if abs(zone[1] - left_border_actual[0]) < y * 0.04:  # dynamic threshold
                left_border_actual[0] = zone[0]
            if abs(zone[0] - right_border_actual[0]) < y * 0.04:
                right_border_actual[0] = zone[1]

        left_border.append(left_border_actual)
        right_border.append(right_border_actual)

    # Remove detected discontinuities along each border
    left_border, flags_l, _ = robust_rail_sides(left_border)  # filter outliers
    right_border, flags_r, _ = robust_rail_sides(right_border)

    return left_border, right_border, flags_l, flags_r

def robust_rail_sides(border, threshold=7):
    border = np.array(border)
    if border.size > 0:
        # Drop border points found on the bottom row of the image
        border = border[border[:, 1] != 1079]

        steps_x = np.diff(border[:, 0])
        median_step = np.median(np.abs(steps_x))

        threshold_step = np.abs(threshold * np.abs(median_step))
        threshold_exceedances = abs(steps_x) > abs(threshold_step)

        flags = []

        if True not in threshold_exceedances:
            return border, flags, []
        else:
            exceedance_indices = [i for i, element in enumerate(threshold_exceedances) if element]
            if exceedance_indices and np.all(np.diff(exceedance_indices) == 1):
                exceedance_indices = [exceedance_indices[0]]

            filtered_border = border

            previously_deleted = []
            for i in exceedance_indices:
                # Shift the index to account for points deleted in earlier iterations
                for item in previously_deleted:
                    if item[0] < i:
                        i -= item[1]
                first_part = filtered_border[:i+1]
                second_part = filtered_border[i+1:]
                if len(second_part) < 2:
                    filtered_border = first_part
                    previously_deleted.append([i, len(second_part)])
                elif len(first_part) < 2:
                    filtered_border = second_part
                    previously_deleted.append([i, len(first_part)])
                else:
                    # Recursively clean both halves and stitch them back together
                    first_b, _, deleted_first = robust_rail_sides(first_part)
                    second_b, _, _ = robust_rail_sides(second_part)
                    filtered_border = np.concatenate((first_b, second_b), axis=0)

                    if deleted_first:
                        for deleted_item in deleted_first:
                            if deleted_item[0] <= i:
                                i -= deleted_item[1]

                    flags.append(i)
            return filtered_border, flags, previously_deleted
    else:
        return border, [], []

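Editor's note: a toy check of the outlier filter (not part of the commit); a single x-jump in an otherwise smooth border is removed:

border = [[10, 100], [11, 110], [12, 120], [13, 130], [300, 140],
          [14, 150], [15, 160], [16, 170], [17, 180]]
filtered, flags, _ = robust_rail_sides(border)
print(filtered[:, 0])  # [10 11 12 13 14 15 16 17] -- the 300 jump is gone
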
def find_dist_from_edges(id_map, image, edges_dict, left_border, right_border, real_life_width_mm, real_life_target_mm, mark_value=30):
    """
    Mark regions representing a real-life distance (e.g., 2 meters) to the left and right of the outermost edges.

    Parameters:
    - id_map: 2D NumPy array representing the segmentation id map.
    - edges_dict: Dictionary with y-levels as keys and lists of (start, end) tuples for edges.
    - real_life_width_mm: The real-world width in millimeters that the widest detected sequence represents.
    - real_life_target_mm: The real-world distance in millimeters to mark from the edges.

    Returns:
    - The id map plus the end points and region points on both sides.
    """
    # Calculate the rail widths
    diffs_widths = {k: sum(e - s for s, e in v) / len(v) for k, v in edges_dict.items() if v}
    diffs_width = {k: max(e - s for s, e in v) for k, v in edges_dict.items() if v}

    # Pixel-to-mm scale factor
    scale_factors = {k: real_life_width_mm / v for k, v in diffs_width.items()}
    # Convert the real-life target distance to pixels
    target_distances_px = {k: int(real_life_target_mm / v) for k, v in scale_factors.items()}

    # Mark the regions representing the target distance to the left and right of the outermost edges
    end_points_left = {}
    region_levels_left = []
    for point in left_border:
        min_edge = point[0]

        # The left end point is allowed to fall outside the image bounds
        #left_mark_start = max(0, min_edge - int(target_distances_px[point[1]]))
        left_mark_start = min_edge - int(target_distances_px[point[1]])
        end_points_left[point[1]] = left_mark_start

        # Left region points
        if left_mark_start < min_edge:
            y_values = np.arange(left_mark_start, min_edge)
            x_values = np.full_like(y_values, point[1])
            region_line = np.column_stack((x_values, y_values))
            region_levels_left.append(region_line)

    end_points_right = {}
    region_levels_right = []
    for point in right_border:
        max_edge = point[0]

        # Ensure we stay within the image bounds
        right_mark_end = min(id_map.shape[1], max_edge + int(target_distances_px[point[1]]))
        if right_mark_end != id_map.shape[1]:
            end_points_right[point[1]] = right_mark_end

        # Right region points
        if max_edge < right_mark_end:
            y_values = np.arange(max_edge, right_mark_end)
            x_values = np.full_like(y_values, point[1])
            region_line = np.column_stack((x_values, y_values))
            region_levels_right.append(region_line)

    return id_map, end_points_left, end_points_right, region_levels_left, region_levels_right

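Editor's note: the per-row scale reduces to simple arithmetic; a worked example with assumed values (not part of the commit):

gauge_px = 200                 # widest rail-pair sequence at some row
scale = 1435 / gauge_px        # 7.175 mm per pixel at that row
target_px = int(1000 / scale)  # a 1000 mm zone maps to 139 px at that row
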
def bresenham_line(x0, y0, x1, y1):
    """
    Generate the coordinates of a line from (x0, y0) to (x1, y1) using Bresenham's algorithm.
    """
    line = []
    dx = abs(x1 - x0)
    dy = -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy  # error value e_xy

    while True:
        line.append((x0, y0))  # Add the current point to the line
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:  # e_xy + e_x > 0
            err += dy
            x0 += sx
        if e2 <= dx:  # e_xy + e_y < 0
            err += dx
            y0 += sy

    return line

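Editor's note: a quick check of the Bresenham helper (not part of the commit):

print(bresenham_line(0, 0, 4, 2))
# [(0, 0), (1, 1), (2, 1), (3, 2), (4, 2)]
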
def interpolate_end_points(end_points_dict, flags):
    line_arr = []
    ys = list(end_points_dict.keys())
    xs = list(end_points_dict.values())

    if flags and len(flags) == 1:
        pass
    elif flags and np.all(np.diff(flags) == 1):
        flags = [flags[0]]

    for i in range(0, len(ys) - 1):
        if i in flags:
            continue
        y1, y2 = ys[i], ys[i + 1]
        x1, x2 = xs[i], xs[i + 1]
        line = np.array(bresenham_line(x1, y1, x2, y2))
        if np.any(line[:, 0] < 0):
            line = line[line[:, 0] > 0]
        line_arr = line_arr + list(line)

    return line_arr

def extrapolate_line(pixels, image, min_y=None, extr_pixels=10):
    """
    Extrapolate a line beyond its last segment using linear regression.

    Parameters:
    - pixels: List of (x, y) tuples representing line pixel coordinates.
    - image: 2D NumPy array representing the image.
    - min_y: Minimum y-value to extrapolate to (optional).

    Returns:
    - A list of new extrapolated (x, y) pixel coordinates.
    """
    if len(pixels) < extr_pixels:
        print("Not enough pixels to perform extrapolation.")
        return []

    recent_pixels = np.array(pixels[-extr_pixels:])

    X = recent_pixels[:, 0].reshape(-1, 1)  # Reshape for sklearn
    y = recent_pixels[:, 1]

    model = LinearRegression()
    model.fit(X, y)

    slope = model.coef_[0]
    intercept = model.intercept_

    extrapolate = lambda x: slope * x + intercept

    # Determine the dominant direction from the differences of the recent pixels
    dx, dy = 0, 0  # Default values

    x_diffs = []
    y_diffs = []
    for i in range(1, extr_pixels - 1):
        x_diffs.append(pixels[-i][0] - pixels[-(i + 1)][0])
        y_diffs.append(pixels[-i][1] - pixels[-(i + 1)][1])

    x_diff = x_diffs[np.argmax(np.abs(x_diffs))]
    y_diff = y_diffs[np.argmax(np.abs(y_diffs))]

    if abs(int(x_diff)) >= abs(int(y_diff)):
        dx = 1 if x_diff >= 0 else -1
    else:
        dy = 1 if y_diff >= 0 else -1

    last_pixel = pixels[-1]
    new_pixels = []
    x, y = last_pixel

    min_y = min_y if min_y is not None else image.shape[0] - 1

    while 0 <= x < image.shape[1] and min_y <= y < image.shape[0]:
        if dx != 0:  # Horizontal or diagonal movement
            x += dx
            y = int(extrapolate(x))
        elif dy != 0:  # Vertical movement
            y += dy
            # For vertical lines, keep x at the last known value
            x = int(x)

        if 0 <= y < image.shape[0] and 0 <= x < image.shape[1]:
            new_pixels.append((x, y))
        else:
            break

    return new_pixels

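Editor's note: a toy check of the extrapolation (not part of the commit); a horizontal border keeps extending with the fitted slope until it leaves the frame:

img = np.zeros((20, 20))
pixels = [(i, 5) for i in range(12)]               # last pixel is (11, 5)
print(extrapolate_line(pixels, img, min_y=0)[:3])  # [(12, 5), (13, 5), (14, 5)]
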
def extrapolate_borders(dist_marked_id_map, border_l, border_r, lowest_y):

    #border_extrapolation_l1 = extrapolate_line(border_l, dist_marked_id_map, lowest_y)
    border_extrapolation_l2 = extrapolate_line(border_l[::-1], dist_marked_id_map, lowest_y)

    #border_extrapolation_r1 = extrapolate_line(border_r, dist_marked_id_map, lowest_y)
    border_extrapolation_r2 = extrapolate_line(border_r[::-1], dist_marked_id_map, lowest_y)

    #border_l = border_extrapolation_l2[::-1] + border_l + border_extrapolation_l1
    #border_r = border_extrapolation_r2[::-1] + border_r + border_extrapolation_r1

    border_l = border_extrapolation_l2[::-1] + border_l
    border_r = border_extrapolation_r2[::-1] + border_r

    return border_l, border_r

def find_zone_border(id_map, image, edges, irl_width_mm=1435, irl_target_mm=1000, lowest_y=0):

    left_border, right_border, flags_l, flags_r = find_rail_sides(id_map, edges)

    dist_marked_id_map, end_points_left, end_points_right, left_region, right_region = find_dist_from_edges(id_map, image, edges, left_border, right_border, irl_width_mm, irl_target_mm)

    border_l = interpolate_end_points(end_points_left, flags_l)
    border_r = interpolate_end_points(end_points_right, flags_r)

    border_l, border_r = extrapolate_borders(dist_marked_id_map, border_l, border_r, lowest_y)

    return [border_l, border_r], [left_region, right_region]

def get_clues(segmentation_mask, number_of_clues):

    lowest, highest = find_extreme_y_values(segmentation_mask)
    if lowest is not None and highest is not None:
        # Note: this evaluates as (highest - lowest) / number_of_clues, plus 1
        clue_step = int((highest - lowest) / number_of_clues + 1)
        clues = []
        for i in range(number_of_clues):
            clues.append(highest - (i * clue_step))
        clues.append(lowest + int(0.5 * clue_step))

        return clues
    else:
        return []

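Editor's note: a worked example of the clue placement as committed (not part of the commit). For a mask whose track classes span rows 100-700 and number_of_clues=10, clue_step = int((700 - 100) / 10 + 1) = 61, so the scan rows are [700, 639, 578, 517, 456, 395, 334, 273, 212, 151] plus the extra clue 100 + int(0.5 * 61) = 130 near the top of the track region.
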
def border_handler(id_map, image, edges, target_distances):

    lowest, _ = find_extreme_y_values(id_map)
    borders = []
    regions = []
    for target in target_distances:
        borders_regions = find_zone_border(id_map, image, edges, irl_target_mm=target, lowest_y=lowest)
        borders.append(borders_regions[0])
        regions.append(borders_regions[1])

    return borders, id_map, regions

def segment(model_seg, image_size, filename, PATH_jpgs, dataset_type, model_type, item=None):
    image_norm, _, image, mask, _ = load(filename, PATH_jpgs, image_size, dataset_type=dataset_type, item=item)
    id_map = process(model_seg, image_norm, mask, model_type)
    id_map = cv2.resize(id_map, [1920, 1080], interpolation=cv2.INTER_NEAREST)
    return id_map, image

def detect(model_det, filename_img, PATH_jpgs):

    image = cv2.imread(os.path.join(PATH_jpgs, filename_img))
    results = model_det.predict(image)

    return results, model_det, image

def manage_detections(results, model):
    bbox = results[0].boxes.xywh.tolist()
    cls = results[0].boxes.cls.tolist()
    # Accepted class indices (COCO label set, which yolov8s is trained on)
    accepted_stationary = np.array([24, 25, 28, 36])
    accepted_moving = np.array([0, 1, 2, 3, 7, 15, 16, 17, 18, 19])
    boxes_moving = {}
    boxes_stationary = {}
    if len(bbox) > 0:
        for xywh, clss in zip(bbox, cls):
            if clss in accepted_moving:
                if clss in boxes_moving.keys() and len(boxes_moving[clss]) > 0:
                    boxes_moving[clss].append(xywh)
                else:
                    boxes_moving[clss] = [xywh]
            if clss in accepted_stationary:
                if clss in boxes_stationary.keys() and len(boxes_stationary[clss]) > 0:
                    boxes_stationary[clss].append(xywh)
                else:
                    boxes_stationary[clss] = [xywh]

    return boxes_moving, boxes_stationary

def compute_detection_borders(borders, output_dims=[1080, 1920]):
    det_height = output_dims[0] - 1
    det_width = output_dims[1] - 1

    for i, border in enumerate(borders):
        border_l = np.array(border[0])
        if not list(border_l):
            border_l = np.array([[0, 0], [0, 0]])

        endpoints_l = [border_l[0], border_l[-1]]

        border_r = np.array(border[1])
        if not list(border_r):
            border_r = np.array([[0, 0], [0, 0]])

        endpoints_r = [border_r[0], border_r[-1]]

        if np.array_equal(np.array([[0, 0], [0, 0]]), endpoints_l):
            endpoints_l = [[0, endpoints_r[0][1]], [0, endpoints_r[1][1]]]

        if np.array_equal(np.array([[0, 0], [0, 0]]), endpoints_r):
            endpoints_r = [[det_width, endpoints_l[0][1]], [det_width, endpoints_l[1][1]]]

        interpolated_top = bresenham_line(endpoints_l[1][0], endpoints_l[1][1], endpoints_r[1][0], endpoints_r[1][1])

        zero_range = [0, 1, 2, 3]
        height_range = [det_height, det_height - 1, det_height - 2, det_height - 3]
        width_range = [det_width, det_width - 1, det_width - 2, det_width - 3]

        if (endpoints_l[0][0] in zero_range and endpoints_r[0][1] in height_range):
            # Left border ends on the left image edge, right border on the bottom edge
            y_values = np.arange(endpoints_l[0][1], det_height)
            x_values = np.full_like(y_values, 0)
            bottom1 = np.column_stack((x_values, y_values))

            x_values = np.arange(0, endpoints_r[0][0])
            y_values = np.full_like(x_values, det_height)
            bottom2 = np.column_stack((x_values, y_values))

            interpolated_bottom = np.vstack((bottom1, bottom2))

        elif (endpoints_l[0][1] in height_range and endpoints_r[0][0] in width_range):
            # Left border ends on the bottom edge, right border on the right image edge
            y_values = np.arange(endpoints_r[0][1], det_height)
            x_values = np.full_like(y_values, det_width)
            bottom1 = np.column_stack((x_values, y_values))

            x_values = np.arange(endpoints_l[0][0], det_width)
            y_values = np.full_like(x_values, det_height)
            bottom2 = np.column_stack((x_values, y_values))

            interpolated_bottom = np.vstack((bottom1, bottom2))

        elif endpoints_l[0][0] in zero_range and endpoints_r[0][0] in width_range:
            # Borders end on the left and right image edges; close the gap along the bottom
            y_values = np.arange(endpoints_l[0][1], det_height)
            x_values = np.full_like(y_values, 0)
            bottom1 = np.column_stack((x_values, y_values))

            y_values = np.arange(endpoints_r[0][1], det_height)
            x_values = np.full_like(y_values, det_width)
            bottom2 = np.column_stack((x_values, y_values))

            bottom3_mid = bresenham_line(bottom1[-1][0], bottom1[-1][1], bottom2[-1][0], bottom2[-1][1])

            interpolated_bottom = np.vstack((bottom1, bottom2, bottom3_mid))

        else:
            interpolated_bottom = bresenham_line(endpoints_l[0][0], endpoints_l[0][1], endpoints_r[0][0], endpoints_r[0][1])

        borders[i].append(interpolated_bottom)
        borders[i].append(interpolated_top)

    return borders

def get_bounding_box_points(cx, cy, w, h):
    top_left = (cx - w / 2, cy - h / 2)
    top_right = (cx + w / 2, cy - h / 2)
    bottom_right = (cx + w / 2, cy + h / 2)
    bottom_left = (cx - w / 2, cy + h / 2)

    corners = [top_left, top_right, bottom_right, bottom_left]

    def interpolate(point1, point2, fraction):
        """Interpolate between two points at a given fraction of the distance."""
        return (point1[0] + fraction * (point2[0] - point1[0]),
                point1[1] + fraction * (point2[1] - point1[1]))

    points = []
    for i in range(4):
        next_i = (i + 1) % 4
        points.append(corners[i])
        points.append(interpolate(corners[i], corners[next_i], 1 / 3))
        points.append(interpolate(corners[i], corners[next_i], 2 / 3))

    return points

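Editor's note: a toy check of the probe-point generator (not part of the commit); each side contributes its corner plus two third-points:

pts = get_bounding_box_points(3, 3, 6, 6)  # a 6x6 box centred at (3, 3)
print(len(pts))  # 12
print(pts[0])    # (0.0, 0.0) -- the top-left corner
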
def classify_detections(boxes_moving, boxes_stationary, borders, img_dims, output_dims=[1080, 1920]):
    img_h, img_w, _ = img_dims
    # Scale factors from the detector input size to the Full HD output.
    # The committed version cross-used the two factors, which is equivalent
    # only for inputs with the same aspect ratio as the output; named here
    # per axis for clarity.
    scale_x_tofullHD = output_dims[1] / img_w
    scale_y_tofullHD = output_dims[0] / img_h
    colors = ["yellow", "orange", "red", "green", "blue"]

    borders = compute_detection_borders(borders, output_dims)

    boxes_info = []

    if boxes_moving or boxes_stationary:
        if boxes_moving:
            for item, coords in boxes_moving.items():
                for coord in coords:
                    x = coord[0] * scale_x_tofullHD
                    y = coord[1] * scale_y_tofullHD
                    w = coord[2] * scale_x_tofullHD
                    h = coord[3] * scale_y_tofullHD

                    points_to_test = get_bounding_box_points(x, y, w, h)

                    complete_border = []
                    criticality = -1
                    color = None
                    for i, border in enumerate(reversed(borders)):
                        border_nonempty = [np.array(arr) for arr in border if np.array(arr).size > 0]
                        complete_border = np.vstack(border_nonempty)
                        instance_border_path = mplPath.Path(np.array(complete_border))

                        is_inside_borders = False
                        for point in points_to_test:
                            is_inside = instance_border_path.contains_point(point)
                            if is_inside:
                                is_inside_borders = True

                        if is_inside_borders:
                            criticality = i
                            color = colors[i]

                    if criticality == -1:
                        color = colors[3]

                    boxes_info.append([item, criticality, color, [x, y], [w, h], 1])

        if boxes_stationary:
            for item, coords in boxes_stationary.items():
                for coord in coords:
                    x = coord[0] * scale_x_tofullHD
                    y = coord[1] * scale_y_tofullHD
                    w = coord[2] * scale_x_tofullHD
                    h = coord[3] * scale_y_tofullHD

                    points_to_test = get_bounding_box_points(x, y, w, h)

                    complete_border = []
                    criticality = -1
                    color = None
                    is_inside_borders = 0
                    for i, border in enumerate(reversed(borders), start=len(borders) - 1):
                        border_nonempty = [np.array(arr) for arr in border if np.array(arr).size > 0]
                        complete_border = np.vstack(border_nonempty)
                        instance_border_path = mplPath.Path(np.array(complete_border))

                        is_inside_borders = False
                        for point in points_to_test:
                            is_inside = instance_border_path.contains_point(point)
                            if is_inside:
                                is_inside_borders = True

                        if is_inside_borders:
                            criticality = i
                            color = colors[4]

                    if criticality == -1:
                        color = colors[3]

                    boxes_info.append([item, criticality, color, [x, y], [w, h], 0])

        return boxes_info

    else:
        print("No accepted detections in this image.")
        return []

def draw_classification(classification, id_map):
    if classification:
        for box in classification:
            x, y = box[3]
            mark_value = 30

            x_start = int(max(x - 2, 0))
            x_end = int(min(x + 3, id_map.shape[1]))
            y_start = int(max(y - 2, 0))
            y_end = int(min(y + 3, id_map.shape[0]))

            id_map[y_start:y_end, x_start:x_end] = mark_value
    else:
        return

def show_result(classification, id_map, names, borders, image, regions, file_index):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (id_map.shape[1], id_map.shape[0]), interpolation=cv2.INTER_LINEAR)
    fig = plt.figure(figsize=(16, 9), dpi=100)
    plt.imshow(image, cmap='gray')

    if classification:
        for box in classification:

            boxes = True
            cx, cy = box[3]
            name = names[box[0]]
            if boxes:
                w, h = box[4]
                x = cx - w / 2
                y = cy - h / 2
                rect = patches.Rectangle((x, y), w, h, linewidth=2, edgecolor=box[2], facecolor='none')

                ax = plt.gca()
                ax.add_patch(rect)
                plt.text(x, y - 17, name, color='black', fontsize=10, ha='center', va='center', fontweight='bold', bbox=dict(facecolor=box[2], edgecolor='none', alpha=1))
            else:
                plt.imshow(id_map, cmap='gray')
                plt.text(cx, cy + 10, name, color=box[2], fontsize=10, ha='center', va='center', fontweight='bold')

    for region in regions:
        for side in region:
            for line in side:
                line = np.array(line)
                plt.plot(line[:, 1], line[:, 0], '-', color='lightgrey', marker=None, linewidth=0.5)
    plt.ylim(0, 1080)
    plt.xlim(0, 1920)
    plt.gca().invert_yaxis()

    colors = ['yellow', 'orange', 'red']
    borders.reverse()
    for i, border in enumerate(borders):
        for side in border:
            side = np.array(side)
            if side.size > 0:
                plt.plot(side[:, 0], side[:, 1], '-', color=colors[i], marker=None, linewidth=0.6)
    plt.ylim(0, 1080)
    plt.xlim(0, 1920)
    plt.gca().invert_yaxis()

    plt.show()
    #plt.tight_layout()
    #plt.savefig(f'Grafika/Video_export/frames_estimated/frame_{file_index:04d}.jpg', format='jpg', bbox_inches='tight')
    #plt.close()
    print('Frame processed successfully.')

def run(model_seg, model_det, image_size, filepath_img, PATH_jpgs, dataset_type, model_type, target_distances, file_index, vis, item=None, num_ys=15):

    segmentation_mask, image = segment(model_seg, image_size, filepath_img, PATH_jpgs, dataset_type, model_type, item)
    print('File: {}'.format(filepath_img))

    # Border search
    clues = get_clues(segmentation_mask, num_ys)
    #edges = find_edges(segmentation_mask, clues, min_width=int(segmentation_mask.shape[1]*0.02))
    edges = find_edges(segmentation_mask, clues, min_width=0)
    #id_map_marked = mark_edges(segmentation_mask, edges)

    borders, id_map, regions = border_handler(segmentation_mask, image, edges, target_distances)

    # Detection
    results, model, image = detect(model_det, filepath_img, PATH_jpgs)
    boxes_moving, boxes_stationary = manage_detections(results, model)

    classification = classify_detections(boxes_moving, boxes_stationary, borders, image.shape, output_dims=segmentation_mask.shape)

    #draw_classification(classification, id_map)
    show_result(classification, id_map, model.names, borders, image, regions, file_index)

if __name__ == "__main__":

    data_type = 'railsem19'  # railsem19, pilsen or testdata
    model_type = "segformer"  # segformer or deeplab
    vis = False
    image_size = [1024, 1024]
    target_distances = [650, 1000, 2000]  # [600,1000,2000] [4000,5500,6500] [2000,3000,4000]
    num_ys = 10

    if data_type == 'pilsen':
        file_index = 0
        model_seg = load_model(PATH_model_seg)
        model_det = load_yolo(PATH_model_det)
        for item in enumerate(data_json["data"]):
            filepath_img = item[1][1]["path"]
            run(model_seg, model_det, image_size, filepath_img, PATH_base, data_type, model_type, target_distances, file_index, vis=vis, item=item, num_ys=num_ys)
    elif data_type == 'railsem19':
        file_index = 0
        model_seg = load_model(PATH_model_seg)
        model_det = load_yolo(PATH_model_det)
        for filename_img in os.listdir(PATH_jpgs):
            #filename_img = "rs07650.jpg"
            run(model_seg, model_det, image_size, filename_img, PATH_jpgs, data_type, model_type, target_distances, file_index, vis=vis, item=None, num_ys=num_ys)
            file_index += 1
    else:
        file_index = 0
        PATH_jpgs = 'Grafika/Video_export/frames'
        model_seg = load_model(PATH_model_seg)
        model_det = load_yolo(PATH_model_det)
        for filename_img in os.listdir(PATH_jpgs):
            # Skip frames that have already been processed
            if os.path.exists(os.path.join('Grafika/Video_export/frames_estimated', filename_img)):
                file_index += 1
                continue
            else:
                run(model_seg, model_det, image_size, filename_img, PATH_jpgs, data_type, model_type, target_distances, file_index, vis=vis, item=None, num_ys=num_ys)
                file_index += 1
test_filtered_cls.py ADDED
@@ -0,0 +1,283 @@
import numpy as np
import pandas as pd
import torch
import cv2
import os
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn.functional as F
from scripts.metrics_filtered_cls import compute_map_cls, compute_IoU, image_morpho
from rs19_val.example_vis import rs19_label2bgr

PATH_jpgs = 'assets/rs19val/jpgs/test'
PATH_masks = 'assets/rs19val/uint8/test'
PATH_model = 'assets/models_pretrained/segformer/SegFormer_B3_1024_finetuned.pth'
def load(filename, PATH_jpgs, input_size=[224, 224], dataset_type='rs19val', item=None):
    transform_img = A.Compose([
        A.Resize(height=input_size[0], width=input_size[1], interpolation=cv2.INTER_NEAREST),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ])
    transform_mask = A.Compose([
        A.Resize(height=input_size[0], width=input_size[1], interpolation=cv2.INTER_NEAREST),
        ToTensorV2(p=1.0),
    ])

    if dataset_type == 'pilsen':
        mask_pth = item[1][1]["masks"]["ground_truth"]["path"]
        mask_pth = os.path.join(PATH_jpgs, mask_pth)
    elif dataset_type == 'railsem19':
        mask_pth = os.path.join(PATH_masks, filename).replace('.jpg', '.png')
    else:
        mask_pth = "rs19_val/jpgs/placeholder_mask.png"

    image_in = cv2.imread(os.path.join(PATH_jpgs, filename))
    mask = cv2.imread(mask_pth, cv2.IMREAD_GRAYSCALE)

    if dataset_type == 'testdata':
        image_in = cv2.resize(image_in, (1920, 1080))

    image_tr = transform_img(image=image_in)['image']
    image_tr = image_tr.unsqueeze(0)
    image_vis = transform_mask(image=image_in)['image']
    mask = transform_mask(image=mask)['image']
    mask_id_map = np.array(mask.cpu().detach().numpy(), dtype=np.uint8)

    image_tr = image_tr.cpu()

    return image_tr, image_vis, image_in, mask, mask_id_map

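Editor's note: a minimal usage sketch for load (not part of the commit; the file name is hypothetical):

image_tr, image_vis, image_in, mask, mask_id_map = load(
    'rs00000.jpg', PATH_jpgs, input_size=[1024, 1024], dataset_type='railsem19')
print(image_tr.shape)  # torch.Size([1, 3, 1024, 1024]) -- a normalized batch of one
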
def load_model(path_model):

    model = torch.load(path_model, map_location=torch.device('cpu'))
    model = model.cpu()
    model.eval()
    return model

def remap_ignored_clss(id_map):
    ignore_list = [0, 1, 2, 6, 8, 9, 15, 16, 19, 20]
    for cls in ignore_list:
        id_map[id_map == cls] = 255

    ignore_set = set(ignore_list)
    cls_remaining = [num for num in range(0, 22) if num not in ignore_set]

    # Renumber the remaining classes densely from 0
    for idx, cls in enumerate(cls_remaining):
        id_map[id_map == cls] = idx

    id_map[id_map == 255] = 12  # background

    return id_map

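Editor's note: a toy check of the class remapping (not part of the commit); ignored ids collapse to the background id 12 and the kept ids are renumbered densely:

ids = np.array([0, 3, 4, 5, 7, 21])
print(remap_ignored_clss(ids))  # [12  0  1  2  3 11]
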
def prepare_for_display(mask, image, id_map, rs19_label2bgr, image_size=[224, 224]):
    # Mask + prediction preparation
    mask = mask + 1
    mask[mask == 256] = 0
    mask = remap_ignored_clss(mask)
    mask = (mask + 100).detach().numpy().squeeze().astype(np.uint8)
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)

    # Add an opacity channel to both the mask and the image
    alpha_channel = np.full((mask.shape[0], mask.shape[1]), 255, dtype=np.uint8)
    back_ids = mask == 100
    alpha_channel[back_ids] = 0
    rgba_mask = cv2.merge((mask_rgb, alpha_channel))

    image = np.array(image.cpu().detach().numpy(), dtype=np.uint8)
    rgba_img = cv2.merge((image.transpose(1, 2, 0), alpha_channel))

    # Label colors + background
    rgbs = list(rs19_label2bgr.values())
    rgbs.append((255, 255, 255))

    blend_sources = np.zeros((image_size[0], image_size[1], 3), dtype=np.uint8)
    for class_id in range(21):
        class_pixels = id_map == class_id
        rgb_color = np.array(rgbs[class_id])

        for i in range(3):
            blend_sources[:, :, i] = blend_sources[:, :, i] + (rgb_color[i] * class_pixels).astype(np.uint8)

    # Opacity channel for the RGB class mask, merged with the input image
    alpha_channel_blend = np.full((blend_sources.shape[0], blend_sources.shape[1]), 150, dtype=np.uint8)
    rgba_blend = cv2.merge((blend_sources, alpha_channel_blend))
    blend_sources = (rgba_blend * 0.1 + rgba_img * 0.9).astype(np.uint8)

    return rgba_mask, rgba_blend, blend_sources

def visualize(rgba_blend, rgba_mask):
    # CV2 visualization; note that this relies on the globals `filename`,
    # `map` and `Mmap` set in the __main__ loop below
    image1 = rgba_blend
    image2 = rgba_mask

    initial_opacity1 = 0.05
    initial_opacity2 = 0.95
    # Prepare two smaller preview images
    small_image1 = cv2.resize(image1, (300, 300), interpolation=cv2.INTER_NEAREST)
    small_image2 = cv2.resize(image2, (300, 300), interpolation=cv2.INTER_NEAREST)

    # Create a blank canvas for the combined visualization
    combined_image = np.zeros((600, 900, 4), dtype=np.uint8)  # Adjust the size as needed

    # Main loop for adjusting opacity and displaying the images
    cv2.namedWindow('{} | mAP:{:.3f} | MmAP:{:.3f} '.format(filename, map, Mmap), cv2.WINDOW_NORMAL)
    cv2.resizeWindow('{} | mAP:{:.3f} | MmAP:{:.3f} '.format(filename, map, Mmap), 900, 600)  # Adjust the size as needed

    while True:

        overlay_image = image1.copy()
        overlay_image[:, :, 3] = (image1[:, :, 3] * initial_opacity1).astype(np.uint8)

        alpha = (image2[:, :, 3] * initial_opacity2).astype(float)
        beta = 1.0 - alpha / 255.0

        blended_image = np.empty_like(overlay_image)
        blended_image[:, :, :3] = (overlay_image[:, :, :3] * alpha[:, :, np.newaxis] + image2[:, :, :3] * beta[:, :, np.newaxis]).astype(np.uint8)
        blended_image[:, :, 3] = (overlay_image[:, :, 3] * alpha + image2[:, :, 3] * beta).astype(np.uint8)

        # Note: the per-channel blend above is overwritten by this simple weighted sum
        blended_image = (image1 * initial_opacity1 + image2 * initial_opacity2).astype(np.uint8)

        blended_image_resized = cv2.resize(blended_image, (600, 600))  # Adjust the size as needed
        combined_image[:, :600, :] = blended_image_resized

        # Copy the smaller images to the right portion of the canvas
        combined_image[0:300, 600:900, :] = small_image1[:, :, :]
        combined_image[300:600, 600:900, :] = small_image2[:, :, :]

        cv2.imshow('{} | mAP:{:.3f} | MmAP:{:.3f} '.format(filename, map, Mmap), combined_image)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
        elif key == ord('a'):
            initial_opacity1 += 0.1
            initial_opacity1 = min(initial_opacity1, 1.0)
        elif key == ord('s'):
            initial_opacity1 -= 0.1
            initial_opacity1 = max(initial_opacity1, 0.0)
        elif key == ord('z'):
            initial_opacity2 += 0.1
            initial_opacity2 = min(initial_opacity2, 1.0)
        elif key == ord('x'):
            initial_opacity2 -= 0.1
            initial_opacity2 = max(initial_opacity2, 0.0)

    cv2.destroyAllWindows()

def stats_mean_and_reorder(classes_ap, classes_Map, classes_stats, classes_Mstats):
    for cls, value in classes_ap.items():
        classes_ap[cls] = np.divide(value[0], value[1])
    classes_ap['all'] = np.mean(np.array(list(classes_ap.values())), axis=0)

    for cls, value in classes_Map.items():
        classes_Map[cls] = np.divide(value[0], value[1])
    classes_Map['all'] = np.mean(np.array(list(classes_Map.values())), axis=0)

    for cls, value in classes_stats.items():
        classes_stats[cls] = np.divide(value[0], value[1])
    classes_stats['all'] = np.mean(np.array(list(classes_stats.values()))[:, :4], axis=0)

    for cls, value in classes_Mstats.items():
        classes_Mstats[cls] = np.divide(value[0], value[1])
    classes_Mstats['all'] = np.mean(np.array(list(classes_Mstats.values()))[:, :4], axis=0)

    # Interleave the major-class stats with the overall per-class stats
    for cls, value in classes_Mstats.items():
        classes_stats[cls] = np.insert(classes_stats[cls], 1, value[0])
        classes_stats[cls] = np.insert(classes_stats[cls], 3, value[1])
        classes_stats[cls] = np.insert(classes_stats[cls], 5, value[2])
        classes_stats[cls] = np.insert(classes_stats[cls], 7, value[3])

    return classes_ap, classes_Map, classes_stats, classes_Mstats

def process(model, input_img, mask, model_type):
    if model_type == "segformer":
        logits = model(input_img).logits  # SegFormer output carries a logits attribute
    elif model_type == "deeplab":
        logits = model(input_img)['out']  # DeepLab (ResNet backbone) returns the logits directly

    upsampled_logits = nn.functional.interpolate(
        logits,
        size=mask.shape[-2:],
        mode="bilinear",
        align_corners=False
    )

    output = upsampled_logits.float()

    confidence_scores = F.softmax(output, dim=1).cpu().detach().numpy().squeeze()
    id_map = np.argmax(confidence_scores, axis=0).astype(np.uint8)
    id_map = image_morpho(id_map)

    return id_map

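Editor's note: a toy check of the softmax/argmax step at the heart of process (not part of the commit); class 1 has the largest logit at the only pixel:

logits = torch.tensor([[[[0.1]], [[2.0]], [[0.3]]]])  # shape (1, 3, 1, 1)
scores = F.softmax(logits, dim=1).numpy().squeeze()
print(np.argmax(scores, axis=0))  # 1
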
if __name__ == "__main__":
    mAPs, MmAPs, IoUs, MIoUs, accs, Maccs, precs, Mprecs, recs, Mrecs = list(), list(), list(), list(), list(), list(), list(), list(), list(), list()
    classes_ap, classes_Map, classes_stats, classes_Mstats = {}, {}, {}, {}
    images_computed = 0

    # Configuration and model load hoisted out of the loop; the committed
    # version re-assigned these and reloaded the model for every image
    vis = False
    to_break = False
    image_size = [1024, 1024]
    model_type = "segformer"  # "deeplab"
    dataset_type = 'rs19val'
    model = load_model(PATH_model)

    for filename in os.listdir(PATH_jpgs):
        images_computed += 1

        if to_break:
            if images_computed > 50:
                break

        image_norm, image, _, mask, id_map_gt = load(filename, PATH_jpgs, image_size, dataset_type)
        # INFERENCE + SOFTMAX
        id_map = process(model, image_norm, mask, model_type)

        # mAP
        id_map_gt = remap_ignored_clss(id_map_gt)
        map, classes_ap = compute_map_cls(id_map_gt, id_map, classes_ap)
        Mmap, classes_Map = compute_map_cls(id_map_gt, id_map, classes_Map, major=True)
        IoU, acc, prec, rec, classes_stats = compute_IoU(id_map_gt, id_map, classes_stats)
        MIoU, Macc, Mprec, Mrec, classes_Mstats = compute_IoU(id_map_gt, id_map, classes_Mstats, major=True)

        print('{} | mAP:{:.3f}/{:.3f} | IoU:{:.3f}/{:.3f} | prec:{:.3f}/{:.3f} | rec:{:.3f}/{:.3f} | acc:{:.3f}/{:.3f}'.format(filename, map, Mmap, IoU, MIoU, prec, Mprec, rec, Mrec, acc, Macc))
        mAPs.append(map)
        MmAPs.append(Mmap)
        IoUs.append(IoU)
        MIoUs.append(MIoU)
        accs.append(acc)
        Maccs.append(Macc)
        precs.append(prec)
        Mprecs.append(Mprec)
        recs.append(rec)
        Mrecs.append(Mrec)

        if vis:
            rgba_mask, rgba_blend, blend_sources = prepare_for_display(mask, image, id_map, rs19_label2bgr, image_size)
            visualize(rgba_blend, rgba_mask)

    mAPs_avg, MmAPs_avg = np.nanmean(mAPs), np.nanmean(MmAPs)
    IoUs_avg, MIoUs_avg = np.nanmean(IoUs), np.nanmean(MIoUs)
    accs_avg, Maccs_avg = np.nanmean(accs), np.nanmean(Maccs)
    precs_avg, Mprecs_avg = np.nanmean(precs), np.nanmean(Mprecs)
    recs_avg, Mrecs_avg = np.nanmean(recs), np.nanmean(Mrecs)

    print('All | mAP:{:.3f}/{:.3f} | IoU:{:.3f}/{:.3f} | prec:{:.3f}/{:.3f} | rec:{:.3f}/{:.3f} | acc:{:.3f}/{:.3f}'.format(mAPs_avg, MmAPs_avg, IoUs_avg, MIoUs_avg, precs_avg, Mprecs_avg, recs_avg, Mrecs_avg, accs_avg, Maccs_avg))
    print('mAP: {:.3f}-{:.3f} | MmAP: {:.3f}-{:.3f} | IoU: {:.3f}-{:.3f} | MIoU: {:.3f}-{:.3f}'.format(np.nanmin(mAPs), np.nanmax(mAPs), np.nanmin(MmAPs), np.nanmax(MmAPs), np.nanmin(IoUs), np.nanmax(IoUs), np.nanmin(MIoUs), np.nanmax(MIoUs)))

    classes_ap, classes_Map, classes_stats, classes_Mstats = stats_mean_and_reorder(classes_ap, classes_Map, classes_stats, classes_Mstats)

    df_ap = pd.DataFrame(list(classes_ap.items()), columns=['Class', 'mAP'])
    df_Map = pd.DataFrame(list(classes_Map.items()), columns=['Class', 'MmAP'])

    classes_stats_flat = [(key, *value) for key, value in classes_stats.items()]
    df_stats = pd.DataFrame(classes_stats_flat, columns=['Class', 'IoU', 'MIoU', 'acc', 'Macc', 'precision', 'Mprecision', 'recall', 'Mrecall'])

    df_merged = pd.merge(df_ap, df_Map, on='Class', how='outer')
    df_merged = pd.merge(df_merged, df_stats, on='Class', how='outer')

    print(df_merged)