# Copyright (c) OpenMMLab. All rights reserved.
# Written by lzx.
import copy
import os.path as osp
from typing import List, Optional, Union

from mmengine.fileio import get_local_path
from mmengine.utils import is_abs

from mmdet.datasets.api_wrappers import COCO
from mmdet.datasets.coco import CocoDataset
from mmdet.registry import DATASETS

@DATASETS.register_module()
class HSIDataset(CocoDataset):
    """COCO-format dataset for hyperspectral image (HSI) detection.

    Besides the regular image prefix, the dataset accepts optional
    ``seg_prefix`` and ``abu_prefix`` directories holding auxiliary
    per-image ``.png`` and ``.mat`` files (see ``parse_data_info``).
    """

    METAINFO = {
        'classes':
        ('CB', 'MP', 'VO', 'ZO', 'TO', 'FG', 'GS', 'IP', 'IS', 'NP', 'LO',
         'NO', 'NC', 'NF', 'K_N', 'K_O', 'P_P', 'P_O', 'V_Y_W', 'C_Y_W',
         'BlueTrap', 'BrownTrap', 'Airport', 'Brown', 'DarkGreen', 'PeaGreen',
         'FauxVineyardGreen'),
        # ``palette`` is a list of color tuples used for visualization.
        # It must provide one color per class (27 entries).
        'palette':
        [(220, 20, 60), (119, 11, 32), (0, 0, 230), (106, 0, 228),
         (0, 60, 100), (0, 0, 70), (250, 170, 30), (100, 170, 30),
         (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
         (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0),
         (120, 166, 157), (110, 76, 0), (174, 57, 255), (199, 100, 0),
         (0, 0, 255), (199, 100, 0), (72, 0, 118), (255, 179, 240),
         (0, 125, 92), (209, 0, 151), (188, 208, 182)]
    }
    COCOAPI = COCO
    # ann_id is unique in coco dataset.
    ANN_ID_UNIQUE = True
    def __init__(self,
                 *args,
                 seg_prefix: Optional[dict] = None,
                 abu_prefix: Optional[dict] = None,
                 **kwargs) -> None:
        """
        Args:
            seg_prefix (dict, optional): Prefix dict (same form as
                ``data_prefix``, e.g. ``dict(img='train_seg/')``) pointing to
                the per-image ``.png`` maps. Defaults to None.
            abu_prefix (dict, optional): Prefix dict pointing to the
                per-image ``.mat`` maps. Defaults to None.
        """
        self.seg_prefix = seg_prefix
        self.abu_prefix = abu_prefix
        super().__init__(*args, **kwargs)
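
    # A minimal config sketch (hypothetical paths, for illustration only)
    # showing how this dataset could be wired into an MMDetection dataloader;
    # `seg_prefix` and `abu_prefix` take the same dict form as `data_prefix`:
    #
    #     dataset=dict(
    #         type='HSIDataset',
    #         data_root='data/hsi/',
    #         ann_file='annotations/instances_train.json',
    #         data_prefix=dict(img='train/'),
    #         seg_prefix=dict(img='train_seg/'),
    #         abu_prefix=dict(img='train_abu/'),
    #         filter_cfg=dict(filter_empty_gt=True, min_size=32),
    #         pipeline=train_pipeline)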
    def _join_prefix(self):
        """Join ``self.data_root`` with ``self.data_prefix`` and
        ``self.ann_file``.

        Examples:
            >>> # self.data_prefix contains relative paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='a/b/c/d/e')
            >>> self.ann_file
            'a/b/c/f'
            >>> # self.data_prefix contains absolute paths
            >>> self.data_root = 'a/b/c'
            >>> self.data_prefix = dict(img='/d/e/')
            >>> self.ann_file = 'f'
            >>> self._join_prefix()
            >>> self.data_prefix
            dict(img='/d/e')
            >>> self.ann_file
            'a/b/c/f'
        """
        # Automatically join the annotation file path with `self.data_root`
        # if `self.ann_file` is not an absolute path.
        if self.ann_file and not is_abs(self.ann_file):
            self.ann_file = osp.join(self.data_root, self.ann_file)
        # Automatically join data directories with `self.data_root` if the
        # path values in `self.data_prefix` are not absolute paths.
        for data_key, prefix in self.data_prefix.items():
            if isinstance(prefix, str):
                if not is_abs(prefix):
                    self.data_prefix[data_key] = osp.join(
                        self.data_root, prefix)
                else:
                    self.data_prefix[data_key] = prefix
            else:
                raise TypeError('prefix should be a string, but got '
                                f'{type(prefix)}')
        # `seg_prefix` and `abu_prefix` are joined in the same way as
        # `data_prefix`.
        if self.seg_prefix is not None:
            for data_key, prefix in self.seg_prefix.items():
                if isinstance(prefix, str):
                    if not is_abs(prefix):
                        self.seg_prefix[data_key] = osp.join(
                            self.data_root, prefix)
                    else:
                        self.seg_prefix[data_key] = prefix
                else:
                    raise TypeError('prefix should be a string, but got '
                                    f'{type(prefix)}')
        if self.abu_prefix is not None:
            for data_key, prefix in self.abu_prefix.items():
                if isinstance(prefix, str):
                    if not is_abs(prefix):
                        self.abu_prefix[data_key] = osp.join(
                            self.data_root, prefix)
                    else:
                        self.abu_prefix[data_key] = prefix
                else:
                    raise TypeError('prefix should be a string, but got '
                                    f'{type(prefix)}')
    def load_data_list(self) -> List[dict]:
        """Load annotations from the annotation file ``self.ann_file``.

        Returns:
            List[dict]: A list of annotations.
        """  # noqa: E501
        with get_local_path(
                self.ann_file, backend_args=self.backend_args) as local_path:
            self.coco = self.COCOAPI(local_path)
        # The order of the returned `cat_ids` does not change with the order
        # of `classes`.
        self.cat_ids = self.coco.get_cat_ids(
            cat_names=self.metainfo['classes'])
        self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
        self.cat_img_map = copy.deepcopy(self.coco.cat_img_map)

        img_ids = self.coco.get_img_ids()
        data_list = []
        total_ann_ids = []
        for img_id in img_ids:
            raw_img_info = self.coco.load_imgs([img_id])[0]
            raw_img_info['img_id'] = img_id

            ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
            raw_ann_info = self.coco.load_anns(ann_ids)
            total_ann_ids.extend(ann_ids)

            parsed_data_info = self.parse_data_info({
                'raw_ann_info': raw_ann_info,
                'raw_img_info': raw_img_info
            })
            data_list.append(parsed_data_info)
        if self.ANN_ID_UNIQUE:
            assert len(set(total_ann_ids)) == len(
                total_ann_ids
            ), f"Annotation ids in '{self.ann_file}' are not unique!"

        del self.coco

        return data_list
    def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]:
        """Parse a raw annotation into the target format.

        Args:
            raw_data_info (dict): Raw data information loaded from
                ``ann_file``.

        Returns:
            Union[dict, List[dict]]: Parsed annotation.
        """
        img_info = raw_data_info['raw_img_info']
        ann_info = raw_data_info['raw_ann_info']

        data_info = {}

        # TODO: need to change data_prefix['img'] to data_prefix['img_path']
        img_path = osp.join(self.data_prefix['img'], img_info['file_name'])
        if self.data_prefix.get('seg', None):
            seg_map_path = osp.join(
                self.data_prefix['seg'],
                img_info['file_name'].rsplit('.', 1)[0] + self.seg_map_suffix)
        else:
            seg_map_path = None
        # Image file names are expected to end with `.npy`; the corresponding
        # auxiliary files use `.png` (seg_prefix) and `.mat` (abu_prefix).
        if self.seg_prefix is not None:
            seg_path = osp.join(self.seg_prefix['img'],
                                img_info['file_name']).replace('.npy', '.png')
        else:
            seg_path = None
        if self.abu_prefix is not None:
            abu_path = osp.join(self.abu_prefix['img'],
                                img_info['file_name']).replace('.npy', '.mat')
        else:
            abu_path = None
        data_info['img_path'] = img_path
        data_info['img_id'] = img_info['img_id']
        data_info['seg_map_path'] = seg_map_path
        data_info['seg_path'] = seg_path
        data_info['abu_path'] = abu_path
        data_info['height'] = img_info['height']
        data_info['width'] = img_info['width']

        instances = []
        for ann in ann_info:
            instance = {}

            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            if ann['area'] <= 0 or w < 1 or h < 1:
                continue
            if ann['category_id'] not in self.cat_ids:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]

            if ann.get('iscrowd', False):
                instance['ignore_flag'] = 1
            else:
                instance['ignore_flag'] = 0
            instance['bbox'] = bbox
            instance['bbox_label'] = self.cat2label[ann['category_id']]

            if ann.get('segmentation', None):
                instance['mask'] = ann['segmentation']

            instances.append(instance)
        data_info['instances'] = instances
        return data_info
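
    # For illustration (hypothetical file names and sizes): given an image
    # entry `demo_001.npy` with `seg_prefix=dict(img='train_seg/')` and
    # `abu_prefix=dict(img='train_abu/')`, `parse_data_info` returns roughly:
    #
    #     {'img_path': '.../train/demo_001.npy',
    #      'seg_path': '.../train_seg/demo_001.png',
    #      'abu_path': '.../train_abu/demo_001.mat',
    #      'seg_map_path': None,
    #      'img_id': 1, 'height': 512, 'width': 512,
    #      'instances': [{'bbox': [x1, y1, x2, y2],
    #                     'bbox_label': 0, 'ignore_flag': 0}, ...]}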
    def filter_data(self) -> List[dict]:
        """Filter annotations according to ``filter_cfg``.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        if self.filter_cfg is None:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False)
        min_size = self.filter_cfg.get('min_size', 0)

        # obtain images that contain annotations
        ids_with_ann = set(
            data_info['img_id'] for data_info in self.data_list)
        # obtain images that contain annotations of the required categories
        ids_in_cat = set()
        for i, class_id in enumerate(self.cat_ids):
            ids_in_cat |= set(self.cat_img_map[class_id])
        # merge the image id sets of the two conditions and use the merged
        # set to filter out images if self.filter_empty_gt=True
        ids_in_cat &= ids_with_ann

        valid_data_infos = []
        for i, data_info in enumerate(self.data_list):
            img_id = data_info['img_id']
            width = data_info['width']
            height = data_info['height']
            if filter_empty_gt and img_id not in ids_in_cat:
                continue
            if min(width, height) >= min_size:
                valid_data_infos.append(data_info)

        return valid_data_infos
# @DATASETS.register_module()
# class HSIDataset16(HSIDataset):
#     """16-class subset of :class:`HSIDataset`."""
#
#     METAINFO = {
#         'classes':
#         ('TO', 'FG', 'GS', 'IP', 'IS', 'NP', 'LO', 'NO', 'NC', 'NF', 'K_N',
#          'K_O', 'P_P', 'P_O', 'V_Y_W', 'C_Y_W'),
#         # palette is a list of color tuples, which is used for visualization.
#         'palette':
#         [(0, 60, 100), (0, 0, 70), (250, 170, 30), (100, 170, 30),
#          (220, 220, 0), (175, 116, 175), (250, 0, 30), (165, 42, 42),
#          (255, 77, 255), (0, 226, 252), (182, 182, 255), (0, 82, 0),
#          (120, 166, 157), (110, 76, 0), (174, 57, 255), (199, 100, 0)]
#     }
#     COCOAPI = COCO
#     # ann_id is unique in coco dataset.
#     ANN_ID_UNIQUE = True
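

# A minimal usage sketch, not part of the original module: every value below
# (data_root, ann_file, the prefix directories) is hypothetical and only
# illustrates how the registered dataset could be built and inspected once
# such data exists on disk.
if __name__ == '__main__':
    cfg = dict(
        type='HSIDataset',
        data_root='data/hsi/',                      # assumed dataset root
        ann_file='annotations/instances_val.json',  # assumed COCO-style file
        data_prefix=dict(img='val/'),               # .npy hyperspectral cubes
        seg_prefix=dict(img='val_seg/'),            # auxiliary .png maps
        abu_prefix=dict(img='val_abu/'),            # auxiliary .mat maps
        test_mode=True,
        pipeline=[])
    dataset = DATASETS.build(cfg)
    # Each data_info carries the img/seg/abu paths set in `parse_data_info`.
    print(len(dataset))
    print(dataset.get_data_info(0))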