Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import shutil | |
import time | |
from argparse import ArgumentParser | |
from itertools import compress | |
import mmcv | |
from mmcv.utils import ProgressBar | |
from mmocr.apis import init_detector, model_inference | |
from mmocr.core.evaluation.ocr_metric import eval_ocr_metric | |
from mmocr.datasets import build_dataset # noqa: F401 | |
from mmocr.models import build_detector # noqa: F401 | |
from mmocr.utils import get_root_logger, list_from_file, list_to_file | |
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Save predicted results to txt files.

    Every sample is written to ``results.txt`` as ``<img> <pred> <gt>``.
    Samples that carry a ground-truth label (non-empty ``gt``) are also
    written to ``correct.txt`` or ``wrong.txt``, depending on whether the
    prediction matches the label exactly. Samples without a label are left
    out of both files, because correctness is undefined for them.

    Args:
        img_paths (list[str]): Image paths, one per sample.
        pred_labels (list[str]): Predicted text, aligned with ``img_paths``.
        gt_labels (list[str]): Ground-truth text, aligned with
            ``img_paths``; an empty string means the sample has no label.
        res_dir (str): Directory the three txt files are written into.
    """
    # Internal invariant: the caller appends to all three lists in lockstep.
    assert len(img_paths) == len(pred_labels) == len(gt_labels)

    lines, correct_lines, wrong_lines = [], [], []
    for img, pred, gt in zip(img_paths, pred_labels, gt_labels):
        line = f'{img} {pred} {gt}'
        lines.append(line)
        if gt == '':
            # Unlabelled sample: don't report it as either correct or wrong.
            continue
        if pred == gt:
            correct_lines.append(line)
        else:
            wrong_lines.append(line)

    list_to_file(osp.join(res_dir, 'results.txt'), lines)
    list_to_file(osp.join(res_dir, 'correct.txt'), correct_lines)
    list_to_file(osp.join(res_dir, 'wrong.txt'), wrong_lines)
def _parse_args():
    """Build the command-line parser for this script and parse sys.argv."""
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    return parser.parse_args()


def main():
    """Run text recognition on every image listed in ``img_list``.

    Each non-blank line of ``img_list`` is ``<relative_img_path> [gt_label]``.
    For every image this:

    - runs ``model_inference``;
    - renders the result via ``model.show_result``, either displaying it
      (``--show``) or saving it under ``<out_dir>/out_vis_dir``;
    - for labelled images that were saved, copies the rendering into
      ``<out_dir>/correct`` or ``<out_dir>/wrong``.

    Per-image results are written with ``save_results``. If every image has
    a label, OCR metrics are computed and logged.

    Raises:
        FileNotFoundError: If an image listed in ``img_list`` does not exist.
    """
    args = _parse_args()

    # The log file lives inside out_dir, so the directory must exist before
    # the logger opens it (logging.FileHandler does not create parents).
    mmcv.mkdir_or_exist(args.out_dir)
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap a (Distributed)DataParallel-style wrapper if present.
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    for vis_dir in (out_vis_dir, correct_vis_dir, wrong_vis_dir):
        mmcv.mkdir_or_exist(vis_dir)

    # Skip blank lines (e.g. a trailing newline in the list file); they would
    # otherwise yield an empty split and raise IndexError below.
    lines = [line for line in list_from_file(args.img_list) if line.strip()]

    img_paths, pred_labels, gt_labels = [], [], []
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.split()
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten sub-directories into the file name so all renderings
        # land directly in out_vis_dir.
        out_img_name = img_file.replace('/', '_')
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show nothing is written to out_file, so there is nothing to
        # copy; copying anyway would raise FileNotFoundError.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Evaluate only when there is at least one sample and all are labelled.
    if pred_labels and num_gt_label == len(pred_labels):
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()