Spaces:
Running
on
T4
Running
on
T4
# Copyright (c) OpenMMLab. All rights reserved. | |
"""Extract subsets from the COCO 2017 dataset.

This script is mainly used to quickly debug and verify the correctness of
the program on a small amount of data.
The root folder must have the following layout:

├── root
│   ├── annotations
│   ├── train2017
│   ├── val2017
│   └── test2017

Currently only COCO 2017 is supported. Support for user-defined datasets in
the standard COCO JSON format is planned for the future.

Example:
    python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img ${NUM_IMG}
"""
import argparse | |
import os.path as osp | |
import shutil | |
import mmengine | |
import numpy as np | |
from pycocotools.coco import COCO | |
# TODO: Currently only supports coco2017 | |
def _process_data(args,
                  in_dataset_type: str,
                  out_dataset_type: str,
                  year: str = '2017'):
    """Sample a random subset of one COCO split and write it to ``out_dir``.

    Images that have at least one annotation passing the class / area filters
    are shuffled, the first ``args.num_img`` of them are copied from
    ``<root>/<in_dataset_type><year>`` into
    ``<out_dir>/<out_dataset_type><year>``, and their (filtered) annotations
    are dumped to
    ``<out_dir>/annotations/instances_<out_dataset_type><year>.json``.

    Args:
        args (argparse.Namespace): Parsed CLI options; reads ``root``,
            ``out_dir``, ``num_img``, ``area_size`` and ``classes``.
        in_dataset_type (str): Source split, ``'train'`` or ``'val'``.
        out_dataset_type (str): Destination split, ``'train'`` or ``'val'``.
        year (str): Dataset year suffix. Defaults to ``'2017'``.

    Raises:
        ValueError: If any name in ``args.classes`` is not a category of the
            source annotation file.
    """
    assert in_dataset_type in ('train', 'val')
    assert out_dataset_type in ('train', 'val')

    in_ann_file_name = f'annotations/instances_{in_dataset_type}{year}.json'
    out_ann_file_name = f'annotations/instances_{out_dataset_type}{year}.json'
    ann_path = osp.join(args.root, in_ann_file_name)

    # COCO() already parses the annotation file and keeps the raw dict in
    # ``coco.dataset``; reuse it instead of loading the (large) JSON twice.
    coco = COCO(ann_path)
    json_data = coco.dataset

    new_json_data = {
        # 'info' and 'licenses' are optional in COCO-style annotation files.
        'info': json_data.get('info', {}),
        'licenses': json_data.get('licenses', []),
        'categories': json_data['categories'],
        'images': [],
        'annotations': []
    }

    # Lower/upper bounds on the annotation ``area`` for each --area-size.
    area_dict = {
        'small': [0., 32 * 32],
        'medium': [32 * 32, 96 * 96],
        'large': [96 * 96, float('inf')]
    }

    # filter annotations by category ids and area range
    area_rng = area_dict[args.area_size] if args.area_size else []
    cat_ids = []
    if args.classes:
        known_names = {cat['name'] for cat in json_data['categories']}
        unknown = [name for name in args.classes if name not in known_names]
        if unknown:
            # getAnnIds() treats an empty catIds list as "no category
            # filter", so silently dropping unknown names would turn a typo
            # into extracting every class.
            raise ValueError(
                f'Unknown class name(s): {unknown}. '
                f'Available classes: {sorted(known_names)}')
        cat_ids = coco.getCatIds(catNms=args.classes)

    ann_ids = coco.getAnnIds(catIds=cat_ids, areaRng=area_rng)
    ann_info = coco.loadAnns(ann_ids)

    # Keep only images that have at least one annotation passing the filters.
    filter_img_ids = {ann['image_id'] for ann in ann_info}
    filter_img = coco.loadImgs(filter_img_ids)

    # shuffle (reproducible when main() has seeded numpy)
    np.random.shuffle(filter_img)

    num_img = args.num_img if args.num_img > 0 else len(filter_img)
    if num_img > len(filter_img):
        print(
            f'num_img is too big, will be set to {len(filter_img)}, '
            'because of not enough image after filter by classes and area_size'
        )
        num_img = len(filter_img)

    in_img_dir = osp.join(args.root, in_dataset_type + year)
    out_img_dir = osp.join(args.out_dir, out_dataset_type + year)

    progress_bar = mmengine.ProgressBar(num_img)
    for img_info in filter_img[:num_img]:
        # Re-query per image so only annotations passing the same class/area
        # filters are kept for this image.
        img_ann_ids = coco.getAnnIds(
            imgIds=[img_info['id']], catIds=cat_ids, areaRng=area_rng)
        new_json_data['images'].append(img_info)
        new_json_data['annotations'].extend(coco.loadAnns(img_ann_ids))
        shutil.copy(osp.join(in_img_dir, img_info['file_name']), out_img_dir)
        progress_bar.update()

    mmengine.dump(new_json_data, osp.join(args.out_dir, out_ann_file_name))
