# Spaces: Running on Zero
# Running on Zero
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""
import faster_coco_eval | |
import faster_coco_eval.core.mask as coco_mask | |
import torch | |
import torch.utils.data | |
import torchvision | |
import os | |
from PIL import Image | |
from ...core import register | |
from .._misc import convert_to_tv_tensor | |
from ._dataset import DetDataset | |
# Public API of this module.
__all__ = ["CocoDetection"]

# Import-time setup (module-level side effects):
# silence torchvision's warning about its beta transforms API,
torchvision.disable_beta_transforms_warning()
# let faster_coco_eval initialise itself as the pycocotools implementation,
faster_coco_eval.init_as_pycocotools()
# and remove PIL's upper bound on image pixel count so very large images load.
Image.MAX_IMAGE_PIXELS = None
class CocoDetection(torchvision.datasets.CocoDetection, DetDataset):
    """COCO-format detection dataset yielding ``(image, target)`` pairs.

    Each sample is loaded through ``torchvision.datasets.CocoDetection``; its
    raw annotation list is then converted by ``ConvertCocoPolysToMask`` into a
    dict of tensors, and ``transforms`` (when given) is applied last.

    Args:
        img_folder: Image root directory, forwarded to the torchvision base.
        ann_file: Path of the COCO annotation file, forwarded to the base.
        transforms: Callable invoked as ``transforms(img, target, dataset)``
            that returns a 3-tuple whose first two items are the transformed
            image and target. May be ``None`` to skip transformation.
        return_masks: When true, segmentation masks are decoded into the
            target under ``"masks"``.
        remap_mscoco_category: When true, category ids are translated to
            labels through the module-level ``mscoco_category2label`` table.
    """

    # NOTE(review): ``register`` is imported at file level but is not applied
    # to this class, while ``__inject__`` / ``__share__`` look like hooks for
    # the project's registry -- confirm whether a ``@register`` decorator is
    # missing here.
    __inject__ = [
        "transforms",
    ]
    __share__ = ["remap_mscoco_category"]

    def __init__(
        self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False
    ):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and target at ``idx``."""
        img, target = self.load_item(idx)
        if self._transforms is not None:
            # The third element returned by the transform is not used here.
            img, target, _ = self._transforms(img, target, self)
        return img, target

    def load_item(self, idx):
        """Load sample ``idx`` and convert its annotations, without transforms.

        Returns:
            ``(image, target)`` where ``target`` is the dict produced by
            ``self.prepare`` plus an ``"idx"`` tensor; ``"boxes"`` and
            ``"masks"`` (when present) are wrapped by ``convert_to_tv_tensor``.
        """
        image, target = super(CocoDetection, self).__getitem__(idx)

        image_id = self.ids[idx]
        image_path = os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"])
        target = {"image_id": image_id, "image_path": image_path, "annotations": target}

        if self.remap_mscoco_category:
            image, target = self.prepare(image, target, category2label=mscoco_category2label)
        else:
            image, target = self.prepare(image, target)

        target["idx"] = torch.tensor([idx])

        if "boxes" in target:
            # ``image.size`` is (width, height); reversed to (height, width).
            target["boxes"] = convert_to_tv_tensor(
                target["boxes"], key="boxes", spatial_size=image.size[::-1]
            )

        if "masks" in target:
            target["masks"] = convert_to_tv_tensor(target["masks"], key="masks")

        return image, target

    def extra_repr(self) -> str:
        """Return a multi-line summary of the dataset configuration."""
        s = f" img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n"
        s += f" return_masks: {self.return_masks}\n"
        if hasattr(self, "_transforms") and self._transforms is not None:
            s += f" transforms:\n {repr(self._transforms)}"
        if hasattr(self, "_preset") and self._preset is not None:
            s += f" preset:\n {repr(self._preset)}"
        return s

    # The accessors below are properties: the three derived mappings iterate
    # ``self.categories`` as an attribute, which raises ``TypeError`` if
    # ``categories`` is a plain bound method.
    @property
    def categories(
        self,
    ):
        """The ``"categories"`` list of the loaded COCO annotation data."""
        return self.coco.dataset["categories"]

    @property
    def category2name(
        self,
    ):
        """Mapping from category ``"id"`` to category ``"name"``."""
        return {cat["id"]: cat["name"] for cat in self.categories}

    @property
    def category2label(
        self,
    ):
        """Mapping from category ``"id"`` to its position in ``categories``."""
        return {cat["id"]: i for i, cat in enumerate(self.categories)}

    @property
    def label2category(
        self,
    ):
        """Mapping from position in ``categories`` back to category ``"id"``."""
        return {i: cat["id"] for i, cat in enumerate(self.categories)}
def convert_coco_poly_to_mask(segmentations, height, width):
    """Decode per-object COCO segmentations into one stacked mask tensor.

    Args:
        segmentations: Iterable with one segmentation entry per object; each
            entry is handed to ``coco_mask.frPyObjects``.
        height: Mask height passed to the decoder.
        width: Mask width passed to the decoder.

    Returns:
        A ``torch.uint8`` tensor holding one ``(height, width)`` mask per
        object, stacked along dim 0. An empty input yields a tensor of shape
        ``(0, height, width)``.
    """

    def _decode_instance(polygons):
        # Decode one object's segmentation and collapse its last axis.
        rles = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rles)
        if len(decoded.shape) < 3:
            # Give a 2-D result a trailing axis so the reduction below applies.
            decoded = decoded[..., None]
        as_tensor = torch.as_tensor(decoded, dtype=torch.uint8)
        # A pixel is set when it is set in any slice along the last axis.
        return as_tensor.any(dim=2)

    per_object = [_decode_instance(polygons) for polygons in segmentations]

    if not per_object:
        # Nothing to stack: return an explicitly shaped empty tensor.
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
class ConvertCocoPolysToMask(object):
    """Callable that converts a raw COCO annotation record into tensor targets.

    Args:
        return_masks: When true, every object's ``"segmentation"`` is decoded
            with ``convert_coco_poly_to_mask`` and returned under ``"masks"``.
    """

    def __init__(self, return_masks=False):
        self.return_masks = return_masks

    def __call__(self, image: Image.Image, target, **kwargs):
        """Convert the annotations of one sample.

        Args:
            image: Image whose ``size`` attribute unpacks as ``(width, height)``.
            target: Dict with ``"image_id"``, ``"image_path"`` and
                ``"annotations"`` (a list of per-object dicts).
            **kwargs: May carry ``category2label``, a mapping applied to each
                object's ``"category_id"``; without it the raw ids are used.

        Returns:
            ``(image, target)``: the image untouched, and a new dict with
            ``"boxes"`` (x1, y1, x2, y2), ``"labels"``, optional ``"masks"``,
            ``"image_id"``, ``"image_path"``, optional ``"keypoints"``,
            ``"area"``, ``"iscrowd"`` and ``"orig_size"`` (width, height).
        """
        img_w, img_h = image.size

        image_id = torch.tensor([target["image_id"]])
        image_path = target["image_path"]

        # Objects flagged as crowd regions are excluded from the targets.
        objects = [obj for obj in target["annotations"] if obj.get("iscrowd", 0) == 0]

        # The reshape keeps a well-formed (0, 4) tensor when no objects remain.
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in objects], dtype=torch.float32
        ).reshape(-1, 4)
        # (x, y, w, h) -> (x1, y1, x2, y2), then clip to the image bounds.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=img_w)
        boxes[:, 1::2].clamp_(min=0, max=img_h)

        category2label = kwargs.get("category2label", None)
        category_ids = [obj["category_id"] for obj in objects]
        if category2label is not None:
            category_ids = [category2label[category_id] for category_id in category_ids]
        labels = torch.tensor(category_ids, dtype=torch.int64)

        if self.return_masks:
            masks = convert_coco_poly_to_mask(
                [obj["segmentation"] for obj in objects], img_h, img_w
            )

        # Keypoints are read only when the first object carries them.
        keypoints = None
        if objects and "keypoints" in objects[0]:
            keypoints = torch.as_tensor(
                [obj["keypoints"] for obj in objects], dtype=torch.float32
            )
            num_objects = keypoints.shape[0]
            if num_objects:
                # Regroup each object's flat list into triplets.
                keypoints = keypoints.view(num_objects, -1, 3)

        # Keep only boxes that still have positive width and height.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])

        converted = {"boxes": boxes[keep], "labels": labels[keep]}
        if self.return_masks:
            converted["masks"] = masks[keep]
        converted["image_id"] = image_id
        converted["image_path"] = image_path
        if keypoints is not None:
            converted["keypoints"] = keypoints[keep]

        # Extra fields kept for conversion back to the COCO API.
        area = torch.tensor([obj["area"] for obj in objects])
        crowd_flags = torch.tensor([obj.get("iscrowd", 0) for obj in objects])
        converted["area"] = area[keep]
        converted["iscrowd"] = crowd_flags[keep]

        converted["orig_size"] = torch.as_tensor([int(img_w), int(img_h)])

        return image, converted
# MS COCO category id -> display name. The ids are not contiguous (for
# example 12, 26, 29 and 30 are absent), which is why the label maps below exist.
mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# Category id -> contiguous label, numbered in the table's insertion order.
mscoco_category2label = {
    category_id: label for label, category_id in enumerate(mscoco_category2name)
}

# Contiguous label -> category id (the inverse of the mapping above).
mscoco_label2category = dict(enumerate(mscoco_category2name))