import torch
from typing import Any, Optional
from transformers import LayoutLMv2ForQuestionAnswering
from transformers import LayoutLMv2Processor
from transformers import LayoutLMv2FeatureExtractor
from transformers import LayoutLMv2ImageProcessor
from transformers import LayoutLMv2TokenizerFast
from transformers.tokenization_utils_base import BatchEncoding
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils import TensorType
# from transformers.modeling_outputs import (
# QuestionAnsweringModelOutput as QuestionAnsweringModelOutputBase
# )
import numpy as np
# from PIL import Image, ImageDraw, ImageFont
# from subprocess import run
import pdf2image
# from pprint import pprint
import logging
from os import environ
# from dataclasses import dataclass
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# install tesseract-ocr and pytesseract
# run("apt install -y tesseract-ocr", shell=True, check=True)
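# LayoutLMv2FeatureExtractor runs Tesseract OCR on each page by default
# (apply_ocr=True), which is why tesseract-ocr must be installed above.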
feature_extractor = LayoutLMv2FeatureExtractor()
# @dataclass
# class QuestionAnsweringModelOutput(QuestionAnsweringModelOutputBase):
# token_logits: Optional[torch.FloatTensor] = None
class NoOCRReaderFound(Exception):
def __init__(self, e):
self.e = e
def __str__(self):
return f"Could not load OCR Reader: {self.e}"
def pdf_to_image(b: bytes):
    # Render each PDF page to an RGB image, then run OCR via the feature extractor.
# TODO: This library requires poppler, which is not present everywhere.
# We should look into alternatives. We could also gracefully handle this
# and simply fall back to _only_ extracted text
images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(b)]
encoded_inputs = feature_extractor(images)
print('feature_extractor: ', encoded_inputs.keys())
data = {}
data['image'] = encoded_inputs.pixel_values
data['words'] = encoded_inputs.words
data['boxes'] = encoded_inputs.boxes
return data
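# Hypothetical usage sketch for pdf_to_image (it is not called elsewhere in this
# handler); "invoice.pdf" is a made-up path:
#   with open("invoice.pdf", "rb") as f:
#       page_data = pdf_to_image(f.read())
#   # page_data['words'][0] and page_data['boxes'][0] hold the OCR output for page 1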
def setup_logger(which_logger: Optional[str] = None):
lib_level = logging.DEBUG # Default level for your logger
root_level = logging.INFO
log_format = '%(asctime)s - %(process)d - %(levelname)s - %(funcName)s - %(message)s'
logging.basicConfig(
        filename=environ.get('LOG_FILE_PATH_LAYOUTLM_V2'),  # taken from local .env file, not set in settings.py
format=log_format,
datefmt='%d-%b-%y %H:%M:%S',
level=root_level,
force=True
)
log = logging.getLogger(which_logger)
log.setLevel(lib_level)
return log
logger = setup_logger(__name__)
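# Note: basicConfig writes to the file named by the LOG_FILE_PATH_LAYOUTLM_V2
# environment variable; when it is unset, filename is None and logs go to stderr.
# Illustrative example only:
#   LOG_FILE_PATH_LAYOUTLM_V2=/tmp/layoutlmv2_handler.log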
class Funcs:
# helper function to unnormalize bboxes for drawing onto the image
@staticmethod
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0] / 1000),
height * (bbox[1] / 1000),
width * (bbox[2] / 1000),
height * (bbox[3] / 1000),
]
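    # Worked example: LayoutLM boxes are normalized to a 0-1000 coordinate space, so
    # unnormalize_box([100, 200, 300, 400], width=850, height=1100)
    # returns [85.0, 220.0, 255.0, 440.0] in pixel coordinates.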
@staticmethod
def num_spans(encoding: BatchEncoding) -> int:
return len(encoding["input_ids"])
@staticmethod
def p_mask(num_spans: int, encoding: BatchEncoding) -> list:
        # Mask out everything that is not a document token (sequence id 1), i.e.
        # question tokens and special tokens, so they cannot be picked as answers.
        return [
            [tok != 1 for tok in encoding.sequence_ids(span_id)]
            for span_id in range(num_spans)
        ]
@staticmethod
def token_start_end(encoding, tokenizer):
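        # sequence_ids() marks special tokens as None, question tokens as 0 and
        # document (OCR word) tokens as 1; the answer can only live in the 1 region.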
sequence_ids = encoding.sequence_ids()
# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
token_start_index += 1
# End token index of the current span in the text.
token_end_index = len(encoding.input_ids) - 1
while sequence_ids[token_end_index] != 1:
token_end_index -= 1
print("Token start index:", token_start_index)
print("Token end index:", token_end_index)
print('token_start_end: ', tokenizer.decode(encoding.input_ids[token_start_index:token_end_index+1]))
return token_start_index, token_end_index
    @staticmethod
    def reconstruct_answer(word_idx_start, word_idx_end,
                           token_start_index, token_end_index,
                           encoding, tokenizer):
        # Map a word-level answer span [word_idx_start, word_idx_end] onto token
        # positions within [token_start_index, token_end_index] (see token_start_end).
        word_ids = encoding.word_ids()[token_start_index:token_end_index + 1]
        print("Word ids:", word_ids)
        for word_id in word_ids:
            if word_id == word_idx_start:
                start_position = token_start_index
                break
            token_start_index += 1
        for word_id in word_ids[::-1]:
            if word_id == word_idx_end:
                end_position = token_end_index
                break
            token_end_index -= 1
        print("Reconstructed answer:",
              tokenizer.decode(encoding.input_ids[start_position:end_position + 1]))
        return start_position, end_position
@staticmethod
def sigmoid(_outputs):
return 1.0 / (1.0 + np.exp(-_outputs))
@staticmethod
def softmax(_outputs):
maxes = np.max(_outputs, axis=-1, keepdims=True)
shifted_exp = np.exp(_outputs - maxes)
return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
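        # Worked example: Funcs.softmax(np.array([1.0, 2.0, 3.0]))
        # ~= array([0.0900, 0.2447, 0.6652]); the max-shift only adds numerical stability.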
class EndpointHandler:
def __init__(self, path="./"):
# self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path).to(device)
self.model = LayoutLMv2ForQuestionAnswering.from_pretrained(path)
self.tokenizer = LayoutLMv2TokenizerFast.from_pretrained(path)
# self.image_processor = LayoutLMv2ImageProcessor() # apply_ocr is set to True by default
self.processor = LayoutLMv2Processor.from_pretrained(
path,
# image_processor=self.image_processor,
tokenizer=self.tokenizer)
    def __call__(self, data: dict[str, bytes]):
        """
        Args:
            data (:obj:`dict`):
                includes the raw PDF as bytes under the "inputs" key; each page
                is rendered to a PIL.Image before being passed to the processor.
        """
        pdf_bytes = data.pop("inputs", data)
        images = [x.convert("RGB") for x in pdf2image.convert_from_bytes(pdf_bytes)]
        # NOTE: the question is currently hard-coded rather than read from the request
        question = "what is the bill date"
        answers = []
        with torch.no_grad():
for image in images:
# max_seq_len = min(self.tokenizer.model_max_length, 512)
# doc_stride = min(max_seq_len // 2, 256)
encoding = self.processor(
image,
question,
# max_length=max_seq_len,
# stride=doc_stride,
truncation=True,
# truncation=TruncationStrategy.ONLY_SECOND,
# return_offsets_mapping=True,
# return_token_type_ids=True,
# return_overflowing_tokens=True,
return_tensors=TensorType.PYTORCH
)
print('encoding: ', encoding.keys())
# for k, v in encoding.items():
# encoding[k] = v.to(self.model.device)
# num_spans = Funcs.num_spans(encoding)
# p_mask = Funcs.p_mask(num_spans, encoding)
# offset_mapping = encoding.pop('offset_mapping')
                # sample_mapping = encoding.pop('overflow_to_sample_mapping')
outputs = self.model(**encoding)
# print('model outputs: ', outputs.keys())
start_logits = outputs.start_logits
end_logits = outputs.end_logits
predicted_start_idx = start_logits.argmax(-1).item()
predicted_end_idx = end_logits.argmax(-1).item()
predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
predicted_answer = self.processor.tokenizer.decode(predicted_answer_tokens)
# print('answer: ', predicted_answer)
                # NOTE: this second forward pass scores an arbitrary target span
                # ([7, 14]) just to produce a loss; its outputs are not used below.
                target_start_index = torch.tensor([7])
                target_end_index = torch.tensor([14])
                outputs = self.model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
# predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
# predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
                logger.info(f'''
                START
                predicted_start_idx: {predicted_start_idx}
                predicted_end_idx: {predicted_end_idx}
                ---
                answer: {predicted_answer}
                END''')
                answers.append(predicted_answer)
        return {'data': answers}
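

# Minimal local smoke test (a sketch, not part of the deployed handler). It assumes
# the fine-tuned weights live in the current directory and uses a made-up PDF path.
if __name__ == "__main__":
    handler = EndpointHandler(path="./")
    with open("sample.pdf", "rb") as f:  # "sample.pdf" is a placeholder path
        result = handler({"inputs": f.read()})
    print(result)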