zdou0830's picture
desco
749745d
raw
history blame
No virus
2.5 kB
# Following LVIS dataset
# Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import os
import time
from collections import defaultdict
import pdb
import pycocotools.mask as mask_utils
import torchvision
from PIL import Image
import torch
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
from maskrcnn_benchmark.structures.keypoint import PersonKeypoints
from maskrcnn_benchmark.config import cfg
# from .coco import ConvertCocoPolysToMask, make_coco_transforms
from .modulated_coco import ConvertCocoPolysToMask
from .lvis import LVIS, LvisDetectionBase
class PacoDetection(LvisDetectionBase):
    """LVIS-style detection dataset wrapper used for PACO annotations.

    Builds on ``LvisDetectionBase`` for image/annotation loading and adds
    target conversion (``ConvertCocoPolysToMask``) plus optional image
    transforms.

    NOTE(review): ``self.lvis`` is assumed to be populated by
    ``LvisDetectionBase.__init__``; it is not assigned in this class.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, **kwargs):
        """Set up the dataset.

        Args:
            img_folder: Forwarded to ``LvisDetectionBase.__init__``.
            ann_file: Forwarded to ``LvisDetectionBase.__init__`` and also
                stored on the instance as ``self.ann_file``.
            transforms: Callable applied to the image in ``__getitem__``,
                or ``None`` to skip the transform step.
            return_masks: Passed to ``ConvertCocoPolysToMask``.
            **kwargs: Accepted and ignored, so callers may pass extra
                dataset options without this constructor failing.
        """
        super(PacoDetection, self).__init__(img_folder, ann_file)
        self.ann_file = ann_file
        self._transforms = transforms
        # Sorted image ids give a stable index -> image id mapping.
        self.ids = sorted(self.lvis.imgs.keys())
        self.id_to_img_map = dict(enumerate(self.ids))
        self.prepare = ConvertCocoPolysToMask(return_masks)

    def categories(self):
        """Return a dict mapping category id to category name.

        Entries are inserted in ascending category-id order, read from
        ``self.lvis.dataset["categories"]``.
        """
        id2cat = {c["id"]: c for c in self.lvis.dataset["categories"]}
        return {cat_id: id2cat[cat_id]['name'] for cat_id in sorted(id2cat)}

    def __getitem__(self, idx):
        """Load one sample.

        Args:
            idx: Position into ``self.ids``.

        Returns:
            A tuple ``(img, target, idx)`` where ``img`` and ``target`` are
            the outputs of ``self.prepare`` and ``img`` has additionally
            been passed through ``self._transforms`` when that is not
            ``None``.
        """
        # A leftover ``pdb.set_trace()`` used to sit here; it halted every
        # sample fetch in the interactive debugger and has been removed.
        img, target = super(PacoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {"image_id": image_id, "annotations": target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            # Only the image is transformed; the target is passed through
            # unchanged.
            img = self._transforms(img)
        return img, target, idx

    def convert_dict_anno_to_box(self, annos):
        """Placeholder; currently does nothing and returns ``None``."""
        pass

    def get_raw_image(self, idx):
        """Return the image as loaded by the base class.

        Neither ``self.prepare`` nor ``self._transforms`` is applied.
        """
        img, _ = super(PacoDetection, self).__getitem__(idx)
        return img

    def get_img_info(self, index):
        """Return the ``self.lvis.imgs`` record for the image at ``index``.

        Raises:
            KeyError: If ``index`` is not a key of ``self.id_to_img_map``.
        """
        img_id = self.id_to_img_map[index]
        return self.lvis.imgs[img_id]