Donut_prototype / demo.py
Edgar404's picture
Update demo.py
54ff2c7 verified
raw
history blame contribute delete
No virus
4.89 kB
# Dependencies: install once from a shell with
#     pip install -q -U gradio peft
# The former `!pip install ...` line was IPython shell-escape syntax, which is
# a SyntaxError when this file is executed as a regular Python script.
import re
from warnings import simplefilter

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from tqdm.notebook import tqdm

# Ignore every warning raised from this point on.
simplefilter('ignore')

# Use the GPU when torch reports one, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Setting up the model: base checkpoint first, then the 'numeric' adapter on top.
from peft import PeftConfig, PeftModel
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)

# Hub repository that holds both the 'numeric' adapter and its processor.
NUMERIC_ADAPTER_REPO = "Edgar404/donut-sroie-lora-r8-x3"

# The adapter config names the base checkpoint the adapter sits on.
numeric_lora_config = PeftConfig.from_pretrained(NUMERIC_ADAPTER_REPO)

# Override the encoder image size and decoder max length before the weights
# of the base model are loaded.
image_size = [720, 960]
max_length = 512
config = VisionEncoderDecoderConfig.from_pretrained(
    numeric_lora_config.base_model_name_or_path
)
config.encoder.image_size = image_size
config.decoder.max_length = max_length

model = VisionEncoderDecoderModel.from_pretrained(
    numeric_lora_config.base_model_name_or_path,
    config=config,
)
numeric_processor = DonutProcessor.from_pretrained(NUMERIC_ADAPTER_REPO)

# Align the model's special-token ids and decoder vocabulary size with the
# tokenizer that ships with the adapter.
numeric_tokenizer = numeric_processor.tokenizer
model.config.pad_token_id = numeric_tokenizer.pad_token_id
model.config.decoder_start_token_id = numeric_tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
model.decoder.resize_token_embeddings(len(numeric_tokenizer))

# Wrap the base model with the adapter, registered under the name 'numeric',
# and move everything to the selected device.
model = PeftModel.from_pretrained(
    model,
    model_id=NUMERIC_ADAPTER_REPO,
    adapter_name='numeric',
)
model.to(device)
# Handwritten setting: processor loaded from the handwritten adapter repo.
# Its tokenizer length is used to size the 'handwritten' embedding and lm_head
# layers, and the processor itself encodes/decodes whenever process_image runs
# in a mode other than 'numeric'.
hand_processor = DonutProcessor.from_pretrained("Edgar404/donut-lora-r8-x2")
def resize_token_handwritten():
    """Make the 'handwritten' adapter slots match the handwritten vocabulary.

    A first attempt is made to load the "Edgar404/donut-lora-r8-x2" adapter
    under the name 'handwritten'. If that attempt raises, the 'handwritten'
    entries of the decoder's ``embed_tokens`` and ``lm_head``
    ``modules_to_save`` containers are replaced by freshly initialised layers
    whose vocabulary dimension is ``len(hand_processor.tokenizer)``.

    The replacement layers are new, untrained modules; the caller is expected
    to load the adapter again afterwards so they receive real weights.

    NOTE(review): the failure being handled is presumably a vocabulary-size
    mismatch between the base model and the handwritten checkpoint -- confirm.
    """
    try:
        model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
    except Exception:
        vocab_size = len(hand_processor.tokenizer)

        # Resizing the handwritten embedding layer: same embedding width,
        # same device and dtype, new number of rows.
        embed_slots = model.decoder.model.decoder.embed_tokens.modules_to_save
        current_embed_weight = embed_slots.handwritten.weight
        _, embed_dim = current_embed_weight.shape
        resized_embedding = nn.Embedding(vocab_size, embed_dim)
        resized_embedding.to(
            current_embed_weight.device,
            dtype=current_embed_weight.dtype,
        )
        embed_slots.handwritten = resized_embedding

        # Resizing the handwritten lm_head layer: same input width, bias-free,
        # one output per tokenizer entry.
        head_slots = model.decoder.lm_head.modules_to_save
        current_head_weight = head_slots.handwritten.weight
        _, head_in_features = current_head_weight.shape
        resized_head = nn.Linear(head_in_features, vocab_size, bias=False)
        resized_head.to(
            current_head_weight.device,
            dtype=current_head_weight.dtype,
        )
        head_slots.handwritten = resized_head
# Two-step load of the 'handwritten' adapter.
# Step 1: the helper makes a first load attempt and, if that raises, swaps in
# embedding / lm_head layers sized for the handwritten tokenizer.
resize_token_handwritten()
# Step 2: load the adapter under the same name again, after any resizing done
# by the helper. When step 1 already succeeded this simply repeats the load.
model.load_adapter("Edgar404/donut-lora-r8-x2" ,'handwritten')
def process_image(image, mode='numeric'):
    """Run Donut document parsing on one image and return the parsed result.

    Parameters
    ----------
    image : PIL image or numpy array
        The document image to read. It is handed unchanged to the processor.
    mode : str, optional
        Name of the adapter to activate via ``model.set_adapter``. 'numeric'
        selects ``numeric_processor``; any other name selects
        ``hand_processor``. ``None`` (for example an unselected radio button)
        is treated as the default 'numeric'.

    Returns
    -------
    The value produced by ``processor.token2json`` for the generated sequence.

    Raises
    ------
    ValueError
        If ``image`` is ``None``.
    """
    # A UI input with nothing selected delivers None; use the documented
    # default rather than failing inside set_adapter.
    if mode is None:
        mode = 'numeric'

    # Fail early, and before the active adapter is switched, when there is
    # nothing to process.
    if image is None:
        raise ValueError("No image was provided.")

    model.set_adapter(mode)
    processor = numeric_processor if mode == 'numeric' else hand_processor
    tokenizer = processor.tokenizer

    # Generation is seeded with the task start token.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        # Never emit the unknown token.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode, drop end-of-sequence / padding markers, then remove the first
    # <...> tag (the task start token) before converting to JSON.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
def image_classifier(image, mode):
    """Gradio callback: forward the image and the chosen mode to process_image.

    Returns whatever ``process_image(image, mode)`` returns.
    """
    parsed = process_image(image, mode)
    return parsed
# Example inputs for the demo: one [image path, mode] pair per row, in the
# same order as the Interface inputs (image first, then mode).
examples_list = [
    ["./test_images/TRAIN_00001.jpg", "handwritten"],
    ["./test_images/001.jpg", "numeric"],
    ["./test_images/TEST_0019.jpg", "handwritten"],
    ["./test_images/005.jpg", "numeric"],
    ["./test_images/007.jpg", "numeric"],
    ["./test_images/VALIDATION_0011.jpg", "handwritten"],
    ["./test_images/VALIDATION_0022.jpg", "handwritten"],
    ["./test_images/062.jpg", "numeric"],
    ["./test_images/119.jpg", "numeric"],
    ["./test_images/150.jpg", "numeric"],
]
# Gradio UI: an image input plus a radio button that picks the mode; the
# callback result is displayed as text.
mode_selector = gr.Radio(["handwritten", "numeric"], label="mode")
demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", mode_selector],
    outputs="text",
    examples=examples_list,
)
demo.launch(share=True)