Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
"""Demo.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1Icb8zeoaudyTDOKM1QySNay1cXzltRAp
"""
# `!pip install ...` is IPython/Colab shell magic and is a SyntaxError in a plain
# Python script, so run the same install through the current interpreter instead.
import subprocess
import sys

subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q", "-U", "gradio", "peft"]
)

import gradio as gr
from tqdm.notebook import tqdm
from PIL import Image
import re
import torch
import torch.nn as nn
from warnings import simplefilter

simplefilter('ignore')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Setting up the model
from peft import PeftConfig, PeftModel

# The numeric adapter's config records which base Donut checkpoint it was trained on.
numeric_lora_config = PeftConfig.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")

from transformers import VisionEncoderDecoderConfig

image_size = [720, 960]
max_length = 512
config = VisionEncoderDecoderConfig.from_pretrained(numeric_lora_config.base_model_name_or_path)
config.encoder.image_size = image_size
config.decoder.max_length = max_length

from transformers import DonutProcessor, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained(numeric_lora_config.base_model_name_or_path, config=config)
numeric_processor = DonutProcessor.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")
model.config.pad_token_id = numeric_processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = numeric_processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
# Match the decoder vocabulary to the numeric adapter's tokenizer before loading
# the adapter (presumably the fine-tuned tokenizer has extra tokens -- confirm).
model.decoder.resize_token_embeddings(len(numeric_processor.tokenizer))
model = PeftModel.from_pretrained(model, model_id="Edgar404/donut-sroie-lora-r8-x3", adapter_name='numeric')
model.to(device)

# Handwritten setting
hand_processor = DonutProcessor.from_pretrained("Edgar404/donut-lora-r8-x2")
def _resize_weight_rows(layer, num_rows):
    """Resize the output dimension (rows of ``weight``, and ``bias`` if any) of *layer* in place.

    Only the parameters are swapped. The module object itself is kept, so its
    class and attributes (e.g. ``padding_idx``, or any output scaling an
    embedding subclass applies) survive. New values are zeros; they are meant
    to be overwritten by adapter weights loaded afterwards.
    """
    weight = layer.weight
    layer.weight = nn.Parameter(
        weight.new_zeros((num_rows, weight.shape[1])),
        requires_grad=weight.requires_grad,
    )
    bias = getattr(layer, "bias", None)
    if bias is not None:
        layer.bias = nn.Parameter(
            bias.new_zeros((num_rows,)),
            requires_grad=bias.requires_grad,
        )


def resize_token_handwritten():
    """Try to load the handwritten adapter; on failure, resize its vocab layers.

    If ``load_adapter`` fails, the ``modules_to_save.handwritten`` copies of the
    decoder embedding and ``lm_head`` are resized to
    ``len(hand_processor.tokenizer)`` rows. The failure is presumably a
    vocabulary-size mismatch between the saved adapter and the base model;
    confirm against the PEFT error. After resizing, a second ``load_adapter``
    call (made by the caller) can load the weights.

    Raises
    ------
    Exception
        The original ``load_adapter`` error is re-raised if the handwritten
        ``modules_to_save`` copies do not exist (e.g. the download itself
        failed), instead of masking it behind an AttributeError.
    """
    try:
        model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
    except Exception as load_error:
        try:
            embedding_layer = model.decoder.model.decoder.embed_tokens.modules_to_save.handwritten
            lm_layer = model.decoder.lm_head.modules_to_save.handwritten
        except AttributeError:
            # The adapter was never injected, so resizing cannot help; surface the real cause.
            raise load_error from None

        new_num_tokens = len(hand_processor.tokenizer)

        # Resize in place rather than swapping in a fresh nn.Embedding / nn.Linear,
        # which would drop the module's class-specific behaviour and attributes.
        _resize_weight_rows(embedding_layer, new_num_tokens)
        embedding_layer.num_embeddings = new_num_tokens

        _resize_weight_rows(lm_layer, new_num_tokens)
        lm_layer.out_features = new_num_tokens
# Size the handwritten adapter's embedding / lm_head copies to its tokenizer,
# then (re)load the handwritten adapter weights into them.
resize_token_handwritten()
model.load_adapter("Edgar404/donut-lora-r8-x2", adapter_name='handwritten')
def process_image(image, mode='numeric'):
    """Perform OCR on *image* with the Donut model via the document-parsing task.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray
        A machine-readable image of the document.
    mode : str
        Name of the LoRA adapter to activate. ``'numeric'`` uses
        ``numeric_processor``; any other value (in this file,
        ``'handwritten'``) uses ``hand_processor``.

    Returns
    -------
    The parsed result returned by ``processor.token2json`` for the decoded
    output sequence.

    Raises
    ------
    ValueError
        If *image* or *mode* is None (e.g. an empty Gradio input).
    """
    # Fail early with a clear message instead of deep inside the processor / PEFT.
    if image is None:
        raise ValueError("process_image() requires an image, got None.")
    if mode is None:
        raise ValueError("process_image() requires a mode ('numeric' or 'handwritten'), got None.")

    model.set_adapter(mode)
    processor = numeric_processor if mode == 'numeric' else hand_processor

    # Generation starts from the task prompt token passed as decoder_input_ids.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    # Remove only the first tag, i.e. the task-start prompt echoed at the front.
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
import gradio as gr


def image_classifier(image, mode):
    """Gradio callback: parse *image* with the adapter selected by *mode*.

    Raises ``gr.Error`` (shown to the user in the UI) when no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    return process_image(image, mode)


examples_list = [
    ['./test_images/TRAIN_00001.jpg', "handwritten"],
    ['./test_images/001.jpg', 'numeric'],
    ['./test_images/TEST_0019.jpg', "handwritten"],
    ['./test_images/005.jpg', 'numeric'],
    ['./test_images/007.jpg', 'numeric'],
    ['./test_images/VALIDATION_0011.jpg', "handwritten"],
    ['./test_images/VALIDATION_0022.jpg', "handwritten"],
    ['./test_images/062.jpg', 'numeric'],
    ['./test_images/119.jpg', 'numeric'],
    ['./test_images/150.jpg', 'numeric'],
]

demo = gr.Interface(
    fn=image_classifier,
    inputs=[
        "image",
        # Default selection so a user who never touches the radio doesn't send mode=None.
        gr.Radio(["handwritten", "numeric"], value="numeric", label="mode"),
    ],
    outputs="text",
    examples=examples_list,
)
demo.launch(share=True)