Spaces:
Runtime error
Runtime error
File size: 4,561 Bytes
2366e36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
import time
from argparse import ArgumentParser
from itertools import compress
import mmcv
from mmcv.utils import ProgressBar
from mmocr.apis import init_detector, model_inference
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric
from mmocr.datasets import build_dataset # noqa: F401
from mmocr.models import build_detector # noqa: F401
from mmocr.utils import get_root_logger, list_from_file, list_to_file
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Write per-image recognition results to text files in ``res_dir``.

    Three files are produced, each holding ``<img> <pred> <gt>`` lines:
    ``results.txt`` (every sample), ``correct.txt`` (prediction equals the
    ground truth) and ``wrong.txt`` (prediction differs from it).

    Args:
        img_paths (list[str]): Image paths, one per sample.
        pred_labels (list[str]): Predicted text for each image.
        gt_labels (list[str]): Ground-truth text for each image.
        res_dir (str): Directory the result files are written to.
    """
    assert len(img_paths) == len(pred_labels) == len(gt_labels)

    # Build the output line and the correctness flag of each sample together.
    records = []
    is_correct = []
    for img, pred, gt in zip(img_paths, pred_labels, gt_labels):
        records.append(f'{img} {pred} {gt}')
        is_correct.append(pred == gt)
    is_wrong = [not flag for flag in is_correct]

    # (file name, lines to write) pairs, written in this order.
    outputs = (
        ('results.txt', records),
        ('correct.txt', compress(records, is_correct)),
        ('wrong.txt', compress(records, is_wrong)),
    )
    for file_name, content in outputs:
        list_to_file(osp.join(res_dir, file_name), content)
def main():
    """Recognize text in a list of images, then save and evaluate results.

    Command-line arguments:
        img_root_path: Root directory that the paths in ``img_list`` are
            joined onto.
        img_list: Text file with one whitespace-separated
            ``<image path> [<gt label>]`` entry per line. Blank lines are
            ignored.
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out_dir: Directory for the log file, visualizations and the
            result lists (default ``./results``).
        --show: Show each visualization instead of saving it.
        --device: Device used for inference (default ``cuda:0``).

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file is placed inside out_dir, so the directory must exist
    # before the logger tries to open that file.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # Unwrap the model when it is exposed through a ``module`` attribute.
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)

    img_paths, pred_labels, gt_labels = [], [], []
    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank line in the list file: there is no image to process.
            continue
        img_file = item_list[0]
        # The ground-truth label is optional; '' marks "no label given".
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten the relative path into a single file name.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show, '' is passed as the output path instead of out_file,
        # so presumably no visualization exists at out_file; copying it
        # would raise FileNotFoundError. Only sort visualizations into
        # correct/wrong when one was requested to be saved.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Metrics are computed only when every processed image had a label.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|