"""Gradio demo that runs a VietOCR model on uploaded handwriting images.

On import this script loads the model configuration and weights from the
working directory, builds a Gradio interface around ``predict`` and
launches it.
"""
import os

import gradio as gr
import omegaconf
import torch
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input

# Example inputs offered in the UI: every entry of the ``examples`` directory.
# os.listdir() returns names in arbitrary order, so sort them to keep the
# example list stable between launches.
# NOTE(review): the split on '\t' keeps only the text before the first tab,
# which is a no-op for ordinary file names -- presumably a leftover from
# parsing a tab-separated annotation file; confirm before removing.
examples_data = [
    os.path.join('examples', entry.split('\t')[0])
    for entry in sorted(os.listdir('examples'))
]

# Load the YAML config and convert it to plain Python containers
# (with interpolations resolved) so it can be indexed like a dict.
config = omegaconf.OmegaConf.load("vgg-seq2seq.yaml")
config = omegaconf.OmegaConf.to_container(config, resolve=True)

vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)
# Weights are mapped onto the CPU so the demo runs without a GPU.
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
# Put the model in inference mode explicitly so layers such as dropout and
# batch norm behave deterministically, regardless of what translate() does.
model.eval()


def predict(inp):
    """Recognise the text in one image and return it decoded.

    Args:
        inp: Image handed over by the Gradio ``Image`` input, which is
            configured with ``type='pil'``.

    Returns:
        The result of ``vocab.decode`` applied to the first element of
        ``translate``'s output for the preprocessed image.
    """
    dataset_cfg = config['dataset']
    # Resize/normalise the image according to the height and width bounds
    # stored in the dataset section of the config.
    img = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )
    token_ids = translate(img, model)[0].tolist()
    return vocab.decode(token_ids)


demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='pil'),
    outputs=gr.Text(),
    examples=examples_data,
)
demo.launch()