# VietOCR / app.py
# Last commit by nhay103: "update app" (e5923d4)
import os
import gradio as gr
import omegaconf
import torch
from vietocr.model.transformerocr import VietOCR
from vietocr.model.vocab import Vocab
from vietocr.translate import translate, process_input
# Example images offered under the input widget.
# - Sorted so the gallery order is stable (os.listdir order is arbitrary).
# - Restricted to image files so a stray non-image file in the folder is not
#   handed to the Image input as an example.
_EXAMPLE_DIR = 'examples'
_IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp'}
examples_data = [
    os.path.join(_EXAMPLE_DIR, name)
    for name in sorted(os.listdir(_EXAMPLE_DIR))
    if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
]
# Build the OCR model from its YAML config and load the trained weights.
# The config is resolved into plain Python containers so it can be indexed
# like a dict here and in predict().
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"),
    resolve=True,
)

# Character vocabulary used to decode predicted ids; its size is the first
# argument to the VietOCR constructor.
vocab = Vocab(config['vocab'])
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)

# Load the checkpoint's tensors onto the CPU regardless of the device they
# were saved from.
# NOTE(review): torch.load unpickles the file; fine for this bundled
# checkpoint, but consider weights_only=True if the installed torch supports it.
model.load_state_dict(
    torch.load('train_old.pth', map_location=torch.device('cpu'))
)
def predict(inp):
    """Recognise the handwritten text in a single image.

    Args:
        inp: Image from the Gradio ``Image(type='pil')`` input, or ``None``
            when the user submits without providing an image.

    Returns:
        The decoded text, or an empty string when no image was given.
    """
    if inp is None:
        # An empty Gradio Image input arrives as None; there is nothing to
        # recognise, and process_input expects an actual image.
        return ''

    dataset_cfg = config['dataset']
    # Convert the image into the model's input using the height and
    # min/max width bounds from the config.
    img = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )

    result = translate(img, model)
    # NOTE(review): depending on the installed vietocr version, translate()
    # returns either the id array itself or an (ids, probabilities) tuple;
    # accept both so decoding always receives a flat list of ids.
    if isinstance(result, tuple):
        result = result[0]
    # First row of the batch -- presumably the only one, since one image was passed.
    token_ids = result[0].tolist()
    return vocab.decode(token_ids)
# Web UI wrapping predict(). The Interface is bound to `demo` so tooling such
# as Gradio's reload mode (`gradio app.py`) can find it, and the server is only
# launched when this file is run as a script, not when it is imported.
demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='pil'),
    outputs=gr.Text(),
    examples=examples_data,
)

if __name__ == '__main__':
    demo.launch()