Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import numpy as np | |
from PIL import Image | |
from pycocotools.coco import COCO | |
import torchvision.transforms as T | |
class LegoDataset(torch.utils.data.Dataset):
    """Single-class ('lego') detection dataset backed by a COCO annotation file.

    Each item is an ``(image, target)`` pair. ``target`` is a dict with:

    * ``"boxes"``: float32 tensor of shape ``(N, 4)`` in
      ``[xmin, ymin, xmax, ymax]`` order,
    * ``"labels"``: int64 tensor of shape ``(N,)``, always 1,
    * ``"masks"``: uint8 tensor of shape ``(N, H, W)`` filled with zeros
      (placeholders so the target carries a ``"masks"`` entry for Mask R-CNN),
    * ``"image_id"``: one-element tensor holding the COCO image id.

    ``N`` is the number of annotations for the image and may be 0.
    """

    def __init__(self, root, annFile, transforms=None):
        """Load the annotation index.

        Args:
            root: Directory that the images' ``file_name`` entries are
                joined onto.
            annFile: Path to the COCO annotation file, passed to ``COCO``.
            transforms: Callable applied to the PIL image. When falsy,
                ``T.ToTensor()`` is used.
        """
        self.root = root
        self.coco = COCO(annFile)
        self.ids = list(self.coco.imgs.keys())
        self.transforms = transforms or T.Compose([T.ToTensor()])

    @staticmethod
    def _build_target(annotations, height, width, img_id):
        """Build the target dict from the annotations of one image.

        Args:
            annotations: Iterable of dicts, each with a ``"bbox"`` entry of
                the form ``[xmin, ymin, width, height]``.
            height: Image height, used for the placeholder masks.
            width: Image width, used for the placeholder masks.
            img_id: Image id stored under ``"image_id"``.

        Returns:
            The target dict described in the class docstring.
        """
        # Convert [x, y, w, h] to corner format [xmin, ymin, xmax, ymax].
        corner_boxes = [
            [x, y, x + w, y + h]
            for x, y, w, h in (ann["bbox"] for ann in annotations)
        ]
        count = len(corner_boxes)
        return {
            # reshape keeps the (N, 4) layout even when there are no
            # annotations; a bare empty list would yield shape (0,).
            "boxes": torch.as_tensor(
                corner_boxes, dtype=torch.float32
            ).reshape(-1, 4),
            # 'lego' is the only class, labeled as 1.
            "labels": torch.ones((count,), dtype=torch.int64),
            # One allocation for all placeholder masks; also gives the
            # correct (0, H, W) shape when count == 0.
            "masks": torch.zeros((count, height, width), dtype=torch.uint8),
            "image_id": torch.tensor([img_id]),
        }

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair at position ``index``."""
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info["file_name"])
        img = Image.open(img_path).convert("RGB")

        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        target = self._build_target(
            annotations, img_info["height"], img_info["width"], img_id
        )

        if self.transforms:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Return the number of images in the annotation file."""
        return len(self.ids)