# microScan / torch_utils / coco_utils.py
# Origin: "initial commit" d70f24c by crazyscientist1 (8.73 kB).
import copy
import os
import torch
import torch.utils.data
import torchvision
# import transforms as T
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO
class FilterAndRemapCocoCategories:
    """Transform that keeps only annotations of selected categories.

    Args:
        categories: Sequence of category ids to keep. When ``remap`` is true,
            the position of an id in this sequence becomes its new id.
        remap: If true, rewrite each kept annotation's ``category_id`` to its
            index in ``categories``.
    """

    def __init__(self, categories, remap=True):
        self.categories = categories
        self.remap = remap

    def __call__(self, image, target):
        """Filter (and optionally remap) ``target["annotations"]``.

        The ``target`` dict is updated in place and returned with ``image``.
        """
        kept = [
            ann
            for ann in target["annotations"]
            if ann["category_id"] in self.categories
        ]
        if self.remap:
            # Work on copies so the caller's annotation dicts keep their
            # original category ids.
            kept = copy.deepcopy(kept)
            for ann in kept:
                ann["category_id"] = self.categories.index(ann["category_id"])
        target["annotations"] = kept
        return image, target
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize per-instance COCO segmentations into a stack of masks.

    Args:
        segmentations: Iterable with one segmentation entry per instance,
            each passed to ``coco_mask.frPyObjects``.
        height: Height of the output masks.
        width: Width of the output masks.

    Returns:
        A tensor with one ``height x width`` mask per instance, stacked along
        dim 0. With no instances, an empty ``(0, height, width)`` uint8
        tensor is returned.
    """

    def _decode(polygons):
        rles = coco_mask.frPyObjects(polygons, height, width)
        decoded = coco_mask.decode(rles)
        if decoded.ndim < 3:
            # Give a single decoded mask a trailing axis so the reduction
            # below works for both layouts.
            decoded = decoded[..., None]
        # Merge all parts of the instance into one mask.
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_instance = [_decode(polygons) for polygons in segmentations]
    if not per_instance:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_instance, dim=0)
class ConvertCocoPolysToMask:
    """Transform turning raw COCO annotations into a tensor target dict."""

    def __call__(self, image, target):
        """Build the detection target for one image.

        Args:
            image: Object whose ``size`` attribute unpacks to
                ``(width, height)`` (presumably a PIL image -- confirm
                against the caller).
            target: Dict with ``"image_id"`` and ``"annotations"`` (a list of
                COCO annotation dicts).

        Returns:
            ``(image, target)`` where ``target`` is a new dict with
            ``boxes``, ``labels``, ``masks``, ``image_id``, ``area``,
            ``iscrowd`` and, when the annotations carry them, ``keypoints``.
            All per-object entries are aligned: each has one row per kept
            (non-crowd, non-degenerate) annotation.
        """
        w, h = image.size

        image_id = torch.tensor([target["image_id"]])

        # Crowd annotations are discarded up front.
        anno = [obj for obj in target["annotations"] if obj["iscrowd"] == 0]

        boxes = [obj["bbox"] for obj in anno]
        # The reshape guards against the no-box case (empty list -> (0, 4)).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Turn (x, y, w, h) into (x1, y1, x2, y2), then clip to the image.
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)

        classes = torch.tensor(
            [obj["category_id"] for obj in anno], dtype=torch.int64
        )

        segmentations = [obj["segmentation"] for obj in anno]
        masks = convert_coco_poly_to_mask(segmentations, h, w)

        keypoints = None
        if anno and "keypoints" in anno[0]:
            keypoints = [obj["keypoints"] for obj in anno]
            keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
            num_keypoints = keypoints.shape[0]
            if num_keypoints:
                # Flat keypoint lists become rows of 3 values per keypoint.
                keypoints = keypoints.view(num_keypoints, -1, 3)

        # Drop boxes that ended up with non-positive width or height.
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])

        new_target = {
            "boxes": boxes[keep],
            "labels": classes[keep],
            "masks": masks[keep],
            "image_id": image_id,
        }
        if keypoints is not None:
            new_target["keypoints"] = keypoints[keep]

        # For conversion to the COCO API. These must be filtered with the
        # same ``keep`` mask as the boxes; otherwise a dropped degenerate box
        # leaves area/iscrowd longer than, and misaligned with, the boxes.
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
        new_target["area"] = area[keep]
        new_target["iscrowd"] = iscrowd[keep]

        return image, new_target
def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Return a ``Subset`` of ``dataset`` holding only usefully annotated images.

    An image is kept when, after optionally restricting annotations to
    ``cat_list``, it has at least one annotation, not every box is degenerate
    (a width or height of at most 1), and -- for keypoint annotations -- the
    image has at least 10 visible keypoints in total.

    Args:
        dataset: A ``torchvision.datasets.CocoDetection`` instance.
        cat_list: Optional collection of category ids; when truthy, only
            annotations of these categories are considered.

    Returns:
        ``torch.utils.data.Subset`` over the indices of the kept images.
    """
    assert isinstance(dataset, torchvision.datasets.CocoDetection)

    min_keypoints_per_image = 10

    def _is_degenerate(bbox):
        # bbox[2:] holds the box extents; either one being <= 1 counts as empty.
        return any(extent <= 1 for extent in bbox[2:])

    def _count_visible_keypoints(anno):
        # Every third value starting at index 2 is the visibility flag.
        return sum(
            1 for ann in anno for vis in ann["keypoints"][2::3] if vis > 0
        )

    def _has_valid_annotation(anno):
        # No annotations at all.
        if not anno:
            return False
        # All boxes have close to zero area.
        if all(_is_degenerate(obj["bbox"]) for obj in anno):
            return False
        # Keypoint tasks have a slightly different criterion: non-keypoint
        # annotations are valid at this point.
        if "keypoints" not in anno[0]:
            return True
        # For keypoint detection, require enough visible keypoints.
        return _count_visible_keypoints(anno) >= min_keypoints_per_image

    coco = dataset.coco
    kept_indices = []
    for ds_idx, img_id in enumerate(dataset.ids):
        anno = coco.loadAnns(coco.getAnnIds(imgIds=img_id, iscrowd=None))
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            kept_indices.append(ds_idx)

    return torch.utils.data.Subset(dataset, kept_indices)
