test2 / tools /convert_datasets /pascal_context.py
mccaly's picture
Upload 660 files
b13b124
raw history blame
No virus
2.88 kB
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from detail import Detail
from PIL import Image
# Raw label ids accepted in the annotation masks. The array is kept in
# ascending order because generate_labels() uses it as the bin edges for
# np.digitize.
_mapping = np.array([
    0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
    158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
    440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
    85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
])
_mapping.sort()

# _key[i] is the compact uint8 index written to the output mask for the raw
# label id _mapping[i].
_key = np.arange(len(_mapping), dtype='uint8')
def generate_labels(img_id, detail, out_dir):
    """Convert one image's annotation mask to a PNG of compact class indices.

    Args:
        img_id (dict): Image record. Its ``'file_name'`` entry names the
            source image; the whole record is passed to ``detail.getMask``.
        detail: Object providing ``getMask(img_id)``, which returns the raw
            label mask (presumably a ``detail.Detail`` instance returning a
            2-D integer array -- confirm against the Detail API).
        out_dir (str): Directory the PNG mask is written into.

    Returns:
        str: Base name of the image file without directory or extension.

    Raises:
        ValueError: If the mask contains a label id not listed in
            ``_mapping``.
    """

    def _class_to_index(mask, _mapping, _key):
        # Validate explicitly instead of with ``assert``: assertions vanish
        # under ``python -O`` and np.digitize would then silently bin an
        # unknown id into a neighbouring class.
        unknown = np.setdiff1d(mask, _mapping)
        if unknown.size:
            raise ValueError(
                'mask contains label ids missing from the mapping: {}'.format(
                    unknown.tolist()))
        # With right=True an id equal to _mapping[i] lands in bin i, so the
        # bin number is the position of the id inside the sorted mapping.
        index = np.digitize(mask.ravel(), _mapping, right=True)
        return _key[index].reshape(mask.shape)

    mask = Image.fromarray(
        _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
    filename = img_id['file_name']
    # Swap only the extension. A plain str.replace('jpg', 'png') would also
    # rewrite "jpg" elsewhere in the path and would leave other extensions
    # (e.g. ".jpeg", ".JPG") untouched.
    png_name = osp.splitext(filename)[0] + '.png'
    mask.save(osp.join(out_dir, png_name))
    return osp.splitext(osp.basename(filename))[0]
def parse_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        argparse.Namespace: Carries ``devkit_path`` and ``json_path`` (both
        required positionals) and ``out_dir``, which is ``None`` unless
        ``-o`` / ``--out_dir`` is supplied.
    """
    parser = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    parser.add_argument('devkit_path', help='pascal voc devkit path')
    parser.add_argument('json_path', help='annotation json filepath')
    parser.add_argument('-o', '--out_dir', help='output path')
    return parser.parse_args()
def main():
    """Convert the train and val splits and write their image-name lists.

    Reads the command-line arguments, writes one PNG label mask per image
    into the output directory (``<devkit>/VOC2010/SegmentationClassContext``
    unless ``--out_dir`` is given), then writes the sorted image base names
    of each split, one per line, to ``train.txt`` and ``val.txt`` under
    ``<devkit>/VOC2010/ImageSets/SegmentationContext``.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    json_path = args.json_path

    out_dir = args.out_dir
    if out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
    mmcv.mkdir_or_exist(out_dir)

    img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')

    # Both splits are loaded before any mask is written, so a failure while
    # loading the annotations happens before any output is produced.
    train_detail = Detail(json_path, img_dir, 'train')
    train_ids = train_detail.getImgs()
    val_detail = Detail(json_path, img_dir, 'val')
    val_ids = val_detail.getImgs()

    list_dir = osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')
    mmcv.mkdir_or_exist(list_dir)

    splits = (
        (train_detail, train_ids, 'train.txt'),
        (val_detail, val_ids, 'val.txt'),
    )
    for split_detail, split_ids, list_name in splits:
        convert_one = partial(
            generate_labels, detail=split_detail, out_dir=out_dir)
        names = mmcv.track_progress(convert_one, split_ids)
        with open(osp.join(list_dir, list_name), 'w') as f:
            f.write(''.join(name + '\n' for name in sorted(names)))

    print('Done!')
# Run the conversion only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()