# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os.path as osp
from functools import partial
from glob import glob

import numpy as np
from mmengine.utils import (mkdir_or_exist, track_parallel_progress,
                            track_progress)
from PIL import Image

# Expected combined number of annotation PNGs (train2017 + val2017).
# main() refuses to run when the number of files it finds differs.
COCO_LEN = 123287

# Raw label value stored in the source PNGs -> label value written to the
# converted PNGs. Pixel values that are not keys of this table are copied
# through unchanged.
clsID_to_trID = {
    0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9,
    10: 10, 12: 11, 13: 12, 14: 13, 15: 14, 16: 15, 17: 16, 18: 17,
    19: 18, 20: 19, 21: 20, 22: 21, 23: 22, 24: 23, 26: 24, 27: 25,
    30: 26, 31: 27, 32: 28, 33: 29, 34: 30, 35: 31, 36: 32, 37: 33,
    38: 34, 39: 35, 40: 36, 41: 37, 42: 38, 43: 39, 45: 40, 46: 41,
    47: 42, 48: 43, 49: 44, 50: 45, 51: 46, 52: 47, 53: 48, 54: 49,
    55: 50, 56: 51, 57: 52, 58: 53, 59: 54, 60: 55, 61: 56, 62: 57,
    63: 58, 64: 59, 66: 60, 69: 61, 71: 62, 72: 63, 73: 64, 74: 65,
    75: 66, 76: 67, 77: 68, 78: 69, 79: 70, 80: 71, 81: 72, 83: 73,
    84: 74, 85: 75, 86: 76, 87: 77, 88: 78, 89: 79, 91: 80, 92: 81,
    93: 82, 94: 83, 95: 84, 96: 85, 97: 86, 98: 87, 99: 88, 100: 89,
    101: 90, 102: 91, 103: 92, 104: 93, 105: 94, 106: 95, 107: 96,
    108: 97, 109: 98, 110: 99, 111: 100, 112: 101, 113: 102, 114: 103,
    115: 104, 116: 105, 117: 106, 118: 107, 119: 108, 120: 109,
    121: 110, 122: 111, 123: 112, 124: 113, 125: 114, 126: 115,
    127: 116, 128: 117, 129: 118, 130: 119, 131: 120, 132: 121,
    133: 122, 134: 123, 135: 124, 136: 125, 137: 126, 138: 127,
    139: 128, 140: 129, 141: 130, 142: 131, 143: 132, 144: 133,
    145: 134, 146: 135, 147: 136, 148: 137, 149: 138, 150: 139,
    151: 140, 152: 141, 153: 142, 154: 143, 155: 144, 156: 145,
    157: 146, 158: 147, 159: 148, 160: 149, 161: 150, 162: 151,
    163: 152, 164: 153, 165: 154, 166: 155, 167: 156, 168: 157,
    169: 158, 170: 159, 171: 160, 172: 161, 173: 162, 174: 163,
    175: 164, 176: 165, 177: 166, 178: 167, 179: 168, 180: 169,
    181: 170, 255: 255
}


def _remap_labels(mask):
    """Return a copy of ``mask`` with ``clsID_to_trID`` applied per pixel.

    Args:
        mask (np.ndarray): Label array as loaded from an annotation PNG.

    Returns:
        np.ndarray: Array with the same shape and dtype as ``mask`` in which
        every value that is a key of ``clsID_to_trID`` has been replaced by
        its mapped value; all other values are unchanged.
    """
    if mask.dtype == np.uint8:
        # A uint8 mask can only hold 0..255, so the whole mapping fits in a
        # 256-entry table (identity by default) and is applied with one
        # gather instead of one full-image comparison per table entry.
        lut = np.arange(256, dtype=np.uint8)
        lut[list(clsID_to_trID)] = list(clsID_to_trID.values())
        return lut[mask]

    # Any other dtype: keep the generic per-class replacement. Comparisons
    # are made against the untouched input so that an already-remapped pixel
    # is never remapped a second time.
    remapped = mask.copy()
    for cls_id, train_id in clsID_to_trID.items():
        remapped[mask == cls_id] = train_id
    return remapped


def convert_to_trainID(maskpath, out_mask_dir, is_train):
    """Remap one annotation PNG and save it under the output directory.

    Args:
        maskpath (str): Path of the source annotation PNG.
        out_mask_dir (str): Root output directory; it must already contain
            the ``train2017`` and ``val2017`` sub-directories.
        is_train (bool): Write into ``train2017`` when true, otherwise into
            ``val2017``.
    """
    # Close the file handle as soon as the pixels have been copied out.
    with Image.open(maskpath) as img:
        mask = np.array(img)

    split_dir = 'train2017' if is_train else 'val2017'
    seg_filename = osp.join(out_mask_dir, split_dir, osp.basename(maskpath))
    Image.fromarray(_remap_labels(mask)).save(seg_filename, 'PNG')


def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        argparse.Namespace: Namespace with ``coco_path``, ``out_dir_name``
        and ``nproc``.
    """
    parser = argparse.ArgumentParser(
        description='Convert COCO Stuff 164k annotations to mmdet format'
    )  # noqa
    parser.add_argument('coco_path', help='coco stuff path')
    parser.add_argument(
        '--out-dir-name',
        '-o',
        default='stuffthingmaps_semseg',
        help='output path')
    parser.add_argument(
        '--nproc', default=16, type=int, help='number of process')
    args = parser.parse_args()
    return args


def main():
    """Convert every train2017/val2017 annotation found under ``coco_path``.

    Raises:
        AssertionError: If the number of PNG files found under
            ``stuffthingmaps/train2017`` and ``stuffthingmaps/val2017`` does
            not add up to ``COCO_LEN``.
    """
    args = parse_args()
    coco_path = args.coco_path
    out_dir = osp.join(coco_path, args.out_dir_name)
    nproc = args.nproc

    mkdir_or_exist(osp.join(out_dir, 'train2017'))
    mkdir_or_exist(osp.join(out_dir, 'val2017'))

    train_list = glob(osp.join(coco_path, 'stuffthingmaps/train2017', '*.png'))
    val_list = glob(osp.join(coco_path, 'stuffthingmaps/val2017', '*.png'))

    # Raised explicitly rather than via ``assert`` so the check still runs
    # under ``python -O``; exception type and message are unchanged.
    if (len(train_list) + len(val_list)) != COCO_LEN:
        raise AssertionError('Wrong length of list {} & {}'.format(
            len(train_list), len(val_list)))

    # Train split first, then val, each with its own progress bar.
    for mask_paths, is_train in ((train_list, True), (val_list, False)):
        convert = partial(
            convert_to_trainID, out_mask_dir=out_dir, is_train=is_train)
        if nproc > 1:
            track_parallel_progress(convert, mask_paths, nproc=nproc)
        else:
            track_progress(convert, mask_paths)

    print('Done!')


if __name__ == '__main__':
    main()