# Copyright (c) OpenMMLab. All rights reserved.
"""Split a COCO-format annotation file into train(val)/val/test subsets."""
import argparse
import json
import random
from pathlib import Path

import numpy as np
from pycocotools.coco import COCO


def parse_args():
    """Parse the command-line arguments of the split script.

    Returns:
        argparse.Namespace: Namespace with the ``json``, ``out_dir``,
        ``ratios``, ``shuffle`` and ``seed`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--json', type=str, required=True, help='COCO json label path')
    parser.add_argument(
        '--out-dir', type=str, required=True, help='output path')
    parser.add_argument(
        '--ratios',
        nargs='+',
        type=float,
        required=True,
        help='ratio for sub dataset, if set 2 number then will generate '
        'trainval + test (eg. "0.85 0.15" or "2 1"), if set 3 number '
        'then will generate train + val + test (eg. "0.8 0.1 0.1" or '
        '"2 1 1")')
    parser.add_argument(
        '--shuffle',
        action='store_true',
        help='Whether to shuffle the images before splitting')
    parser.add_argument('--seed', default=-1, type=int, help='seed')
    return parser.parse_args()


def split_coco_dataset(coco_json_path: str, save_dir: str, ratios: list,
                       shuffle: bool, seed: int):
    """Split one COCO json file into several smaller COCO json files.

    Two ratios produce ``trainval.json`` and ``test.json``; three ratios
    produce ``train.json``, ``val.json`` and ``test.json``. The files are
    written into ``save_dir``, which is created when missing.

    Args:
        coco_json_path (str): Path of the source COCO annotation file.
        save_dir (str): Directory that receives the generated json files.
        ratios (list): Two or three non-negative weights. They are
            normalized, so ``[2, 1, 1]`` equals ``[0.5, 0.25, 0.25]``.
        shuffle (bool): Whether to shuffle the image ids before splitting.
        seed (int): Random seed used for shuffling; ``-1`` leaves the
            random generators unseeded.

    Raises:
        FileNotFoundError: If ``coco_json_path`` does not exist.
        ValueError: If ``ratios`` does not hold 2 or 3 values, contains a
            negative value, or sums to zero.
    """
    if not Path(coco_json_path).exists():
        raise FileNotFoundError(f'Can not find {coco_json_path}')

    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Validate the group count before doing any arithmetic on the ratios.
    if len(ratios) not in (2, 3):
        raise ValueError('ratios must set 2 or 3 group!')

    ratios = np.asarray(ratios, dtype=float)
    if (ratios < 0).any() or ratios.sum() <= 0:
        raise ValueError(
            'ratios must be non-negative and sum to a positive value!')

    # ratio normalize
    ratios = ratios / ratios.sum()
    if len(ratios) == 2:
        ratio_train, ratio_test = ratios
        ratio_val = 0
        train_type = 'trainval'
    else:
        ratio_train, ratio_val, ratio_test = ratios
        train_type = 'train'

    # Read coco info
    coco = COCO(coco_json_path)
    coco_image_ids = coco.getImgIds()

    # gen image number of each dataset; val/test counts are truncated and
    # the train split absorbs the remainder so that no image is dropped.
    total_image_num = len(coco_image_ids)
    val_image_num = int(total_image_num * ratio_val)
    test_image_num = int(total_image_num * ratio_test)
    train_image_num = total_image_num - val_image_num - test_image_num
    print('Split info: ====== \n'
          f'Train ratio = {ratio_train}, number = {train_image_num}\n'
          f'Val ratio = {ratio_val}, number = {val_image_num}\n'
          f'Test ratio = {ratio_test}, number = {test_image_num}')

    seed = int(seed)
    if seed != -1:
        print(f'Set the global seed: {seed}')
        # The shuffle below relies on the stdlib ``random`` module, so it
        # has to be seeded for the split to be reproducible; numpy is
        # seeded as well to keep the seed "global".
        random.seed(seed)
        np.random.seed(seed)

    if shuffle:
        print('shuffle dataset.')
        random.shuffle(coco_image_ids)

    # split each dataset
    val_end = train_image_num + val_image_num
    train_image_ids = coco_image_ids[:train_image_num]
    if val_image_num != 0:
        val_image_ids = coco_image_ids[train_image_num:val_end]
    else:
        # No val.json is written when the val split holds no image.
        val_image_ids = None
    test_image_ids = coco_image_ids[val_end:]

    # Save new json. Each split carries its own file name, so two splits
    # with equal content (e.g. both empty) can never be confused.
    categories = coco.loadCats(coco.getCatIds())
    splits = [
        (train_type, train_image_ids),
        ('val', val_image_ids),
        ('test', test_image_ids),
    ]
    for split_name, img_id_list in splits:
        if img_id_list is None:
            continue

        # ``getAnnIds`` treats an empty ``imgIds`` as "no filter" and would
        # return every annotation, so empty splits are handled explicitly.
        if img_id_list:
            ann_ids = coco.getAnnIds(imgIds=img_id_list)
        else:
            ann_ids = []

        # Gen new json
        img_dict = {
            'images': coco.loadImgs(ids=img_id_list),
            'categories': categories,
            'annotations': coco.loadAnns(ann_ids)
        }

        # save json
        json_file_path = Path(save_dir, f'{split_name}.json')
        print(f'Saving json to {json_file_path}')
        # ``ensure_ascii=False`` emits raw non-ASCII characters, so the
        # encoding must not depend on the platform default.
        with open(json_file_path, 'w', encoding='utf-8') as f_json:
            json.dump(img_dict, f_json, ensure_ascii=False, indent=2)

    print('All done!')


def main():
    """Run the split with the options given on the command line."""
    args = parse_args()
    split_coco_dataset(args.json, args.out_dir, args.ratios, args.shuffle,
                       args.seed)


if __name__ == '__main__':
    main()