| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import cv2 |
| import json |
| import copy |
| import numpy as np |
|
|
| try: |
| from collections.abc import Sequence |
| except Exception: |
| from collections import Sequence |
|
|
| from ppdet.core.workspace import register, serializable |
| from ppdet.data.crop_utils.annotation_cropper import AnnoCropper |
| from .coco import COCODataSet |
| from .dataset import _make_dataset, _is_valid_file |
| from ppdet.utils.logger import setup_logger |
|
|
| logger = setup_logger('sniper_coco_dataset') |
|
|
|
|
@register
@serializable
class SniperCOCODataSet(COCODataSet):
    """COCO-format dataset that crops images into fixed-size chips for
    SNIPER-style multi-scale training and inference.

    On top of the plain ``COCODataSet`` it can:
      * load external detection proposals from a JSON file and attach them
        to each annotation record (``proposals`` key),
      * run an ``AnnoCropper`` over the parsed records to produce chip-level
        roidbs (positive/negative chips for training, sliding chips for
        inference),
      * be driven directly from a list of image files (``set_images``) for
        pure inference without annotations.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 proposals_file=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=True,
                 empty_ratio=1.,
                 is_trainset=True,
                 image_target_sizes=[2000, 1000],
                 valid_box_ratio_ranges=[[-1, 0.1], [0.08, -1]],
                 chip_target_size=500,
                 chip_target_stride=200,
                 use_neg_chip=False,
                 max_neg_num_per_im=8,
                 max_per_img=-1,
                 nms_thresh=0.5):
        """
        Args:
            dataset_dir (str|None): root directory of the dataset.
            image_dir (str|None): image directory relative to ``dataset_dir``.
            anno_path (str|None): COCO annotation json path.
            proposals_file (str|None): optional json file of external
                proposals in COCO result format (``image_id`` + xywh ``bbox``).
            data_fields (list): record fields to keep for the reader.
            sample_num (int): max number of samples to load, -1 means all.
            load_crowd (bool): whether to keep crowd annotations.
            allow_empty (bool): whether to keep images without gt boxes.
            empty_ratio (float): ratio of kept empty images.
            is_trainset (bool): True for training (positive/negative chips),
                False for inference (sliding-window chips).
            image_target_sizes (list[int]): per-scale target sizes of the
                resized full images.
            valid_box_ratio_ranges (list[list[float]]): per-scale valid box
                ratio ranges; -1 means unbounded on that side.
            chip_target_size (int): output chip edge length.
            chip_target_stride (int): stride between neighboring chips.
            use_neg_chip (bool): whether to sample negative chips (training).
            max_neg_num_per_im (int): cap on negative chips per image.
            max_per_img (int): cap on proposals per image, -1 means no cap.
            nms_thresh (float): NMS threshold used when aggregating
                chip-level detections.
        """
        super(SniperCOCODataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            load_crowd=load_crowd,
            allow_empty=allow_empty,
            empty_ratio=empty_ratio)
        self.proposals_file = proposals_file
        self.proposals = None
        self.anno_cropper = None
        self.is_trainset = is_trainset
        self.image_target_sizes = image_target_sizes
        self.valid_box_ratio_ranges = valid_box_ratio_ranges
        self.chip_target_size = chip_target_size
        self.chip_target_stride = chip_target_stride
        self.use_neg_chip = use_neg_chip
        self.max_neg_num_per_im = max_neg_num_per_im
        self.max_per_img = max_per_img
        self.nms_thresh = nms_thresh

    def parse_dataset(self):
        """Parse the COCO annotations (if not already done), optionally merge
        external proposals, then replace ``self.roidbs`` with chip-level
        records. The original full-image records are kept in
        ``self.ori_roidbs`` for result aggregation."""
        if not hasattr(self, "roidbs"):
            super(SniperCOCODataSet, self).parse_dataset()
        if self.is_trainset:
            self._parse_proposals()
            self._merge_anno_proposals()
        # Deep-copy so chip generation cannot mutate the originals.
        self.ori_roidbs = copy.deepcopy(self.roidbs)
        self.init_anno_cropper()
        self.roidbs = self.generate_chips_roidbs(self.roidbs,
                                                 self.is_trainset)

    def set_proposals_file(self, file_path):
        """Set the proposals json path (must be called before parse_dataset)."""
        self.proposals_file = file_path

    def init_anno_cropper(self):
        """Build the AnnoCropper from the configured chip parameters."""
        logger.info("Init AnnoCropper...")
        self.anno_cropper = AnnoCropper(
            image_target_sizes=self.image_target_sizes,
            valid_box_ratio_ranges=self.valid_box_ratio_ranges,
            chip_target_size=self.chip_target_size,
            chip_target_stride=self.chip_target_stride,
            use_neg_chip=self.use_neg_chip,
            max_neg_num_per_im=self.max_neg_num_per_im,
            max_per_img=self.max_per_img,
            nms_thresh=self.nms_thresh)

    def generate_chips_roidbs(self, roidbs, is_trainset):
        """Crop full-image records into chip records.

        Args:
            roidbs (list[dict]): full-image annotation records.
            is_trainset (bool): choose training crop (with gt filtering)
                vs. inference crop (sliding windows).

        Returns:
            list[dict]: chip-level records.
        """
        if is_trainset:
            roidbs = self.anno_cropper.crop_anno_records(roidbs)
        else:
            roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
        return roidbs

    def _parse_proposals(self):
        """Load external proposals (COCO result format, xywh boxes) and
        index them by image id as xyxy boxes. No-op when no file is set."""
        if self.proposals_file:
            self.proposals = {}
            logger.info("Parse proposals file:{}".format(self.proposals_file))
            with open(self.proposals_file, 'r') as f:
                proposals = json.load(f)
            for prop in proposals:
                image_id = prop["image_id"]
                if image_id not in self.proposals:
                    self.proposals[image_id] = []
                x, y, w, h = prop["bbox"]
                # Convert COCO xywh to xyxy.
                self.proposals[image_id].append([x, y, x + w, y + h])

    def _merge_anno_proposals(self):
        """Attach the parsed proposals to each roidb record under the
        ``proposals`` key (empty float32 array when an image has none)."""
        assert self.roidbs
        if self.proposals and len(self.proposals.keys()) > 0:
            logger.info("merge proposals to annos")
            # NOTE: renamed loop var from `id` (shadowed the builtin).
            for idx, record in enumerate(self.roidbs):
                image_id = int(record["im_id"])
                if image_id not in self.proposals:
                    logger.info("image id :{} no proposals".format(image_id))
                record["proposals"] = np.array(
                    self.proposals.get(image_id, []), dtype=np.float32)
                self.roidbs[idx] = record

    def get_ori_roidbs(self):
        """Return the pre-crop full-image records, or None if not parsed."""
        if not hasattr(self, "ori_roidbs"):
            return None
        return self.ori_roidbs

    def get_roidbs(self):
        """Return chip-level records, parsing the dataset on first use."""
        if not hasattr(self, "roidbs"):
            self.parse_dataset()
        return self.roidbs

    def set_roidbs(self, roidbs):
        """Directly override the chip-level records."""
        self.roidbs = roidbs

    def check_or_download_dataset(self):
        """Skip the automatic dataset download check for SNIPER datasets."""
        return

    def _parse(self):
        """Collect image file paths from ``self.image_dir`` (a directory,
        a single file, or a sequence of either).

        Returns:
            list[str]: image file paths.
        """
        image_dir = self.image_dir
        if not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            # NOTE(review): isdir is tested on the raw path *before* joining
            # with dataset_dir, so a relative dir is only found when it also
            # resolves from the CWD — TODO confirm this matches caller usage.
            if os.path.isdir(im_dir):
                im_dir = os.path.join(self.dataset_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build minimal inference records (id/path/height/width) for every
        collected image, honoring ``sample_num``.

        Returns:
            list[dict]: one record per image.

        Raises:
            ValueError: if an image file cannot be decoded by OpenCV.
        """
        images = self._parse()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            im = cv2.imread(image)
            # cv2.imread returns None (no exception) on unreadable files;
            # fail with a clear message instead of an unpack error.
            if im is None:
                raise ValueError(
                    "Failed to decode image file: {}".format(image))
            h, w, c = im.shape
            rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the image-id -> file-path mapping built by _load_images."""
        return self._imid2path

    def set_images(self, images):
        """Use an explicit list of image paths as the dataset (inference).

        Args:
            images (str|list[str]): image file(s) or directory(ies).
        """
        self._imid2path = {}
        self.image_dir = images
        self.roidbs = self._load_images()
|
|
|
|