from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run

# Install the Tesseract OCR binary when this module is imported.
# NOTE(review): presumably needed by the processor's OCR step -- confirm.
# The command is passed as an argument list (no shell) so nothing is ever
# interpreted by a shell; check=True still aborts the import on failure.
run(["apt", "install", "-y", "tesseract-ocr"], check=True)


class HugEndpointException(Exception):
    """Wrap any error raised while running or post-processing a prediction.

    The wrapped error is kept on ``self.e`` and is included in the string
    representation of this exception.
    """

    def __init__(self, e):
        # Initialise the base class explicitly so ``args`` is populated
        # through the regular Exception machinery.
        super().__init__(e)
        self.e = e

    def __str__(self):
        return f'Custom Endpoint Exception: {self.e}'


def unnormalize_box(bbox, width, height):
    """Scale a bounding box from a 0-1000 coordinate range to pixel units.

    Args:
        bbox: Sequence whose first four items are read as two (x, y) pairs;
            x values are scaled by ``width`` and y values by ``height``.
            NOTE(review): presumably ``[x0, y0, x1, y1]`` -- confirm.
        width: Width to scale the x coordinates to.
        height: Height to scale the y coordinates to.

    Returns:
        A list of four numbers: each input coordinate divided by 1000 and
        multiplied by ``width`` (x) or ``height`` (y).
    """
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    """Callable that labels the tokens of a document image.

    Loads a ``LayoutLMForTokenClassification`` model and a
    ``LayoutLMv2Processor`` from ``path`` and, when called, returns every
    token whose predicted label is not ``"O"``.
    """

    def __init__(self, path=""):
        """Load the model (moved to ``device``) and the processor from ``path``."""
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run token classification on one image.

        Args:
            data: Request payload. The image is taken (and removed) from the
                ``"inputs"`` key; when that key is absent, ``data`` itself is
                used as the image. The image must expose ``width`` and
                ``height`` (the original docstring describes it as a
                deserialized ``PIL.Image``).

        Returns:
            ``{"predictions": [...]}`` where each entry is a dict with the
            keys ``"label"``, ``"score"``, ``"text"`` and ``"bbox"``. Tokens
            labelled ``"O"`` are left out.

        Raises:
            HugEndpointException: If anything fails during model inference or
                post-processing; the original error is chained as the cause.
                Errors raised by the processor itself happen before the
                guarded section and propagate unwrapped.
        """
        # process input
        image = data.pop("inputs", data)

        # process image
        encoding = self.processor(image, return_tensors="pt")

        try:
            # run prediction
            with torch.inference_mode():
                outputs = self.model(
                    input_ids=encoding.input_ids.to(device),
                    bbox=encoding.bbox.to(device),
                    attention_mask=encoding.attention_mask.to(device),
                    token_type_ids=encoding.token_type_ids.to(device),
                )
                predictions = outputs.logits.softmax(-1)

            # Loop-invariant lookups, hoisted out of the per-token loop.
            id2label = self.model.config.id2label
            decode = self.processor.tokenizer.decode

            # post process output: drop the leading dimension and walk the
            # per-token probabilities, input ids and boxes in lockstep
            result = []
            for probs, inp_ids, token_box in zip(
                predictions.squeeze(0).cpu(),
                encoding.input_ids.squeeze(0).cpu(),
                encoding.bbox.squeeze(0).cpu(),
            ):
                label = id2label[int(probs.argmax())]
                if label == "O":
                    continue
                result.append(
                    {
                        "label": label,
                        "score": probs.max().item(),
                        "text": decode(inp_ids),
                        "bbox": unnormalize_box(
                            token_box.tolist(), image.width, image.height
                        ),
                    }
                )
            return {"predictions": result}
        except Exception as e:
            # Top-level boundary: surface every failure as the endpoint's own
            # exception type while keeping the original traceback as the cause.
            raise HugEndpointException(e) from e