{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Custom Handler for Inference Endpoints\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: diffusers in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (0.9.0)\n", "Collecting diffusers\n", " Using cached diffusers-0.10.2-py3-none-any.whl (503 kB)\n", "Requirement already satisfied: importlib-metadata in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (4.11.4)\n", "Requirement already satisfied: filelock in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (3.8.0)\n", "Requirement already satisfied: numpy in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (1.22.4)\n", "Requirement already satisfied: Pillow in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (9.2.0)\n", "Requirement already satisfied: huggingface-hub>=0.10.0 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (0.11.1)\n", "Requirement already satisfied: requests in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (2.28.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from diffusers) (2022.7.25)\n", "Requirement already satisfied: packaging>=20.9 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (21.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (4.3.0)\n", "Requirement already satisfied: pyyaml>=5.1 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (6.0)\n", "Requirement already satisfied: tqdm in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from huggingface-hub>=0.10.0->diffusers) (4.64.0)\n", "Requirement already satisfied: zipp>=0.5 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from importlib-metadata->diffusers) (3.8.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (3.3)\n", "Requirement already satisfied: charset-normalizer<3,>=2 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (2.1.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (2022.6.15)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from requests->diffusers) (1.26.11)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /home/ubuntu/miniconda/envs/dev/lib/python3.9/site-packages (from packaging>=20.9->huggingface-hub>=0.10.0->diffusers) (3.0.9)\n", "Installing collected packages: diffusers\n", " Attempting uninstall: diffusers\n", " Found existing installation: diffusers 0.9.0\n", " Uninstalling diffusers-0.9.0:\n", " Successfully uninstalled diffusers-0.9.0\n", "Successfully installed diffusers-0.10.2\n" ] } ], "source": [ "!pip install diffusers --upgrade" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "\n", "if device.type != 'cuda':\n", " raise ValueError(\"need to run on GPU\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting handler.py\n" ] } ], "source": [ "%%writefile handler.py\n", "from typing import Dict, List, Any\n", "import torch\n", "from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline\n", "from PIL import Image\n", "import base64\n", "from io import BytesIO\n", "\n", "\n", "# set device\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "if device.type != 'cuda':\n", " raise ValueError(\"need to run on GPU\")\n", "\n", "class EndpointHandler():\n", " def __init__(self, path=\"\"):\n", " # load StableDiffusionInpaintPipeline pipeline\n", " self.pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16)\n", " # use DPMSolverMultistepScheduler\n", " self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)\n", " # move to device \n", " self.pipe = self.pipe.to(device)\n", "\n", "\n", " def __call__(self, data: Any) -> List[List[Dict[str, float]]]:\n", " \"\"\"\n", " :param data: A dictionary contains `inputs` and optional `image` field.\n", " :return: A dictionary with `image` field contains image in base64.\n", " \"\"\"\n", " inputs = data.pop(\"inputs\", data)\n", " encoded_image = data.pop(\"image\", None)\n", " encoded_mask_image = data.pop(\"mask_image\", None)\n", " \n", " # hyperparamters\n", " num_inference_steps = data.pop(\"num_inference_steps\", 25)\n", " guidance_scale = data.pop(\"guidance_scale\", 7.5)\n", " negative_prompt = data.pop(\"negative_prompt\", None)\n", " height = data.pop(\"height\", None)\n", " width = data.pop(\"width\", None)\n", " \n", " # process image\n", " if encoded_image is not None and encoded_mask_image is not None:\n", " image = self.decode_base64_image(encoded_image)\n", " mask_image = self.decode_base64_image(encoded_mask_image)\n", " else:\n", " image = None\n", " mask_image = None \n", " \n", " # run inference pipeline\n", " out = self.pipe(inputs, \n", " image=image, \n", " mask_image=mask_image, \n", " num_inference_steps=num_inference_steps,\n", " guidance_scale=guidance_scale,\n", " num_images_per_prompt=1,\n", " negative_prompt=negative_prompt,\n", " height=height,\n", " width=width\n", " )\n", " \n", " # return first generate PIL image\n", " return out.images[0]\n", " \n", " # helper to decode input image\n", " def decode_base64_image(self, image_string):\n", " base64_image = base64.b64decode(image_string)\n", " buffer = BytesIO(base64_image)\n", " image = Image.open(buffer)\n", " return image" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from handler import EndpointHandler\n", "\n", "# init handler\n", "my_handler = EndpointHandler(path=\".\")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8523b8998b74472ead35a11270dff3a5", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/25 [00:00" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "pred.save(\"result.png\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.13 ('dev': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc" } } }, "nbformat": 4, "nbformat_minor": 2 }