#!/usr/bin/env python3
# -*- coding:utf-8 -*-

import numbers
import os
from functools import wraps

import cv2
import numpy as np
from loguru import logger
from pycocotools.coco import COCO
from torch.utils.data.dataset import Dataset as torchDataset

COCO_CLASSES = (
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
)


def remove_useless_info(coco):
    """
    Drop fields from a COCO object's dataset that this module never reads.

    The COCO object is modified in place; the purpose is to reduce the memory
    held by large annotation files. Objects that are not
    ``pycocotools.coco.COCO`` instances are left untouched.

    Args:
        coco: ``pycocotools.coco.COCO`` instance to strip.
    """
    if not isinstance(coco, COCO):
        return
    dataset = coco.dataset
    dataset.pop("info", None)
    dataset.pop("licenses", None)
    for img in dataset["images"]:
        img.pop("license", None)
        img.pop("coco_url", None)
        img.pop("date_captured", None)
        img.pop("flickr_url", None)
    # Segmentation masks are not used by the box-only loader below.
    if "annotations" in dataset:
        for anno in dataset["annotations"]:
            anno.pop("segmentation", None)


class Dataset(torchDataset):
    """
    Subclass of :class:`torch.utils.data.Dataset` whose ``input_dim`` can be
    overridden on the fly and whose mosaic flag can be toggled per item.

    Args:
        input_dimension (tuple): default network input size; only the first
            two entries are kept. NOTE(review): historically documented as
            (width, height), but COCODataset passes ``img_size``, which it
            indexes as (height, width) — confirm the convention.
        mosaic (bool): initial value of ``enable_mosaic``.
    """

    def __init__(self, input_dimension, mosaic=True):
        super().__init__()
        self.__input_dim = input_dimension[:2]
        self.enable_mosaic = mosaic

    @property
    def input_dim(self):
        """
        Current network input dimension, the single source of truth that
        transforms should use for the target image size.

        Returns ``self._input_dim`` when some other code has set it (nothing
        in this module does), otherwise the first two entries of the
        ``input_dimension`` passed to ``__init__``.
        """
        if hasattr(self, "_input_dim"):
            return self._input_dim
        return self.__input_dim

    @staticmethod
    def mosaic_getitem(getitem_fn):
        """
        Decorator for ``__getitem__`` that lets a sampler toggle mosaic per item.

        The wrapped method accepts either a plain integer index or a
        ``(enable_mosaic, index)`` pair; for a pair, ``enable_mosaic`` is
        stored on the dataset before ``getitem_fn`` is called with the index.

        Example:
            >>> class CustomSet(Dataset):
            ...     def __len__(self):
            ...         return 10
            ...     @Dataset.mosaic_getitem
            ...     def __getitem__(self, index):
            ...         return self.enable_mosaic
        """

        @wraps(getitem_fn)
        def wrapper(self, index):
            # numbers.Integral also matches NumPy integer scalars, which a plain
            # ``isinstance(index, int)`` check would mistake for a pair.
            if not isinstance(index, numbers.Integral):
                self.enable_mosaic = index[0]
                index = index[1]
            return getitem_fn(self, index)

        return wrapper


