Spaces:
Runtime error
Runtime error
File size: 4,381 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from mmocr.utils import bezier_to_polygon, sort_points
# The default dictionary used by CurvedSynText150k: the 95 printable ASCII
# characters from ' ' to '~', in code-point order.
dict95 = [
    ' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.',
    '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=',
    '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L',
    'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[',
    '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j',
    'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y',
    'z', '{', '|', '}', '~'
]
# Index used for a character that is not in ``dict95``.
UNK = len(dict95)
# Index that terminates a text sequence; everything after it is ignored.
EOS = UNK + 1


def digit2text(rec):
    """Decode a sequence of character indices into a string.

    Args:
        rec (Iterable[int]): Character indices. Values below ``UNK`` are
            looked up in ``dict95``, ``UNK`` stands for an unknown character
            and ``EOS`` ends the sequence.

    Returns:
        str: The decoded text. Each ``UNK`` index is rendered as '口'.
        Decoding stops at the first ``EOS``.

    Raises:
        AssertionError: If an index is greater than ``EOS``.
    """
    res = []
    for d in rec:
        assert d <= EOS
        if d == EOS:
            break
        if d == UNK:
            print('Warning: Has a UNK character')
            res.append('口')  # Or any special character not in the target dict
            # UNK == len(dict95), so it has no entry in dict95. Skip the
            # lookup below, which would otherwise raise IndexError.
            continue
        res.append(dict95[d])
    return ''.join(res)
def modify_annotation(ann, num_sample, start_img_id=0, start_ann_id=0):
    """Rewrite one annotation in place for the converted dataset.

    The ``rec`` entry is removed and replaced by its decoded ``text``, a
    ``segmentation`` entry is derived from ``bezier_pts``, and the image and
    annotation ids are shifted by the given offsets.

    Args:
        ann (dict): Annotation with the keys ``rec``, ``bezier_pts``,
            ``image_id`` and ``id``. It is modified in place.
        num_sample (int): Forwarded to ``bezier_to_polygon`` as
            ``num_sample``.
        start_img_id (int): Offset added to ``ann['image_id']``.
        start_ann_id (int): Offset added to ``ann['id']``.

    Returns:
        dict: The same ``ann`` object, after modification.
    """
    encoded_text = ann.pop('rec')
    ann['text'] = digit2text(encoded_text)

    # Get the segmentation points: convert the Bezier points to polygon
    # points, sort them, and store them as one flat list inside a list.
    sampled_pts = bezier_to_polygon(ann['bezier_pts'], num_sample=num_sample)
    ordered_pts = np.asarray(sort_points(sampled_pts))
    ann['segmentation'] = ordered_pts.reshape(1, -1).tolist()

    ann['image_id'] += start_img_id
    ann['id'] += start_ann_id
    return ann
def modify_image_info(image_info, path_prefix, start_img_id=0):
    """Prefix an image's file name and shift its id, in place.

    Args:
        image_info (dict): Image record with the keys ``file_name`` and
            ``id``. It is modified in place.
        path_prefix (str): Path joined in front of ``file_name``.
        start_img_id (int): Offset added to ``image_info['id']``.

    Returns:
        dict: The same ``image_info`` object, after modification.
    """
    original_name = image_info['file_name']
    image_info['file_name'] = osp.join(path_prefix, original_name)
    image_info['id'] += start_img_id
    return image_info
def parse_args():
    """Parse the command-line arguments of the converter.

    Returns:
        argparse.Namespace: Namespace with ``root_path``, ``out_dir``
        (``None`` when not given), ``num_sample`` (default 4) and ``nproc``
        (default 1).
    """
    arg_parser = argparse.ArgumentParser(
        description='Convert CurvedSynText150k to COCO format')
    arg_parser.add_argument('root_path', help='CurvedSynText150k root path')
    arg_parser.add_argument('-o', '--out-dir', help='Output path')
    arg_parser.add_argument(
        '-n',
        '--num-sample',
        type=int,
        default=4,
        help='Number of sample points at each Bezier curve.')
    arg_parser.add_argument(
        '--nproc', default=1, type=int, help='Number of processes')
    return arg_parser.parse_args()
def convert_annotations(data,
                        path_prefix,
                        num_sample,
                        nproc,
                        start_img_id=0,
                        start_ann_id=0):
    """Convert every annotation and image record of one dataset part.

    Each entry of ``data['annotations']`` is passed through
    ``modify_annotation`` and each entry of ``data['images']`` through
    ``modify_image_info``. ``data['categories']`` is then set to a single
    ``text`` category.

    Args:
        data (dict): Loaded annotation file with the keys ``annotations``
            and ``images``. It is modified in place.
        path_prefix (str): Prefix passed on to ``modify_image_info``.
        num_sample (int): Sample count passed on to ``modify_annotation``.
        nproc (int): When greater than 1, ``mmcv.track_parallel_progress``
            is used with this many processes; otherwise
            ``mmcv.track_progress`` is used.
        start_img_id (int): Offset applied to image ids.
        start_ann_id (int): Offset applied to annotation ids.

    Returns:
        dict: The same ``data`` object, after conversion.
    """
    convert_ann = partial(
        modify_annotation,
        num_sample=num_sample,
        start_img_id=start_img_id,
        start_ann_id=start_ann_id)
    convert_img = partial(
        modify_image_info, path_prefix=path_prefix, start_img_id=start_img_id)

    # Choose the progress runner once; both record lists use the same one.
    if nproc > 1:
        run = partial(mmcv.track_parallel_progress, nproc=nproc)
    else:
        run = mmcv.track_progress

    data['annotations'] = run(convert_ann, data['annotations'])
    data['images'] = run(convert_img, data['images'])
    data['categories'] = [{'id': 1, 'name': 'text'}]
    return data
def main():
    """Convert both parts of the dataset and write one merged file.

    Loads ``train1.json`` and ``train2.json`` from the root path and
    converts each part. The ids of the second part are offset so that they
    start right after the largest ids of the first part. The merged result
    is dumped to ``instances_training.json`` in the output directory, which
    defaults to the root path.
    """
    args = parse_args()
    root_path = args.root_path
    out_dir = args.out_dir or root_path
    mmcv.mkdir_or_exist(out_dir)

    part1 = convert_annotations(
        mmcv.load(osp.join(root_path, 'train1.json')), 'syntext_word_eng',
        args.num_sample, args.nproc)

    # Ids of the second part continue after the largest ids of the first.
    next_img_id = max(img['id'] for img in part1['images']) + 1
    next_ann_id = max(ann['id'] for ann in part1['annotations']) + 1

    part2 = convert_annotations(
        mmcv.load(osp.join(root_path, 'train2.json')),
        'emcs_imgs',
        args.num_sample,
        args.nproc,
        start_img_id=next_img_id,
        start_ann_id=next_ann_id)

    part1['images'] += part2['images']
    part1['annotations'] += part2['annotations']
    mmcv.dump(part1, osp.join(out_dir, 'instances_training.json'))


if __name__ == '__main__':
    main()
|