# -*- coding: utf-8 -*-
"""Demo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1Icb8zeoaudyTDOKM1QySNay1cXzltRAp
"""
import gradio as gr
from PIL import Image
import re
import torch
import torch.nn as nn
from warnings import simplefilter
simplefilter('ignore')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Setting up the model
from peft import PeftConfig, PeftModel
numeric_lora_config = PeftConfig.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")
from transformers import VisionEncoderDecoderConfig
image_size = [720, 960]
max_length = 512
config = VisionEncoderDecoderConfig.from_pretrained(numeric_lora_config.base_model_name_or_path)
config.encoder.image_size = image_size
config.decoder.max_length = max_length
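# The overrides above set the encoder's input resolution and the decoder's
# maximum sequence length before the weights are loaded, presumably to match
# the settings used when the adapters were fine-tuned.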
from transformers import DonutProcessor, VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_pretrained(numeric_lora_config.base_model_name_or_path, config=config)
numeric_processor = DonutProcessor.from_pretrained("Edgar404/donut-sroie-lora-r8-x3")
model.config.pad_token_id = numeric_processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = numeric_processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
model.decoder.resize_token_embeddings(len(numeric_processor.tokenizer))
model = PeftModel.from_pretrained(model, model_id="Edgar404/donut-sroie-lora-r8-x3", adapter_name='numeric')
model.to(device)
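# At this point the base Donut model carries a single LoRA adapter named
# 'numeric', loaded from the SROIE-receipt fine-tune above.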
# Setting up the handwritten adapter
hand_processor = DonutProcessor.from_pretrained("Edgar404/donut-lora-r8-x2")
def resize_token_handwritten():
    """Resize the 'handwritten' copies of the embedding and lm_head layers
    so they match the handwritten tokenizer's vocabulary size."""
    try:
        model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
    except Exception:
        # The load above is expected to fail when the handwritten tokenizer's
        # vocabulary size differs from the base model's; the copied layers
        # are then resized to match before the weights are loaded again.

        # Resizing the handwritten embedding layer
        embedding_layer = model.decoder.model.decoder.embed_tokens.modules_to_save.handwritten
        old_num_tokens, old_embedding_dim = embedding_layer.weight.shape
        new_embeddings = nn.Embedding(
            len(hand_processor.tokenizer), old_embedding_dim
        )
        new_embeddings.to(
            embedding_layer.weight.device,
            dtype=embedding_layer.weight.dtype,
        )
        model.decoder.model.decoder.embed_tokens.modules_to_save.handwritten = new_embeddings

        # Resizing the handwritten lm_head layer
        lm_layer = model.decoder.lm_head.modules_to_save.handwritten
        old_num_tokens, old_input_dim = lm_layer.weight.shape
        new_lm_head = nn.Linear(
            old_input_dim, len(hand_processor.tokenizer),
            bias=False
        )
        new_lm_head.to(
            lm_layer.weight.device,
            dtype=lm_layer.weight.dtype,
        )
        model.decoder.lm_head.modules_to_save.handwritten = new_lm_head

resize_token_handwritten()
# With the resized layers in place, this second load only has to copy the
# adapter weights, so it succeeds.
model.load_adapter("Edgar404/donut-lora-r8-x2", 'handwritten')
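# Optional sanity check: both adapters should now be registered on the model.
# print(model.peft_config.keys())  # expected: dict_keys(['numeric', 'handwritten'])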
def process_image(image, mode='numeric'):
    """Run OCR on a document image with the Donut model via the document
    parsing task.

    Parameters
    ----------
    image : PIL.Image or numpy array
        A machine-readable image of the document to parse.
    mode : str
        Either 'numeric' or 'handwritten'; selects the LoRA adapter and
        the matching processor.
    """
    model.set_adapter(mode)
    processor = numeric_processor if mode == 'numeric' else hand_processor

    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Strip the special tokens and the leading task prompt, then convert
    # the token sequence into a structured JSON-like dict.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    output = processor.token2json(sequence)
    return output
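# Minimal usage sketch (assuming the bundled test images referenced below):
#   result = process_image(Image.open('./test_images/001.jpg'), mode='numeric')
#   print(result)  # dict of parsed fields produced by token2json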
def image_classifier(image, mode):
    return process_image(image, mode)
examples_list = [
    ['./test_images/TRAIN_00001.jpg', 'handwritten'],
    ['./test_images/001.jpg', 'numeric'],
    ['./test_images/TEST_0019.jpg', 'handwritten'],
    ['./test_images/005.jpg', 'numeric'],
    ['./test_images/007.jpg', 'numeric'],
    ['./test_images/VALIDATION_0011.jpg', 'handwritten'],
    ['./test_images/VALIDATION_0022.jpg', 'handwritten'],
    ['./test_images/062.jpg', 'numeric'],
    ['./test_images/119.jpg', 'numeric'],
    ['./test_images/150.jpg', 'numeric'],
]
demo = gr.Interface(
    fn=image_classifier,
    inputs=["image", gr.Radio(["handwritten", "numeric"], label="mode")],
    outputs="text",
    examples=examples_list,
)
# share=True additionally exposes the demo through a temporary public link
demo.launch(share=True)