ewfian's picture
add docker runtime
1ea5949
raw
history blame contribute delete
No virus
1.55 kB
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
import gradio as gr
import torch
from PIL import Image
# One Hugging Face Hub checkpoint supplies both the processor and the model.
_CHECKPOINT = "ewfian/donut_cn_invoice"

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
model.to(device)

# Task start token that seeds the decoder. It is tokenized once here, without
# extra special tokens, and the resulting ids are reused for every request.
task_prompt = "<s_totalAmountInWords>"
decoder_input_ids = processor.tokenizer(
    task_prompt,
    add_special_tokens=False,
    return_tensors="pt",
).input_ids
def process_document(image):
    """Run the Donut model on one invoice image and return the parsed result.

    Args:
        image: The image handed over by the Gradio ``"image"`` input, or
            ``None`` when the request was submitted without an image.

    Returns:
        Whatever ``processor.token2json`` produces for the generated
        sequence, or an empty dict when ``image`` is ``None``.
    """
    # Gradio calls the handler with None when nothing was uploaded; there is
    # nothing to parse, so return an empty result instead of crashing inside
    # the processor.
    if image is None:
        return {}

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        # Forbid the unknown token so it never appears in the output.
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode the single generated sequence and strip the EOS / padding markers.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
    return processor.token2json(sequence)
# Web UI: one image input routed through process_document, with the result
# rendered as JSON.
demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    title="Demo: Donut 🍩 for Chinese Invoice Parsing",
    cache_examples=False,
)
# Listen on all network interfaces rather than only localhost
# (presumably required for the Docker runtime -- confirm against the deployment).
demo.launch(server_name="0.0.0.0")