ChequeEasy / predict_cheque_parser.py
shivi's picture
updated model to extract bank_name and cheque_date
0f164ad
raw
history blame
No virus
5.27 kB
from datetime import datetime
import functools
import re

import pkg_resources
import torch
from dateutil import relativedelta
from PIL import Image
from symspellpy import SymSpell
from transformers import DonutProcessor, VisionEncoderDecoderModel
from word2number import w2n
from word2number import w2n
# Checkpoint identifier handed to `from_pretrained` for both the processor and
# the model (looks like a Hugging Face Hub repo id).
CHEQUE_PARSER_MODEL = "shivi/donut-cheque-parser"

# Task prompt fed to the decoder as its initial input_ids.
TASK_PROMPT = "<parse-cheque>"

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@functools.lru_cache(maxsize=1)
def load_donut_model_and_processor():
    """Load the Donut processor and cheque-parsing model, with the model on ``device``.

    The pair is cached after the first successful call. Every later call,
    including each ``parse_cheque_with_donut`` invocation, reuses the
    already-loaded weights instead of re-reading the checkpoint. Callers
    therefore share one processor/model instance and should not mutate it.
    Use ``load_donut_model_and_processor.cache_clear()`` to force a reload.

    Returns:
        tuple: ``(donut_processor, model)``.
    """
    donut_processor = DonutProcessor.from_pretrained(CHEQUE_PARSER_MODEL)
    model = VisionEncoderDecoderModel.from_pretrained(CHEQUE_PARSER_MODEL)
    model.to(device)
    return donut_processor, model
def prepare_data_using_processor(donut_processor, image_path):
    """Build the encoder and decoder inputs that ``model.generate`` needs.

    Args:
        donut_processor: the processor returned by
            ``load_donut_model_and_processor``.
        image_path: location of the cheque image, passed to ``load_image``.

    Returns:
        tuple: ``(pixel_values, decoder_input_ids)``, both moved to ``device``.
    """
    # Encoder side: run the image through the processor's feature extractor.
    cheque_image = load_image(image_path)
    encoded_image = donut_processor(cheque_image, return_tensors="pt")
    pixel_values = encoded_image.pixel_values.to(device)

    # Decoder side: tokenise only the task prompt (no extra special tokens)
    # so the decoder is primed with exactly TASK_PROMPT.
    prompt_encoding = donut_processor.tokenizer(
        TASK_PROMPT,
        add_special_tokens=False,
        return_tensors="pt",
    )
    decoder_input_ids = prompt_encoding["input_ids"].to(device)

    return pixel_values, decoder_input_ids
def load_image(image_path):
    """Open the image at ``image_path`` and return an RGB copy of it.

    ``Image.open`` is lazy and keeps the underlying file open, so it is used
    as a context manager here. ``convert`` reads the pixel data and returns a
    new image, so the returned image stays valid after the source file is
    closed.
    """
    with Image.open(image_path) as source_image:
        return source_image.convert("RGB")
# Field names read from the model's `cheque_details` output.
_REQUIRED_CHEQUE_FIELDS = (
    "amt_in_words",
    "amt_in_figures",
    "payee_name",
    "bank_name",
    "cheque_date",
)


def _generate_cheque_sequence(donut_processor, model, pixel_values, decoder_input_ids):
    """Run greedy Donut decoding and return the output token string without the task prompt."""
    tokenizer = donut_processor.tokenizer
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    decoded = donut_processor.batch_decode(outputs.sequences)[0]
    decoded = decoded.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    # The first tag in the sequence is the task prompt; remove only that one.
    return re.sub(r"<.*?>", "", decoded, count=1).strip()


def _collect_cheque_fields(cheque_details_json):
    """Merge the ``cheque_details`` entries of the parsed output into one field -> value dict."""
    try:
        cheque_details = cheque_details_json["cheque_details"]
    except (KeyError, TypeError) as err:
        raise KeyError(
            f"Model output has no 'cheque_details' entry: {cheque_details_json!r}"
        ) from err
    # Entries are merged by key, so the result does not depend on the order
    # in which the model emitted the fields.
    entries = [cheque_details] if isinstance(cheque_details, dict) else cheque_details
    fields = {}
    for entry in entries:
        if isinstance(entry, dict):
            fields.update(entry)
    missing = [name for name in _REQUIRED_CHEQUE_FIELDS if name not in fields]
    if missing:
        raise KeyError(f"Model output is missing cheque fields {missing}: {cheque_details!r}")
    return fields


def parse_cheque_with_donut(input_image_path):
    """Extract the key fields from a cheque image and run basic consistency checks.

    Args:
        input_image_path: path to the cheque image (passed to ``load_image``).

    Returns:
        tuple: ``(payee_name, amt_in_words, amt_in_figures, bank_name,
        cheque_date, macthing_amts, stale_cheque)``. ``macthing_amts`` comes
        from ``match_legal_and_courstesy_amount``, and ``stale_cheque`` comes
        from ``check_if_cheque_is_stale``.

    Raises:
        KeyError: if the model output lacks ``cheque_details`` or any of the
            required fields. Errors raised by the helper checks
            (e.g. an unparseable cheque date) propagate unchanged.
    """
    donut_processor, model = load_donut_model_and_processor()
    pixel_values, decoder_input_ids = prepare_data_using_processor(donut_processor, input_image_path)

    cleaned_cheque_details = _generate_cheque_sequence(
        donut_processor, model, pixel_values, decoder_input_ids
    )
    # Convert the output token sequence into a JSON-like dict.
    cheque_details_json = donut_processor.token2json(cleaned_cheque_details)
    fields = _collect_cheque_fields(cheque_details_json)
    print("cheque_details_json:", cheque_details_json["cheque_details"])

    amt_in_words = fields["amt_in_words"]
    amt_in_figures = fields["amt_in_figures"]
    payee_name = fields["payee_name"]
    bank_name = fields["bank_name"]
    cheque_date = fields["cheque_date"]

    macthing_amts = match_legal_and_courstesy_amount(amt_in_words, amt_in_figures)
    stale_cheque = check_if_cheque_is_stale(cheque_date)
    return payee_name, amt_in_words, amt_in_figures, bank_name, cheque_date, macthing_amts, stale_cheque
