philschmid (HF staff) committed
Commit 2a261b1
1 Parent(s): 294e921

add handler

Files changed (3)
  1. create_handler.ipynb +151 -0
  2. handler.py +45 -0
  3. sample.png +0 -0
create_handler.ipynb ADDED
@@ -0,0 +1,151 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install transformers --upgrade"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Create Custom Handler for Inference Endpoints\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Overwriting handler.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%writefile handler.py\n",
+ "from typing import Dict, Any\n",
+ "from transformers import DonutProcessor, VisionEncoderDecoderModel\n",
+ "import torch\n",
+ "\n",
+ "\n",
+ "# use the GPU when one is available\n",
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ "\n",
+ "\n",
+ "class EndpointHandler:\n",
+ "    def __init__(self, path=\"\"):\n",
+ "        # load the model\n",
+ "        self.processor = DonutProcessor.from_pretrained(path)\n",
+ "        self.model = VisionEncoderDecoderModel.from_pretrained(path)\n",
+ "        # move model to device\n",
+ "        self.model.to(device)\n",
+ "        # task prompt for the CORD-v2 document-parsing task\n",
+ "        self.decoder_input_ids = self.processor.tokenizer(\n",
+ "            \"<s_cord-v2>\", add_special_tokens=False, return_tensors=\"pt\"\n",
+ "        ).input_ids\n",
+ "\n",
+ "    def __call__(self, data: Any) -> Dict[str, Any]:\n",
+ "        inputs = data.pop(\"inputs\", data)\n",
+ "\n",
+ "        # preprocess the input image\n",
+ "        pixel_values = self.processor(inputs, return_tensors=\"pt\").pixel_values\n",
+ "\n",
+ "        # forward pass\n",
+ "        outputs = self.model.generate(\n",
+ "            pixel_values.to(device),\n",
+ "            decoder_input_ids=self.decoder_input_ids.to(device),\n",
+ "            max_length=self.model.decoder.config.max_position_embeddings,\n",
+ "            early_stopping=True,\n",
+ "            pad_token_id=self.processor.tokenizer.pad_token_id,\n",
+ "            eos_token_id=self.processor.tokenizer.eos_token_id,\n",
+ "            use_cache=True,\n",
+ "            num_beams=1,\n",
+ "            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],\n",
+ "            return_dict_in_generate=True,\n",
+ "        )\n",
+ "        # decode the generated sequence and convert it to JSON\n",
+ "        prediction = self.processor.batch_decode(outputs.sequences)[0]\n",
+ "        prediction = self.processor.token2json(prediction)\n",
+ "\n",
+ "        return prediction\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Test the custom handler locally"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from handler import EndpointHandler\n",
+ "\n",
+ "my_handler = EndpointHandler(\".\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
+ "To disable this warning, you can either:\n",
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
+ ]
+ }
+ ],
+ "source": [
+ "from PIL import Image\n",
+ "\n",
+ "payload = {\"inputs\": Image.open(\"sample.png\").convert(\"RGB\")}\n",
+ "\n",
+ "my_handler(payload)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.9.13 ('dev': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.13"
+ },
+ "orig_nbformat": 4,
+ "vscode": {
+ "interpreter": {
+ "hash": "f6dd96c16031089903d5a31ec148b80aeb0d39c32affb1a1080393235fbfa2fc"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
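
The handler returns the result of `DonutProcessor.token2json`, which parses Donut's tag-style generation output into nested JSON. A minimal sketch of that conversion, using the public `naver-clova-ix/donut-base-finetuned-cord-v2` checkpoint (an assumption, not part of this commit) and a made-up token sequence; the field names `menu`, `nm`, and `price` come from the CORD-v2 schema, and the values are illustrative only:

```python
from transformers import DonutProcessor

# any Donut checkpoint with the CORD-v2 vocabulary works here
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")

# hypothetical generated sequence: <s_key>value</s_key> pairs nest into JSON
sequence = "<s_menu><s_nm>Latte</s_nm><s_price>4.50</s_price></s_menu>"
print(processor.token2json(sequence))
# {'menu': {'nm': 'Latte', 'price': '4.50'}}
```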
handler.py ADDED
@@ -0,0 +1,45 @@
+ from typing import Dict, Any
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
+ import torch
+
+
+ # use the GPU when one is available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load the model
+         self.processor = DonutProcessor.from_pretrained(path)
+         self.model = VisionEncoderDecoderModel.from_pretrained(path)
+         # move model to device
+         self.model.to(device)
+         # task prompt for the CORD-v2 document-parsing task
+         self.decoder_input_ids = self.processor.tokenizer(
+             "<s_cord-v2>", add_special_tokens=False, return_tensors="pt"
+         ).input_ids
+
+     def __call__(self, data: Any) -> Dict[str, Any]:
+         inputs = data.pop("inputs", data)
+
+         # preprocess the input image
+         pixel_values = self.processor(inputs, return_tensors="pt").pixel_values
+
+         # forward pass
+         outputs = self.model.generate(
+             pixel_values.to(device),
+             decoder_input_ids=self.decoder_input_ids.to(device),
+             max_length=self.model.decoder.config.max_position_embeddings,
+             early_stopping=True,
+             pad_token_id=self.processor.tokenizer.pad_token_id,
+             eos_token_id=self.processor.tokenizer.eos_token_id,
+             use_cache=True,
+             num_beams=1,
+             bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
+             return_dict_in_generate=True,
+         )
+         # decode the generated sequence and convert it to JSON
+         prediction = self.processor.batch_decode(outputs.sequences)[0]
+         prediction = self.processor.token2json(prediction)
+
+         return prediction
sample.png ADDED
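
Once this repository is deployed as an Inference Endpoint, the handler can be exercised over HTTP. A minimal sketch, assuming the endpoint accepts the raw bytes of `sample.png` with an image content type (which the Inference Endpoints toolkit deserializes into the `inputs` image the handler expects); `ENDPOINT_URL` and `HF_TOKEN` are placeholders for your own endpoint URL and access token:

```python
import requests

ENDPOINT_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder access token

# send the image as raw bytes; the endpoint passes it to the handler as "inputs"
with open("sample.png", "rb") as f:
    response = requests.post(
        ENDPOINT_URL,
        headers={
            "Authorization": f"Bearer {HF_TOKEN}",
            "Content-Type": "image/png",
        },
        data=f.read(),
    )

print(response.json())  # the JSON produced by token2json in the handler
```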