def convert_to_coco_api(ds):
    """Build a ``COCO`` object from a dataset yielding ``(image, target)`` pairs.

    Args:
        ds: Indexable dataset with ``len``. Each item is ``(img, targets)``
            where ``targets`` holds ``image_id``, ``boxes`` (x1, y1, x2, y2),
            ``labels``, ``area``, ``iscrowd`` and optionally ``masks`` and
            ``keypoints``. ``img.shape[-2:]`` is read as (height, width), so
            ``img`` is assumed to be tensor-like -- confirm against callers.

    Returns:
        A ``COCO`` instance whose ``dataset`` dict has been populated and
        indexed via ``createIndex()``.
    """
    coco_ds = COCO()
    # annotation IDs need to start at 1, not 0, see torchvision issue #1530
    ann_id = 1
    dataset = {"images": [], "categories": [], "annotations": []}
    categories = set()
    for img_idx in range(len(ds)):
        img, targets = ds[img_idx]
        image_id = targets["image_id"].item()
        dataset["images"].append(
            {
                "id": image_id,
                "height": img.shape[-2],
                "width": img.shape[-1],
            }
        )

        # Clone before converting to (x, y, w, h): the subtraction below is
        # in place and must not corrupt the dataset's own target tensor
        # (which would be visible to any caller holding or caching it).
        bboxes = targets["boxes"].clone()
        bboxes[:, 2:] -= bboxes[:, :2]
        bboxes = bboxes.tolist()
        labels = targets["labels"].tolist()
        areas = targets["area"].tolist()
        iscrowd = targets["iscrowd"].tolist()

        has_masks = "masks" in targets
        has_keypoints = "keypoints" in targets
        if has_masks:
            masks = targets["masks"]
            # make masks Fortran contiguous for coco_mask
            masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
        if has_keypoints:
            keypoints = targets["keypoints"]
            keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()

        for i, bbox in enumerate(bboxes):
            ann = {
                "image_id": image_id,
                "bbox": bbox,
                "category_id": labels[i],
                "area": areas[i],
                "iscrowd": iscrowd[i],
                "id": ann_id,
            }
            categories.add(labels[i])
            if has_masks:
                ann["segmentation"] = coco_mask.encode(masks[i].numpy())
            if has_keypoints:
                ann["keypoints"] = keypoints[i]
                # Count keypoints whose visibility flag is non-zero.
                ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
            dataset["annotations"].append(ann)
            ann_id += 1

    dataset["categories"] = [{"id": i} for i in sorted(categories)]
    coco_ds.dataset = dataset
    coco_ds.createIndex()
    return coco_ds
