# Layoutlm_invoices / handler.py
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run
import subprocess
import os
# install tesseract-ocr and pytesseract
# run("apt install -y tesseract-ocr", shell=True, check=True)
# run("pip install pytesseract", shell=True, check=True)
subprocess.check_call("mkdir -p /data", shell = True)
subprocess.check_call("chmod 777 /data", shell = True)
# subprocess.check_call("apt-get update", shell = True)
# subprocess.check_call("apt-get install git-lfs" ,shell = True)
# subprocess.check_call("mkdir -p /var/lib/dpkg", shell = True)
# subprocess.check_call("id",shell = True)
subprocess.check_call("apt install tesseract-ocr -y", shell=True)
# subprocess.check_call("sudo apt install libtesseract-dev", shell=True)
subprocess.check_call("tesseract --version", shell = True)
# os.system('chmod 777 /tmp')
# os.system('apt-get update -y')
# os.system('apt-get install tesseract-ocr -y')
os.system('pip install -q pytesseract')
import pytesseract
# pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# helper function to unnormalize bboxes (LayoutLM's 0-1000 scale) back to pixel coordinates
def unnormalize_box(bbox, width, height):
return [
width * (bbox[0] / 1000),
height * (bbox[1] / 1000),
width * (bbox[2] / 1000),
height * (bbox[3] / 1000),
]
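# Worked example (hypothetical values, not taken from the model): for a
# 2000x1000 px page, the normalized box [100, 200, 300, 400] maps back to
# pixel coordinates:
# >>> unnormalize_box([100, 200, 300, 400], width=2000, height=1000)
# [200.0, 200.0, 600.0, 400.0]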
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EndpointHandler:
def __init__(self, path=""):
# load model and processor from path
self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        # apply_ocr=True makes the processor run Tesseract on the raw image to extract words and boxes
        self.processor = LayoutLMv2Processor.from_pretrained(path, apply_ocr=True)
    def __call__(self, data: Dict[str, bytes]) -> Dict[str, List[Any]]:
        """
        Args:
            data (:obj:`dict`):
                includes the deserialized image file as a PIL.Image under the "inputs" key
        Returns:
            :obj:`dict`: a "predictions" list of {"label", "score", "text", "bbox"} entries
        """
# process input
image = data.pop("inputs", data)
# process image
encoding = self.processor(image, return_tensors="pt")
# run prediction
with torch.inference_mode():
outputs = self.model(
input_ids=encoding.input_ids.to(device),
bbox=encoding.bbox.to(device),
attention_mask=encoding.attention_mask.to(device),
token_type_ids=encoding.token_type_ids.to(device),
)
predictions = outputs.logits.softmax(-1)
# post process output
result = []
for item, inp_ids, bbox in zip(
predictions.squeeze(0).cpu(), encoding.input_ids.squeeze(0).cpu(), encoding.bbox.squeeze(0).cpu()
):
            # map the highest-scoring class id to its label; skip background ("O") tokens
            label = self.model.config.id2label[int(item.argmax().cpu())]
            if label == "O":
                continue
            score = item.max().item()
            # decode the single token id back into its (sub)word text
            text = self.processor.tokenizer.decode(inp_ids)
            # convert the 0-1000 normalized box back to pixel coordinates
            bbox = unnormalize_box(bbox.tolist(), image.width, image.height)
            result.append({"label": label, "score": score, "text": text, "bbox": bbox})
return {"predictions": result}