stevengrove
initial commit
186701e
raw history blame
No virus
5.04 kB
# Copyright (c) OpenMMLab. All rights reserved.
"""Extracting subsets from coco2017 dataset.
This script is mainly used to quickly debug and verify the correctness of a
program.
The root folder must have the following layout:
├── root
│   ├── annotations
│   ├── train2017
│   ├── val2017
│   ├── test2017
Currently only COCO2017 is supported. Support for user-defined datasets in
the standard COCO JSON format is planned.
Example:
python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img ${NUM_IMG}
"""
import argparse
import os.path as osp
import shutil
import mmengine
import numpy as np
from pycocotools.coco import COCO
# TODO: Currently only supports coco2017
def _process_data(args,
                  in_dataset_type: str,
                  out_dataset_type: str,
                  year: str = '2017'):
    """Extract a random, optionally filtered, image subset of one COCO split.

    Images are read from ``{args.root}/{in_dataset_type}{year}`` using the
    annotation file ``annotations/instances_{in_dataset_type}{year}.json``.
    The selected images are copied to
    ``{args.out_dir}/{out_dataset_type}{year}`` and their (filtered)
    annotations are dumped to
    ``{args.out_dir}/annotations/instances_{out_dataset_type}{year}.json``.

    Args:
        args (argparse.Namespace): Parsed CLI arguments. Reads ``root``,
            ``out_dir``, ``num_img``, ``area_size`` and ``classes``.
        in_dataset_type (str): Source split, ``'train'`` or ``'val'``.
        out_dataset_type (str): Destination split, ``'train'`` or ``'val'``.
        year (str): Dataset year suffix. Defaults to ``'2017'``.

    Raises:
        ValueError: If ``args.classes`` is given but none of the names match
            a category in the source annotation file.
    """
    assert in_dataset_type in ('train', 'val')
    assert out_dataset_type in ('train', 'val')

    in_ann_file_name = f'annotations/instances_{in_dataset_type}{year}.json'
    out_ann_file_name = f'annotations/instances_{out_dataset_type}{year}.json'
    ann_path = osp.join(args.root, in_ann_file_name)

    # COCO() already parses the whole JSON into ``coco.dataset``; reuse it
    # instead of loading the (potentially very large) file a second time.
    coco = COCO(ann_path)
    json_data = coco.dataset
    new_json_data = {
        'info': json_data.get('info', {}),
        'licenses': json_data.get('licenses', []),
        # All categories are kept, even when filtering by ``args.classes``.
        'categories': json_data['categories'],
        'images': [],
        'annotations': []
    }

    area_dict = {
        'small': [0., 32 * 32],
        'medium': [32 * 32, 96 * 96],
        'large': [96 * 96, float('inf')]
    }
    # An empty list means "no filtering" for pycocotools' getAnnIds.
    area_rng = area_dict[args.area_size] if args.area_size else []

    cat_ids = []
    if args.classes:
        cat_ids = coco.getCatIds(catNms=args.classes)
        # getAnnIds treats an empty catIds as "all categories", so if no
        # requested name matched we would silently extract every class.
        if not cat_ids:
            raise ValueError(
                f'None of the classes {args.classes} were found in {ann_path}')
        found_names = {cat['name'] for cat in coco.loadCats(cat_ids)}
        missing = [name for name in args.classes if name not in found_names]
        if missing:
            print(f'Warning: classes {missing} were not found in {ann_path} '
                  'and will be ignored')

    # Filter annotations by category ids and area range, then keep only the
    # images that still own at least one annotation.
    ann_ids = coco.getAnnIds(catIds=cat_ids, areaRng=area_rng)
    filter_img_ids = {ann['image_id'] for ann in coco.loadAnns(ann_ids)}
    filter_img = coco.loadImgs(filter_img_ids)
    np.random.shuffle(filter_img)

    num_img = args.num_img if args.num_img > 0 else len(filter_img)
    if num_img > len(filter_img):
        print(
            f'num_img is too big, will be set to {len(filter_img)}, '
            'because of not enough image after filter by classes and area_size'
        )
        num_img = len(filter_img)

    src_img_dir = osp.join(args.root, in_dataset_type + year)
    dst_img_dir = osp.join(args.out_dir, out_dataset_type + year)
    progress_bar = mmengine.ProgressBar(num_img)
    for img_info in filter_img[:num_img]:
        img_ann_ids = coco.getAnnIds(
            imgIds=[img_info['id']], catIds=cat_ids, areaRng=area_rng)
        new_json_data['images'].append(img_info)
        new_json_data['annotations'].extend(coco.loadAnns(img_ann_ids))
        shutil.copy(osp.join(src_img_dir, img_info['file_name']), dst_img_dir)
        progress_bar.update()

    mmengine.dump(new_json_data, osp.join(args.out_dir, out_ann_file_name))
def _make_dirs(out_dir):
    """Create the output directory tree.

    Creates ``out_dir`` itself, then its ``annotations``, ``train2017`` and
    ``val2017`` sub-directories, in that order.
    """
    mmengine.mkdir_or_exist(out_dir)
    for sub_dir in ('annotations', 'train2017', 'val2017'):
        mmengine.mkdir_or_exist(osp.join(out_dir, sub_dir))
def parse_args():
    """Parse the command-line arguments of this script from ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options with the fields ``root``,
        ``out_dir``, ``num_img``, ``area_size``, ``classes``,
        ``use_training_set`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Extract coco subset')

    # Positional: source dataset root and destination directory.
    parser.add_argument('root', help='root path')
    parser.add_argument('out_dir',
                        type=str,
                        help='directory where subset coco will be saved.')

    # Subset size and ground-truth filters.
    parser.add_argument('--num-img',
                        type=int,
                        default=50,
                        help='num of extract image, -1 means all images')
    parser.add_argument('--area-size',
                        choices=['small', 'medium', 'large'],
                        help='filter ground-truth info by area size')
    parser.add_argument('--classes',
                        nargs='+',
                        help='filter ground-truth by class name')

    # Source split for the training subset, and randomness control.
    training_set_help = ('Whether to use the training set when extract the '
                         'training set. The training subset is extracted from '
                         'the validation set by default which can speed up.')
    parser.add_argument('--use-training-set',
                        action='store_true',
                        help=training_set_help)
    parser.add_argument('--seed', type=int, default=-1, help='seed')

    return parser.parse_args()
def main():
    """Command-line entry point: build train/val COCO subsets in ``out_dir``.

    The training subset is taken from the validation split unless
    ``--use-training-set`` is passed; the validation subset always comes
    from the validation split.

    Raises:
        ValueError: If ``out_dir`` resolves to the same directory as ``root``.
    """
    args = parse_args()
    # Compare resolved paths: a plain string comparison lets 'data/coco' vs
    # 'data/coco/', relative paths or symlinks through, and writing into the
    # source root would overwrite its annotation files.
    if osp.realpath(args.out_dir) == osp.realpath(args.root):
        raise ValueError('The file will be overwritten in place, '
                         'so the same folder is not allowed !')

    seed = int(args.seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    _make_dirs(args.out_dir)

    print('====Start processing train dataset====')
    if args.use_training_set:
        _process_data(args, 'train', 'train')
    else:
        _process_data(args, 'val', 'train')
    print('\n====Start processing val dataset====')
    _process_data(args, 'val', 'val')
    print(f'\n Result save to {args.out_dir}')
# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()