|
import torch

from sconf import Config

from PIL import Image, ImageOps

from donut import DonutConfig, DonutModel

import warnings

# Silence all Python warnings (Pillow / torch deprecation noise) so the
# demo's console output stays readable.
warnings.filterwarnings("ignore")

from transformers import logging

# Show only warnings and errors from transformers; suppresses the
# info-level download/progress chatter during from_pretrained.
logging.set_verbosity_warning()

# Load demo settings (checkpoint path, input size, sequence length, ...)
# from the YAML file next to this script.
config = Config(default="./config.yaml")

# Instantiate the fine-tuned Donut model from the configured checkpoint.
# ignore_mismatched_sizes=True lets loading succeed even when the
# checkpoint's positional-embedding / input-size tensors differ from the
# sizes requested here (they are re-initialized where mismatched).
model = DonutModel.from_pretrained(

    config.pretrained_model_name_or_path,

    input_size=config.input_size,

    max_length=config.max_position_embeddings,

    align_long_axis=config.align_long_axis,

    ignore_mismatched_sizes=True,

)

# Donut conditions its decoder on a task-specific start token; for this
# demo decoding starts from "<s_matricula>".
task_name = "matricula"

task_prompt = f"<s_{task_name}>"
|
|
|
def predict_matricula(model, task_name, image, size=(1280, 960)):
    """Run Donut OCR inference on a single matricula (vehicle registration) image.

    Args:
        model: A loaded ``DonutModel`` instance.
        task_name: Task identifier; the decoder prompt becomes ``<s_{task_name}>``.
        image: Input ``PIL.Image`` (as delivered by the Gradio upload widget).
        size: Target ``(width, height)`` the image is resized to before
            inference. Defaults to the resolution the model was trained with.

    Returns:
        The first entry of ``model.inference(...)["predictions"]`` — the
        parsed fields for the document.
    """
    # Honor the EXIF orientation tag so rotated phone photos come in upright.
    image = ImageOps.exif_transpose(image)
    # NEAREST keeps glyph edges crisp (no interpolation blur) at the cost
    # of some aliasing.
    image = image.resize(size=size, resample=Image.Resampling.NEAREST)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = model.inference(image=image, prompt=f"<s_{task_name}>")["predictions"][0]
    return result
|
|
|
|
|
import gradio as gr

# Gradio web UI: one image in -> one JSON document out. The lambda binds
# the module-level model and task name so the callback matches Gradio's
# single-argument (image) signature.
demo = gr.Interface(

    fn=lambda x:predict_matricula(model, task_name="matricula", image=x),

    title="Demo: Donut 🍩 for DR Matriculas",

    # Fixed typo: "Infering" -> "Inferring".
    description="Dominican Vehicle **Matriculas OCR** Inferring",

    inputs=gr.Image(label="Matricula", sources="upload", type="pil", show_label=True),

    outputs=[gr.JSON(label="Matricula JSON", show_label=True, value={})]

)

# share=True additionally exposes a temporary public *.gradio.live URL
# besides the local server.
demo.launch(share=True)
|
|