"""Run a text-recognition model on one image, a directory of images, or a glob pattern."""
import argparse
import glob
import logging
import os

import cv2
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not guarantee the Image submodule is loaded
import torch
import torch.nn.functional as F
import tqdm
from torchvision import transforms

from utils import CharsetMapper, Config, Logger


def get_model(config):
    """Build the model class named by ``config.model_name`` and put it in eval mode.

    ``config.model_name`` is a dotted path ``<module path>.<ClassName>``. The class
    is imported dynamically, constructed with ``config``, and logged.
    """
    import importlib

    module_name, _, class_name = config.model_name.rpartition('.')
    cls = getattr(importlib.import_module(module_name), class_name)
    model = cls(config)
    logging.info(model)
    return model.eval()


def preprocess(img, width, height):
    """Resize and normalize an image into a batched float tensor.

    Args:
        img: An image convertible with ``np.array`` (e.g. an RGB PIL image).
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A tensor of shape ``(1, C, height, width)``. Pixel values are scaled to
        [0, 1] by ``ToTensor``, then normalized per channel with the mean/std below.
        These match the usual ImageNet statistics, which assumes a 3-channel input.
    """
    img = cv2.resize(np.array(img), (width, height))
    img = transforms.ToTensor()(img).unsqueeze(0)  # CHW -> 1xCHW
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    # [..., None, None] turns (C,) into (C, 1, 1) so it broadcasts over H and W.
    return (img - mean[..., None, None]) / std[..., None, None]


def postprocess(output, charset, model_eval):
    """Greedy-decode the model output into text.

    Args:
        output: Either a single result dict with a ``'logits'`` key, or a
            list/tuple of such dicts each carrying a ``'name'`` key.
        charset: Charset mapper providing ``get_text``, ``null_char`` and ``max_length``.
        model_eval: Name of the result to decode when ``output`` is a list/tuple.

    Returns:
        ``(pt_text, pt_scores, pt_lengths)``, with one entry per batch item.

    Raises:
        ValueError: If ``output`` is a list/tuple and no entry has ``name == model_eval``.
    """

    def _get_output(last_output, model_eval):
        """Select the result named ``model_eval``; a non-sequence is returned as is."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'No model output named {model_eval!r}; available: {available}')
        # The original kept the last match when names repeat; preserve that.
        return matches[-1]

    def _decode(logit):
        """Greedy decode.

        Assumes ``logit`` is (batch, seq_len, num_classes), consistent with the
        softmax over dim=2 and the per-sample argmax over dim=1. TODO confirm.
        """
        out = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in out:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # stop at the end token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            pt_lengths.append(min(len(text) + 1, charset.max_length))  # +1 for the end token
        return pt_text, pt_scores, pt_lengths

    output = _get_output(output, model_eval)
    return _decode(output['logits'])


def load(model, file, device=None, strict=True):
    """Load a checkpoint's weights into ``model`` and return it.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        file: Path to the checkpoint file.
        device: ``map_location`` for ``torch.load``. ``None`` means ``'cpu'``; an
            ``int`` means that CUDA device index; anything else is passed through.
        strict: Forwarded to ``load_state_dict``.

    Raises:
        FileNotFoundError: If ``file`` is not an existing file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint file not found: {file}')
    # NOTE(security): torch.load unpickles the file and can run arbitrary code;
    # only load checkpoints from trusted sources.
    state = torch.load(file, map_location=device)
    # Training checkpoints bundle the weights with optimizer state; unwrap them.
    if isinstance(state, dict) and set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model


def _collect_paths(input_path):
    """Return the sorted file paths in a directory, or the files matching a glob pattern."""
    if os.path.isdir(input_path):
        candidates = [os.path.join(input_path, fname) for fname in os.listdir(input_path)]
    else:
        candidates = glob.glob(os.path.expanduser(input_path))
    # Skip subdirectories and other non-files, which PIL cannot open.
    paths = [path for path in candidates if os.path.isfile(path)]
    if not paths:
        raise FileNotFoundError('The input path(s) was not found')
    return sorted(paths)


def main():
    """Parse CLI args, load the model and checkpoint, and log a prediction for each input image."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figs/test')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str,
                        default='workdir/train-abinet/best-train-abinet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # Weights come from the single checkpoint above, not from the per-branch checkpoints.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    paths = _collect_paths(args.input)

    # Inference only: avoid building autograd graphs.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # Close the underlying file handle once the image is converted.
            with PIL.Image.open(path) as pil_img:
                img = pil_img.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, _ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')


if __name__ == '__main__':
    main()