sitloboi2012's picture
update handler
96ac1a0
raw
history blame
1.61 kB
import io
import sys
from subprocess import run
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3Tokenizer,
)
# Install OCR dependencies at import time. This is module-level code, so it
# runs once per process, when the endpoint first imports this module.
# apt-get is used rather than apt because apt's CLI is not stable for scripts.
# NOTE(review): presumably this requires root inside the endpoint container,
# and may need an `apt-get update` first on images with stale package lists --
# confirm for the target image.
run(["apt-get", "install", "-y", "tesseract-ocr"], check=True)
# Install into the interpreter that is actually running this handler; a bare
# "python" on PATH may belong to a different environment.
# pytesseract is the Python wrapper used by the LayoutLMv3 processor's OCR step.
run([sys.executable, "-m", "pip", "install", "pytesseract"], check=True)
# NOTE(review): this wheel index targets CPU builds for torch 1.10 -- make sure
# it matches the installed torch. LayoutLMv3 itself does not appear to need
# detectron2 (LayoutLMv2 did); confirm before removing this step.
run(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "detectron2",
        "-f",
        "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html",
    ],
    check=True,
)
# Module-wide compute device: CUDA when torch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Inference Endpoint handler that classifies a single document image.

    Wraps a LayoutLMv3 sequence-classification model: the image is
    preprocessed (and OCR'd) by a ``LayoutLMv3Processor``, run through the
    model, and the name of the highest-scoring label is returned.
    """

    def __init__(self, path: str = ""):
        """Build the processor and load the classification model.

        Args:
            path: Directory supplied by the endpoint runtime.
                NOTE(review): currently unused -- the tokenizer and model are
                downloaded from fixed Hub repo ids below, not loaded from
                ``path``. Confirm whether the model should come from ``path``.
        """
        # Default feature-extractor settings. Presumably OCR is enabled by
        # default, which needs tesseract/pytesseract at call time -- confirm.
        self.FEATURE_EXTRACTOR = LayoutLMv3FeatureExtractor()
        self.TOKENIZER = LayoutLMv3Tokenizer.from_pretrained("microsoft/layoutlmv3-base")
        # The processor combines image preprocessing with tokenization.
        self.PROCESSOR = LayoutLMv3Processor(self.FEATURE_EXTRACTOR, self.TOKENIZER)
        self.MODEL = LayoutLMv3ForSequenceClassification.from_pretrained(
            "OtraBoi/document_classifier_testing"
        ).to(device)

    @staticmethod
    def _load_image(raw: Any) -> Image.Image:
        """Return ``raw`` as an RGB PIL image.

        Accepts an already-decoded ``PIL.Image.Image`` (the inference toolkit
        may pre-decode ``image/*`` request bodies -- confirm for this
        deployment) or anything ``io.BytesIO`` accepts, i.e. encoded image
        bytes. Other types raise ``TypeError`` from ``io.BytesIO``.
        """
        if isinstance(raw, Image.Image):
            image = raw
        else:
            image = Image.open(io.BytesIO(raw))
        return image.convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> str:
        """Classify one document image.

        Args:
            data: Request payload; ``data["inputs"]`` holds either encoded
                image bytes or a decoded PIL image.

        Returns:
            The entry of ``self.MODEL.config.id2label`` for the argmax logit.

        Raises:
            KeyError: if ``data`` has no ``"inputs"`` entry.
            TypeError: if ``data["inputs"]`` is neither a PIL image nor
                bytes-like.
        """
        image = self._load_image(data["inputs"])
        encoding = self.PROCESSOR(
            image, return_tensors="pt", padding="max_length", truncation=True
        )
        # Put every input tensor on the same device as the model weights.
        encoding = {key: tensor.to(self.MODEL.device) for key, tensor in encoding.items()}
        # Inference only: disable autograd tracking.
        with torch.inference_mode():
            logits = self.MODEL(**encoding).logits
        # One image in, so the argmax over the last dim is a single element.
        predicted_class_idx = logits.argmax(-1).item()
        return self.MODEL.config.id2label[predicted_class_idx]