"""Custom inference handler for document question answering with LayoutLMv2.

PDF pages are rasterized with pdf2image (which requires poppler), OCR'd by the
LayoutLMv2 image processor (which requires tesseract-ocr and pytesseract to be
installed on the host), and passed together with a question to
LayoutLMv2ForQuestionAnswering.
"""
import logging
from os import environ
from typing import Any, Optional

import numpy as np
import pdf2image
import torch
from transformers import LayoutLMv2FeatureExtractor  # noqa: F401  (currently unused)
from transformers import LayoutLMv2ForQuestionAnswering
from transformers import LayoutLMv2ImageProcessor
from transformers import LayoutLMv2Processor
from transformers import LayoutLMv2TokenizerFast
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_base import TruncationStrategy  # noqa: F401  (currently unused)
from transformers.utils import TensorType

# Module-level OCR image processor used by ``pdf_to_image``. apply_ocr is True
# by default, so it returns OCR'd words and boxes alongside the pixel values.
# LayoutLMv2ImageProcessor supersedes the deprecated LayoutLMv2FeatureExtractor.
feature_extractor = LayoutLMv2ImageProcessor()


class NoOCRReaderFound(Exception):
    """Error signalling that an OCR reader could not be loaded.

    Not raised anywhere in this file; provided for callers.

    Args:
        e: the underlying error (or a description of it).
    """

    def __init__(self, e):
        # Pass ``e`` to Exception so ``args`` / pickling behave normally.
        super().__init__(e)
        self.e = e

    def __str__(self):
        return f"Could not load OCR Reader: {self.e}"


def _pdf_pages(b: bytes) -> list:
    """Rasterize every page of the PDF contained in ``b`` into an RGB image."""
    # TODO: pdf2image requires poppler, which is not present everywhere. We
    # should look into alternatives. We could also gracefully handle this and
    # simply fall back to _only_ extracted text.
    return [page.convert("RGB") for page in pdf2image.convert_from_bytes(b)]


def pdf_to_image(b: bytes) -> dict:
    """Rasterize a PDF and run the OCR image processor over all of its pages.

    Args:
        b: raw PDF file contents.

    Returns:
        dict with keys ``'image'`` (pixel values), ``'words'`` and ``'boxes'``,
        exactly as produced by ``feature_extractor`` for the list of pages.
    """
    images = _pdf_pages(b)
    encoded_inputs = feature_extractor(images)
    logger.debug("feature_extractor: %s", list(encoded_inputs.keys()))
    return {
        "image": encoded_inputs.pixel_values,
        "words": encoded_inputs.words,
        "boxes": encoded_inputs.boxes,
    }


def setup_logger(which_logger: Optional[str] = None) -> logging.Logger:
    """Configure root logging and return a logger set to DEBUG.

    The root logger is (re)configured at INFO with ``force=True``, replacing
    any handlers already attached to it. Records go to the file named by the
    ``LOG_FILE_PATH_LAYOUTLM_V2`` environment variable, or to stderr when that
    variable is unset (``logging.basicConfig`` default).

    Args:
        which_logger: name of the logger to return; ``None`` returns the root
            logger (which is then also set to DEBUG).
    """
    lib_level = logging.DEBUG  # level for the returned logger
    root_level = logging.INFO
    log_format = "%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s"
    logging.basicConfig(
        # Taken from the local .env file; not set in settings.py.
        filename=environ.get("LOG_FILE_PATH_LAYOUTLM_V2"),
        format=log_format,
        datefmt="%d-%b-%y %H:%M:%S",
        level=root_level,
        force=True,
    )
    log = logging.getLogger(which_logger)
    log.setLevel(lib_level)
    return log


logger = setup_logger(__name__)