def _make_dirs(out_dir):
    """Create ``out_dir`` and the sub-folders the extracted subset goes into.

    Ensures ``out_dir`` itself, then ``annotations``, ``train2017`` and
    ``val2017`` beneath it, all via ``mmengine.mkdir_or_exist``.
    """
    mmengine.mkdir_or_exist(out_dir)
    # Only the splits produced by main() are created; test2017 is not.
    for sub_dir in ('annotations', 'train2017', 'val2017'):
        mmengine.mkdir_or_exist(osp.join(out_dir, sub_dir))
def parse_args():
    """Parse the command line of the subset-extraction script.

    Returns:
        argparse.Namespace: Options with the attributes ``root``,
        ``out_dir``, ``num_img``, ``area_size``, ``classes``,
        ``use_training_set`` and ``seed``.
    """
    cli = argparse.ArgumentParser(description='Extract coco subset')

    # Positional: where COCO is read from and where the subset is written.
    cli.add_argument('root', help='root path')
    cli.add_argument('out_dir',
                     type=str,
                     help='directory where subset coco will be saved.')

    # How many images to sample and how to filter their ground truth.
    cli.add_argument('--num-img',
                     type=int,
                     default=50,
                     help='num of extract image, -1 means all images')
    cli.add_argument('--area-size',
                     choices=['small', 'medium', 'large'],
                     help='filter ground-truth info by area size')
    cli.add_argument('--classes',
                     nargs='+',
                     help='filter ground-truth by class name')

    # Source split for the training subset, and RNG seeding.
    training_help = ('Whether to use the training set when extract the '
                     'training set. The training subset is extracted from '
                     'the validation set by default which can speed up.')
    cli.add_argument('--use-training-set',
                     action='store_true',
                     help=training_help)
    cli.add_argument('--seed', type=int, default=-1, help='seed')

    return cli.parse_args()
def main():
    """Entry point: parse CLI options and extract the train/val subsets.

    The training subset is taken from the validation split unless
    ``--use-training-set`` is given; the validation subset always comes from
    the validation split.

    Raises:
        ValueError: If ``out_dir`` resolves to the same directory as
            ``root``. Writing there would copy images into the source folders
            and overwrite the source annotation files.
    """
    args = parse_args()

    # Compare resolved paths rather than raw strings so that e.g.
    # 'data/coco' vs 'data/coco/' or a relative vs absolute spelling of the
    # same folder is still rejected. An explicit raise (not ``assert``) keeps
    # this data-safety guard active under ``python -O``.
    if osp.realpath(args.out_dir) == osp.realpath(args.root):
        raise ValueError('The file will be overwritten in place, '
                         'so the same folder is not allowed !')

    seed = int(args.seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    _make_dirs(args.out_dir)

    print('====Start processing train dataset====')
    if args.use_training_set:
        _process_data(args, 'train', 'train')
    else:
        _process_data(args, 'val', 'train')

    print('\n====Start processing val dataset====')
    _process_data(args, 'val', 'val')
    print(f'\n Result save to {args.out_dir}')
# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    main()