File size: 3,963 Bytes
186701e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import random
from pathlib import Path

import numpy as np
from pycocotools.coco import COCO


def parse_args():
    """Parse the command line options of the COCO split script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``json``,
        ``out_dir``, ``ratios``, ``shuffle`` and ``seed``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--json', type=str, required=True, help='COCO json label path')
    parser.add_argument(
        '--out-dir', type=str, required=True, help='output path')
    # Required: without ratios the split cannot be computed, so fail here
    # with a usage message rather than later with an obscure error.
    parser.add_argument(
        '--ratios',
        nargs='+',
        type=float,
        required=True,
        help='ratio for sub dataset, if set 2 number then will generate '
        'trainval + test (eg. "0.85 0.15" or "2 1"), if set 3 number '
        'then will generate train + val + test (eg. "0.8 0.1 0.1" or '
        '"2 1 1")')
    parser.add_argument(
        '--shuffle',
        action='store_true',
        help='Whether to shuffle the images before splitting')
    parser.add_argument(
        '--seed',
        default=-1,
        type=int,
        help='random seed, -1 means do not set a seed')
    return parser.parse_args()


def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,
                       shuffle: bool, seed: int):
    """Split a COCO annotation file into several smaller COCO json files.

    Images are partitioned by position in the (optionally shuffled) image id
    list: the first part becomes the train split, then val, then test. Each
    split is saved as ``<name>.json`` in ``save_dir`` with the keys
    ``images``, ``categories`` and ``annotations``.

    Args:
        coco_json_path (str): Path of the source COCO json label file.
        save_dir (str): Output directory; created (with parents) if missing.
        ratios (list): Relative size of each split, normalized by their sum.
            Two numbers produce ``trainval.json`` + ``test.json``; three
            numbers produce ``train.json`` + ``val.json`` + ``test.json``.
        shuffle (bool): Whether to shuffle the image ids before splitting.
        seed (int): Random seed. ``-1`` means no seed is set.

    Raises:
        FileNotFoundError: If ``coco_json_path`` does not exist.
        ValueError: If ``ratios`` does not hold 2 or 3 numbers, or holds
            negative / non-finite numbers, or sums to zero.
    """
    if not Path(coco_json_path).exists():
        raise FileNotFoundError(f'Can not find {coco_json_path}')

    # Validate before touching the file system or loading annotations.
    if ratios is None or len(ratios) not in (2, 3):
        raise ValueError('ratios must set 2 or 3 group!')
    ratios = np.asarray(ratios, dtype=float)
    if (not np.isfinite(ratios).all() or (ratios < 0).any()
            or ratios.sum() <= 0):
        raise ValueError(
            'ratios must be non-negative numbers with a positive sum!')

    # ratio normalize
    ratios = ratios / ratios.sum()

    if len(ratios) == 2:
        ratio_train, ratio_test = ratios
        ratio_val = 0
        train_type = 'trainval'
    else:
        ratio_train, ratio_val, ratio_test = ratios
        train_type = 'train'

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Read coco info
    coco = COCO(coco_json_path)
    coco_image_ids = coco.getImgIds()

    # gen image number of each dataset; train takes whatever is left after
    # the truncated val/test counts so that every image is used exactly once
    val_image_num = int(len(coco_image_ids) * ratio_val)
    test_image_num = int(len(coco_image_ids) * ratio_test)
    train_image_num = len(coco_image_ids) - val_image_num - test_image_num
    print('Split info: ====== \n'
          f'Train ratio = {ratio_train}, number = {train_image_num}\n'
          f'Val ratio = {ratio_val}, number = {val_image_num}\n'
          f'Test ratio = {ratio_test}, number = {test_image_num}')

    seed = int(seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        # The shuffle below uses the stdlib ``random`` generator, so it must
        # be seeded as well; seeding numpy alone leaves the split
        # non-reproducible.
        random.seed(seed)
        np.random.seed(seed)

    if shuffle:
        print('shuffle dataset.')
        random.shuffle(coco_image_ids)

    # split each dataset, keeping the output name next to its id list so the
    # file name never has to be guessed back from the list contents
    val_end = train_image_num + val_image_num
    splits = [(train_type, coco_image_ids[:train_image_num])]
    if val_image_num != 0:
        splits.append(('val', coco_image_ids[train_image_num:val_end]))
    splits.append(('test', coco_image_ids[val_end:]))

    # Save new json
    categories = coco.loadCats(coco.getCatIds())
    for split_name, img_id_list in splits:
        # pycocotools treats an empty ``imgIds`` filter as "no filter" and
        # would return every annotation, so an empty split is special-cased.
        ann_ids = coco.getAnnIds(imgIds=img_id_list) if img_id_list else []

        # Gen new json
        img_dict = {
            'images': coco.loadImgs(ids=img_id_list),
            'categories': categories,
            'annotations': coco.loadAnns(ann_ids)
        }

        # save json
        json_file_path = Path(save_dir, f'{split_name}.json')
        print(f'Saving json to {json_file_path}')
        # ensure_ascii=False emits raw non-ASCII characters, so the encoding
        # must not depend on the platform default.
        with open(json_file_path, 'w', encoding='utf-8') as f_json:
            json.dump(img_dict, f_json, ensure_ascii=False, indent=2)

    print('All done!')


def main():
    """Script entry point: read the CLI options and run the split."""
    cli = parse_args()
    split_coco_dataset(
        coco_json_path=cli.json,
        save_dir=cli.out_dir,
        ratios=cli.ratios,
        shuffle=cli.shuffle,
        seed=cli.seed)


# Run the split only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()