hank1996 committed on
Commit eb9d1af · 1 Parent(s): 75ebbc7

Create new file

Files changed (1):
  lib/dataset/AutoDriveDataset.py  +259 −0

lib/dataset/AutoDriveDataset.py ADDED
@@ -0,0 +1,259 @@
import cv2
import numpy as np
# np.set_printoptions(threshold=np.inf)
import random
import torch
import torchvision.transforms as transforms
# from visualization import plot_img_and_mask, plot_one_box, show_seg_result
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from ..utils import letterbox, augment_hsv, random_perspective, xyxy2xywh, cutout


class AutoDriveDataset(Dataset):
    """
    A general Dataset implementing functionality shared by the concrete datasets.
    """
    def __init__(self, cfg, is_train, inputsize=640, transform=None):
        """
        Initialize all attributes.

        Inputs:
        - cfg: configuration object
        - is_train (bool): whether this is the training set
        - inputsize: target input size for the network
        - transform: ToTensor and Normalize

        Returns:
        None
        """
        self.is_train = is_train
        self.cfg = cfg
        self.transform = transform
        self.inputsize = inputsize
        self.Tensor = transforms.ToTensor()
        img_root = Path(cfg.DATASET.DATAROOT)
        label_root = Path(cfg.DATASET.LABELROOT)
        mask_root = Path(cfg.DATASET.MASKROOT)
        lane_root = Path(cfg.DATASET.LANEROOT)
        if is_train:
            indicator = cfg.DATASET.TRAIN_SET
        else:
            indicator = cfg.DATASET.TEST_SET
        self.img_root = img_root / indicator
        self.label_root = label_root / indicator
        self.mask_root = mask_root / indicator
        self.lane_root = lane_root / indicator
        # self.label_list = self.label_root.iterdir()
        self.mask_list = self.mask_root.iterdir()

        self.db = []

        self.data_format = cfg.DATASET.DATA_FORMAT

        self.scale_factor = cfg.DATASET.SCALE_FACTOR
        self.rotation_factor = cfg.DATASET.ROT_FACTOR
        self.flip = cfg.DATASET.FLIP
        self.color_rgb = cfg.DATASET.COLOR_RGB

        # self.target_type = cfg.MODEL.TARGET_TYPE
        self.shapes = np.array(cfg.DATASET.ORG_IMG_SIZE)

    def _get_db(self):
        """
        Build the database; implemented in the child Dataset classes (for a
        dataset that is not in BDD100K format, override this in a child Dataset).
        """
        raise NotImplementedError

    def evaluate(self, cfg, preds, output_dir):
        """
        Implemented in the child Dataset classes.
        """
        raise NotImplementedError

    def __len__(self):
        """
        Number of objects in the dataset.
        """
        return len(self.db)

    def __getitem__(self, idx):
        """
        Get an input and its ground truth from the database, and apply data
        augmentation to the input.

        Inputs:
        - idx: index of the image in self.db (the database, a list)
          self.db (list) [a, b, c, ...]
          a: (dictionary) {'image':, 'information':}

        Returns:
        - image: transformed image; first passed through the data augmentation
          in __getitem__ (type: numpy), then through self.transform
        - target: ground truth (det_gt, seg_gt, lane_gt)

        Functions that may be useful:
        cv2.imread
        cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        cv2.warpAffine
        """
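        # Illustrative (hypothetical) shape of one db record, inferred from the
        # keys accessed below; a child class fills in the real values:
        #   {"image": "<path>.jpg",       # RGB input frame
        #    "label": np.ndarray (n, 5),  # [cls, x_c, y_c, w, h], normalized
        #    "mask":  "<path>.png",       # drivable-area mask
        #    "lane":  "<path>.png"}       # lane-line mask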
        data = self.db[idx]
        img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.cfg.num_seg_class == 3:
            seg_label = cv2.imread(data["mask"])
        else:
            seg_label = cv2.imread(data["mask"], 0)
        lane_label = cv2.imread(data["lane"], 0)
        resized_shape = self.inputsize
        if isinstance(resized_shape, list):
            resized_shape = max(resized_shape)
        h0, w0 = img.shape[:2]  # orig hw
        r = resized_shape / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
            seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
            lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
        h, w = img.shape[:2]

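        # letterbox pads the aspect-preserving resize up to resized_shape
        # (stride-aligned minimal rectangle when auto=True) and returns the
        # scale ratio plus the (w, h) padding per side; both are needed to
        # remap the boxes below.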
        (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True, scaleup=self.is_train)
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
        # ratio = (w / w0, h / h0)

        det_label = data["label"]
        labels = []

        if det_label.size > 0:
            # Normalized xywh to pixel xyxy format
            labels = det_label.copy()
            labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0]  # pad width
            labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1]  # pad height
            labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0]
            labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1]
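            # Worked example with illustrative numbers: for w = 640,
            # ratio[0] = 1, pad[0] = 0, a box with x_c = 0.5, w_box = 0.25
            # maps to x1 = 640 * (0.5 - 0.125) = 240 and
            # x2 = 640 * (0.5 + 0.125) = 400.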

        if self.is_train:
            combination = (img, seg_label, lane_label)
            (img, seg_label, lane_label), labels = random_perspective(
                combination=combination,
                targets=labels,
                degrees=self.cfg.DATASET.ROT_FACTOR,
                translate=self.cfg.DATASET.TRANSLATE,
                scale=self.cfg.DATASET.SCALE_FACTOR,
                shear=self.cfg.DATASET.SHEAR
            )
            augment_hsv(img, hgain=self.cfg.DATASET.HSV_H, sgain=self.cfg.DATASET.HSV_S, vgain=self.cfg.DATASET.HSV_V)
            # img, seg_label, labels = cutout(combination=combination, labels=labels)

            if len(labels):
                # convert xyxy to xywh
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

                # Normalize coordinates to 0 - 1
                labels[:, [2, 4]] /= img.shape[0]  # height
                labels[:, [1, 3]] /= img.shape[1]  # width

            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                seg_label = np.fliplr(seg_label)
                lane_label = np.fliplr(lane_label)
                if len(labels):
                    labels[:, 1] = 1 - labels[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                seg_label = np.flipud(seg_label)
                lane_label = np.flipud(lane_label)
                if len(labels):
                    labels[:, 2] = 1 - labels[:, 2]

        else:
            if len(labels):
                # convert xyxy to xywh
                labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])

                # Normalize coordinates to 0 - 1
                labels[:, [2, 4]] /= img.shape[0]  # height
                labels[:, [1, 3]] /= img.shape[1]  # width

        labels_out = torch.zeros((len(labels), 6))
        if len(labels):
            labels_out[:, 1:] = torch.from_numpy(labels)
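        # Column 0 of labels_out stays zero here; collate_fn fills it with the
        # image index within the batch, so each row ends up as
        # [img_idx, cls, x_c, y_c, w, h] with coordinates normalized to 0 - 1.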

        # Convert
        # img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        # img = img.transpose(2, 0, 1)
        img = np.ascontiguousarray(img)

        if self.cfg.num_seg_class == 3:
            _, seg0 = cv2.threshold(seg_label[:, :, 0], 128, 255, cv2.THRESH_BINARY)
            _, seg1 = cv2.threshold(seg_label[:, :, 1], 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label[:, :, 2], 1, 255, cv2.THRESH_BINARY)
        else:
            _, seg1 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY)
            _, seg2 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY_INV)
        _, lane1 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY)
        _, lane2 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY_INV)
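        # In the 2-class case the THRESH_BINARY / THRESH_BINARY_INV pair turns
        # one grayscale mask into complementary foreground / background maps, so
        # stacking them below yields a one-hot (background, foreground) label
        # per pixel; the same applies to the lane masks.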

        if self.cfg.num_seg_class == 3:
            seg0 = self.Tensor(seg0)
        seg1 = self.Tensor(seg1)
        seg2 = self.Tensor(seg2)
        lane1 = self.Tensor(lane1)
        lane2 = self.Tensor(lane2)

        if self.cfg.num_seg_class == 3:
            seg_label = torch.stack((seg0[0], seg1[0], seg2[0]), 0)
        else:
            seg_label = torch.stack((seg2[0], seg1[0]), 0)

        lane_label = torch.stack((lane2[0], lane1[0]), 0)
        # _, gt_mask = torch.max(seg_label, 0)
        # _ = show_seg_result(img, gt_mask, idx, 0, save_dir='debug', is_gt=True)

        target = [labels_out, seg_label, lane_label]
        img = self.transform(img)

        return img, target, data["image"], shapes

    def select_data(self, db):
        """
        Use this function to filter out useless images from the dataset.

        Inputs:
        - db: (list) database

        Returns:
        - db_selected: (list) filtered database
        """
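        # The body below is a stub. One possible (hypothetical) filter, for
        # illustration only: keep records that carry at least one detection
        # label, e.g. db_selected = [rec for rec in db if rec["label"].size].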
        db_selected = ...
        return db_selected

    @staticmethod
    def collate_fn(batch):
        img, label, paths, shapes = zip(*batch)
        label_det, label_seg, label_lane = [], [], []
        for i, l in enumerate(label):
            l_det, l_seg, l_lane = l
            l_det[:, 0] = i  # add target image index for build_targets()
            label_det.append(l_det)
            label_seg.append(l_seg)
            label_lane.append(l_lane)
        return torch.stack(img, 0), [torch.cat(label_det, 0), torch.stack(label_seg, 0), torch.stack(label_lane, 0)], paths, shapes
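
# --- Usage sketch (illustrative; not part of this commit) ---
# AutoDriveDataset is abstract (_get_db and evaluate raise NotImplementedError),
# so `ChildDataset` stands in for a concrete subclass and `cfg` for a loaded
# configuration; both names are assumptions.
#
#   import torchvision.transforms as transforms
#   from torch.utils.data import DataLoader
#
#   transform = transforms.Compose([
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                            std=[0.229, 0.224, 0.225]),
#   ])
#   dataset = ChildDataset(cfg, is_train=True, inputsize=640, transform=transform)
#   loader = DataLoader(dataset, batch_size=8, shuffle=True,
#                       collate_fn=AutoDriveDataset.collate_fn)
#   for img, (det, seg, lane), paths, shapes in loader:
#       pass  # det: (N, 6) rows [img_idx, cls, x, y, w, h]; seg/lane: one-hot masks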