zdou0830's picture
desco
749745d
raw
history blame
No virus
4.21 kB
import os
import torch
import torch.utils.data
from PIL import Image
import sys
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
from maskrcnn_benchmark.structures.bounding_box import BoxList
class PascalVOCDataset(torch.utils.data.Dataset):
    """Object-detection dataset read from a Pascal VOC style directory.

    The following layout is read relative to ``data_dir``::

        Annotations/<image id>.xml
        JPEGImages/<image id>.jpg
        ImageSets/Main/<split>.txt

    Indexing the dataset yields ``(image, target, index)`` where ``target``
    is a ``BoxList`` in ``xyxy`` mode carrying ``labels`` and ``difficult``
    fields.
    """

    # The position of a name in this tuple is its integer label.
    # NOTE: the trailing space in the background entry is kept on purpose;
    # it is the exact string returned by map_class_id_to_class_name(0).
    CLASSES = (
        "__background__ ",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
        """Read the image ids of ``split`` and build the label lookup.

        Args:
            data_dir: root directory of the dataset (see class docstring).
            split: name of the image-set file in ``ImageSets/Main`` without
                the ``.txt`` extension.
            use_difficult: when false, objects flagged as difficult are
                dropped while parsing annotations.
            transforms: optional callable invoked as
                ``transforms(img, target)`` and returning the transformed
                ``(img, target)`` pair.

        Raises:
            IOError / OSError: if the image-set file cannot be opened.
        """
        self.root = data_dir
        self.image_set = split
        self.keep_difficult = use_difficult
        self.transforms = transforms

        # Path templates; the "%s" placeholder is filled with an image id
        # (or with the split name for the image-set file).
        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        with open(self._imgsetpath % self.image_set) as f:
            # Strip every kind of surrounding whitespace (not only "\n") and
            # skip blank lines, so that a trailing empty line or stray
            # spaces never turn into a bogus image id.
            stripped = (line.strip() for line in f)
            self.ids = [image_id for image_id in stripped if image_id]

        self.id_to_img_map = dict(enumerate(self.ids))

        cls = PascalVOCDataset.CLASSES
        self.class_to_ind = dict(zip(cls, range(len(cls))))

    def __getitem__(self, index):
        """Return ``(image, target, index)`` for the sample at ``index``.

        The image is loaded as RGB. The target is passed through
        ``clip_to_image(remove_empty=True)`` before the optional transforms
        are applied.
        """
        img_id = self.ids[index]
        img = Image.open(self._imgpath % img_id).convert("RGB")

        target = self.get_groundtruth(index)
        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, index

    def __len__(self):
        """Return the number of image ids in the split."""
        return len(self.ids)

    def get_groundtruth(self, index):
        """Parse the annotation of sample ``index`` into a ``BoxList``.

        The returned ``BoxList`` is created in ``xyxy`` mode with the image
        size given as ``(width, height)`` and has the fields ``labels`` and
        ``difficult`` attached.
        """
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self._preprocess_annotation(anno)

        height, width = anno["im_info"]
        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
        target.add_field("labels", anno["labels"])
        target.add_field("difficult", anno["difficult"])
        return target

    def _preprocess_annotation(self, target):
        """Convert a parsed annotation XML tree into tensors.

        Args:
            target: root element of an annotation file; its ``object``
                descendants and its ``size`` child are read.

        Returns:
            dict with the keys

            * ``"boxes"``: float32 tensor of shape ``(N, 4)`` holding
              ``xmin, ymin, xmax, ymax`` shifted to 0-based pixel indexes,
            * ``"labels"``: int64 tensor of ``N`` class indexes,
            * ``"difficult"``: bool tensor of ``N`` difficult flags,
            * ``"im_info"``: ``(height, width)`` tuple of ints.

            ``N`` may be 0 (for example when every object is difficult and
            difficult objects are not kept); the shapes and dtypes above
            still hold in that case.

        Raises:
            KeyError: if an object name is not in ``CLASSES``.
        """
        boxes = []
        gt_classes = []
        difficult_boxes = []
        TO_REMOVE = 1

        for obj in target.iter("object"):
            # A missing or empty <difficult> tag is treated as "not
            # difficult" instead of raising AttributeError.
            difficult_text = (obj.findtext("difficult") or "").strip()
            difficult = bool(difficult_text) and int(difficult_text) == 1
            if not self.keep_difficult and difficult:
                continue

            name = obj.find("name").text.lower().strip()
            bb = obj.find("bndbox")
            # Make pixel indexes 0-based
            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
            bndbox = tuple(
                int(bb.find(tag).text) - TO_REMOVE
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )

            boxes.append(bndbox)
            gt_classes.append(self.class_to_ind[name])
            difficult_boxes.append(difficult)

        size = target.find("size")
        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))

        res = {
            # reshape(-1, 4) keeps the tensor two-dimensional even when no
            # object was kept; torch.tensor([]) alone would have shape (0,).
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            # Explicit dtypes: an empty Python list would otherwise be
            # inferred as float32.
            "labels": torch.tensor(gt_classes, dtype=torch.int64),
            "difficult": torch.tensor(difficult_boxes, dtype=torch.bool),
            "im_info": im_info,
        }
        return res

    def get_img_info(self, index):
        """Return ``{"height": h, "width": w}`` read from the annotation XML."""
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        size = anno.find("size")
        im_info = tuple(map(int, (size.find("height").text, size.find("width").text)))
        return {"height": im_info[0], "width": im_info[1]}

    def map_class_id_to_class_name(self, class_id):
        """Return the entry of ``CLASSES`` at position ``class_id``."""
        return PascalVOCDataset.CLASSES[class_id]