class COCODataset(Dataset):
    """
    COCO detection dataset yielding resized images and box annotations.
    """

    def __init__(
        self,
        data_dir='data/COCO',
        json_file="instances_train2017.json",
        name="train2017",
        img_size=(416, 416),
        preproc=None,
    ):
        """
        COCO dataset initialization. Annotation data are read into memory by
        the COCO API and converted into per-image box arrays up front.

        Args:
            data_dir (str): dataset root directory; annotations are read from
                ``<data_dir>/annotations/<json_file>`` and images from
                ``<data_dir>/<name>/``.
            json_file (str): COCO json file name
            name (str): image sub-directory name (e.g. 'train2017' or 'val2017')
            img_size (tuple(int)): target (height, width) after resizing
            preproc: optional callable ``preproc(img, target, input_dim)``
                returning ``(img, target)``; applied in ``__getitem__``.
        """
        super().__init__(img_size)
        self.data_dir = data_dir
        self.json_file = json_file

        self.coco = COCO(os.path.join(self.data_dir, "annotations", self.json_file))
        remove_useless_info(self.coco)
        self.ids = self.coco.getImgIds()
        self.class_ids = sorted(self.coco.getCatIds())
        # Load categories in the same sorted-id order used to map
        # category_id -> label index, so ``_classes[label]`` names that label
        # even when the json lists categories out of id order.
        self.cats = self.coco.loadCats(self.class_ids)
        self._classes = tuple(c["name"] for c in self.cats)
        # Optional cache of padded resized images indexed like ``self.ids``;
        # nothing in this class fills it.
        self.imgs = None
        self.name = name
        self.img_size = img_size
        self.preproc = preproc
        self.annotations = self._load_coco_annotations()

    def __len__(self):
        return len(self.ids)

    def __del__(self):
        # ``__init__`` may have raised before ``imgs`` was assigned (e.g. a
        # missing annotation file); avoid a spurious AttributeError here.
        if hasattr(self, "imgs"):
            del self.imgs

    def _load_coco_annotations(self):
        """Build the annotation record for every image id, in ``self.ids`` order."""
        return [self.load_anno_from_ids(_ids) for _ids in self.ids]

    def load_anno_from_ids(self, id_):
        """
        Build the cached annotation record for one COCO image id.

        Boxes are converted from COCO ``[x, y, w, h]`` to ``[x1, y1, x2, y2]``,
        clipped to the image bounds and scaled by the same ratio that
        ``load_resized_img`` applies. Annotations with non-positive area, or
        whose clipped box is inverted (entirely outside the image), are
        skipped; ``getAnnIds(..., iscrowd=False)`` additionally restricts the
        candidates to non-crowd annotations in pycocotools.

        Args:
            id_: COCO image id.

        Returns:
            tuple: ``(res, img_info, resized_info, file_name)`` where ``res`` is
            an ``(N, 5)`` float array of ``[x1, y1, x2, y2, class_index]`` in
            resized-image pixels, ``img_info`` is the original
            ``(height, width)``, ``resized_info`` the resized
            ``(height, width)`` and ``file_name`` the image file name.
        """
        im_ann = self.coco.loadImgs(id_)[0]
        width = im_ann["width"]
        height = im_ann["height"]
        anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        objs = []
        for obj in annotations:
            bx, by, bw, bh = obj["bbox"][:4]
            x1 = max(0, bx)
            y1 = max(0, by)
            # Clip the *original* right/bottom edges. Measuring the extent from
            # the already-clipped x1/y1 would shift boxes that start left of or
            # above the image instead of clipping them.
            x2 = min(width, bx + max(0, bw))
            y2 = min(height, by + max(0, bh))
            if obj["area"] > 0 and x2 >= x1 and y2 >= y1:
                obj["clean_bbox"] = [x1, y1, x2, y2]
                objs.append(obj)

        res = np.zeros((len(objs), 5))
        for ix, obj in enumerate(objs):
            res[ix, 0:4] = obj["clean_bbox"]
            # Label = position of the category id in the sorted id list.
            res[ix, 4] = self.class_ids.index(obj["category_id"])

        # Same aspect-preserving scale factor load_resized_img uses, so the
        # boxes line up with the resized image.
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r

        img_info = (height, width)
        resized_info = (int(height * r), int(width * r))

        file_name = (
            im_ann["file_name"]
            if "file_name" in im_ann
            else "{:012}".format(id_) + ".jpg"
        )

        return res, img_info, resized_info, file_name

    def load_anno(self, index):
        """Return the ``(N, 5)`` box/label array for dataset ``index``."""
        return self.annotations[index][0]

    def load_resized_img(self, index):
        """
        Load image ``index`` and resize it, keeping the aspect ratio, so that
        it fits inside ``img_size`` (height, width).
        """
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),  # cv2 takes (w, h)
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        """
        Read image ``index`` from ``<data_dir>/<name>/<file_name>`` with
        OpenCV (BGR channel order).

        Raises:
            AssertionError: if OpenCV cannot read the file.
        """
        file_name = self.annotations[index][3]

        img_file = os.path.join(self.data_dir, self.name, file_name)

        img = cv2.imread(img_file)
        assert img is not None, f"file named {img_file} not found"

        return img

    def pull_item(self, index):
        """
        Return ``(img, target, img_info, img_id)`` for ``index`` without
        applying ``preproc``.

        When ``self.imgs`` holds a cache (filled elsewhere; presumably padded
        resized images — confirm), the image is cropped from it to
        ``resized_info``; otherwise it is read and resized from disk.
        """
        id_ = self.ids[index]

        res, img_info, resized_info, _ = self.annotations[index]
        if self.imgs is not None:
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)

        # Copy so callers/augmentations cannot mutate the cached annotations.
        return img, res.copy(), img_info, np.array([id_])

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        """
        Pick one image / label pair for the given index and pre-process it.

        Args:
            index (int): data index

        Returns:
            img (numpy.ndarray): resized image, or whatever ``preproc``
                returns for it.
            target: without ``preproc``, an ``(N, 5)`` array of
                ``[x1, y1, x2, y2, class_index]`` in resized-image pixels;
                with ``preproc``, whatever it returns. NOTE(review): the
                original docs describe a ``[max_labels, 5]`` array of
                ``[class, xc, yc, w, h]`` normalised to [0, 1] — that depends
                on the transform; confirm against the preproc in use.
            img_info (tuple): original ``(height, width)`` of the image.
            img_id (numpy.ndarray): one-element array holding the COCO image
                id ``self.ids[index]`` (not the dataset index). Used for
                evaluation.
        """
        img, target, img_info, img_id = self.pull_item(index)

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)
        return img, target, img_info, img_id