| """ |
| Copied and modified from https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py |
| |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. |
| """ |
|
|
|
|
| import torch |
| import torch.utils.data |
| import torchvision |
| import torchvision.transforms.functional as TVF |
| import faster_coco_eval.core.mask as coco_mask |
| from faster_coco_eval import COCO |
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize per-object COCO segmentations into one mask per object.

    Args:
        segmentations: iterable with one entry per object; each entry is passed
            unchanged to ``coco_mask.frPyObjects`` (presumably the COCO polygon
            list format -- confirm for other encodings).
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        A ``(num_objects, height, width)`` tensor. When at least one object is
        present the tensor is ``torch.bool`` (each object's polygon parts are
        OR-ed together); with no objects it is an empty ``torch.uint8`` tensor.
    """
    per_object = []
    for polygons in segmentations:
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        # A 2-D result gets a trailing parts axis so the reduction below is uniform.
        if len(decoded.shape) < 3:
            decoded = decoded[..., None]
        # Collapse the parts axis: a pixel belongs to the object if any part covers it.
        per_object.append(torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2))

    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
class ConvertCocoPolysToMask:
    """Convert a raw COCO-style target dict into per-object detection tensors.

    Crowd annotations are dropped. Boxes are converted from ``[x, y, w, h]``
    to ``[x1, y1, x2, y2]`` and clamped to the image. Polygons are rasterized
    to masks. Objects whose clamped box has zero width or height are removed
    from every per-object field.
    """

    def __call__(self, image, target):
        """Return ``(image, new_target)``.

        Args:
            image: object whose ``.size`` is ``(width, height)`` (presumably a
                PIL image -- TODO confirm for other image types).
            target: dict with ``"image_id"`` and ``"annotations"`` (a list of
                COCO annotation dicts).

        Returns:
            The unchanged ``image`` and a new dict containing:

            - ``"boxes"``: float32 tensor of shape ``(N, 4)``, in xyxy format.
            - ``"labels"``: int64 tensor of shape ``(N,)``.
            - ``"masks"``: tensor of shape ``(N, H, W)``.
            - ``"image_id"``: passed through unchanged.
            - ``"keypoints"``: tensor of shape ``(N, K, 3)``, only when the
              annotations carry keypoints.
            - ``"area"`` and ``"iscrowd"``: tensors of shape ``(N,)``.

            Every per-object entry shares the same ``N``.
        """
        w, h = image.size
        image_id = target["image_id"]

        # Keep only non-crowd objects; a missing "iscrowd" key counts as 0.
        anno = [obj for obj in target["annotations"] if obj.get("iscrowd", 0) == 0]

        # reshape(-1, 4) keeps a (0, 4) shape when there are no annotations.
        boxes = torch.as_tensor([obj["bbox"] for obj in anno], dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor([obj["category_id"] for obj in anno], dtype=torch.int64)

        masks = convert_coco_poly_to_mask([obj["segmentation"] for obj in anno], h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = torch.as_tensor([obj["keypoints"] for obj in anno], dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Needed later for conversion back to the COCO API.
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj.get("iscrowd", 0) for obj in anno], dtype=torch.int64)

        # Drop degenerate boxes. The SAME mask is applied to every per-object
        # field (including area/iscrowd) so that index i refers to one object
        # everywhere.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        new_target = {
            "boxes": boxes[keep],
            "labels": classes[keep],
            "masks": masks[keep],
            "image_id": image_id,
        }
        if keypoints is not None:
            new_target["keypoints"] = keypoints[keep]
        new_target["area"] = area[keep]
        new_target["iscrowd"] = iscrowd[keep]

        return image, new_target
|
|
|
|
| def _coco_remove_images_without_annotations(dataset, cat_list=None): |
| def _has_only_empty_bbox(anno): |
| return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) |
|
|
| def _count_visible_keypoints(anno): |
| return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) |
|
|
| min_keypoints_per_image = 10 |
|
|
| def _has_valid_annotation(anno): |
| |
| if len(anno) == 0: |
| return False |
| |
| if _has_only_empty_bbox(anno): |
| return False |
| |
| |
| if "keypoints" not in anno[0]: |
| return True |
| |
| |
| if _count_visible_keypoints(anno) >= min_keypoints_per_image: |
| return True |
| return False |
|
|
| ids = [] |
| for ds_idx, img_id in enumerate(dataset.ids): |
| ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) |
| anno = dataset.coco.loadAnns(ann_ids) |
| if cat_list: |
| anno = [obj for obj in anno if obj["category_id"] in cat_list] |
| if _has_valid_annotation(anno): |
| ids.append(ds_idx) |
|
|
| dataset = torch.utils.data.Subset(dataset, ids) |
| return dataset |
|
|
|
|
def convert_to_coco_api(ds):
    """Build an in-memory COCO ground-truth object from a detection dataset.

    Each item is read with ``ds.load_item(idx)``, which must return
    ``(img, targets)``:

    - ``img.size`` is ``(width, height)``.
    - ``targets`` has ``"image_id"`` (int or 1-element tensor), ``"boxes"``
      (xyxy), ``"labels"``, ``"area"`` and ``"iscrowd"``, plus optionally
      ``"masks"`` and ``"keypoints"``.

    Args:
        ds: dataset exposing ``__len__`` and ``load_item``.

    Returns:
        A ``COCO`` instance with ``createIndex()`` already called. Categories
        are the sorted set of all labels seen.
    """
    coco_ds = COCO()
    # Annotation ids start at 1, not 0 (upstream torchvision reports that id 0
    # breaks COCO evaluation -- see torchvision issue #1530).
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        img, targets = ds.load_item(img_idx)
        width, height = img.size

        image_id = targets["image_id"]
        # Accept a 1-element tensor as well as a plain Python int.
        if isinstance(image_id, torch.Tensor):
            image_id = image_id.item()
        dataset["images"].append({"id": image_id, "width": width, "height": height})

        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]  # xyxy -> xywh (COCO box format)
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()

        has_masks = "masks" in targets
        has_keypoints = "keypoints" in targets
        if has_masks:
            # Make each (H, W) slice Fortran-ordered; presumably required by
            # coco_mask.encode (as in upstream torchvision).
            masks = targets["masks"].permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if has_keypoints:
            # flatten(start_dim=1) also handles zero objects, where
            # reshape(0, -1) would raise.
            keypoints = targets["keypoints"].flatten(start_dim=1).tolist()

        # Index-based access (rather than zip) keeps an IndexError if the
        # per-object lists disagree in length, instead of silently truncating.
        for i, bbox in enumerate(bboxes):
            ann = {
                "image_id": image_id,
                "bbox": bbox,
                "category_id": labels[i],
                "area": areas[i],
                "iscrowd": iscrowd[i],
                "id": ann_id,
            }
            categories.add(labels[i])
            if has_masks:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if has_keypoints:
                ann["keypoints"] = keypoints[i]
                # Every third value is a visibility flag; count non-zero flags.
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1

    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
|
|
|
|
def get_coco_api_from_dataset(dataset):
    """Return a COCO API object describing the ground truth of ``dataset``.

    All nested ``torch.utils.data.Subset`` wrappers are removed first, however
    deep the nesting.

    - If the underlying dataset is a ``torchvision.datasets.CocoDetection``,
      its existing ``.coco`` object is returned. That object covers the whole
      underlying dataset, not just the subset.
    - Otherwise a COCO object is built from the unwrapped dataset with
      ``convert_to_coco_api``.
    """
    # No depth cap: a fixed cap would hand a still-wrapped Subset (which has no
    # ``load_item``) to convert_to_coco_api when nesting is deeper than the cap.
    while not isinstance(dataset, torchvision.datasets.CocoDetection) and isinstance(
        dataset, torch.utils.data.Subset
    ):
        dataset = dataset.dataset
    if isinstance(dataset, torchvision.datasets.CocoDetection):
        return dataset.coco
    return convert_to_coco_api(dataset)
|
|