{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Custom Handler for Inference Endpoints\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting requirements.txt\n" ] } ], "source": [ "%%writefile requirements.txt\n", "\n", "git+https://github.com/huggingface/diffusers.git\n", "safetensors \n", "opencv-python\n", "controlnet_hinter==0.0.5" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -r requirements.txt --quiet" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Create handler.py" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "\n", "if device.type != 'cuda':\n", " raise ValueError(\"need to run on GPU\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Overwriting handler.py\n" ] } ], "source": [ "%%writefile handler.py\n", "from typing import Dict, List, Any\n", "import base64\n", "from PIL import Image\n", "from io import BytesIO\n", "from diffusers import StableDiffusionControlNetPipeline, ControlNetModel\n", "import torch\n", "\n", "\n", "import numpy as np\n", "import cv2\n", "import controlnet_hinter\n", "\n", "# set device\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "if device.type != 'cuda':\n", " raise ValueError(\"need to run on GPU\")\n", "# set mixed precision dtype\n", "dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16\n", "\n", "# controlnet mapping for controlnet id and control hinter\n", "CONTROLNET_MAPPING = {\n", " \"canny_edge\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-canny\",\n", " \"hinter\": controlnet_hinter.hint_canny\n", " },\n", " \"pose\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-openpose\",\n", " \"hinter\": controlnet_hinter.hint_openpose\n", " },\n", " \"depth\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-depth\",\n", " \"hinter\": controlnet_hinter.hint_depth\n", " },\n", " \"scribble\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-scribble\",\n", " \"hinter\": controlnet_hinter.hint_scribble,\n", " },\n", " \"segmentation\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-seg\",\n", " \"hinter\": controlnet_hinter.hint_segmentation,\n", " },\n", " \"normal\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-normal\",\n", " \"hinter\": controlnet_hinter.hint_normal,\n", " },\n", " \"hed\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-hed\",\n", " \"hinter\": controlnet_hinter.hint_hed,\n", " },\n", " \"hough\": {\n", " \"model_id\": \"lllyasviel/sd-controlnet-mlsd\",\n", " \"hinter\": controlnet_hinter.hint_hough,\n", " }\n", "}\n", "\n", "\n", "class EndpointHandler():\n", " def __init__(self, path=\"\"):\n", " # define default controlnet id and load controlnet\n", " self.control_type = \"normal\"\n", " self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type][\"model_id\"],torch_dtype=dtype).to(device)\n", " \n", " # Load StableDiffusionControlNetPipeline \n", " self.stable_diffusion_id = \"runwayml/stable-diffusion-v1-5\"\n", " self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id, \n", " controlnet=self.controlnet, \n", " torch_dtype=dtype,\n", " safety_checker=None).to(device)\n", " # Define Generator with seed\n", " self.generator = torch.Generator(device=\"cpu\").manual_seed(3)\n", "\n", " def __call__(self, data: Any) -> List[List[Dict[str, float]]]:\n", " \"\"\"\n", " :param data: A dictionary contains `inputs` and optional `image` field.\n", " :return: A dictionary with `image` field contains image in base64.\n", " \"\"\"\n", " prompt = data.pop(\"inputs\", None)\n", " image = data.pop(\"image\", None)\n", " controlnet_type = data.pop(\"controlnet_type\", None)\n", " \n", " # Check if neither prompt nor image is provided\n", " if prompt is None and image is None:\n", " return {\"error\": \"Please provide a prompt and base64 encoded image.\"}\n", " \n", " # Check if a new controlnet is provided\n", " if controlnet_type is not None and controlnet_type != self.control_type:\n", " print(f\"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model\")\n", " self.control_type = controlnet_type\n", " self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type][\"model_id\"],\n", " torch_dtype=dtype).to(device)\n", " self.pipe.controlnet = self.controlnet\n", " \n", " \n", " # hyperparamters\n", " negatice_prompt = data.pop(\"negative_prompt\", None)\n", " num_inference_steps = data.pop(\"num_inference_steps\", 30)\n", " guidance_scale = data.pop(\"guidance_scale\", 7.5)\n", " negative_prompt = data.pop(\"negative_prompt\", None)\n", " height = data.pop(\"height\", None)\n", " width = data.pop(\"width\", None)\n", " controlnet_conditioning_scale = data.pop(\"controlnet_conditioning_scale\", 1.0)\n", " \n", " # process image\n", " image = self.decode_base64_image(image)\n", " control_image = CONTROLNET_MAPPING[self.control_type][\"hinter\"](image)\n", " \n", " # run inference pipeline\n", " out = self.pipe(\n", " prompt=prompt, \n", " negative_prompt=negative_prompt,\n", " image=control_image,\n", " num_inference_steps=num_inference_steps, \n", " guidance_scale=guidance_scale,\n", " num_images_per_prompt=1,\n", " height=height,\n", " width=width,\n", " controlnet_conditioning_scale=controlnet_conditioning_scale,\n", " generator=self.generator\n", " )\n", "\n", " \n", " # return first generate PIL image\n", " return out.images[0]\n", " \n", " # helper to decode input image\n", " def decode_base64_image(self, image_string):\n", " base64_image = base64.b64decode(image_string)\n", " buffer = BytesIO(base64_image)\n", " image = Image.open(buffer)\n", " return image" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load Handler" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/pytorch/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /opt/conda/envs/pytorch/lib/python3.9/site-packages/torchvision/image.so: undefined symbol: _ZN2at4_ops19empty_memory_format4callEN3c108ArrayRefIlEENS2_8optionalINS2_10ScalarTypeEEENS5_INS2_6LayoutEEENS5_INS2_6DeviceEEENS5_IbEENS5_INS2_12MemoryFormatEEE\n", " warn(f\"Failed to load image Python extension: {e}\")\n" ] }, { "data": { "application/json": { "ascii": false, "bar_format": null, "colour": null, "elapsed": 0.016447782516479492, "initial": 0, "n": 0, "ncols": null, "nrows": null, "postfix": null, "prefix": "Fetching 15 files", "rate": null, "total": 15, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { "model_id": "72bcab888d3e4919aec099f691a9a567", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 15 files: 0%| | 0/15 [00:00 by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n" ] } ], "source": [ "from handler import EndpointHandler\n", "\n", "# init handler\n", "my_handler = EndpointHandler(path=\".\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Test handler" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!wget https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_imgvar/input_image_vermeer.png" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "application/json": { "ascii": false, "bar_format": null, "colour": null, "elapsed": 0.011706113815307617, "initial": 0, "n": 0, "ncols": null, "nrows": null, "postfix": null, "prefix": "", "rate": null, "total": 30, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { "model_id": "806f910bde6a4f3096cb5a239c04421c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import base64\n", "from PIL import Image\n", "from io import BytesIO\n", "import base64\n", "\n", "# helper image utils\n", "def encode_image(image_path):\n", " with open(image_path, \"rb\") as i:\n", " b64 = base64.b64encode(i.read())\n", " return b64.decode(\"utf-8\")\n", "\n", "prompt = \"high quality, cypberpunk background, lights, futuristic, lights, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3\"\n", "negative_prompt=\"worst quality, low quality, oil painting, historic\"\n", "controlnet_type = \"hed\"\n", "image = encode_image(\"crysis.jpeg\")\n", "\n", "# prepare sample payload\n", "request = {\"inputs\": prompt, \"image\": image, \"negative_prompt\": negative_prompt, \"controlnet_type\": controlnet_type}\n", "\n", "# test the handler\n", "pred = my_handler(request)\n", "pred\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "pred.save(\"result_crysis.png\")" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "application/json": { "ascii": false, "bar_format": null, "colour": null, "elapsed": 0.012011289596557617, "initial": 0, "n": 0, "ncols": null, "nrows": null, "postfix": null, "prefix": "", "rate": null, "total": 30, "unit": "it", "unit_divisor": 1000, "unit_scale": false }, "application/vnd.jupyter.widget-view+json": { "model_id": "e8cff0984648407c83575a3984a32f3a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/30 [00:00