from collections import defaultdict

import torch

from vietocr.vietocr.tool.translate import build_model, translate, translate_beam_search, process_input, predict, batch_translate_beam_search
from vietocr.vietocr.tool.utils import download_weights


class Predictor:
    """Run OCR inference with a model and vocabulary built from a config dict.

    Config keys read directly by this class:
        ``'device'``, ``'weights'``,
        ``'dataset'`` -> ``'image_height'``, ``'image_min_width'``,
        ``'image_max_width'``,
        ``'predictor'`` -> ``'beamsearch'``.
    ``build_model`` may read further keys.
    """

    def __init__(self, config):
        """Build the model and vocabulary, then load the configured weights.

        Args:
            config: configuration dict (see the class docstring).
                ``config['weights']`` is either a string starting with
                ``'http'``, which is fetched via ``download_weights``, or a
                local file path.
        """
        device = config['device']

        # NOTE(review): presumably build_model already places the model on
        # config['device']; map_location below only controls where the
        # checkpoint tensors are loaded.
        model, vocab = build_model(config)

        if config['weights'].startswith('http'):
            weights = download_weights(config['weights'])
        else:
            weights = config['weights']

        # NOTE(review): torch.load unpickles the checkpoint, so only load
        # weights from trusted sources (newer torch supports weights_only=True).
        model.load_state_dict(torch.load(weights, map_location=torch.device(device)))

        self.config = config
        self.model = model
        self.vocab = vocab
        self.device = device

    def _preprocess(self, img):
        """Apply ``process_input`` with the image sizes from ``config['dataset']``."""
        dataset_cfg = self.config['dataset']
        return process_input(
            img,
            dataset_cfg['image_height'],
            dataset_cfg['image_min_width'],
            dataset_cfg['image_max_width'],
        )

    def predict(self, img, return_prob=False):
        """Recognize the text in a single image.

        Args:
            img: an image in whatever form ``process_input`` accepts.
            return_prob: when True, also return the probability value.

        Returns:
            The decoded string, or ``(string, prob)`` when ``return_prob`` is
            True. ``prob`` is ``None`` when ``config['predictor']['beamsearch']``
            is enabled, because this code path does not obtain one.
        """
        img = self._preprocess(img).to(self.device)

        if self.config['predictor']['beamsearch']:
            s = translate_beam_search(img, self.model)
            prob = None
        else:
            s, prob = translate(img, self.model)
            # Take row 0 of the outputs; presumably process_input yields a
            # batch of one image — TODO confirm.
            s = s[0].tolist()
            prob = prob[0]

        s = self.vocab.decode(s)
        if return_prob:
            return s, prob
        return s

    def predict_batch(self, imgs, return_prob=False):
        """Recognize the text in several images, batching images of equal width.

        Images are grouped by the width of their processed tensor so each
        group can be concatenated along dim 0 and decoded in one
        ``translate`` call. Results are returned in the same order as
        ``imgs``.

        Note: this method always uses ``translate``; the
        ``config['predictor']['beamsearch']`` flag is not consulted here.

        Args:
            imgs: a sized sequence of images accepted by ``process_input``.
            return_prob: when True, also return the per-image probabilities.

        Returns:
            A list of decoded strings, or ``(strings, probs)`` when
            ``return_prob`` is True.
        """
        sents = [None] * len(imgs)
        probs = [None] * len(imgs)

        # Processed width -> list of (position in imgs, processed tensor).
        groups = defaultdict(list)
        for i, img in enumerate(imgs):
            processed = self._preprocess(img)
            groups[processed.shape[-1]].append((i, processed))

        for members in groups.values():
            positions = [pos for pos, _ in members]
            batch = torch.cat([tensor for _, tensor in members], 0).to(self.device)

            s, prob = translate(batch, self.model)
            prob_list = prob.tolist()
            decoded = self.vocab.batch_decode(s.tolist())

            # Scatter the batch results back to their original positions.
            for row, pos in enumerate(positions):
                sents[pos] = decoded[row]
                probs[pos] = prob_list[row]

        if return_prob:
            return sents, probs
        return sents