zdou0830's picture
desco
749745d
raw
history blame contribute delete
No virus
1.63 kB
import os
import os.path
import json
from PIL import Image
import torch
import torchvision
import torch.utils.data as data
from maskrcnn_benchmark.structures.bounding_box import BoxList
class Background(data.Dataset):
"""Background
Args:
root (string): Root directory where images are downloaded to.
annFile (string): Path to json annotation file.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.ToTensor``
"""
def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None):
self.root = root
with open(ann_file, "r") as f:
self.ids = json.load(f)["images"]
self.transform = transforms
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``.
"""
im_info = self.ids[index]
path = im_info["file_name"]
fp = os.path.join(self.root, path)
img = Image.open(fp).convert("RGB")
if self.transform is not None:
img, _ = self.transform(img, None)
null_target = BoxList(torch.zeros((0, 4)), (img.shape[-1], img.shape[-2]))
null_target.add_field("labels", torch.zeros(0))
return img, null_target, index
def __len__(self):
return len(self.ids)
def get_img_info(self, index):
im_info = self.ids[index]
return im_info