Smashinfries commited on
Commit
38206f4
·
verified ·
1 Parent(s): 24cd3c8

add colab notebook

Browse files
Files changed (1) hide show
  1. WD_Tagger_Mobile.ipynb +436 -0
WD_Tagger_Mobile.ipynb ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "collapsed_sections": [
8
+ "GAZViC5o2bya",
9
+ "QwoVd4CE2njF",
10
+ "8r0qzU2NRoIT",
11
+ "lgaEjLAo7lMd",
12
+ "RadVNaev2_mF"
13
+ ]
14
+ },
15
+ "kernelspec": {
16
+ "name": "python3",
17
+ "display_name": "Python 3"
18
+ },
19
+ "language_info": {
20
+ "name": "python"
21
+ }
22
+ },
23
+ "cells": [
24
+ {
25
+ "cell_type": "markdown",
26
+ "source": [
27
+ "# Dependencies"
28
+ ],
29
+ "metadata": {
30
+ "id": "GAZViC5o2bya"
31
+ }
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 1,
36
+ "metadata": {
37
+ "colab": {
38
+ "base_uri": "https://localhost:8080/"
39
+ },
40
+ "id": "wNCZ04U82IiL",
41
+ "outputId": "dc277e76-67e7-4781-95a1-123bb139bbf3"
42
+ },
43
+ "outputs": [
44
+ {
45
+ "output_type": "stream",
46
+ "name": "stdout",
47
+ "text": [
48
+ "Collecting onnxruntime\n",
49
+ " Downloading onnxruntime-1.17.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)\n",
50
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.8/6.8 MB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
51
+ "\u001b[?25hCollecting coloredlogs (from onnxruntime)\n",
52
+ " Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
53
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
54
+ "\u001b[?25hRequirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (23.5.26)\n",
55
+ "Requirement already satisfied: numpy>=1.21.6 in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (1.25.2)\n",
56
+ "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (23.2)\n",
57
+ "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (3.20.3)\n",
58
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime) (1.12)\n",
59
+ "Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)\n",
60
+ " Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
61
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
62
+ "\u001b[?25hRequirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime) (1.3.0)\n",
63
+ "Installing collected packages: humanfriendly, coloredlogs, onnxruntime\n",
64
+ "Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.17.1\n",
65
+ "Collecting onnx\n",
66
+ " Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)\n",
67
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.7/15.7 MB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
68
+ "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from onnx) (1.25.2)\n",
69
+ "Requirement already satisfied: protobuf>=3.20.2 in /usr/local/lib/python3.10/dist-packages (from onnx) (3.20.3)\n",
70
+ "Installing collected packages: onnx\n",
71
+ "Successfully installed onnx-1.15.0\n",
72
+ "Collecting onnxruntime-extensions\n",
73
+ " Downloading onnxruntime_extensions-0.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.0 MB)\n",
74
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m21.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
75
+ "\u001b[?25hInstalling collected packages: onnxruntime-extensions\n",
76
+ "Successfully installed onnxruntime-extensions-0.10.1\n"
77
+ ]
78
+ }
79
+ ],
80
+ "source": [
81
+ "!pip install onnxruntime\n",
82
+ "!pip install onnx\n",
83
+ "!pip install onnxruntime-extensions"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "markdown",
88
+ "source": [
89
+ "# Download ONNX Model\n",
90
+ "This downloads [wd-convnext-tagger-v3](https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3) created by [SmilingWolf](https://huggingface.co/SmilingWolf).\n",
91
+ "\n",
92
+ "Feel free to use SmilingWolfs other model variants instead.\n",
93
+ "\n",
94
+ "The tags and power image is also downloaded for inferencing."
95
+ ],
96
+ "metadata": {
97
+ "id": "QwoVd4CE2njF"
98
+ }
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "source": [
103
+ "!wget https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3/resolve/main/model.onnx?download=true -O model.onnx"
104
+ ],
105
+ "metadata": {
106
+ "colab": {
107
+ "base_uri": "https://localhost:8080/"
108
+ },
109
+ "id": "AMF_IIxm2tT_",
110
+ "outputId": "b8e574b8-8276-4f74-92fa-b3a946e92655"
111
+ },
112
+ "execution_count": 2,
113
+ "outputs": [
114
+ {
115
+ "output_type": "stream",
116
+ "name": "stdout",
117
+ "text": [
118
+ "--2024-03-09 05:03:49-- https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3/resolve/main/model.onnx?download=true\n",
119
+ "Resolving huggingface.co (huggingface.co)... 3.163.189.90, 3.163.189.74, 3.163.189.37, ...\n",
120
+ "Connecting to huggingface.co (huggingface.co)|3.163.189.90|:443... connected.\n",
121
+ "HTTP request sent, awaiting response... 302 Found\n",
122
+ "Location: https://cdn-lfs-us-1.huggingface.co/repos/d8/61/d8612304f05de662484c881a2ac180318d718b820314ffaaa700ef22c267e1a1/02f30d4de9bada756981a11464d13aa206f5e2d4ff6da384511beb812d58b2ca?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.onnx%3B+filename%3D%22model.onnx%22%3B&Expires=1710219829&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDIxOTgyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2Q4LzYxL2Q4NjEyMzA0ZjA1ZGU2NjI0ODRjODgxYTJhYzE4MDMxOGQ3MThiODIwMzE0ZmZhYWE3MDBlZjIyYzI2N2UxYTEvMDJmMzBkNGRlOWJhZGE3NTY5ODFhMTE0NjRkMTNhYTIwNmY1ZTJkNGZmNmRhMzg0NTExYmViODEyZDU4YjJjYT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=NUW35U0E0VvUCTynr4WArU1pgdg-F506HK5TiNnP7IrwbhEJfQpEcJo5CBoz1e4iUprWUCcEZJS0dRCmlGrr0PGYIjKXZ00BE4EiGZyi2vUqdP%7ExxUzWxps6XwEIVGiXc5R9yC%7EQgtd6oSJYQOH4ITBvEoNOJoQUPnjL5m1vk9T8-xHpeAxkHkHeOaF8FjlU5HKvUIc65SlUGirxOsHXl0v8o7sKmYlFs0Nmkoj9MurWKFL0sLFW5XIxkZveAGS9GB2sisitzkc4BUhICqDMSfv5CtlTEhXpgDUGbFo%7EohbeuKkQjIgnSU%7EVdFhDvY7Qew%7E5emodk-508AHvCx-UrA__&Key-Pair-Id=KCD77M1F0VK2B [following]\n",
123
+ "--2024-03-09 05:03:49-- https://cdn-lfs-us-1.huggingface.co/repos/d8/61/d8612304f05de662484c881a2ac180318d718b820314ffaaa700ef22c267e1a1/02f30d4de9bada756981a11464d13aa206f5e2d4ff6da384511beb812d58b2ca?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.onnx%3B+filename%3D%22model.onnx%22%3B&Expires=1710219829&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMDIxOTgyOX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2Q4LzYxL2Q4NjEyMzA0ZjA1ZGU2NjI0ODRjODgxYTJhYzE4MDMxOGQ3MThiODIwMzE0ZmZhYWE3MDBlZjIyYzI2N2UxYTEvMDJmMzBkNGRlOWJhZGE3NTY5ODFhMTE0NjRkMTNhYTIwNmY1ZTJkNGZmNmRhMzg0NTExYmViODEyZDU4YjJjYT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=NUW35U0E0VvUCTynr4WArU1pgdg-F506HK5TiNnP7IrwbhEJfQpEcJo5CBoz1e4iUprWUCcEZJS0dRCmlGrr0PGYIjKXZ00BE4EiGZyi2vUqdP%7ExxUzWxps6XwEIVGiXc5R9yC%7EQgtd6oSJYQOH4ITBvEoNOJoQUPnjL5m1vk9T8-xHpeAxkHkHeOaF8FjlU5HKvUIc65SlUGirxOsHXl0v8o7sKmYlFs0Nmkoj9MurWKFL0sLFW5XIxkZveAGS9GB2sisitzkc4BUhICqDMSfv5CtlTEhXpgDUGbFo%7EohbeuKkQjIgnSU%7EVdFhDvY7Qew%7E5emodk-508AHvCx-UrA__&Key-Pair-Id=KCD77M1F0VK2B\n",
124
+ "Resolving cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)... 3.163.189.20, 3.163.189.28, 3.163.189.91, ...\n",
125
+ "Connecting to cdn-lfs-us-1.huggingface.co (cdn-lfs-us-1.huggingface.co)|3.163.189.20|:443... connected.\n",
126
+ "HTTP request sent, awaiting response... 200 OK\n",
127
+ "Length: 394990732 (377M) [application/octet-stream]\n",
128
+ "Saving to: ‘model.onnx’\n",
129
+ "\n",
130
+ "model.onnx 100%[===================>] 376.69M 31.6MB/s in 5.2s \n",
131
+ "\n",
132
+ "2024-03-09 05:03:54 (72.7 MB/s) - ‘model.onnx’ saved [394990732/394990732]\n",
133
+ "\n"
134
+ ]
135
+ }
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "source": [
141
+ "## Download Tags / Test Image"
142
+ ],
143
+ "metadata": {
144
+ "id": "8r0qzU2NRoIT"
145
+ }
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "source": [
150
+ "!wget https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3/resolve/main/selected_tags.csv?download=true -O tags.csv\n",
151
+ "!wget https://huggingface.co/spaces/SmilingWolf/wd-tagger/resolve/main/power.jpg?download=true -O power.jpg"
152
+ ],
153
+ "metadata": {
154
+ "colab": {
155
+ "base_uri": "https://localhost:8080/"
156
+ },
157
+ "id": "WPrRzNP-RqKs",
158
+ "outputId": "a4a5af15-3bf2-4383-d4de-616e85485c20"
159
+ },
160
+ "execution_count": 3,
161
+ "outputs": [
162
+ {
163
+ "output_type": "stream",
164
+ "name": "stdout",
165
+ "text": [
166
+ "--2024-03-09 05:03:54-- https://huggingface.co/SmilingWolf/wd-convnext-tagger-v3/resolve/main/selected_tags.csv?download=true\n",
167
+ "Resolving huggingface.co (huggingface.co)... 3.163.189.90, 3.163.189.74, 3.163.189.37, ...\n",
168
+ "Connecting to huggingface.co (huggingface.co)|3.163.189.90|:443... connected.\n",
169
+ "HTTP request sent, awaiting response... 200 OK\n",
170
+ "Length: 308468 (301K) [text/plain]\n",
171
+ "Saving to: ‘tags.csv’\n",
172
+ "\n",
173
+ "\rtags.csv 0%[ ] 0 --.-KB/s \rtags.csv 100%[===================>] 301.24K --.-KB/s in 0.03s \n",
174
+ "\n",
175
+ "2024-03-09 05:03:54 (11.1 MB/s) - ‘tags.csv’ saved [308468/308468]\n",
176
+ "\n",
177
+ "--2024-03-09 05:03:55-- https://huggingface.co/spaces/SmilingWolf/wd-tagger/resolve/main/power.jpg?download=true\n",
178
+ "Resolving huggingface.co (huggingface.co)... 3.163.189.90, 3.163.189.74, 3.163.189.37, ...\n",
179
+ "Connecting to huggingface.co (huggingface.co)|3.163.189.90|:443... connected.\n",
180
+ "HTTP request sent, awaiting response... 200 OK\n",
181
+ "Length: 91159 (89K) [image/jpeg]\n",
182
+ "Saving to: ‘power.jpg’\n",
183
+ "\n",
184
+ "power.jpg 100%[===================>] 89.02K --.-KB/s in 0.01s \n",
185
+ "\n",
186
+ "2024-03-09 05:03:55 (8.09 MB/s) - ‘power.jpg’ saved [91159/91159]\n",
187
+ "\n"
188
+ ]
189
+ }
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "markdown",
194
+ "source": [
195
+ "# ONNX QUANT\n",
196
+ "To cut down on model size and have it work on mobile devices, quantization is needed (i think).\n",
197
+ "\n",
198
+ "First preprocess model for quantization - then quantize.\n",
199
+ "\n",
200
+ "The quant model name will be **model.quant.onnx**\n",
201
+ "\n",
202
+ "The convnext model went from ~377 MB down to 105 MB!"
203
+ ],
204
+ "metadata": {
205
+ "id": "lgaEjLAo7lMd"
206
+ }
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "source": [
211
+ "!python -m onnxruntime.quantization.preprocess --input model.onnx --output model_pre_quant.onnx"
212
+ ],
213
+ "metadata": {
214
+ "id": "sdk95gWw7Imp"
215
+ },
216
+ "execution_count": 4,
217
+ "outputs": []
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "source": [
222
+ "import onnx\n",
223
+ "from onnxruntime.quantization import quantize_dynamic, QuantType\n",
224
+ "\n",
225
+ "model_fp32 = 'model_pre_quant.onnx'\n",
226
+ "model_quant = 'model.quant.onnx'\n",
227
+ "# quantized_model = quantize_dynamic(model_fp32, model_quant, nodes_to_exclude=[\"Conv\", \"/core_model/stem/stem.0/Conv\", \"/core_model/stages/stages.0/blocks/blocks.0/conv_dw/Conv\", \"/core_model/stages/stages.0/blocks/blocks.1/conv_dw/Conv\", \"/core_model/stages/stages.0/blocks/blocks.2/conv_dw/Conv\"])\n",
228
+ "quantized_model = quantize_dynamic(model_fp32, model_quant, op_types_to_quantize=['MatMul', 'Transpose', 'Gemm', 'LayerNormalization'])\n",
229
+ "\n",
230
+ "# remove unneeded model\n",
231
+ "%rm model_pre_quant.onnx"
232
+ ],
233
+ "metadata": {
234
+ "id": "E7M68khX7H93"
235
+ },
236
+ "execution_count": 5,
237
+ "outputs": []
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "source": [
242
+ "# Add Preprocessing / Postprocessing\n",
243
+ "\n",
244
+ "To make mobile inference easier, we will add preprocessing to the model.\n",
245
+ "\n",
246
+ "Instead of resizing, adding padding, converting image to float32 array, and converting to BGR before inferencing - we can add these steps to the model so that only a uint8 tensor is needed.\n",
247
+ "\n",
248
+ "The model will be named **model.quant.preproc.onnx**\n",
249
+ "\n",
250
+ "**WARNING** \n",
251
+ "It's very possible that I could be doing this wrong or that it could have some improvements. I'm not really sure what I'm doing but I found that these settings have given me the closest results to the base quant model."
252
+ ],
253
+ "metadata": {
254
+ "id": "TmCdMPTwb6Mc"
255
+ }
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "source": [
260
+ "import onnx\n",
261
+ "from onnxruntime_extensions.tools.pre_post_processing import create_named_value, Normalize, Transpose, Debug, ReverseAxis, PixelsToYCbCr, PrePostProcessor, Unsqueeze, LetterBox, ConvertImageToBGR, Resize, CenterCrop, ImageBytesToFloat, ChannelsLastToChannelsFirst\n",
262
+ "\n",
263
+ "image_mean = [0.5,0.5,0.5]\n",
264
+ "image_std = [0.5,0.5,0.5]\n",
265
+ "\n",
266
+ "img_size = 448\n",
267
+ "mean_std = list(zip(image_mean, image_std))\n",
268
+ "new_input = create_named_value('image', onnx.TensorProto.UINT8, [\"num_bytes\"])\n",
269
+ "pipeline = PrePostProcessor([new_input], onnx_opset=18)\n",
270
+ "pipeline.add_pre_processing(\n",
271
+ " [\n",
272
+ "\n",
273
+ " ConvertImageToBGR(),\n",
274
+ " Resize((img_size, img_size), policy=\"not_larger\"),\n",
275
+ " LetterBox(target_shape=(img_size, img_size)), # adds padding\n",
276
+ " ImageBytesToFloat((255/2) / 255), # NO IDEA WHAT IM DOING. all i know is that the default value gives bad results\n",
277
+ " Normalize(mean_std, layout='HWC'), # copied values from the config on HF. seems to help results match closer to non-preprocessed model.\n",
278
+ " Unsqueeze(axes=[0]), # add batch dim so shape is {1, 448, 448, channels}.\n",
279
+ " ]\n",
280
+ ")"
281
+ ],
282
+ "metadata": {
283
+ "id": "HzJjPcSrb-DL"
284
+ },
285
+ "execution_count": 7,
286
+ "outputs": []
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "source": [
291
+ "# Save Model\n",
292
+ "model = onnx.load('model.quant.onnx')\n",
293
+ "new_model = pipeline.run(model)\n",
294
+ "onnx.save_model(new_model, 'model.quant.preproc.onnx')"
295
+ ],
296
+ "metadata": {
297
+ "id": "_zjoi_AWhIZN"
298
+ },
299
+ "execution_count": 8,
300
+ "outputs": []
301
+ },
302
+ {
303
+ "cell_type": "markdown",
304
+ "source": [
305
+ "# Test Model\n",
306
+ "Most of the inference code is directly from SmilingWolf's wd tagger space: https://huggingface.co/spaces/SmilingWolf/wd-tagger/blob/main/app.py"
307
+ ],
308
+ "metadata": {
309
+ "id": "nXk6AZM0kfL4"
310
+ }
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "source": [
315
+ "import onnxruntime as _ort\n",
316
+ "from onnxruntime_extensions import get_library_path as _lib_path\n",
317
+ "from PIL import Image\n",
318
+ "import numpy as np\n",
319
+ "import pandas as pd\n",
320
+ "\n",
321
+ "# Step 1: setup session options\n",
322
+ "so = _ort.SessionOptions()\n",
323
+ "so.register_custom_ops_library(_lib_path())\n",
324
+ "\n",
325
+ "# Step 2: create session\n",
326
+ "sess = _ort.InferenceSession(\"/content/model.quant.preproc.onnx\",so) # Don't forget to add session options (so)\n",
327
+ "\n",
328
+ "# Step 3: load image (no preprocessing needed!)\n",
329
+ "image = np.frombuffer(open('/content/power.jpg', 'rb').read(), dtype=np.uint8)\n",
330
+ "\n",
331
+ "# Step 4: run cell!\n",
332
+ "\n",
333
+ "\n",
334
+ "###### Inference Code ######\n",
335
+ "kaomojis = [\n",
336
+ " \"0_0\",\n",
337
+ " \"(o)_(o)\",\n",
338
+ " \"+_+\",\n",
339
+ " \"+_-\",\n",
340
+ " \"._.\",\n",
341
+ " \"<o>_<o>\",\n",
342
+ " \"<|>_<|>\",\n",
343
+ " \"=_=\",\n",
344
+ " \">_<\",\n",
345
+ " \"3_3\",\n",
346
+ " \"6_9\",\n",
347
+ " \">_o\",\n",
348
+ " \"@_@\",\n",
349
+ " \"^_^\",\n",
350
+ " \"o_o\",\n",
351
+ " \"u_u\",\n",
352
+ " \"x_x\",\n",
353
+ " \"|_|\",\n",
354
+ " \"||_||\",\n",
355
+ "]\n",
356
+ "\n",
357
+ "\n",
358
+ "def load_labels(dataframe) -> list[str]:\n",
359
+ " name_series = dataframe[\"name\"]\n",
360
+ " name_series = name_series.map(\n",
361
+ " lambda x: x.replace(\"_\", \" \") if x not in kaomojis else x\n",
362
+ " )\n",
363
+ " tag_names = name_series.tolist()\n",
364
+ "\n",
365
+ " rating_indexes = list(np.where(dataframe[\"category\"] == 9)[0])\n",
366
+ " general_indexes = list(np.where(dataframe[\"category\"] == 0)[0])\n",
367
+ " character_indexes = list(np.where(dataframe[\"category\"] == 4)[0])\n",
368
+ " return tag_names, rating_indexes, general_indexes, character_indexes\n",
369
+ "\n",
370
+ "csv_path = \"/content/tags.csv\"\n",
371
+ "\n",
372
+ "tags_df = pd.read_csv(csv_path)\n",
373
+ "sep_tags = load_labels(tags_df)\n",
374
+ "\n",
375
+ "tag_names = sep_tags[0]\n",
376
+ "rating_indexes = sep_tags[1]\n",
377
+ "general_indexes = sep_tags[2]\n",
378
+ "character_indexes = sep_tags[3]\n",
379
+ "\n",
380
+ "input_name = sess.get_inputs()[0].name\n",
381
+ "label_name = sess.get_outputs()[0].name\n",
382
+ "\n",
383
+ "preds = sess.run([label_name], {input_name: image})[0]\n",
384
+ "\n",
385
+ "\n",
386
+ "labels = list(zip(tag_names, preds[0].astype(float)))\n",
387
+ "ratings_names = [labels[i] for i in rating_indexes]\n",
388
+ "rating = dict(ratings_names)\n",
389
+ "\n",
390
+ "character_names = [labels[i] for i in character_indexes]\n",
391
+ "\n",
392
+ "character_res = [x for x in character_names if x[1] > 0.85]\n",
393
+ "character_res = dict(character_res)\n",
394
+ "\n",
395
+ "general_names = [labels[i] for i in general_indexes]\n",
396
+ "general_res = [x for x in general_names if x[1] > 0.35]\n",
397
+ "general_res = dict(general_res)\n",
398
+ "\n",
399
+ "sorted_general_strings = sorted(\n",
400
+ " general_res.items(),\n",
401
+ " key=lambda x: x[1],\n",
402
+ " reverse=True,\n",
403
+ ")\n",
404
+ "sorted_general_strings = [x[0] for x in sorted_general_strings]\n",
405
+ "sorted_general_strings = (\n",
406
+ " \", \".join(sorted_general_strings).replace(\"(\", \"\\(\").replace(\")\", \"\\)\")\n",
407
+ ")\n",
408
+ "\n",
409
+ "print(rating)\n",
410
+ "print(character_res)\n",
411
+ "print(general_res)\n",
412
+ "print(sorted_general_strings)"
413
+ ],
414
+ "metadata": {
415
+ "colab": {
416
+ "base_uri": "https://localhost:8080/"
417
+ },
418
+ "id": "m5B0Wj4NkhMt",
419
+ "outputId": "6b551906-9b9e-4db9-f2bc-5ae5427381a3"
420
+ },
421
+ "execution_count": 9,
422
+ "outputs": [
423
+ {
424
+ "output_type": "stream",
425
+ "name": "stdout",
426
+ "text": [
427
+ "{'general': 0.9169240593910217, 'sensitive': 0.0812525749206543, 'questionable': 0.0006865859031677246, 'explicit': 0.0002942383289337158}\n",
428
+ "{'power (chainsaw man)': 0.9924684762954712}\n",
429
+ "{'1girl': 0.9980734586715698, 'solo': 0.967477560043335, 'long hair': 0.8743129968643188, 'looking at viewer': 0.8921941518783569, 'smile': 0.7079806327819824, 'open mouth': 0.8572969436645508, 'simple background': 0.6686466336250305, 'shirt': 0.9388805627822876, 'blonde hair': 0.647895336151123, 'white background': 0.5928694009780884, 'red eyes': 0.4210684299468994, 'hair between eyes': 0.8992906212806702, 'jacket': 0.5598545074462891, 'white shirt': 0.8964416980743408, 'upper body': 0.666782557964325, 'horns': 0.9738106727600098, 'teeth': 0.9321538209915161, 'necktie': 0.9494357109069824, 'collared shirt': 0.8381757736206055, 'orange eyes': 0.4594384431838989, 'symbol-shaped pupils': 0.8655499219894409, 'fangs': 0.3685188889503479, 'demon horns': 0.5966249704360962, 'sharp teeth': 0.8942122459411621, 'black necktie': 0.8483953475952148, 'claw pose': 0.5946617722511292, 'red horns': 0.9497503042221069, 'cross-shaped pupils': 0.9292328357696533, 'pillarboxed': 0.766990065574646}\n",
430
+ "1girl, horns, solo, red horns, necktie, shirt, teeth, cross-shaped pupils, hair between eyes, white shirt, sharp teeth, looking at viewer, long hair, symbol-shaped pupils, open mouth, black necktie, collared shirt, pillarboxed, smile, simple background, upper body, blonde hair, demon horns, claw pose, white background, jacket, orange eyes, red eyes, fangs\n"
431
+ ]
432
+ }
433
+ ]
434
+ }
435
+ ]
436
+ }