!pip install -q -U gradio peft

import gradio as gr
from tqdm.notebook import tqdm
from PIL import Image
import re
import torch
import torch.nn as nn
from warnings import simplefilter

simplefilter('ignore')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Setting up the model
from peft import PeftConfig, PeftModel

numeric_lora_config = PeftConfig.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")

from transformers import VisionEncoderDecoderConfig

image_size = [720, 960]
max_length = 512

config = VisionEncoderDecoderConfig.from_pretrained(numeric_lora_config.base_model_name_or_path)
config.encoder.image_size = image_size
config.decoder.max_length = max_length

from transformers import DonutProcessor, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained(
    numeric_lora_config.base_model_name_or_path,
    config=config,
)
numeric_processor = DonutProcessor.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")

model.config.pad_token_id = numeric_processor.tokenizer.pad_token_id
# The start token below appears to have been stripped by HTML rendering in the
# original; "<s>" (Donut's usual decoder start token) is assumed here.
model.config.decoder_start_token_id = numeric_processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]
model.decoder.resize_token_embeddings(len(numeric_processor.tokenizer))

model = PeftModel.from_pretrained(
    model,
    model_id="Edgar404/donut-sroie-lora-r8-x3",
    adapter_name='numeric',
)
model.to(device)

# Handwritten setting
hand_processor = DonutProcessor.from_pretrained("Edgar404/donut-lora-r8-x2")

def resize_token_handwritten():
    try:
        model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
    except Exception:
        # Resizing the handwritten embedding layer to match the handwritten tokenizer
        embedding_layer = model.decoder.model.decoder.embed_tokens.modules_to_save.handwritten
        old_num_tokens, old_embedding_dim = embedding_layer.weight.shape
        new_embeddings = nn.Embedding(len(hand_processor.tokenizer), old_embedding_dim)
        new_embeddings.to(
            embedding_layer.weight.device,
            dtype=embedding_layer.weight.dtype,
        )
        model.decoder.model.decoder.embed_tokens.modules_to_save.handwritten = new_embeddings

        # Resizing the handwritten lm_head layer
        lm_layer = model.decoder.lm_head.modules_to_save.handwritten
        old_num_tokens, old_input_dim = lm_layer.weight.shape
        new_lm_head = nn.Linear(old_input_dim, len(hand_processor.tokenizer), bias=False)
        new_lm_head.to(
            lm_layer.weight.device,
            dtype=lm_layer.weight.dtype,
        )
        model.decoder.lm_head.modules_to_save.handwritten = new_lm_head

resize_token_handwritten()
model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
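# Optional sanity check that both adapters are registered. A minimal sketch,
# not part of the original script; it assumes the setup above ran to completion.
print(model.peft_config.keys())  # expected: dict_keys(['numeric', 'handwritten'])
print(model.active_adapter)      # the adapter currently selected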
def process_image(image, mode='numeric'):
    """Perform OCR on an image with the Donut model via the document-parsing task.

    Parameters
    ----------
    image : PIL.Image.Image or numpy.ndarray
        A machine-readable image of the document to parse.
    mode : str, optional
        Which adapter to use: 'numeric' or 'handwritten'.
    """
    model.set_adapter(mode)
    processor = numeric_processor if mode == 'numeric' else hand_processor

    # The task prompt appears to have been stripped by HTML rendering in the
    # original; "<s>" is assumed here, matching the decoder start token above.
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove the first task-start token

    output = processor.token2json(sequence)
    return output

def image_classifier(image, mode):
    return process_image(image, mode)

examples_list = [
    ['./test_images/TRAIN_00001.jpg', "handwritten"],
    ['./test_images/001.jpg', 'numeric'],
    ['./test_images/TEST_0019.jpg', "handwritten"],
    ['./test_images/005.jpg', 'numeric'],
    ['./test_images/007.jpg', 'numeric'],
    ['./test_images/VALIDATION_0011.jpg', "handwritten"],
    ['./test_images/VALIDATION_0022.jpg', "handwritten"],
    ['./test_images/062.jpg', 'numeric'],
    ['./test_images/119.jpg', 'numeric'],
    ['./test_images/150.jpg', 'numeric'],
]

demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", gr.Radio(["handwritten", "numeric"], label="mode")],
    outputs="text",
    examples=examples_list,
)
demo.launch(share=True)
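# Quick smoke test of process_image() without the Gradio UI. A minimal sketch,
# not part of the original script; it assumes the example images from
# examples_list exist locally.
receipt = Image.open('./test_images/001.jpg').convert('RGB')
print(process_image(receipt, mode='numeric'))      # parsed fields as a dict
note = Image.open('./test_images/TRAIN_00001.jpg').convert('RGB')
print(process_image(note, mode='handwritten'))     # uses the handwritten adapter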