|
import os
|
|
import gradio as gr
|
|
import omegaconf
|
|
import torch
|
|
from vietocr.model.transformerocr import VietOCR
|
|
from vietocr.model.vocab import Vocab
|
|
from vietocr.translate import translate, process_input
|
|
|
|
# Example images offered as clickable samples under the demo's image input.
# Entries are sorted so the order is stable across runs (os.listdir order is
# arbitrary). Hidden files (e.g. .gitkeep) and sub-directories are skipped
# because every kept entry is handed to the image input as an example.
examples_data = [
    os.path.join('examples', name)
    for name in sorted(os.listdir('examples'))
    if not name.startswith('.') and os.path.isfile(os.path.join('examples', name))
]
|
|
|
|
# Read the YAML model configuration and turn it into plain Python containers,
# resolving any interpolations, so it can be indexed like config['dataset'][...].
config = omegaconf.OmegaConf.to_container(
    omegaconf.OmegaConf.load("vgg-seq2seq.yaml"),
    resolve=True,
)
|
|
|
|
# Character vocabulary from the config; used below for the model size and by
# predict() to turn token ids back into text.
vocab = Vocab(config['vocab'])

# len(vocab) is the first VietOCR argument (presumably the output vocabulary
# size); the remaining sections of the config describe the architecture.
model = VietOCR(
    len(vocab),
    config['backbone'],
    config['cnn'],
    config['transformer'],
    config['seq_modeling'],
)

# Weights are mapped onto the CPU so the demo runs without a GPU.
# NOTE(review): torch.load unpickles the file, so only ship trusted
# checkpoints; on torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load('train_old.pth', map_location=torch.device('cpu')))

# This process only does inference: switch dropout/normalisation layers to
# evaluation behaviour explicitly rather than relying on the decoding helper.
model.eval()
|
|
|
|
def predict(inp):
    """Recognise the handwritten text in one image.

    Args:
        inp: The image from the Gradio ``Image`` input (``type='pil'``), or
            ``None`` when the user submits without providing an image.

    Returns:
        The decoded text as a string; an empty string when ``inp`` is None.
    """
    # Gradio passes None when the image box is empty; process_input cannot
    # handle that, so return an empty prediction instead of raising.
    if inp is None:
        return ''

    dataset_cfg = config['dataset']
    img = process_input(
        inp,
        dataset_cfg['image_height'],
        dataset_cfg['image_min_width'],
        dataset_cfg['image_max_width'],
    )

    result = translate(img, model)
    # NOTE(review): depending on the installed vietocr version, translate()
    # returns either the token-id array or a (token_ids, char_probs) tuple;
    # unwrap the tuple form so both work. Confirm against the pinned version.
    if isinstance(result, tuple):
        result = result[0]

    # A single image was processed, so take the first row of token ids.
    token_ids = result[0].tolist()
    return vocab.decode(token_ids)
|
|
|
|
# Assemble the web demo (a PIL image in, the recognised text out, with the
# files in examples_data as clickable samples) and start serving it.
demo = gr.Interface(
    fn=predict,
    title='Vietnamese Handwriting Recognition',
    inputs=gr.Image(type='pil'),
    outputs=gr.Text(),
    examples=examples_data,
)
demo.launch()