{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Custom Handler for Inference Endpoints\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting handler.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile handler.py\n",
    "from typing import Dict, List, Any\n",
    "from transformers import DonutProcessor, VisionEncoderDecoderModel\n",
    "import torch\n",
    "\n",
    "\n",
    "# Pick the torch device once at import time. `Module.to()` and `Tensor.to()` need a\n",
    "# real torch device; a pipeline-style index of -1 for \"CPU\" is rejected by torch.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "\n",
    "class EndpointHandler:\n",
    "    \"\"\"Custom Inference Endpoints handler that runs a Donut model on an image.\"\"\"\n",
    "\n",
    "    def __init__(self, path=\"\"):\n",
    "        \"\"\"Load the processor and model from `path` and build the decoder prompt.\n",
    "\n",
    "        Args:\n",
    "            path: Location handed to both `from_pretrained` calls.\n",
    "        \"\"\"\n",
    "        # load the model\n",
    "        self.processor = DonutProcessor.from_pretrained(path)\n",
    "        self.model = VisionEncoderDecoderModel.from_pretrained(path)\n",
    "        # move model to device\n",
    "        self.model.to(device)\n",
    "        # NOTE(review): the decoder prompt is an empty string here. Donut checkpoints\n",
    "        # are usually prompted with a task start token; confirm that the empty prompt\n",
    "        # is intended and was not lost when the notebook was exported.\n",
    "        # The prompt never changes, so it is tokenized and moved to the device once\n",
    "        # instead of on every request.\n",
    "        self.decoder_input_ids = self.processor.tokenizer(\n",
    "            \"\", add_special_tokens=False, return_tensors=\"pt\"\n",
    "        ).input_ids.to(device)\n",
    "\n",
    "    def __call__(self, data: Any) -> Dict[str, Any]:\n",
    "        \"\"\"Run the model on one request and return the parsed prediction.\n",
    "\n",
    "        Args:\n",
    "            data: Request payload. If it is a dict, the value stored under\n",
    "                \"inputs\" is used (or the dict itself when that key is missing);\n",
    "                any other object is handed to the processor as-is.\n",
    "\n",
    "        Returns:\n",
    "            The result of `processor.token2json` on the first decoded sequence.\n",
    "        \"\"\"\n",
    "        # Read the input without mutating the caller's payload, so the same payload\n",
    "        # object can be sent again.\n",
    "        inputs = data.get(\"inputs\", data) if isinstance(data, dict) else data\n",
    "\n",
    "        # preprocess the input\n",
    "        pixel_values = self.processor(inputs, return_tensors=\"pt\").pixel_values\n",
    "\n",
    "        # forward pass\n",
    "        outputs = self.model.generate(\n",
    "            pixel_values.to(device),\n",
    "            decoder_input_ids=self.decoder_input_ids,\n",
    "            max_length=self.model.decoder.config.max_position_embeddings,\n",
    "            early_stopping=True,\n",
    "            pad_token_id=self.processor.tokenizer.pad_token_id,\n",
    "            eos_token_id=self.processor.tokenizer.eos_token_id,\n",
    "            use_cache=True,\n",
    "            num_beams=1,\n",
    "            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],\n",
    "            return_dict_in_generate=True,\n",
    "        )\n",
    "        # process output: only the first generated sequence is decoded and parsed\n",
    "        prediction = self.processor.batch_decode(outputs.sequences)[0]\n",
    "        prediction = self.processor.token2json(prediction)\n",
    "\n",
    "        return prediction\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the custom handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from handler import EndpointHandler\n",
    "\n",
    "my_handler = EndpointHandler(\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'menu': [{'nm': '0571-1854 BLUS WANITA',\n",
       " 'unitprice': '@120.000',\n",
       " 'cnt': '1',\n",
       " 'price': '120,000'},\n",
       " {'nm': '1002-0060 SHOPPING BAG', 'cnt': '1', 'price': '0'}],\n",
       " 'total': {'total_price': '120,000',\n",
       " 'changeprice': '0',\n",
       " 'creditcardprice': '120,000',\n",
       " 'menuqty_cnt': '1'}}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from PIL import Image\n",
    "\n",
    "payload = {\"inputs\": Image.open(\"sample.png\").convert(\"RGB\")}\n",
    "\n",
    "my_handler(payload)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('dev': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}