def get_coco_api_from_dataset(dataset):
    """Return a ``COCO`` API object for ``dataset``.

    Up to 10 levels of ``torch.utils.data.Subset`` wrapping are peeled off
    while looking for an underlying ``torchvision.datasets.CocoDetection``.
    If one is found, its ``coco`` attribute is returned; otherwise the
    (possibly unwrapped) dataset is converted with ``convert_to_coco_api``.
    """
    coco_cls = torchvision.datasets.CocoDetection

    hops_left = 10
    while hops_left > 0 and not isinstance(dataset, coco_cls):
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
        hops_left -= 1

    if not isinstance(dataset, coco_cls):
        return convert_to_coco_api(dataset)
    return dataset.coco
class CocoDetection(torchvision.datasets.CocoDetection):
    """``CocoDetection`` variant that applies joint ``(image, target)`` transforms.

    Args:
        img_folder: Image directory, forwarded to the parent class.
        ann_file: Annotation file, forwarded to the parent class.
        transforms: Callable taking and returning ``(image, target)``, or
            ``None`` to return items untransformed.
    """

    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        """Return ``(image, target)`` with ``target`` wrapped in a dict.

        The parent's annotation list is stored under ``"annotations"``
        alongside the image's id under ``"image_id"``.
        """
        img, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        if self._transforms is None:
            return img, target
        img, target = self._transforms(img, target)
        return img, target
class _ComposeTransforms:
    """Apply a sequence of ``(image, target)`` transforms in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target


def get_coco(root, image_set, transforms, mode="instances"):
    """Build the COCO 2017 detection dataset for one split.

    Args:
        root: Dataset root directory containing ``train2017`` / ``val2017``
            and an ``annotations`` folder.
        image_set: ``"train"`` or ``"val"``; any other value raises
            ``KeyError``.
        transforms: Optional callable taking and returning
            ``(image, target)``; applied after ``ConvertCocoPolysToMask``.
        mode: Annotation file prefix, e.g. ``"instances"`` or
            ``"person_keypoints"``.

    Returns:
        A ``CocoDetection`` dataset; for the ``"train"`` split it is wrapped
        in a ``Subset`` that excludes images without valid annotations.
    """
    anno_file_template = "{}_{}2017.json"
    PATHS = {
        "train": (
            "train2017",
            os.path.join("annotations", anno_file_template.format(mode, "train")),
        ),
        "val": (
            "val2017",
            os.path.join("annotations", anno_file_template.format(mode, "val")),
        ),
    }

    pipeline = [ConvertCocoPolysToMask()]
    if transforms is not None:
        pipeline.append(transforms)
    # The module-level ``T`` alias is not imported in this file, so the
    # pipeline is chained with a local helper instead of ``T.Compose``
    # (which raised NameError).
    transforms = _ComposeTransforms(pipeline)

    img_folder, ann_file = PATHS[image_set]
    img_folder = os.path.join(root, img_folder)
    ann_file = os.path.join(root, ann_file)

    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)

    if image_set == "train":
        dataset = _coco_remove_images_without_annotations(dataset)

    return dataset
def get_coco_kp(root, image_set, transforms):
    """Build the COCO dataset using the ``person_keypoints`` annotation files.

    Thin wrapper around ``get_coco`` with ``mode="person_keypoints"``.
    """
    return get_coco(
        root=root,
        image_set=image_set,
        transforms=transforms,
        mode="person_keypoints",
    )