{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup & Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting requirements.txt\n"
     ]
    }
   ],
   "source": [
    "%%writefile requirements.txt\n",
    "diffusers==0.2.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Custom Handler for Inference Endpoints\n",
    "\n",
    "The handler below loads the Stable Diffusion pipeline in half precision (`float16`) on the GPU and returns the generated image as a base64-encoded JPEG string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "if device.type != 'cuda':\n",
    "    raise ValueError(\"need to run on GPU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting handler.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile handler.py\n",
    "from typing import Any, Dict\n",
    "import base64\n",
    "from io import BytesIO\n",
    "\n",
    "import torch\n",
    "from torch import autocast\n",
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "\n",
    "# set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "if device.type != 'cuda':\n",
    "    raise ValueError(\"need to run on GPU\")\n",
    "\n",
    "\n",
    "class EndpointHandler():\n",
    "    def __init__(self, path=\"\"):\n",
    "        # load the model weights in half precision (float16) and move them to the GPU\n",
    "        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)\n",
    "        self.pipe = self.pipe.to(device)\n",
    "\n",
    "    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:\n",
    "        \"\"\"\n",
    "        Run text-to-image inference for a single request.\n",
    "\n",
    "        Args:\n",
    "            data (:obj:`dict`):\n",
    "                request payload. `inputs` holds the text prompt; the optional\n",
    "                `parameters` dict is forwarded as keyword arguments to the\n",
    "                pipeline (`guidance_scale` defaults to 7.5).\n",
    "        Return:\n",
    "            A :obj:`dict` of the form `{\"image\": <base64-encoded JPEG string>}`.\n",
    "        \"\"\"\n",
    "        # use .get so the caller's payload is not mutated\n",
    "        inputs = data.get(\"inputs\", data)\n",
    "        # caller-supplied parameters override the default guidance scale\n",
    "        parameters = {\"guidance_scale\": 7.5, **(data.get(\"parameters\") or {})}\n",
    "\n",
    "        # run inference pipeline under autocast (mixed precision)\n",
    "        with autocast(device.type):\n",
    "            image = self.pipe(inputs, **parameters)[\"sample\"][0]\n",
    "\n",
    "        # encode image as base64 JPEG\n",
    "        buffered = BytesIO()\n",
    "        image.save(buffered, format=\"JPEG\")\n",
    "        img_str = base64.b64encode(buffered.getvalue())\n",
    "\n",
    "        # postprocess the prediction\n",
    "        return {\"image\": img_str.decode()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test the Custom Handler\n",
    "\n",
    "Instantiate the handler with `path=\".\"` (the model files are loaded from the current directory) and run a sample prompt through it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.11.0+cu113'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.\n"
     ]
    }
   ],
   "source": [
    "from handler import EndpointHandler\n",
    "\n",
    "# init handler\n",
    "my_handler = EndpointHandler(path=\".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "376de150f16b4b4bb0c3ab8c513de5c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import base64\n",
    "from io import BytesIO\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "# helper decoder\n",
    "def decode_base64_image(image_string):\n",
    "    \"\"\"Decode a base64-encoded image string into a PIL Image.\"\"\"\n",
    "    base64_image = base64.b64decode(image_string)\n",
    "    buffer = BytesIO(base64_image)\n",
    "    return Image.open(buffer)\n",
    "\n",
    "\n",
    "# prepare sample payload\n",
    "request = {\"inputs\": \"a high resolution image of a macbook\"}\n",
    "\n",
    "# test the handler\n",
    "pred = my_handler(request)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "decode_base64_image(pred[\"image\"]).save(\"sample.jpg\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![test](sample.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('dev': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}