File size: 8,480 Bytes
1cff332 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import os
import cv2
import numpy as np
from loguru import logger
from functools import wraps
from pycocotools.coco import COCO
from torch.utils.data.dataset import Dataset as torchDataset
# Human-readable names of the 80 COCO object detection categories,
# listed five per row.
COCO_CLASSES = (
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
)
def remove_useless_info(coco):
    """
    Drop metadata that training does not need from a COCO object, in place.

    Only ``COCO`` instances are touched; any other argument is ignored.
    The dataset-level "info"/"licenses" entries, the per-image
    "license"/"coco_url"/"date_captured"/"flickr_url" fields and, when
    annotations are present, every annotation's "segmentation" are removed.
    This exists mainly to save memory (about 30% according to the original
    note).
    """
    if not isinstance(coco, COCO):
        return

    dataset = coco.dataset
    for top_level_key in ("info", "licenses"):
        dataset.pop(top_level_key, None)

    for image_record in dataset["images"]:
        for image_key in ("license", "coco_url", "date_captured", "flickr_url"):
            image_record.pop(image_key, None)

    if "annotations" in coco.dataset:
        for annotation in coco.dataset["annotations"]:
            annotation.pop("segmentation", None)
class Dataset(torchDataset):
    """This class is a subclass of the base :class:`torch.utils.data.Dataset`,
    that enables on the fly resizing of the ``input_dim``.

    Args:
        input_dimension (tuple): default input dimensions of the network; only
            the first two entries are kept. NOTE(review): historically
            documented as (width, height), but ``COCODataset`` indexes its
            ``img_size`` as (height, width) — confirm the intended order.
        mosaic (bool): initial value of ``enable_mosaic``; the
            ``mosaic_getitem`` decorator may overwrite it per item.
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        # Name-mangled to ``_Dataset__input_dim``; used as the fallback when no
        # ``_input_dim`` override has been set on the instance.
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Dimension that can be used by transforms to set the correct image size, etc.
        This allows transforms to have a single source of truth
        for the input dimension of the network.

        If an ``_input_dim`` attribute has been assigned on the instance, it
        takes precedence over the default passed to ``__init__``.

        Return:
            The current input dimension: either ``_input_dim`` or the first two
            entries of the ``input_dimension`` given at construction (same
            sequence type as was passed in).
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator method that needs to be used around the ``__getitem__`` method.

        This decorator enables the closing mosaic: besides a plain integer
        index, the wrapped ``__getitem__`` accepts an ``(enable_mosaic, index)``
        pair. In that case it stores the flag on ``self.enable_mosaic`` and
        forwards only ``index`` to the original method.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            # NumPy integer scalars are plain indices too; treating them as an
            # (enable_mosaic, index) pair would fail on ``index[0]``.
            if not isinstance(index, (int, np.integer)):
                self.enable_mosaic = index[0]
                index = index[1]
            return getitem_fn(self, index)

        return wrapper
class COCODataset(Dataset):
    """
    COCO dataset class.

    Reads COCO-format annotations through the COCO API, pre-computes per-image
    labels scaled to ``img_size`` and serves resized images with those labels.
    """

    def __init__(
        self,
        data_dir='data/COCO',
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by COCO API.

        Args:
            data_dir (str): dataset root directory. Annotations are read from
                ``<data_dir>/annotations/<json_file>``.
            json_file (str): COCO json file name
            name (str): COCO data name (e.g. 'train2017' or 'val2017'); also the
                image sub-directory, i.e. images are read from ``<data_dir>/<name>/``.
            img_size (tuple(int)): target image size after pre-processing,
                indexed as (height, width) by the resize logic in this class.
            preproc: data augmentation strategy, called as
                ``preproc(img, target, input_dim)`` in ``__getitem__``.
        """
        super().__init__(img_size)
        self.data_dir = data_dir
        self.json_file = json_file
        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
        # A label's class index is the position of its category id in this
        # sorted list (see load_anno_from_ids).
        self.class_ids = sorted(self.coco.getCatIds())
        # Load categories in the same sorted order as ``class_ids`` so that
        # ``_classes[i]`` names class index ``i``. Loading them in
        # ``getCatIds()`` (json) order misaligns the two whenever the json's
        # category list is not sorted by id.
        self.cats = self.coco.loadCats(self.class_ids)
        self._classes = tuple(c["name"] for c in self.cats)
        # Optional cache of padded, resized images indexed like ``ids``
        # (consumed by pull_item); not filled anywhere in this class.
        self.imgs = None
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

    def __len__(self):
        return len(self.ids)

    def __del__(self):
        # ``imgs`` does not exist yet if ``__init__`` raised before assigning it
        # (e.g. the annotation file could not be loaded); skip the delete
        # rather than raising AttributeError from the finalizer.
        if hasattr(self, "imgs"):
            del self.imgs

    def _load_coco_annotations(self):
        """Return ``load_anno_from_ids`` results for every image, in ``self.ids`` order."""
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        """
        Build the label record for a single image.

        Args:
            id_: COCO image id.

        Returns:
            tuple: ``(res, img_info, resized_info, file_name)`` where ``res`` is
            an ``(N, 5)`` float array of ``[x1, y1, x2, y2, class_index]`` rows
            with coordinates scaled by the resize ratio, ``img_info`` is the
            original ``(height, width)``, ``resized_info`` is the
            ``(height, width)`` after resizing and ``file_name`` is the image
            file name.
        """
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        objs = []
        for obj in annotations:
            # bbox is [x, y, w, h]; convert to [x1, y1, x2, y2], clamping the
            # origin at 0, negative sizes to 0 and the far corner to the image.
            x1 = np.max((0, obj["bbox"][0]))
            y1 = np.max((0, obj["bbox"][1]))
            x2 = np.min((width, x1 + np.max((0, obj["bbox"][2]))))
            y2 = np.min((height, y1 + np.max((0, obj["bbox"][3]))))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                # NOTE: this stores "clean_bbox" on the COCO API's own
                # annotation dict (in-place mutation of shared data).
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        num_objs = len(objs)
        res = np.zeros((num_objs, 5))
        for ix, obj in enumerate(objs):
            cls = self.class_ids.index(obj["category_id"])
            res[ix, 0:4] = obj["clean_bbox"]
            res[ix, 4] = cls

        # Same aspect-preserving ratio that load_resized_img applies, so boxes
        # line up with the resized image (assuming the annotation's
        # width/height match the image file).
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r

        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))

        # Fall back to the zero-padded "<id>.jpg" name when the image record
        # carries no file_name.
        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )

        return res, img_info, resized_info, file_name

    def load_anno(self, index):
        """Return the pre-computed ``(N, 5)`` label array for dataset position ``index``."""
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """
        Read image ``index`` from disk and resize it to fit inside ``img_size``
        while keeping its aspect ratio. Returns a ``uint8`` array.
        """
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            # cv2.resize takes (width, height).
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        """Read the raw image for dataset position ``index`` with OpenCV."""
        file_name = self.annotations[index][3]

        img_file = os.path.join(self.data_dir, self.name, file_name)

        img = cv2.imread(img_file)
        assert img is not None, f"file named {img_file} not found"

        return img

    def pull_item(self, index):
        """
        Fetch the resized image and raw labels for dataset position ``index``.

        Returns:
            tuple: ``(img, res, img_info, img_id)``. ``img`` comes from the
            ``self.imgs`` cache (cropped from its padding) when set, otherwise
            from disk. ``res`` is a copy of the pre-computed label array.
            ``img_info`` is the original ``(height, width)``, and ``img_id`` is
            a one-element array holding the COCO image id.
        """
        id_ = self.ids[index]
        res, img_info, resized_info, _ = self.annotations[index]
        if self.imgs is not None:
            pad_img = self.imgs[index]
            # Cached images are padded; crop back to the real resized size.
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)

        return img, res.copy(), img_info, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        """
        One image / label pair for the given index is picked up and pre-processed.

        Args:
            index (int): data index, or an ``(enable_mosaic, index)`` pair that
                the ``mosaic_getitem`` decorator unpacks.

        Returns:
            img (numpy.ndarray): resized image, passed through ``preproc`` when
                one is set.
            target: labels. Without ``preproc`` this is an ``(N, 5)`` float
                array of ``[x1, y1, x2, y2, class_index]`` in resized-image
                pixels. With ``preproc`` it is whatever ``preproc`` returns.
                NOTE(review): earlier docs described ``[max_labels, 5]`` rows of
                ``[class, xc, yc, w, h]`` in 0..1; that depends on the preproc
                in use — confirm.
            img_info (tuple): original ``(height, width)`` of the image.
            img_id (numpy.ndarray): one-element array holding the COCO image id
                (``self.ids[index]``), not the dataset index. Used for evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id
|