josedolot committed
Commit 549e090 · 1 Parent(s): aff8ba8

Upload hybridnets/dataset.py

Files changed (1)
  1. hybridnets/dataset.py +409 -0
hybridnets/dataset.py ADDED
@@ -0,0 +1,409 @@
import cv2
import numpy as np
# np.set_printoptions(threshold=np.inf)
import random
import torch
import torchvision.transforms as transforms
from pathlib import Path
from torch.utils.data import Dataset
from utils.utils import letterbox, augment_hsv, random_perspective
from tqdm.autonotebook import tqdm
import json
import albumentations as A


class BddDataset(Dataset):
    def __init__(self, params, is_train, inputsize=640, transform=None):
        """
        Initialize all dataset attributes.

        Inputs:
        -params: configuration parameters
        -is_train (bool): whether this is the train set or not
        -transform: ToTensor and Normalize

        Returns:
        None
        """
        self.single_cls = True  # just detect vehicle
        self.is_train = is_train
        self.params = params
        self.transform = transform
        self.inputsize = inputsize
        self.Tensor = transforms.ToTensor()
        img_root = Path(params.dataset['dataroot'])
        label_root = Path(params.dataset['labelroot'])
        mask_root = Path(params.dataset['maskroot'])
        lane_root = Path(params.dataset['laneroot'])
        if is_train:
            indicator = params.dataset['train_set']
        else:
            indicator = params.dataset['test_set']
        self.img_root = img_root / indicator
        self.label_root = label_root / indicator
        self.mask_root = mask_root / indicator
        self.lane_root = lane_root / indicator
        # self.label_list = self.label_root.iterdir()
        self.mask_list = self.mask_root.iterdir()
        self.data_format = params.dataset['data_format']
        self.scale_factor = params.dataset['scale_factor']
        self.rotation_factor = params.dataset['rot_factor']
        self.flip = params.dataset['flip']
        self.color_rgb = params.dataset['color_rgb']
        self.albumentations_transform = A.Compose([
            A.Blur(p=0.01),
            A.MedianBlur(p=0.01),
            A.ToGray(p=0.01),
            A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(p=0.01),
            A.RandomGamma(p=0.01),
            A.ImageCompression(quality_lower=75, p=0.01)],
            bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']),
            additional_targets={'mask0': 'mask'})
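        # Light photometric augmentations: boxes are passed in pascal_voc (pixel xyxy) format,
        # and the lane mask is registered as the extra target 'mask0' so it is transformed
        # together with the image and the segmentation mask.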

        # bdd_labels = {
        #     'unlabeled': 0, 'dynamic': 1, 'ego vehicle': 2, 'ground': 3,
        #     'static': 4, 'parking': 5, 'rail track': 6, 'road': 7,
        #     'sidewalk': 8, 'bridge': 9, 'building': 10, 'fence': 11,
        #     'garage': 12, 'guard rail': 13, 'tunnel': 14, 'wall': 15,
        #     'banner': 16, 'billboard': 17, 'lane divider': 18, 'parking sign': 19,
        #     'pole': 20, 'polegroup': 21, 'street light': 22, 'traffic cone': 23,
        #     'traffic device': 24, 'traffic light': 25, 'traffic sign': 26, 'traffic sign frame': 27,
        #     'terrain': 28, 'vegetation': 29, 'sky': 30, 'person': 31,
        #     'rider': 32, 'bicycle': 33, 'bus': 34, 'car': 35,
        #     'caravan': 36, 'motorcycle': 37, 'trailer': 38, 'train': 39,
        #     'truck': 40
        # }
        self.id_dict = {'person': 0, 'rider': 1, 'car': 2, 'bus': 3, 'truck': 4,
                        'bike': 5, 'motor': 6, 'tl_green': 7, 'tl_red': 8,
                        'tl_yellow': 9, 'tl_none': 10, 'traffic sign': 11, 'train': 12}
        self.id_dict_single = {'car': 0, 'bus': 1, 'truck': 2, 'train': 3}
        # id_dict = {'car': 0, 'bus': 1, 'truck': 2}

        self.shapes = np.array(params.dataset['org_img_size'])
        self.db = self._get_db()

    def _get_db(self):
        """
        Build the database from the annotation files.

        Inputs:

        Returns:
        gt_db: (list) database [a, b, c, ...]
            a: (dictionary) {'image':, 'information':, ......}
            image: image path
            mask: path of the segmentation label
            label: [cls_id, center_x, center_y, w, h] normalized to the original image size
        """
        print('building database...')
        gt_db = []
        height, width = self.shapes
        for mask in tqdm(list(self.mask_list)):
            mask_path = str(mask)
            label_path = mask_path.replace(str(self.mask_root), str(self.label_root)).replace(".png", ".json")
            image_path = mask_path.replace(str(self.mask_root), str(self.img_root)).replace(".png", ".jpg")
            lane_path = mask_path.replace(str(self.mask_root), str(self.lane_root))
            with open(label_path, 'r') as f:
                label = json.load(f)
            data = label['frames'][0]['objects']
            data = self.select_data(data)
            gt = np.zeros((len(data), 5))
            for idx, obj in enumerate(data):
                category = obj['category']
                if category == "traffic light":
                    color = obj['attributes']['trafficLightColor']
                    category = "tl_" + color
                if category in self.id_dict.keys():
                    x1 = float(obj['box2d']['x1'])
                    y1 = float(obj['box2d']['y1'])
                    x2 = float(obj['box2d']['x2'])
                    y2 = float(obj['box2d']['y2'])
                    cls_id = self.id_dict[category]
                    if self.single_cls:
                        cls_id = 0
                    gt[idx][0] = cls_id
                    box = self.convert((width, height), (x1, x2, y1, y2))
                    gt[idx][1:] = list(box)

            rec = [{
                'image': image_path,
                'label': gt,
                'mask': mask_path,
                'lane': lane_path
            }]

            # img = cv2.imread(image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_UNCHANGED)
            # # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            # for label in gt:
            #     # print(label[1])
            #     x1 = label[1] - label[3] / 2
            #     x1 *= 1280
            #     x1 = int(x1)
            #     # print(x1)
            #     x2 = label[1] + label[3] / 2
            #     x2 *= 1280
            #     x2 = int(x2)
            #     y1 = label[2] - label[4] / 2
            #     y1 *= 720
            #     y1 = int(y1)
            #     y2 = label[2] + label[4] / 2
            #     y2 *= 720
            #     y2 = int(y2)
            #     img = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # cv2.imwrite('gt/{}'.format(image_path.split('/')[-1]), img)

            gt_db += rec
        print('database build finished')
        return gt_db


    def evaluate(self, params, preds, output_dir):
        """
        Implemented by child datasets.
        """
        raise NotImplementedError

    def __len__(self):
        """
        Number of images in the dataset.
        """
        return len(self.db)

    def __getitem__(self, idx):
        """
        Get the input and ground truth from the database and apply data augmentation to the input.

        Inputs:
        -idx: index of the image in self.db (the database list)
            self.db (list) [a, b, c, ...]
            a: (dictionary) {'image':, 'information':}

        Returns:
        -image: transformed image; first passed through the data augmentation in this function (numpy),
                then through self.transform
        -target: ground truth (det_gt, seg_gt)

        Functions that may be useful:
        cv2.imread
        cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        cv2.warpAffine
        """
        data = self.db[idx]
        img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.params.num_seg_class == 3:
            seg_label = cv2.imread(data["mask"])
        else:
            seg_label = cv2.imread(data["mask"], 0)
        lane_label = cv2.imread(data["lane"], 0)

        # print(lane_label.shape)
        # print(seg_label.shape)
        # print(lane_label.shape)
        # print(seg_label.shape)
        resized_shape = self.inputsize
        if isinstance(resized_shape, list):
            resized_shape = max(resized_shape)
        h0, w0 = img.shape[:2]  # orig hw
        r = resized_shape / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
            seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
            lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
        h, w = img.shape[:2]

        (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True,
                                                             scaleup=self.is_train)
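        # letterbox returns the scaling ratio and the (width, height) padding it added;
        # both are reused below to map the normalized labels into the padded image.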
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
        # ratio = (w / w0, h / h0)
        # print(resized_shape)

        det_label = data["label"]
        # print(det_label)

        labels = []
        labels_app = np.array([])

        if det_label.size > 0:
            # Normalized xywh to pixel xyxy format
            labels = det_label.copy()
            labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0]  # pad width
            labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1]  # pad height
            labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0]
            labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1]

        # print(labels[:, 1:4])
        if self.is_train:
            # albumentations
            try:
                new = self.albumentations_transform(image=img, mask=seg_label, mask0=lane_label,
                                                    bboxes=labels[:, 1:] if len(labels) else labels,
                                                    class_labels=labels[:, 0] if len(labels) else labels)
                img = new['image']
                labels = np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]) if len(labels) else labels
                seg_label = new['mask']
                lane_label = new['mask0']
            except ValueError:  # bbox has width or height == 0
                pass

            # augmentation
            combination = (img, seg_label, lane_label)
            (img, seg_label, lane_label), labels = random_perspective(
                combination=combination,
                targets=labels,
                degrees=self.params.dataset['rot_factor'],
                translate=self.params.dataset['translate'],
                scale=self.params.dataset['scale_factor'],
                shear=self.params.dataset['shear']
            )
            # print(labels.shape)
            augment_hsv(img, hgain=self.params.dataset['hsv_h'], sgain=self.params.dataset['hsv_s'], vgain=self.params.dataset['hsv_v'])
            # img, seg_label, labels = cutout(combination=combination, labels=labels)

            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = img[:, ::-1, :]

                if len(labels):
                    rows, cols, channels = img.shape

                    x1 = labels[:, 1].copy()
                    x2 = labels[:, 3].copy()

                    x_tmp = x1.copy()

                    labels[:, 1] = cols - x2
                    labels[:, 3] = cols - x_tmp

                # Segmentation
                seg_label = np.fliplr(seg_label)
                lane_label = np.fliplr(lane_label)

            # cv2.imwrite('img0.jpg', img)
            # cv2.imwrite('img1.jpg', seg_label)
            # cv2.imwrite('img2.jpg', lane_label)

            # exit()

            # print(labels)

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                seg_label = np.flipud(seg_label)
                lane_label = np.flipud(lane_label)
                if len(labels):
                    rows, cols, channels = img.shape

                    y1 = labels[:, 2].copy()
                    y2 = labels[:, 4].copy()

                    y_tmp = y1.copy()

                    labels[:, 2] = rows - y2
                    labels[:, 4] = rows - y_tmp

        # for anno in labels:
        #     x1, y1, x2, y2 = [int(x) for x in anno[1:5]]
        #     print(x1, y1, x2, y2)
        #     cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        # cv2.imwrite(data["image"].split("/")[-1], img)

        if len(labels):
            labels_app = np.zeros((len(labels), 5))
            labels_app[:, 0:4] = labels[:, 1:5]
            labels_app[:, 4] = labels[:, 0]
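        # each row of labels_app is [x1, y1, x2, y2, cls_id] in letterboxed pixel coordinates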

        img = np.ascontiguousarray(img)

        _, seg1 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY)
        _, lane1 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY)
        # prefer lane
        seg1 = seg1 - (seg1 & lane1)

        union = seg1 | lane1
        # print(union.shape)
        background = 255 - union
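        # seg1, lane1 and background are now mutually exclusive binary masks: lane pixels were
        # removed from the segmentation mask above, and background covers everything else.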

        # print(img.shape)
        # print(lane1.shape)
        # img_copy = img.copy()
        # img_copy[lane1 == 255] = (0, 255, 0)
        # cv2.imwrite('seg_gt/' + data['image'].split('/')[-1], img_copy)
        # cv2.imwrite('background.jpg', background)
        # cv2.imwrite('{}.jpg'.format(data['image'].split('/')[-1]), img)
        # cv2.imwrite('{}-lane.jpg'.format(data['image'].split('/')[-1]), lane1)
        # cv2.imwrite('{}-seg.jpg'.format(data['image'].split('/')[-1]), seg1)

        seg1 = self.Tensor(seg1)
        lane1 = self.Tensor(lane1)
        background = self.Tensor(background)

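        # stack into a 3-channel ground truth: channel 0 = background, 1 = segmentation mask, 2 = lane mask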
        segmentation = torch.cat([background, seg1, lane1], dim=0)
        # print(segmentation.size())
        # print(seg1.shape)

        # for anno in labels_app:
        #     x1, y1, x2, y2 = [int(x) for x in anno[anno != -1][:4]]
        #     cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 1)
        # cv2.imwrite(data["image"].split("/")[-1], img)

        img = self.transform(img)

        return img, data["image"], shapes, torch.from_numpy(labels_app), segmentation

    def select_data(self, db):
        """
        Filter out objects that cannot be used for detection: objects without a 'box2d'
        annotation and, in single-class mode, categories outside id_dict_single.

        Inputs:
        -db: (list) database

        Returns:
        -db_selected: (list) filtered database
        """
        remain = []
        for obj in db:
            if 'box2d' in obj.keys():  # obj.has_key('box2d'):
                if self.single_cls:
                    if obj['category'] in self.id_dict_single.keys():
                        remain.append(obj)
                else:
                    remain.append(obj)
        return remain

    def convert(self, size, box):
        dw = 1. / (size[0])
        dh = 1. / (size[1])
        x = (box[0] + box[1]) / 2.0
        y = (box[2] + box[3]) / 2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return x, y, w, h
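    # Example of convert (illustrative numbers): on a 1280x720 image, the pixel box
    # (x1=100, x2=300, y1=200, y2=400) maps to normalized xywh
    # (x=0.15625, y~=0.4167, w=0.15625, h~=0.2778).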

    @staticmethod
    def collate_fn(batch):
        img, paths, shapes, labels_app, segmentation = zip(*batch)
        filenames = [file.split('/')[-1] for file in paths]
        # print(len(labels_app))
        max_num_annots = max(label.size(0) for label in labels_app)

        if max_num_annots > 0:
            annot_padded = torch.ones((len(labels_app), max_num_annots, 5)) * -1
            for idx, label in enumerate(labels_app):
                if label.size(0) > 0:
                    annot_padded[idx, :label.size(0), :] = label
        else:
            annot_padded = torch.ones((len(labels_app), 1, 5)) * -1

        # print("ABC", seg1.size())
        return {'img': torch.stack(img, 0), 'annot': annot_padded, 'segmentation': torch.stack(segmentation, 0),
                'filenames': filenames, 'shapes': shapes}
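
# Minimal usage sketch (assumes a `params` config object exposing the `dataset` dict,
# `num_seg_class`, and a normalization transform, as used above):
#
#     from torch.utils.data import DataLoader
#
#     train_set = BddDataset(params, is_train=True, inputsize=640,
#                            transform=transforms.Compose([
#                                transforms.ToTensor(),
#                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                                     std=[0.229, 0.224, 0.225]),
#                            ]))
#     train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
#                               collate_fn=BddDataset.collate_fn)
#     batch = next(iter(train_loader))
#     # batch['img']: (B, 3, H, W), batch['annot']: (B, max_boxes, 5) padded with -1,
#     # batch['segmentation']: (B, 3, H, W)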