{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation\n",
    "\n",
    "Install Tesseract OCR and `pytesseract`. The LayoutLMv2 processor uses them to extract the words and their bounding boxes from the document image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!apt-get install -y tesseract-ocr\n",
    "%pip install pytesseract"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Custom Handler for Inference Endpoints\n",
    "\n",
    "The cell below writes `handler.py`. Inference Endpoints instantiates `EndpointHandler` once and then calls it for every request; the request's `inputs` field holds the deserialized `PIL.Image`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile handler.py\n",
    "import sys\n",
    "from subprocess import run\n",
    "from typing import Any, Dict, List\n",
    "\n",
    "import torch\n",
    "from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor\n",
    "\n",
    "# Install Tesseract OCR and pytesseract when the module is imported: the\n",
    "# LayoutLMv2 processor runs OCR on the input image to get words and boxes.\n",
    "# `sys.executable -m pip` targets the interpreter that runs this handler.\n",
    "run([\"apt-get\", \"install\", \"-y\", \"tesseract-ocr\"], check=True)\n",
    "run([sys.executable, \"-m\", \"pip\", \"install\", \"pytesseract\"], check=True)\n",
    "\n",
    "# Hugging Face Hub repository the model and processor are loaded from\n",
    "MODEL_ID = \"philschmid/layoutlm-funsd\"\n",
    "\n",
    "# run on the GPU when one is available\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "def unnormalize_box(bbox, width, height):\n",
    "    \"\"\"Scale a box from the 0-1000 range back to pixel coordinates.\n",
    "\n",
    "    Args:\n",
    "        bbox: [x0, y0, x1, y1] with every coordinate in the 0-1000 range.\n",
    "        width: image width in pixels.\n",
    "        height: image height in pixels.\n",
    "\n",
    "    Returns:\n",
    "        [x0, y0, x1, y1] in pixels, ready to be drawn onto the image.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        width * (bbox[0] / 1000),\n",
    "        height * (bbox[1] / 1000),\n",
    "        width * (bbox[2] / 1000),\n",
    "        height * (bbox[3] / 1000),\n",
    "    ]\n",
    "\n",
    "\n",
    "class EndpointHandler:\n",
    "    \"\"\"Inference Endpoints handler: OCR + LayoutLM token classification.\"\"\"\n",
    "\n",
    "    def __init__(self, path=\"\"):\n",
    "        \"\"\"Load the model and processor.\n",
    "\n",
    "        Args:\n",
    "            path: directory passed in by Inference Endpoints. NOTE(review): it is\n",
    "                not used; the weights are always loaded from MODEL_ID.\n",
    "        \"\"\"\n",
    "        self.model = LayoutLMForTokenClassification.from_pretrained(MODEL_ID).to(device)\n",
    "        self.processor = LayoutLMv2Processor.from_pretrained(MODEL_ID)\n",
    "        # ids of [CLS], [SEP], [PAD], ...; they carry no word from the document\n",
    "        self.special_ids = set(self.processor.tokenizer.all_special_ids)\n",
    "\n",
    "    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:\n",
    "        \"\"\"Label every document token whose predicted label is not \"O\".\n",
    "\n",
    "        Args:\n",
    "            data: request payload whose \"inputs\" entry is the deserialized\n",
    "                PIL.Image (if the key is missing, `data` itself is used).\n",
    "\n",
    "        Returns:\n",
    "            {\"predictions\": [...]} with one dict per labelled token holding its\n",
    "            \"label\", \"score\" (softmax probability), \"text\" (decoded token) and\n",
    "            \"bbox\" ([x0, y0, x1, y1] in pixels).\n",
    "        \"\"\"\n",
    "        # read instead of pop so the caller's request dict is not emptied\n",
    "        image = data.get(\"inputs\", data)\n",
    "        # the image processor expects 3-channel input; PNGs can be RGBA or grayscale\n",
    "        image = image.convert(\"RGB\")\n",
    "        width, height = image.size\n",
    "\n",
    "        # OCR + tokenization; truncate so long documents do not exceed the\n",
    "        # model's maximum sequence length\n",
    "        encoding = self.processor(image, return_tensors=\"pt\", truncation=True)\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            outputs = self.model(\n",
    "                input_ids=encoding.input_ids.to(device),\n",
    "                bbox=encoding.bbox.to(device),\n",
    "                attention_mask=encoding.attention_mask.to(device),\n",
    "                token_type_ids=encoding.token_type_ids.to(device),\n",
    "            )\n",
    "        probabilities = outputs.logits.softmax(-1).squeeze(0).cpu()\n",
    "\n",
    "        # keep one entry per document token whose best label is not \"O\"\n",
    "        result = []\n",
    "        for probs, token_id, box in zip(\n",
    "            probabilities, encoding.input_ids.squeeze(0), encoding.bbox.squeeze(0)\n",
    "        ):\n",
    "            if int(token_id) in self.special_ids:\n",
    "                continue\n",
    "            score, label_id = probs.max(-1)\n",
    "            label = self.model.config.id2label[int(label_id)]\n",
    "            if label == \"O\":\n",
    "                continue\n",
    "            result.append(\n",
    "                {\n",
    "                    \"label\": label,\n",
    "                    \"score\": score.item(),\n",
    "                    \"text\": self.processor.tokenizer.decode(token_id),\n",
    "                    \"bbox\": unnormalize_box(box.tolist(), width, height),\n",
    "                }\n",
    "            )\n",
    "        return {\"predictions\": result}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test the Custom Handler Locally\n",
    "\n",
    "Load the handler the same way the endpoint does and run it on the example invoice `invoice_example.png`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image, ImageDraw, ImageFont\n",
    "\n",
    "from handler import EndpointHandler\n",
    "\n",
    "my_handler = EndpointHandler(\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read the example invoice from disk and wrap it like an endpoint request\n",
    "image = Image.open(\"invoice_example.png\")\n",
    "request = {\"inputs\": image}\n",
    "\n",
    "# run the handler and show the first few predictions\n",
    "pred = my_handler(request)\n",
    "pred[\"predictions\"][:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Visualize the Predictions\n",
    "\n",
    "Draw each predicted token's box and label onto a copy of the image, colored by entity type (header, question, answer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_result(image, result):\n",
    "    \"\"\"Return a copy of `image` with every prediction's box and label drawn on it.\n",
    "\n",
    "    Args:\n",
    "        image: the PIL image the predictions were computed on.\n",
    "        result: list of dicts with \"label\" and \"bbox\" keys, e.g. `pred[\"predictions\"]`.\n",
    "\n",
    "    Returns:\n",
    "        A new annotated image; `image` itself is not modified, so the cell can be re-run.\n",
    "    \"\"\"\n",
    "    label2color = {\n",
    "        \"B-HEADER\": \"blue\",\n",
    "        \"B-QUESTION\": \"red\",\n",
    "        \"B-ANSWER\": \"green\",\n",
    "        \"I-HEADER\": \"blue\",\n",
    "        \"I-QUESTION\": \"red\",\n",
    "        \"I-ANSWER\": \"green\",\n",
    "    }\n",
    "\n",
    "    annotated = image.copy()\n",
    "    draw = ImageDraw.Draw(annotated)\n",
    "    font = ImageFont.load_default()\n",
    "    for res in result:\n",
    "        # fall back to black for labels missing from the color map\n",
    "        color = label2color.get(res[\"label\"], \"black\")\n",
    "        x0, y0 = res[\"bbox\"][0], res[\"bbox\"][1]\n",
    "        draw.rectangle(res[\"bbox\"], outline=color)\n",
    "        draw.text((x0 + 10, y0 - 10), text=res[\"label\"], fill=color, font=font)\n",
    "    return annotated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "draw_result(image, pred[\"predictions\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('dev': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}