# Spaces: Build error
import argparse
import glob
import logging
import os

import tqdm
import torch
import PIL
from PIL import Image
import cv2
import numpy as np
import torch.nn.functional as F
from torchvision import transforms

from utils import Config, Logger, CharsetMapper
def get_model(config):
    """Build the model class named by ``config.model_name`` and put it in eval mode.

    ``config.model_name`` is a dotted path of the form ``pkg.module.ClassName``.
    The module part is imported dynamically, and the class is called with
    ``config`` as its single argument. The resulting model is logged and then
    switched to evaluation mode.

    Args:
        config: Object exposing a ``model_name`` attribute; it is passed as-is
            to the model constructor.

    Returns:
        The constructed model, after ``.eval()`` has been called on it.
    """
    import importlib

    module_path, _, attr_name = config.model_name.rpartition('.')
    model_cls = getattr(importlib.import_module(module_path), attr_name)
    instance = model_cls(config)
    logging.info(instance)
    return instance.eval()
def preprocess(img, width, height):
    """Resize an image and return it as a normalized, batched float tensor.

    Args:
        img: Image convertible with ``np.array`` (the caller passes an RGB PIL
            image). Three channels are expected, to match the three-entry
            mean/std below.
        width: Target width in pixels (``cv2.resize`` takes ``(width, height)``).
        height: Target height in pixels.

    Returns:
        A tensor with a leading batch dimension of 1, produced by
        ``transforms.ToTensor`` and then normalized per channel.
    """
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized).unsqueeze(0)
    # Per-channel statistics (the widely used ImageNet mean/std values),
    # reshaped to (C, 1, 1) so they broadcast over the spatial dimensions.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
def postprocess(output, charset, model_eval):
    """Greedily decode the model output selected by ``model_eval`` into text.

    Args:
        output: Either a single result dict, or a list/tuple of result dicts
            that each carry a ``'name'`` entry. For a sequence, the entry whose
            ``'name'`` equals ``model_eval`` is decoded; if several match, the
            last one wins. The selected dict must contain a ``'logits'`` tensor,
            which is softmaxed over dim 2 (presumably shaped
            (batch, steps, classes); confirm against the model).
        charset: Object providing ``get_text(labels, padding=..., trim=...)``,
            ``null_char`` (end token) and ``max_length``.
        model_eval: Name of the output to decode when ``output`` is a sequence.

    Returns:
        Tuple ``(pt_text, pt_scores, pt_lengths)`` of lists, one entry per sample:
        - the decoded string, truncated at the first ``null_char``;
        - the per-step max softmax probability;
        - the predicted length, counting one end token and capped at
          ``charset.max_length``.

    Raises:
        ValueError: if ``output`` is a list/tuple with no entry named
            ``model_eval``.
    """
    def _get_output(last_output, model_eval):
        """Pick the result dict named ``model_eval`` (last match wins)."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            # Previously this surfaced as an opaque UnboundLocalError.
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'No model output named {model_eval!r}; available: {available}')
        return matches[-1]

    def _decode(logit):
        """Greedy decode: take the argmax class at every step."""
        probs = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for o in probs:
            text = charset.get_text(o.argmax(dim=1), padding=False, trim=False)
            text = text.split(charset.null_char)[0]  # stop at the first end token
            pt_text.append(text)
            pt_scores.append(o.max(dim=1)[0])
            pt_lengths.append(min(len(text) + 1, charset.max_length))  # +1 for the end token
        return pt_text, pt_scores, pt_lengths

    selected = _get_output(output, model_eval)
    # Only the logits are needed: lengths are recomputed from the decoded
    # text, so a model-supplied 'pt_lengths' entry is not required.
    return _decode(selected['logits'])
def load(model, file, device=None, strict=True):
    """Load checkpoint weights from ``file`` into ``model`` and return the model.

    Args:
        model: Module whose ``load_state_dict`` receives the weights.
        file: Path to a file written with ``torch.save``. It holds either a
            bare state dict, or a dict with exactly the keys ``'model'`` and
            ``'opt'`` (only ``'model'`` is used).
        device: ``None`` means CPU; an ``int`` selects that CUDA device; any
            other value is passed unchanged as ``map_location`` to ``torch.load``.
        strict: Forwarded to ``model.load_state_dict``.

    Returns:
        ``model``, with the loaded weights.

    Raises:
        FileNotFoundError: if ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint not found: {file}')
    # NOTE(security): torch.load is pickle-based; only load trusted checkpoints.
    state = torch.load(file, map_location=device)
    # Training checkpoints wrap the weights together with optimizer state.
    if isinstance(state, dict) and set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
def _collect_paths(input_path):
    """Return the sorted list of input files named by ``input_path``.

    If ``input_path`` is a directory, its direct children are used; otherwise
    it is treated as a glob pattern (``~`` is expanded). Only regular files are
    kept, so sub-directories never reach the image loader.

    Raises:
        FileNotFoundError: if no input file is found.
    """
    if os.path.isdir(input_path):
        candidates = [os.path.join(input_path, fname) for fname in os.listdir(input_path)]
    else:
        candidates = glob.glob(os.path.expanduser(input_path))
    paths = [p for p in candidates if os.path.isfile(p)]
    if not paths:
        raise FileNotFoundError(f'The input path(s) was not found: {input_path}')
    return sorted(paths)


def main():
    """Parse CLI arguments, load the model, and log a prediction for each input image."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figs/test')
    parser.add_argument('--cuda', type=int, default=-1)
    parser.add_argument('--checkpoint', type=str, default='workdir/train-abinet/best-train-abinet.pth')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'])
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # NOTE(review): presumably disables loading separate sub-model checkpoints,
    # since the full checkpoint is loaded below; confirm against the model code.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    paths = _collect_paths(args.input)
    # Inference only: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # The context manager closes the underlying file handle promptly.
            with Image.open(path) as raw:
                img = raw.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, __ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
if __name__ == '__main__':
    # Script entry point: run inference only when executed directly, not on import.
    main()