import os
import sys

import torch
import torch.utils.data
from PIL import Image

if sys.version_info[0] == 2:
    import xml.etree.cElementTree as ET
else:
    import xml.etree.ElementTree as ET

from maskrcnn_benchmark.structures.bounding_box import BoxList


class PascalVOCDataset(torch.utils.data.Dataset):
    """Pascal VOC detection dataset yielding (image, BoxList target, index) tuples."""

    CLASSES = (
        "__background__",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
        self.root = data_dir
        self.image_set = split
        self.keep_difficult = use_difficult
        self.transforms = transforms

        # Path templates keyed by image id (e.g. "000005").
        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        with open(self._imgsetpath % self.image_set) as f:
            self.ids = [x.strip("\n") for x in f.readlines()]
        self.id_to_img_map = dict(enumerate(self.ids))

        cls = PascalVOCDataset.CLASSES
        self.class_to_ind = dict(zip(cls, range(len(cls))))

    def __getitem__(self, index):
        img_id = self.ids[index]
        img = Image.open(self._imgpath % img_id).convert("RGB")

        target = self.get_groundtruth(index)
        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, index

    def __len__(self):
        return len(self.ids)

    def get_groundtruth(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self._preprocess_annotation(anno)

        height, width = anno["im_info"]
        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
        target.add_field("labels", anno["labels"])
        target.add_field("difficult", anno["difficult"])
        return target

    def _preprocess_annotation(self, target):
        boxes = []
        gt_classes = []
        difficult_boxes = []
        TO_REMOVE = 1

        for obj in target.iter("object"):
            difficult = int(obj.find("difficult").text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.lower().strip()
            bb = obj.find("bndbox")
            # Make pixel indexes 0-based
            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
            box = [
                bb.find("xmin").text,
                bb.find("ymin").text,
                bb.find("xmax").text,
                bb.find("ymax").text,
            ]
            bndbox = tuple(
                map(lambda x: x - TO_REMOVE, list(map(int, box)))
            )

            boxes.append(bndbox)
            gt_classes.append(self.class_to_ind[name])
            difficult_boxes.append(difficult)

        size = target.find("size")
        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))

        res = {
            # reshape(-1, 4) preserves a valid (0, 4) shape when every object
            # was skipped as difficult, so BoxList construction does not fail.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.tensor(gt_classes),
            "difficult": torch.tensor(difficult_boxes),
            "im_info": im_info,
        }
        return res

    def get_img_info(self, index):
        # Read only the <size> element so callers (e.g. aspect-ratio grouping)
        # can get dimensions without decoding the image itself.
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        size = anno.find("size")
        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
        return {"height": im_info[0], "width": im_info[1]}

    def map_class_id_to_class_name(self, class_id):
        return PascalVOCDataset.CLASSES[class_id]
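

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It
# assumes a standard VOCdevkit directory layout; the path below is a
# hypothetical stand-in, and no transform pipeline is wired in (transforms
# defaults to None), whereas maskrcnn_benchmark normally builds one from its
# config. The .bbox and .get_field accessors used here are part of BoxList.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical dataset root: Annotations/, JPEGImages/, and
    # ImageSets/Main/ must all exist directly underneath it.
    voc_root = "/path/to/VOCdevkit/VOC2007"

    dataset = PascalVOCDataset(voc_root, split="train", use_difficult=False)
    print("images:", len(dataset))

    img, target, idx = dataset[0]
    print("image size:", img.size)                # PIL (width, height)
    print("boxes:", target.bbox.shape)            # (N, 4) xyxy tensor
    print("labels:", target.get_field("labels"))  # class indices into CLASSES
    print("names:", [dataset.map_class_id_to_class_name(int(c))
                     for c in target.get_field("labels")])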