philschmid (HF staff) committed
Commit
56d3072
1 Parent(s): f35b460

Upload create_handler.ipynb

Files changed (1)
  1. create_handler.ipynb +275 -0
create_handler.ipynb ADDED
@@ -0,0 +1,275 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup & Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting requirements.txt\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile requirements.txt\n",
+ "diffusers==0.2.4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -r requirements.txt --upgrade"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Create Custom Handler for Inference Endpoints\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "device(type='cuda')"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if device.type != 'cuda':\n",
+ "    raise ValueError(\"need to run on GPU\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting handler.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile handler.py\n",
+ "from typing import Dict, List, Any\n",
+ "import torch\n",
+ "from torch import autocast\n",
+ "from diffusers import StableDiffusionPipeline\n",
+ "import base64\n",
+ "from io import BytesIO\n",
+ "\n",
+ "\n",
+ "# set device\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "if device.type != 'cuda':\n",
+ "    raise ValueError(\"need to run on GPU\")\n",
+ "\n",
+ "class EndpointHandler():\n",
+ "    def __init__(self, path=\"\"):\n",
+ "        # load the optimized model\n",
+ "        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)\n",
+ "        self.pipe = self.pipe.to(device)\n",
+ "\n",
+ "\n",
+ "    def __call__(self, data: Any) -> Dict[str, str]:\n",
+ "        \"\"\"\n",
+ "        Args:\n",
+ "            data (:obj:`dict`):\n",
+ "                includes the input data and the parameters for the inference.\n",
+ "        Return:\n",
+ "            A :obj:`dict` with the base64 encoded image.\n",
+ "        \"\"\"\n",
+ "        inputs = data.pop(\"inputs\", data)\n",
+ "\n",
+ "        # run inference pipeline\n",
+ "        with autocast(device.type):\n",
+ "            image = self.pipe(inputs, guidance_scale=7.5)[\"sample\"][0]\n",
+ "\n",
+ "        # encode image as base64\n",
+ "        buffered = BytesIO()\n",
+ "        image.save(buffered, format=\"JPEG\")\n",
+ "        img_str = base64.b64encode(buffered.getvalue())\n",
+ "\n",
+ "        # postprocess the prediction\n",
+ "        return {\"image\": img_str.decode()}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Test the custom pipeline locally"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'1.11.0+cu113'"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "\n",
+ "torch.__version__"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from handler import EndpointHandler\n",
+ "\n",
+ "# init handler\n",
+ "my_handler = EndpointHandler(path=\".\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "376de150f16b4b4bb0c3ab8c513de5c0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "0it [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import base64\n",
+ "from PIL import Image\n",
+ "from io import BytesIO\n",
+ "import json\n",
+ "\n",
+ "# helper decoder\n",
+ "def decode_base64_image(image_string):\n",
+ "    base64_image = base64.b64decode(image_string)\n",
+ "    buffer = BytesIO(base64_image)\n",
+ "    return Image.open(buffer)\n",
+ "\n",
+ "# prepare sample payload\n",
+ "request = {\"inputs\": \"a high resolution image of a macbook\"}\n",
+ "\n",
+ "# test the handler\n",
+ "pred = my_handler(request)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decode_base64_image(pred[\"image\"]).save(\"sample.jpg\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![test](sample.jpg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.13 ('dev': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
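
Note: the handler above accepts a payload of the form {"inputs": "<prompt>"} and returns {"image": "<base64-encoded JPEG>"}. As a minimal sketch of how a client might call the deployed Inference Endpoint once this handler is live (not part of the committed notebook; ENDPOINT_URL and HF_TOKEN are placeholders you would replace with your own endpoint URL and access token):

# query the deployed endpoint and decode the base64 image (sketch, placeholders below)
import base64
from io import BytesIO

import requests
from PIL import Image

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

def generate_image(prompt: str) -> Image.Image:
    # same payload shape the handler expects: {"inputs": "<prompt>"}
    response = requests.post(
        ENDPOINT_URL,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "application/json",
        },
        json={"inputs": prompt},
    )
    response.raise_for_status()
    # the handler returns {"image": "<base64-encoded JPEG>"}
    image_string = response.json()["image"]
    return Image.open(BytesIO(base64.b64decode(image_string)))

image = generate_image("a high resolution image of a macbook")
image.save("sample.jpg")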