class Funcs:
    """Stateless helpers for LayoutLMv2 question-answering pre/post-processing."""

    @staticmethod
    def unnormalize_box(bbox, width, height):
        """Scale an ``[x0, y0, x1, y1]`` box from 0-1000 coordinates to pixels.

        Intended for drawing boxes back onto a ``width`` x ``height`` image.
        """
        return [
            width * (bbox[0] / 1000),
            height * (bbox[1] / 1000),
            width * (bbox[2] / 1000),
            height * (bbox[3] / 1000),
        ]

    @staticmethod
    def num_spans(encoding: BatchEncoding) -> int:
        """Return the number of encoded sequences (spans), i.e. ``len(input_ids)``."""
        return len(encoding["input_ids"])

    @staticmethod
    def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
        """Build one boolean mask per span: True for tokens NOT in sequence 1.

        Sequence 1 is the second text of each pair (with the processor call in
        ``EndpointHandler`` that is the OCR'd words; the question is sequence 0).
        """
        return [
            [seq_id != 1 for seq_id in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]

    @staticmethod
    def _first_row_ids(encoding) -> list:
        """Return the input ids of the first sequence in ``encoding`` as a flat list.

        Handles un-batched lists, batched lists of lists, and tensors/arrays
        (anything exposing ``tolist()``).
        """
        ids = encoding["input_ids"]
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        if ids and isinstance(ids[0], list):
            ids = ids[0]
        return ids

    @staticmethod
    def token_start_end(encoding, tokenizer):
        """Return the first and last token index belonging to sequence 1.

        Only the first sequence of the encoding is inspected
        (``encoding.sequence_ids()`` defaults to batch index 0).

        Raises:
            ValueError: if the encoding contains no sequence-1 tokens.
        """
        sequence_ids = encoding.sequence_ids()
        # Scan sequence_ids itself: len(input_ids) is the batch size, not the
        # sequence length, when the encoding is batched.
        context_positions = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
        if not context_positions:
            raise ValueError("encoding contains no tokens from the second sequence")
        token_start_index, token_end_index = context_positions[0], context_positions[-1]

        logger.debug("Token start index: %s", token_start_index)
        logger.debug("Token end index: %s", token_end_index)
        ids = Funcs._first_row_ids(encoding)
        logger.debug(
            "token_start_end: %s",
            tokenizer.decode(ids[token_start_index:token_end_index + 1]),
        )
        return token_start_index, token_end_index

    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end, encoding, tokenizer):
        """Map a word-level answer span onto token positions in ``encoding``.

        Args:
            word_idx_start: index of the answer's first word.
            word_idx_end: index of the answer's last word.
            encoding: fast-tokenizer encoding (``word_ids`` must be available).
            tokenizer: tokenizer used only to decode the span for logging.

        Returns:
            ``(start_position, end_position)``: the first token of the start
            word and the last token of the end word, both inclusive and
            restricted to the sequence-1 region found by ``token_start_end``.

        Raises:
            ValueError: if either word has no token inside that region.
        """
        token_start_index, token_end_index = Funcs.token_start_end(encoding, tokenizer)
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        logger.debug("Word ids: %s", word_ids)

        # First token (scanning forward) that belongs to the start word.
        start_position = next(
            (
                token_start_index + offset
                for offset, word_id in enumerate(word_ids)
                if word_id == word_idx_start
            ),
            None,
        )
        # Last token (scanning backward) that belongs to the end word.
        end_position = next(
            (
                token_end_index - offset
                for offset, word_id in enumerate(reversed(word_ids))
                if word_id == word_idx_end
            ),
            None,
        )
        if start_position is None or end_position is None:
            raise ValueError(
                f"answer words {word_idx_start}..{word_idx_end} are not inside the "
                f"encoded context tokens {token_start_index}..{token_end_index}"
            )

        ids = Funcs._first_row_ids(encoding)
        logger.debug(
            "Reconstructed answer: %s",
            tokenizer.decode(ids[start_position:end_position + 1]),
        )
        return start_position, end_position

    @staticmethod
    def sigmoid(_outputs):
        """Element-wise logistic sigmoid."""
        return 1.0 / (1.0 + np.exp(-_outputs))

    @staticmethod
    def softmax(_outputs):
        """Numerically stable softmax over the last axis."""
        # Subtracting the row max prevents overflow in exp().
        maxes = np.max(_outputs, axis=-1, keepdims=True)
        shifted_exp = np.exp(_outputs - maxes)
        return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)


class EndpointHandler:
    """Callable document-QA handler: PDF bytes in, one answer per page out."""

    # Question asked when the request payload does not provide one.
    DEFAULT_QUESTION = "what is the bill date"

    def __init__(self, path="./"):
        """Load model, fast tokenizer and processor from ``path``."""
        self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
        self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
        # The processor's image processor has apply_ocr=True by default, so it
        # OCRs each page itself and pairs the words with the question.
        self.processor = LayoutLMv2Processor.from_pretrained(path, tokenizer=self.tokenizer)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        """Answer a question about every page of a PDF.

        Args:
            data: request payload. ``data["inputs"]`` holds the raw PDF bytes
                and is popped from the dict; an optional ``data["question"]``
                overrides ``DEFAULT_QUESTION``.

        Returns:
            ``{"data": "success", "answers": [...]}`` with one decoded answer
            string per page, in page order.

        Raises:
            TypeError: if the payload has no ``"inputs"`` entry.
        """
        pdf_bytes = data.pop("inputs", data)
        if isinstance(pdf_bytes, dict):
            # Without "inputs" the fallback is the payload dict itself, which
            # pdf2image cannot read; fail with a clear message instead.
            raise TypeError("request payload must carry the PDF bytes under 'inputs'")
        question = data.get("question", self.DEFAULT_QUESTION)

        answers = [self._answer_page(page, question) for page in _pdf_pages(pdf_bytes)]
        return {"data": "success", "answers": answers}

    def _answer_page(self, page, question: str) -> str:
        """Run one forward pass for a single page image and decode the best span."""
        encoding = self.processor(
            page,
            question,
            truncation=True,
            return_tensors=TensorType.PYTORCH,
        )
        logger.debug("encoding: %s", list(encoding.keys()))

        with torch.no_grad():
            outputs = self.model(**encoding)

        predicted_start_idx = outputs.start_logits.argmax(-1).item()
        predicted_end_idx = outputs.end_logits.argmax(-1).item()
        # The batch holds a single page, so take row 0. If end < start the
        # slice is empty and the decoded answer is "".
        answer_tokens = encoding.input_ids[0][predicted_start_idx:predicted_end_idx + 1]
        predicted_answer = self.processor.tokenizer.decode(answer_tokens)

        logger.info(
            "\nSTART\npredicted_start_idx: %s\npredicted_end_idx: %s\n---\nanswer: %s\nEND",
            predicted_start_idx,
            predicted_end_idx,
            predicted_answer,
        )
        return predicted_answer