Pinwheel's picture
HF Demo
128757a
raw
history blame
1.63 kB
import os
import os.path
import json
from PIL import Image
import torch
import torchvision
import torch.utils.data as data
from maskrcnn_benchmark.structures.bounding_box import BoxList
class Background(data.Dataset):
    """Dataset of images that are paired with an empty (zero-box) target.

    Image records are read from the top-level ``"images"`` list of a JSON
    annotation file; every sample is returned with a ``BoxList`` holding no
    boxes and an empty ``"labels"`` field.

    Args:
        ann_file (string): Path to the JSON annotation file. Its ``"images"``
            entry is expected to be a list of dicts, each with a
            ``"file_name"`` that is joined onto ``root``.
        root (string): Directory the ``file_name`` entries are relative to.
        remove_images_without_annotations: Accepted for signature
            compatibility with other dataset constructors; not used here.
        transforms (callable, optional): Called as ``transforms(img, None)``
            on the RGB PIL image and expected to return an
            ``(image, target)`` pair; only the image is kept.
    """

    def __init__(self, ann_file, root, remove_images_without_annotations=None, transforms=None):
        self.root = root
        # JSON is UTF-8 by specification; don't rely on the locale default.
        with open(ann_file, 'r', encoding='utf-8') as f:
            # List of per-image records taken verbatim from the JSON file.
            self.ids = json.load(f)['images']
        self.transform = transforms

    def __getitem__(self, index):
        """Load one image and build its empty target.

        Args:
            index (int): Position in ``self.ids``.

        Returns:
            tuple: ``(image, target, index)`` where ``image`` is the (possibly
            transformed) RGB image, ``target`` is a ``BoxList`` with zero
            boxes and an empty ``"labels"`` field, and ``index`` is echoed
            back unchanged.
        """
        im_info = self.ids[index]
        fp = os.path.join(self.root, im_info['file_name'])
        # ``with`` releases the underlying file handle deterministically;
        # ``convert`` returns a new, fully loaded image.
        with Image.open(fp) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img, _ = self.transform(img, None)
        # BoxList takes (width, height). A PIL image (no transform, or a
        # transform that keeps PIL output) has no ``.shape`` -- use ``.size``,
        # which is already (width, height). Otherwise assume a tensor whose
        # last two dims are (height, width).
        if isinstance(img, Image.Image):
            image_size = img.size
        else:
            image_size = (img.shape[-1], img.shape[-2])
        null_target = BoxList(torch.zeros((0, 4)), image_size)
        null_target.add_field('labels', torch.zeros(0))
        return img, null_target, index

    def __len__(self):
        """Return the number of image records loaded from the annotation file."""
        return len(self.ids)

    def get_img_info(self, index):
        """Return the raw JSON record (a dict from ``"images"``) at ``index``."""
        im_info = self.ids[index]
        return im_info