# Frequency dictionaries that ship inside the symspellpy package.
_SYMSPELL_DICTIONARY_FILE = "frequency_dictionary_en_82_765.txt"
_SYMSPELL_BIGRAM_FILE = "frequency_bigramdictionary_en_243_342.txt"


@functools.lru_cache(maxsize=1)
def _get_sym_spell():
    """Build and cache the SymSpell corrector so its dictionaries are read from disk only once."""
    sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
    dictionary_path = pkg_resources.resource_filename("symspellpy", _SYMSPELL_DICTIONARY_FILE)
    bigram_path = pkg_resources.resource_filename("symspellpy", _SYMSPELL_BIGRAM_FILE)
    # The load_* methods return False instead of raising when a file is
    # missing. Without these checks the corrector would silently run with an
    # empty dictionary.
    if not sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1):
        raise FileNotFoundError(f"SymSpell dictionary not found: {dictionary_path}")
    if not sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2):
        raise FileNotFoundError(f"SymSpell bigram dictionary not found: {bigram_path}")
    return sym_spell


def spell_check(amt_in_words):
    """Return the top compound spelling correction of ``amt_in_words``.

    Uses SymSpell's ``lookup_compound`` with a maximum edit distance of 2.
    The input is returned unchanged if SymSpell yields no suggestion.

    Raises:
        FileNotFoundError: if the bundled symspellpy dictionaries cannot be loaded.
    """
    suggestions = _get_sym_spell().lookup_compound(amt_in_words, max_edit_distance=2)
    if not suggestions:
        return amt_in_words
    return suggestions[0].term
def match_legal_and_courstesy_amount(legal_amount, courtesy_amount):
    """Check whether the amount written in words matches the amount in figures.

    The legal (words) amount is spell-corrected with ``spell_check`` and then
    converted to a number with ``w2n.word_to_num``. The courtesy (figures)
    amount may be a number or a string such as ``"12,500"``,
    ``"12500.00"`` or ``"Rs. 12500/-"``. For strings, commas are removed and
    the first number found is used.

    Returns:
        bool: True when both amounts parse and their integer parts are equal.
        False when either amount is empty/None or cannot be read as a number.
    """
    if not legal_amount:
        return False

    if isinstance(courtesy_amount, (int, float)):
        numeric_courtesy_amt = courtesy_amount
    else:
        figure_match = re.search(r"\d+(?:\.\d+)?", str(courtesy_amount).replace(",", ""))
        if figure_match is None:
            return False
        numeric_courtesy_amt = float(figure_match.group())

    corrected_amt_in_words = spell_check(legal_amount)
    print("corrected_amt_in_words:", corrected_amt_in_words)
    try:
        numeric_legal_amt = w2n.word_to_num(corrected_amt_in_words)
    except ValueError:
        # word_to_num raises ValueError when no usable number words are present.
        return False
    print("numeric_legal_amt:", numeric_legal_amt)

    # Compare integer parts, as the original int(...) == int(...) check did.
    return int(numeric_legal_amt) == int(numeric_courtesy_amt)
# Accepted cheque-date layouts, tried in order. The original two-digit-year
# format is first, so inputs that parsed before still parse the same way.
_CHEQUE_DATE_FORMATS = (
    "%d/%m/%y",
    "%d/%m/%Y",
    "%d-%m-%y",
    "%d-%m-%Y",
    "%d.%m.%y",
    "%d.%m.%Y",
)


def _parse_cheque_date(cheque_issue_date):
    """Parse a cheque date string with the first matching entry of _CHEQUE_DATE_FORMATS."""
    text = cheque_issue_date.strip()
    for date_format in _CHEQUE_DATE_FORMATS:
        try:
            return datetime.strptime(text, date_format).date()
        except ValueError:
            continue
    raise ValueError(
        f"Unrecognised cheque date {cheque_issue_date!r}; expected one of {_CHEQUE_DATE_FORMATS}"
    )


def check_if_cheque_is_stale(cheque_issue_date, validity_months=3, reference_date=None):
    """Return True if the cheque is older than ``validity_months`` months.

    A cheque is stale once the reference date is strictly after the issue
    date plus ``validity_months`` calendar months. For example, with the
    default of 3 months, a cheque dated 15/01/23 is still valid on 15/04/23
    and stale from 16/04/23. Post-dated cheques are never stale.

    Args:
        cheque_issue_date: issue date as ``dd/mm/yy``. ``dd/mm/yyyy`` and the
            ``-`` / ``.`` separated variants are also accepted.
        validity_months: how long the cheque stays valid (default 3).
        reference_date: date or datetime to check against; defaults to now.

    Raises:
        ValueError: if ``cheque_issue_date`` matches none of the accepted formats.
    """
    issue_date = _parse_cheque_date(cheque_issue_date)

    today = datetime.now() if reference_date is None else reference_date
    if isinstance(today, datetime):
        today = today.date()

    relative_diff = relativedelta.relativedelta(today, issue_date)
    months_difference = (relative_diff.years * 12) + relative_diff.months
    print("months_difference:", months_difference)

    # Comparing against issue date + N months accounts for leftover days.
    # The old whole-month count treated e.g. 3 months + 20 days as "not stale".
    expiry_date = issue_date + relativedelta.relativedelta(months=validity_months)
    return today > expiry_date