Upload hybridnets/dataset.py
hybridnets/dataset.py
ADDED
@@ -0,0 +1,409 @@
import cv2
import numpy as np
# np.set_printoptions(threshold=np.inf)
import random
import torch
import torchvision.transforms as transforms
from pathlib import Path
from torch.utils.data import Dataset
from utils.utils import letterbox, augment_hsv, random_perspective
from tqdm.autonotebook import tqdm
import json
import albumentations as A

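# The `params.dataset` keys read below imply a config of roughly this shape.
# This is an illustrative sketch assembled from the lookups in this file;
# the values are placeholders, not the project's actual defaults:
#
# params.dataset = {
#     'dataroot': '/data/bdd100k/images',    # image root
#     'labelroot': '/data/bdd100k/labels',   # per-image JSON annotations
#     'maskroot': '/data/bdd100k/segments',  # drivable-area masks (.png)
#     'laneroot': '/data/bdd100k/lane',      # lane-line masks (.png)
#     'train_set': 'train', 'test_set': 'val',
#     'data_format': 'jpg',
#     'org_img_size': [720, 1280],           # (height, width) of source images
#     'rot_factor': 10, 'scale_factor': 0.25, 'translate': 0.1, 'shear': 0.0,
#     'flip': True, 'color_rgb': True,
#     'hsv_h': 0.015, 'hsv_s': 0.7, 'hsv_v': 0.4,
# }
# `params.num_seg_class` (2 or 3) selects grayscale vs. 3-channel mask loading.
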
class BddDataset(Dataset):
    def __init__(self, params, is_train, inputsize=640, transform=None):
        """
        Initialize all dataset attributes.

        Inputs:
        -params: configuration parameters
        -is_train(bool): whether this is the train set or not
        -transform: ToTensor and Normalize

        Returns:
        None
        """
        self.single_cls = True  # only detect vehicles
        self.is_train = is_train
        self.params = params
        self.transform = transform
        self.inputsize = inputsize
        self.Tensor = transforms.ToTensor()
        img_root = Path(params.dataset['dataroot'])
        label_root = Path(params.dataset['labelroot'])
        mask_root = Path(params.dataset['maskroot'])
        lane_root = Path(params.dataset['laneroot'])
        if is_train:
            indicator = params.dataset['train_set']
        else:
            indicator = params.dataset['test_set']
        self.img_root = img_root / indicator
        self.label_root = label_root / indicator
        self.mask_root = mask_root / indicator
        self.lane_root = lane_root / indicator
        # self.label_list = self.label_root.iterdir()
        self.mask_list = self.mask_root.iterdir()
        self.data_format = params.dataset['data_format']
        self.scale_factor = params.dataset['scale_factor']
        self.rotation_factor = params.dataset['rot_factor']
        self.flip = params.dataset['flip']
        self.color_rgb = params.dataset['color_rgb']
        # Light photometric augmentations; bboxes are pixel xyxy (pascal_voc),
        # and the lane mask rides along as an additional target.
        self.albumentations_transform = A.Compose([
            A.Blur(p=0.01),
            A.MedianBlur(p=0.01),
            A.ToGray(p=0.01),
            A.CLAHE(p=0.01),
            A.RandomBrightnessContrast(p=0.01),
            A.RandomGamma(p=0.01),
            A.ImageCompression(quality_lower=75, p=0.01)],
            bbox_params=A.BboxParams(format='pascal_voc', label_fields=['class_labels']),
            additional_targets={'mask0': 'mask'})

        # bdd_labels = {
        #     'unlabeled': 0, 'dynamic': 1, 'ego vehicle': 2, 'ground': 3,
        #     'static': 4, 'parking': 5, 'rail track': 6, 'road': 7,
        #     'sidewalk': 8, 'bridge': 9, 'building': 10, 'fence': 11,
        #     'garage': 12, 'guard rail': 13, 'tunnel': 14, 'wall': 15,
        #     'banner': 16, 'billboard': 17, 'lane divider': 18, 'parking sign': 19,
        #     'pole': 20, 'polegroup': 21, 'street light': 22, 'traffic cone': 23,
        #     'traffic device': 24, 'traffic light': 25, 'traffic sign': 26, 'traffic sign frame': 27,
        #     'terrain': 28, 'vegetation': 29, 'sky': 30, 'person': 31,
        #     'rider': 32, 'bicycle': 33, 'bus': 34, 'car': 35,
        #     'caravan': 36, 'motorcycle': 37, 'trailer': 38, 'train': 39,
        #     'truck': 40
        # }
        self.id_dict = {'person': 0, 'rider': 1, 'car': 2, 'bus': 3, 'truck': 4,
                        'bike': 5, 'motor': 6, 'tl_green': 7, 'tl_red': 8,
                        'tl_yellow': 9, 'tl_none': 10, 'traffic sign': 11, 'train': 12}
        self.id_dict_single = {'car': 0, 'bus': 1, 'truck': 2, 'train': 3}
        # id_dict = {'car': 0, 'bus': 1, 'truck': 2}

        self.shapes = np.array(params.dataset['org_img_size'])
        self.db = self._get_db()

    def _get_db(self):
        """
        Build the database from the annotation files.

        Inputs:

        Returns:
        gt_db: (list) database [a, b, c, ...]
        a: (dictionary) {'image':, 'information':, ......}
        image: image path
        mask: path of the segmentation label
        label: [cls_id, center_x, center_y, w, h], normalized by the original image size
        """
        print('building database...')
        gt_db = []
        height, width = self.shapes
        for mask in tqdm(list(self.mask_list)):
            mask_path = str(mask)
            label_path = mask_path.replace(str(self.mask_root), str(self.label_root)).replace(".png", ".json")
            image_path = mask_path.replace(str(self.mask_root), str(self.img_root)).replace(".png", ".jpg")
            lane_path = mask_path.replace(str(self.mask_root), str(self.lane_root))
            with open(label_path, 'r') as f:
                label = json.load(f)
            data = label['frames'][0]['objects']
            data = self.select_data(data)
            gt = np.zeros((len(data), 5))
            for idx, obj in enumerate(data):
                category = obj['category']
                if category == "traffic light":
                    color = obj['attributes']['trafficLightColor']
                    category = "tl_" + color
                if category in self.id_dict.keys():
                    x1 = float(obj['box2d']['x1'])
                    y1 = float(obj['box2d']['y1'])
                    x2 = float(obj['box2d']['x2'])
                    y2 = float(obj['box2d']['y2'])
                    cls_id = self.id_dict[category]
                    if self.single_cls:
                        cls_id = 0
                    gt[idx][0] = cls_id
                    box = self.convert((width, height), (x1, x2, y1, y2))
                    gt[idx][1:] = list(box)

            rec = [{
                'image': image_path,
                'label': gt,
                'mask': mask_path,
                'lane': lane_path
            }]

            # (debug) draw the normalized boxes back onto the image:
            # img = cv2.imread(image_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_UNCHANGED)
            # for label in gt:
            #     x1 = int((label[1] - label[3] / 2) * 1280)
            #     x2 = int((label[1] + label[3] / 2) * 1280)
            #     y1 = int((label[2] - label[4] / 2) * 720)
            #     y2 = int((label[2] + label[4] / 2) * 720)
            #     img = cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            # cv2.imwrite('gt/{}'.format(image_path.split('/')[-1]), img)

            gt_db += rec
        print('database build finished')
        return gt_db

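    # For reference, a single record produced by _get_db looks roughly like the
    # following (illustrative BDD100K-style paths; label rows are the normalized
    # [cls, cx, cy, w, h] produced by convert()):
    #
    # {'image': '<dataroot>/train/0000f77c-6257be58.jpg',
    #  'label': np.array([[0., 0.52, 0.41, 0.08, 0.12], ...]),
    #  'mask':  '<maskroot>/train/0000f77c-6257be58.png',
    #  'lane':  '<laneroot>/train/0000f77c-6257be58.png'}
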
    def evaluate(self, params, preds, output_dir):
        """
        To be implemented in child datasets.
        """
        raise NotImplementedError

    def __len__(self, ):
        """
        Number of samples in the dataset.
        """
        return len(self.db)

    def __getitem__(self, idx):
        """
        Get input and ground truth from the database & apply data augmentation to the input.

        Inputs:
        -idx: the index of the image in self.db (database) (list)
        self.db (list) [a, b, c, ...]
        a: (dictionary) {'image':, 'information':}

        Returns:
        -image: transformed image; first passed through the data augmentation in __getitem__ (type: numpy), then self.transform is applied
        -target: ground truth (det_gt, seg_gt)

        Functions that may be useful:
        cv2.imread
        cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        cv2.warpAffine
        """
        data = self.db[idx]
        img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # num_seg_class == 3 loads the drivable mask as 3-channel, otherwise grayscale
        if self.params.num_seg_class == 3:
            seg_label = cv2.imread(data["mask"])
        else:
            seg_label = cv2.imread(data["mask"], 0)
        lane_label = cv2.imread(data["lane"], 0)

        # print(lane_label.shape)
        # print(seg_label.shape)
        resized_shape = self.inputsize
        if isinstance(resized_shape, list):
            resized_shape = max(resized_shape)
        h0, w0 = img.shape[:2]  # orig hw
        r = resized_shape / max(h0, w0)  # resize image to img_size
        if r != 1:  # always resize down, only resize up if training with augmentation
            interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR
            img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp)
            seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
            lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp)
        h, w = img.shape[:2]

        (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True,
                                                             scaleup=self.is_train)
        shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
        # ratio = (w / w0, h / h0)

        det_label = data["label"]

        labels = []
        labels_app = np.array([])

        if det_label.size > 0:
            # Normalized xywh to pixel xyxy format
            labels = det_label.copy()
            labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0]  # pad width
            labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1]  # pad height
            labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0]
            labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1]

        if self.is_train:
            # albumentations (photometric; boxes and both masks are transformed together)
            try:
                new = self.albumentations_transform(image=img, mask=seg_label, mask0=lane_label,
                                                    bboxes=labels[:, 1:] if len(labels) else labels,
                                                    class_labels=labels[:, 0] if len(labels) else labels)
                img = new['image']
                labels = np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]) if len(labels) else labels
                seg_label = new['mask']
                lane_label = new['mask0']
            except ValueError:  # a bbox has width or height == 0
                pass

            # geometric augmentation
            combination = (img, seg_label, lane_label)
            (img, seg_label, lane_label), labels = random_perspective(
                combination=combination,
                targets=labels,
                degrees=self.params.dataset['rot_factor'],
                translate=self.params.dataset['translate'],
                scale=self.params.dataset['scale_factor'],
                shear=self.params.dataset['shear']
            )
            augment_hsv(img, hgain=self.params.dataset['hsv_h'], sgain=self.params.dataset['hsv_s'], vgain=self.params.dataset['hsv_v'])
            # img, seg_label, labels = cutout(combination=combination, labels=labels)

            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = img[:, ::-1, :]

                if len(labels):
                    rows, cols, channels = img.shape

                    x1 = labels[:, 1].copy()
                    x2 = labels[:, 3].copy()

                    x_tmp = x1.copy()

                    labels[:, 1] = cols - x2
                    labels[:, 3] = cols - x_tmp

                # Segmentation
                seg_label = np.fliplr(seg_label)
                lane_label = np.fliplr(lane_label)

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                seg_label = np.flipud(seg_label)
                lane_label = np.flipud(lane_label)
                if len(labels):
                    rows, cols, channels = img.shape

                    y1 = labels[:, 2].copy()
                    y2 = labels[:, 4].copy()

                    y_tmp = y1.copy()

                    labels[:, 2] = rows - y2
                    labels[:, 4] = rows - y_tmp

        # (debug) draw the augmented boxes:
        # for anno in labels:
        #     x1, y1, x2, y2 = [int(x) for x in anno[1:5]]
        #     cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 3)
        # cv2.imwrite(data["image"].split("/")[-1], img)

        # Repack boxes as [x1, y1, x2, y2, cls]
        if len(labels):
            labels_app = np.zeros((len(labels), 5))
            labels_app[:, 0:4] = labels[:, 1:5]
            labels_app[:, 4] = labels[:, 0]

        img = np.ascontiguousarray(img)

        # Binarize both masks; where they overlap, prefer the lane class
        _, seg1 = cv2.threshold(seg_label, 1, 255, cv2.THRESH_BINARY)
        _, lane1 = cv2.threshold(lane_label, 1, 255, cv2.THRESH_BINARY)
        seg1 = seg1 - (seg1 & lane1)

        union = seg1 | lane1
        background = 255 - union

        # (debug) visualize the masks:
        # img_copy = img.copy()
        # img_copy[lane1 == 255] = (0, 255, 0)
        # cv2.imwrite('seg_gt/' + data['image'].split('/')[-1], img_copy)

        seg1 = self.Tensor(seg1)
        lane1 = self.Tensor(lane1)
        background = self.Tensor(background)

        segmentation = torch.cat([background, seg1, lane1], dim=0)

        img = self.transform(img)

        return img, data["image"], shapes, torch.from_numpy(labels_app), segmentation

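    # Shapes of the tuple returned by __getitem__, for reference (H, W are the
    # letterboxed dimensions; n is the number of boxes in this image):
    #   img:           FloatTensor (3, H, W) after self.transform
    #   data["image"]: str, source image path
    #   shapes:        ((h0, w0), ((h / h0, w / w0), pad)), used for mAP rescaling
    #   labels_app:    FloatTensor (n, 5), rows of [x1, y1, x2, y2, cls]
    #   segmentation:  FloatTensor (3, H, W) = [background, drivable, lane]
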
    def select_data(self, db):
        """
        Filter the raw annotation list, keeping only objects with a 2D box
        (and, in single-class mode, only vehicle categories).

        Inputs:
        -db: (list) database

        Returns:
        -db_selected: (list) filtered database
        """
        remain = []
        for obj in db:
            if 'box2d' in obj.keys():
                if self.single_cls:
                    if obj['category'] in self.id_dict_single.keys():
                        remain.append(obj)
                else:
                    remain.append(obj)
        return remain

    def convert(self, size, box):
        # Convert a pixel-space box (x1, x2, y1, y2) on an image of
        # size=(width, height) into normalized (cx, cy, w, h).
        dw = 1. / (size[0])
        dh = 1. / (size[1])
        x = (box[0] + box[1]) / 2.0
        y = (box[2] + box[3]) / 2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return x, y, w, h

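    # Worked example for convert(): with size=(1280, 720) and a pixel box
    # (x1=100, x2=300, y1=50, y2=150):
    #   cx = (100 + 300) / 2 / 1280 = 0.15625
    #   cy = (50 + 150) / 2 / 720  ≈ 0.13889
    #   w  = (300 - 100) / 1280    = 0.15625
    #   h  = (150 - 50) / 720      ≈ 0.13889
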
    @staticmethod
    def collate_fn(batch):
        img, paths, shapes, labels_app, segmentation = zip(*batch)
        filenames = [file.split('/')[-1] for file in paths]
        # Pad every per-image annotation tensor to the batch's max box count with -1 rows
        max_num_annots = max(label.size(0) for label in labels_app)

        if max_num_annots > 0:
            annot_padded = torch.ones((len(labels_app), max_num_annots, 5)) * -1
            for idx, label in enumerate(labels_app):
                if label.size(0) > 0:
                    annot_padded[idx, :label.size(0), :] = label
        else:
            annot_padded = torch.ones((len(labels_app), 1, 5)) * -1

        return {'img': torch.stack(img, 0), 'annot': annot_padded, 'segmentation': torch.stack(segmentation, 0),
                'filenames': filenames, 'shapes': shapes}
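
# Minimal usage sketch. Assumptions: `params` is a loaded config object exposing
# the fields referenced above, and the normalization constants below are
# illustrative ImageNet values, not necessarily this project's:
#
# from torch.utils.data import DataLoader
# import torchvision.transforms as transforms
#
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
# train_set = BddDataset(params, is_train=True, inputsize=640, transform=transform)
# train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4,
#                           collate_fn=BddDataset.collate_fn)
# for batch in train_loader:
#     imgs = batch['img']          # (B, 3, H, W)
#     annots = batch['annot']      # (B, max_n, 5), padded with -1
#     seg = batch['segmentation']  # (B, 3, H, W)
#     break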