# Copyright (c) OpenMMLab. All rights reserved. | |
import copy | |
import logging | |
import os.path as osp | |
import pickle | |
from typing import List, Union | |
import h5py | |
import tqdm | |
from mmdet.datasets.api_wrappers import COCO | |
from mmdet.datasets.base_det_dataset import BaseDetDataset | |
from mmdet.registry import DATASETS | |
from mmengine.fileio import get_local_path | |
from mmengine.logging import print_log | |
class MASADataset(BaseDetDataset):
    """COCO-style dataset for MASA.

    Annotations can come either from the COCO json (``ann_file``) or, when
    ``anno_hdf5_path`` is given, from an HDF5 archive of pickled per-image
    annotation lists (see ``get_ann_info``).
    """

    METAINFO = {
        # NOTE: must be a tuple of class names. The original ``("object")``
        # was a plain string — the trailing comma was missing — which made
        # ``metainfo["classes"]`` a str instead of a 1-tuple.
        "classes": ("object",),
        # palette is a list of color tuples, which is used for visualization.
        "palette": [(220, 20, 60)],
    }
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True
def __init__(self, anno_hdf5_path=None, img_prefix=None, *args, **kwargs): | |
self.anno_hdf5_path = anno_hdf5_path | |
self.img_prefix = img_prefix | |
super().__init__(*args, **kwargs) | |
def read_dicts_from_hdf5(self, hdf5_file_path, pkl_file_path): | |
with h5py.File(hdf5_file_path, "r") as hf: | |
# Retrieve the dataset corresponding to the specified .pkl file path | |
dataset = hf[pkl_file_path] | |
binary_data = dataset[()] | |
# Deserialize the binary data and load the list of dictionaries | |
list_of_dicts = pickle.loads(binary_data) | |
return list_of_dicts | |
def get_ann_info(self, img_info): | |
"""Get COCO annotation by index. | |
Args: | |
idx (int): Index of data. | |
Returns: | |
dict: Annotation info of specified index. | |
""" | |
if self.anno_hdf5_path is not None: | |
try: | |
ann_info = self.read_dicts_from_hdf5( | |
self.anno_hdf5_path, img_info["file_name"].replace(".jpg", ".pkl") | |
) | |
return ann_info | |
except: | |
print(self.anno_hdf5_path) | |
print(img_info["file_name"].replace(".jpg", ".pkl")) | |
return None | |
else: | |
img_id = img_info["id"] | |
ann_ids = self.coco.get_ann_ids(img_ids=[img_id], cat_ids=self.cat_ids) | |
ann_info = self.coco.load_anns(ann_ids) | |
return ann_info | |
def __getitem__(self, idx: int) -> dict: | |
"""Get the idx-th image and data information of dataset after | |
``self.pipeline``, and ``full_init`` will be called if the dataset has | |
not been fully initialized. | |
During training phase, if ``self.pipeline`` get ``None``, | |
``self._rand_another`` will be called until a valid image is fetched or | |
the maximum limit of refetech is reached. | |
Args: | |
idx (int): The index of self.data_list. | |
Returns: | |
dict: The idx-th image and data information of dataset after | |
``self.pipeline``. | |
""" | |
# Performing full initialization by calling `__getitem__` will consume | |
# extra memory. If a dataset is not fully initialized by setting | |
# `lazy_init=True` and then fed into the dataloader. Different workers | |
# will simultaneously read and parse the annotation. It will cost more | |
# time and memory, although this may work. Therefore, it is recommended | |
# to manually call `full_init` before dataset fed into dataloader to | |
# ensure all workers use shared RAM from master process. | |
if not self._fully_initialized: | |
print_log( | |
"Please call `full_init()` method manually to accelerate " "the speed.", | |
logger="current", | |
level=logging.WARNING, | |
) | |
self.full_init() | |
if self.test_mode: | |
data = self.prepare_data(idx) | |
if data is None: | |
raise Exception( | |
"Test time pipline should not get `None` " "data_sample" | |
) | |
return data | |
for _ in range(self.max_refetch + 1): | |
try: | |
data = self.prepare_data(idx) | |
except Exception as e: | |
data = None | |
# Broken images or random augmentations may cause the returned data | |
# to be None | |
if data is None: | |
idx = self._rand_another() | |
continue | |
return data | |
raise Exception( | |
f"Cannot find valid image after {self.max_refetch}! " | |
"Please check your image path and pipeline" | |
) | |
def load_data_list(self) -> List[dict]: | |
"""Load annotations from an annotation file named as ``self.ann_file`` | |
Returns: | |
List[dict]: A list of annotation. | |
""" # noqa: E501 | |
with get_local_path( | |
self.ann_file, backend_args=self.backend_args | |
) as local_path: | |
self.coco = self.COCOAPI(local_path) | |
# The order of returned `cat_ids` will not | |
# change with the order of the `classes` | |
self.cat_ids = self.coco.get_cat_ids(cat_names=self.metainfo["classes"]) | |
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} | |
self.cat_img_map = copy.deepcopy(self.coco.cat_img_map) | |
img_ids = self.coco.get_img_ids() | |
data_list = [] | |
total_ann_ids = [] | |
print("Loading data list...") | |
for img_id in tqdm.tqdm(img_ids): | |
raw_img_info = self.coco.load_imgs([img_id])[0] | |
raw_img_info["img_id"] = img_id | |
ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) | |
raw_ann_info = self.coco.load_anns(ann_ids) | |
total_ann_ids.extend(ann_ids) | |
parsed_data_info = self.parse_data_info( | |
{"raw_ann_info": raw_ann_info, "raw_img_info": raw_img_info} | |
) | |
data_list.append(parsed_data_info) | |
if self.ANN_ID_UNIQUE: | |
assert len(set(total_ann_ids)) == len( | |
total_ann_ids | |
), f"Annotation ids in '{self.ann_file}' are not unique!" | |
del self.coco | |
return data_list | |
def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: | |
"""Parse raw annotation to target format. | |
Args: | |
raw_data_info (dict): Raw data information load from ``ann_file`` | |
Returns: | |
Union[dict, List[dict]]: Parsed annotation. | |
""" | |
img_info = raw_data_info["raw_img_info"] | |
ann_info = raw_data_info["raw_ann_info"] | |
data_info = {} | |
# TODO: need to change data_prefix['img'] to data_prefix['img_path'] | |
img_path = osp.join(self.data_prefix["img"], img_info["file_name"]) | |
if self.data_prefix.get("seg", None): | |
seg_map_path = osp.join( | |
self.data_prefix["seg"], | |
img_info["file_name"].rsplit(".", 1)[0] + self.seg_map_suffix, | |
) | |
else: | |
seg_map_path = None | |
data_info["img_path"] = img_path | |
data_info["img_id"] = img_info["img_id"] | |
data_info["seg_map_path"] = seg_map_path | |
data_info["height"] = img_info["height"] | |
data_info["width"] = img_info["width"] | |
if self.return_classes: | |
data_info["text"] = self.metainfo["classes"] | |
data_info["caption_prompt"] = self.caption_prompt | |
data_info["custom_entities"] = True | |
instances = [] | |
for i, ann in enumerate(ann_info): | |
instance = {} | |
if ann.get("ignore", False): | |
continue | |
x1, y1, w, h = ann["bbox"] | |
inter_w = max(0, min(x1 + w, img_info["width"]) - max(x1, 0)) | |
inter_h = max(0, min(y1 + h, img_info["height"]) - max(y1, 0)) | |
if inter_w * inter_h == 0: | |
continue | |
if ann["area"] <= 0 or w < 1 or h < 1: | |
continue | |
if "category_id" not in ann: | |
ann["category_id"] = 1 | |
if ann["category_id"] not in self.cat_ids: | |
continue | |
bbox = [x1, y1, x1 + w, y1 + h] | |
if ann.get("iscrowd", False): | |
instance["ignore_flag"] = 1 | |
else: | |
instance["ignore_flag"] = 0 | |
instance["bbox"] = bbox | |
instance["bbox_label"] = self.cat2label[ann["category_id"]] | |
if ann.get("segmentation", None): | |
instance["mask"] = ann["segmentation"] | |
if "instance_id" in ann: | |
instance["instance_id"] = ann["instance_id"] | |
else: | |
instance["instance_id"] = ann["id"] | |
instances.append(instance) | |
data_info["instances"] = instances | |
return data_info | |
def filter_data(self) -> List[dict]: | |
"""Filter annotations according to filter_cfg. | |
Returns: | |
List[dict]: Filtered results. | |
""" | |
if self.test_mode: | |
return self.data_list | |
if self.filter_cfg is None: | |
return self.data_list | |
filter_empty_gt = self.filter_cfg.get("filter_empty_gt", False) | |
min_size = self.filter_cfg.get("min_size", 0) | |
# obtain images that contain annotation | |
ids_with_ann = set(data_info["img_id"] for data_info in self.data_list) | |
# obtain images that contain annotations of the required categories | |
ids_in_cat = set() | |
for i, class_id in enumerate(self.cat_ids): | |
ids_in_cat |= set(self.cat_img_map[class_id]) | |
# merge the image id sets of the two conditions and use the merged set | |
# to filter out images if self.filter_empty_gt=True | |
ids_in_cat &= ids_with_ann | |
valid_data_infos = [] | |
for i, data_info in enumerate(self.data_list): | |
img_id = data_info["img_id"] | |
width = data_info["width"] | |
height = data_info["height"] | |
if filter_empty_gt and img_id not in ids_in_cat: | |
continue | |
if min(width, height) >= min_size: | |
valid_data_infos.append(data_info) | |
return valid_data_infos | |