# Copyright (c) OpenMMLab. All rights reserved.
"""Extracting subsets from coco2017 dataset.

This script is mainly used to debug and verify the correctness of the
program quickly.
The root folder format must be in the following format:

β”œβ”€β”€ root
β”‚   β”œβ”€β”€ annotations
β”‚   β”œβ”€β”€ train2017
β”‚   β”œβ”€β”€ val2017
β”‚   β”œβ”€β”€ test2017

Currently, only COCO 2017 is supported. User-defined datasets in the
standard COCO JSON format will be supported in the future.

Example:
   python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img ${NUM_IMG}
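
   To additionally filter by class name and object area size, for example:
   python tools/misc/extract_subcoco.py ${ROOT} ${OUT_DIR} --num-img 20 \
       --classes person bicycle --area-size small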
"""

import argparse
import os.path as osp
import shutil

import mmengine
import numpy as np
from pycocotools.coco import COCO


# TODO: Currently only supports coco2017
def _process_data(args,
                  in_dataset_type: str,
                  out_dataset_type: str,
                  year: str = '2017'):
    assert in_dataset_type in ('train', 'val')
    assert out_dataset_type in ('train', 'val')

    in_ann_file_name = f'annotations/instances_{in_dataset_type}{year}.json'
    out_ann_file_name = f'annotations/instances_{out_dataset_type}{year}.json'

    ann_path = osp.join(args.root, in_ann_file_name)
    json_data = mmengine.load(ann_path)

    new_json_data = {
        'info': json_data['info'],
        'licenses': json_data['licenses'],
        'categories': json_data['categories'],
        'images': [],
        'annotations': []
    }

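    # Object-area ranges in pixels^2, matching the standard COCO
    # small/medium/large splits (thresholds at 32^2 and 96^2).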
    area_dict = {
        'small': [0., 32 * 32],
        'medium': [32 * 32, 96 * 96],
        'large': [96 * 96, float('inf')]
    }

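    # build a pycocotools index over the same annotation file so that
    # annotations can be queried by category and area range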
    coco = COCO(ann_path)

    # filter annotations by category ids and area range
    areaRng = area_dict[args.area_size] if args.area_size else []
    catIds = coco.getCatIds(args.classes) if args.classes else []
    ann_ids = coco.getAnnIds(catIds=catIds, areaRng=areaRng)
    ann_info = coco.loadAnns(ann_ids)

    # collect the unique image ids referenced by the filtered annotations
    filter_img_ids = {ann['image_id'] for ann in ann_info}
    # sort for a deterministic order before the seeded shuffle below
    filter_img = coco.loadImgs(sorted(filter_img_ids))

    # shuffle so the extracted subset is a random sample
    np.random.shuffle(filter_img)

    num_img = args.num_img if args.num_img > 0 else len(filter_img)
    if num_img > len(filter_img):
        print(
            f'num_img is too large and will be set to {len(filter_img)}, '
            'because not enough images remain after filtering by classes '
            'and area_size')
        num_img = len(filter_img)

    progress_bar = mmengine.ProgressBar(num_img)

    for i in range(num_img):
        file_name = filter_img[i]['file_name']
        image_path = osp.join(args.root, in_dataset_type + year, file_name)

        # re-query the annotations of this image with the same filters
        ann_ids = coco.getAnnIds(
            imgIds=[filter_img[i]['id']], catIds=catIds, areaRng=areaRng)
        img_ann_info = coco.loadAnns(ann_ids)

        new_json_data['images'].append(filter_img[i])
        new_json_data['annotations'].extend(img_ann_info)

        # copy the selected image into the output split folder
        shutil.copy(image_path, osp.join(args.out_dir,
                                         out_dataset_type + year))

        progress_bar.update()

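    # save the filtered subset as a standard COCO-format annotation file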
    mmengine.dump(new_json_data, osp.join(args.out_dir, out_ann_file_name))


def _make_dirs(out_dir):
    mmengine.mkdir_or_exist(out_dir)
    mmengine.mkdir_or_exist(osp.join(out_dir, 'annotations'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'train2017'))
    mmengine.mkdir_or_exist(osp.join(out_dir, 'val2017'))


def parse_args():
    parser = argparse.ArgumentParser(description='Extract coco subset')
    parser.add_argument('root', help='root path')
    parser.add_argument(
        'out_dir', type=str, help='directory where subset coco will be saved.')
    parser.add_argument(
        '--num-img',
        default=50,
        type=int,
        help='number of images to extract; -1 means all images')
    parser.add_argument(
        '--area-size',
        choices=['small', 'medium', 'large'],
        help='filter ground-truth info by area size')
    parser.add_argument(
        '--classes', nargs='+', help='filter ground-truth by class name')
    parser.add_argument(
        '--use-training-set',
        action='store_true',
        help='Whether to extract the training subset from the training set. '
        'By default it is extracted from the validation set, which is '
        'faster.')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.out_dir != args.root, \
        'Files would be overwritten in place, ' \
        'so out_dir must differ from root!'

    seed = int(args.seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        np.random.seed(seed)

    _make_dirs(args.out_dir)

    print('====Start processing train dataset====')
    if args.use_training_set:
        _process_data(args, 'train', 'train')
    else:
        _process_data(args, 'val', 'train')
    print('\n====Start processing val dataset====')
    _process_data(args, 'val', 'val')
    print(f'\nResults saved to {args.out_dir}')


if __name__ == '__main__':
    main()