|
import io
import sys
from subprocess import run
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import (
    LayoutLMv3FeatureExtractor,
    LayoutLMv3ForSequenceClassification,
    LayoutLMv3Processor,
    LayoutLMv3Tokenizer,
)
|
|
|
|
|
# Bootstrap container-level dependencies at import time: Tesseract is the OCR
# backend invoked by LayoutLMv3FeatureExtractor, and the detectron2 CPU wheel
# is installed from the official index. List-form argv with the default
# shell=False avoids shell-injection and quoting pitfalls; check=True raises
# CalledProcessError if either install fails, so a broken environment is
# surfaced immediately instead of at first inference.
run(["apt", "install", "-y", "tesseract-ocr"], check=True)
run(
    [
        # sys.executable guarantees the package lands in the interpreter
        # actually running this handler, not whichever "python" is on PATH.
        sys.executable,
        "-m",
        "pip",
        "install",
        "detectron2",
        "-f",
        "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html",
    ],
    check=True,
)

# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
class EndpointHandler:
    """Inference-endpoint handler: classify a document image with LayoutLMv3.

    Loads the processor (OCR feature extractor + tokenizer) and the
    fine-tuned sequence-classification model once at construction; each
    call turns raw image bytes into a predicted label string.
    """

    def __init__(self, path=""):
        # Assemble the processor from its two components; the feature
        # extractor runs OCR on the page image, the tokenizer encodes the
        # recognized words.
        self.FEATURE_EXTRACTOR = LayoutLMv3FeatureExtractor()
        self.TOKENIZER = LayoutLMv3Tokenizer.from_pretrained(
            "microsoft/layoutlmv3-base"
        )
        self.PROCESSOR = LayoutLMv3Processor(self.FEATURE_EXTRACTOR, self.TOKENIZER)
        # Fine-tuned classifier weights, moved to the module-level device.
        self.MODEL = LayoutLMv3ForSequenceClassification.from_pretrained(
            "OtraBoi/document_classifier_testing"
        ).to(device)

    def __call__(self, data: Dict):
        """Return the predicted class label for the image bytes in
        ``data["inputs"]``.

        NOTE(review): assumes ``data["inputs"]`` is a raw image byte
        string — confirm against the endpoint's request format.
        """
        raw = data["inputs"]
        document = Image.open(io.BytesIO(raw)).convert("RGB")

        # OCR + tokenize the page, padded/truncated to the model's max length.
        encoding = self.PROCESSOR(
            document, return_tensors="pt", padding="max_length", truncation=True
        )

        # Move every tensor of the batch onto the model's device in place.
        for name in list(encoding.keys()):
            encoding[name] = encoding[name].to(self.MODEL.device)

        with torch.inference_mode():
            model_output = self.MODEL(**encoding)

        # Single-example batch: argmax over the class dimension gives the
        # predicted index, which the model config maps to its label string.
        best_idx = model_output.logits.argmax(-1).item()
        return self.MODEL.config.id2label[best_idx]
|
|