Spaces:
Build error
Build error
File size: 4,255 Bytes
708dec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import torch
import torch.utils.data
from PIL import Image
import sys
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
from maskrcnn_benchmark.structures.bounding_box import BoxList
class PascalVOCDataset(torch.utils.data.Dataset):
    """Pascal VOC detection dataset read from a VOC-style directory tree.

    The following layout is expected under ``data_dir``::

        Annotations/<id>.xml
        JPEGImages/<id>.jpg
        ImageSets/Main/<split>.txt

    Indexing yields ``(image, target, index)`` where ``target`` is a
    ``BoxList`` in ``xyxy`` mode carrying ``labels`` and ``difficult``
    fields.
    """

    # The position of a name in this tuple is the integer label it maps to.
    CLASSES = (
        "__background__ ",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    )

    def __init__(self, data_dir, split, use_difficult=False, transforms=None):
        """Read the image-id list for *split* and set up path templates.

        Args:
            data_dir: Root directory of the VOC tree.
            split: Name of the image-set file (without ``.txt``) under
                ``ImageSets/Main``.
            use_difficult: When false, objects flagged ``difficult`` in the
                annotations are dropped.
            transforms: Optional callable invoked as
                ``transforms(img, target)`` and expected to return the
                transformed ``(img, target)`` pair.

        Raises:
            IOError/OSError: If the image-set file cannot be opened.
        """
        self.root = data_dir
        self.image_set = split
        self.keep_difficult = use_difficult
        self.transforms = transforms

        self._annopath = os.path.join(self.root, "Annotations", "%s.xml")
        self._imgpath = os.path.join(self.root, "JPEGImages", "%s.jpg")
        self._imgsetpath = os.path.join(self.root, "ImageSets", "Main", "%s.txt")

        with open(self._imgsetpath % self.image_set) as f:
            # Strip all surrounding whitespace (not just "\n") and skip blank
            # lines so a trailing newline or stray "\r" never yields a bogus id.
            stripped = (line.strip() for line in f)
            self.ids = [image_id for image_id in stripped if image_id]
        self.id_to_img_map = dict(enumerate(self.ids))

        cls = PascalVOCDataset.CLASSES
        self.class_to_ind = dict(zip(cls, range(len(cls))))

    def __getitem__(self, index):
        """Return ``(img, target, index)`` for the sample at *index*."""
        img_id = self.ids[index]
        # convert() produces an independent, fully loaded image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(self._imgpath % img_id) as raw:
            img = raw.convert("RGB")

        target = self.get_groundtruth(index)
        # NOTE(review): clip_to_image is defined on BoxList elsewhere;
        # presumably it clips boxes to the image and drops empty ones.
        target = target.clip_to_image(remove_empty=True)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target, index

    def __len__(self):
        """Return the number of image ids in the split."""
        return len(self.ids)

    def get_groundtruth(self, index):
        """Parse the annotation for *index* into a ``BoxList``.

        The returned ``BoxList`` is built with image size ``(width, height)``
        and mode ``"xyxy"``, and has ``labels`` and ``difficult`` fields.
        """
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self._preprocess_annotation(anno)

        height, width = anno["im_info"]
        target = BoxList(anno["boxes"], (width, height), mode="xyxy")
        target.add_field("labels", anno["labels"])
        target.add_field("difficult", anno["difficult"])
        return target

    @staticmethod
    def _read_im_info(root):
        """Return ``(height, width)`` as ints from the ``<size>`` child of *root*."""
        size = root.find("size")
        return int(size.find("height").text), int(size.find("width").text)

    def _preprocess_annotation(self, target):
        """Convert a parsed annotation element into tensors.

        Args:
            target: Root XML element of one annotation file.

        Returns:
            A dict with keys ``"boxes"`` (float32 tensor of shape ``(N, 4)``,
            also when ``N == 0``), ``"labels"`` (int64 tensor of class
            indices), ``"difficult"`` (tensor of per-box difficult flags) and
            ``"im_info"`` (``(height, width)`` tuple of ints).

        Raises:
            KeyError: If an object's class name is not in ``CLASSES``.
        """
        boxes = []
        gt_classes = []
        difficult_boxes = []
        TO_REMOVE = 1

        for obj in target.iter("object"):
            # A missing or empty <difficult> tag is treated as "not difficult"
            # instead of raising on the absent element.
            difficult = int(obj.findtext("difficult") or 0) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.lower().strip()
            bb = obj.find("bndbox")
            # Make pixel indexes 0-based
            # Refer to "https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/pascal_voc.py#L208-L211"
            bndbox = tuple(
                int(bb.find(tag).text) - TO_REMOVE
                for tag in ("xmin", "ymin", "xmax", "ymax")
            )

            boxes.append(bndbox)
            gt_classes.append(self.class_to_ind[name])
            difficult_boxes.append(difficult)

        res = {
            # reshape keeps the (N, 4) layout even when no object survived
            # filtering; a bare empty list would become a 1-D (0,) tensor.
            "boxes": torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            # Explicit dtype so an empty label list does not default to float.
            "labels": torch.tensor(gt_classes, dtype=torch.int64),
            # dtype is inferred from the Python bools collected above.
            "difficult": torch.tensor(difficult_boxes),
            "im_info": self._read_im_info(target),
        }
        return res

    def get_img_info(self, index):
        """Return ``{"height": h, "width": w}`` read from the annotation file."""
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        height, width = self._read_im_info(anno)
        return {"height": height, "width": width}

    def map_class_id_to_class_name(self, class_id):
        """Return the entry of ``CLASSES`` at position *class_id*."""
        return PascalVOCDataset.CLASSES[class_id]
|