daspartho commited on
Commit
fed847e
1 Parent(s): 02182ab

updated model

Browse files
Files changed (2) hide show
  1. model.ipynb +273 -0
  2. model.pkl +2 -2
model.ipynb ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "FDfI95Sh1lW0"
7
+ },
8
+ "source": [
9
+ "# Is it Huggable?\n",
10
+ "*Classify objects as huggable or not.*\n",
11
+ "\n",
12
+ "This notebook has steps to make the model.\n",
13
+ "\n",
14
+ "Just want to play? Use directly on the [website](https://daspartho.github.io/is-it-huggable)."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {
20
+ "id": "r-GyBdvhzfY2"
21
+ },
22
+ "source": [
23
+ "### Install required libraries"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": null,
29
+ "metadata": {
30
+ "id": "CQdd5Egc-FQV"
31
+ },
32
+ "outputs": [],
33
+ "source": [
34
+ "!pip install -Uqq fastai duckduckgo_search"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "metadata": {
40
+ "id": "vgvpU91p0ERn"
41
+ },
42
+ "source": [
43
+ "### Import modules required"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": 2,
49
+ "metadata": {
50
+ "id": "BD7-yF0l-Y4h"
51
+ },
52
+ "outputs": [],
53
+ "source": [
54
+ "from duckduckgo_search import ddg_images\n",
55
+ "from fastcore.all import *\n",
56
+ "from fastdownload import download_url\n",
57
+ "from fastai.vision.all import *"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "markdown",
62
+ "metadata": {
63
+ "id": "WKZC9jY_zOfx"
64
+ },
65
+ "source": [
66
+ "### Use DuckDuckGo to search for images of examples of the two groups"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {
73
+ "id": "hqnqTAXWCAn6"
74
+ },
75
+ "outputs": [],
76
+ "source": [
77
+ "def search_images(term, max_images=50):\n",
78
+ " print(f\"Searching for '{term}'\")\n",
79
+ " return L(ddg_images(term, max_results=max_images)).itemgot('image')\n",
80
+ "\n",
81
+ "path = Path('huggable_or_not')\n",
82
+ "\n",
83
+ "# examples of both groups\n",
84
+ "categories={\n",
85
+ " 'huggable':['plushie', 'pillow' , 'ballon', 'dog', 'cat', 'bunny', 'snowman', 'bed', 'sofa', 'people', 'baby', 'cloud', 'dolphin', 'horse', 'cow', 'sheep'], \n",
86
+ " 'not huggable':['chainsaw', 'sword', 'cactus', 'barbwire', 'bear', 'snake', 'lion', 'shark', 'fire','knive','fork', 'dinosaur', 'crocodile', 'spider', 'bees', 'porcupine']\n",
87
+ " }\n",
88
+ "\n",
89
+ "for category in categories:\n",
90
+ " dest = (path/category)\n",
91
+ " dest.mkdir(exist_ok=True, parents=True)\n",
92
+ " for example in categories[category]:\n",
93
+ " download_images(dest, urls=search_images(f'{example} photo'))\n",
94
+ " resize_images(path/category, max_size=400, dest=path/category)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "metadata": {
100
+ "id": "Bpsp4MTGxBWl"
101
+ },
102
+ "source": [
103
+ "### Remove photos that didn't download correctly."
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {
110
+ "id": "FzuHMc_qD0UO"
111
+ },
112
+ "outputs": [],
113
+ "source": [
114
+ "failed = verify_images(get_image_files(path))\n",
115
+ "failed.map(Path.unlink)\n",
116
+ "len(failed)"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "markdown",
121
+ "metadata": {
122
+ "id": "eFFr_VE45ihe"
123
+ },
124
+ "source": [
125
+ "### Preparing the data for training"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "metadata": {
132
+ "id": "wzSeghRAFYqF"
133
+ },
134
+ "outputs": [],
135
+ "source": [
136
+ "dls = DataBlock(\n",
137
+ " blocks=(ImageBlock, CategoryBlock), # inputs to our model are images, and the outputs are categories\n",
138
+ " get_items=get_image_files, \n",
139
+ " splitter=RandomSplitter(valid_pct=0.2, seed=42), # Split the data into training and validation sets randomly, using 20% of the data for the validation set\n",
140
+ " get_y=parent_label, # The labels is the name of the parent of each file\n",
141
+ " item_tfms=RandomResizedCrop(224, min_scale=0.3), # picks a random scaled crop of an image and resize it to 224x224 pixels\n",
142
+ " batch_tfms=aug_transforms() # applies augmentations to an entire batch\n",
143
+ ").dataloaders(path, bs=32)\n",
144
+ "\n",
145
+ "dls.show_batch()"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "metadata": {
151
+ "id": "y-FNWY-3zEF3"
152
+ },
153
+ "source": [
154
+ "### Fine-tune a pretrained neural network to recognise these two groups"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "metadata": {
161
+ "id": "5ao0lw2cG2WP"
162
+ },
163
+ "outputs": [],
164
+ "source": [
165
+ "learn = vision_learner(dls, resnet34, metrics=error_rate)\n",
166
+ "learn.fine_tune(10)"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "metadata": {
172
+ "id": "0wZAFpxi7L6z"
173
+ },
174
+ "source": [
175
+ "### Show predictions the model made on images in validation set"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "metadata": {
182
+ "id": "aE0vp3jeVtBT"
183
+ },
184
+ "outputs": [],
185
+ "source": [
186
+ "learn.show_results()"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {
192
+ "id": "gFpqdZr87ZSS"
193
+ },
194
+ "source": [
195
+ "### Download an image from internet for trying the model"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "metadata": {
202
+ "id": "C0kfX6QUMRoN"
203
+ },
204
+ "outputs": [],
205
+ "source": [
206
+ "term='penguin' # change the search term\n",
207
+ "download_url(search_images(term, max_images=1)[0], 'test.jpg', show_progress=False)\n",
208
+ "Image.open('test.jpg').to_thumb(256,256)"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "markdown",
213
+ "metadata": {
214
+ "id": "AgOQPzTX7q3o"
215
+ },
216
+ "source": [
217
+ "### Trying the model on the downloaded image"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {
224
+ "id": "sz1dVCZMHz3N"
225
+ },
226
+ "outputs": [],
227
+ "source": [
228
+ "predict,n,prob = learn.predict(PILImage.create('test.jpg'))\n",
229
+ "print(f\"It's {predict}!\")\n",
230
+ "perc = prob[n]*100\n",
231
+ "print(f\"I'm {perc:.02f}% confident.\")"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "markdown",
236
+ "metadata": {
237
+ "id": "lSSjWJq874WE"
238
+ },
239
+ "source": [
240
+ "### Export the model"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 94,
246
+ "metadata": {
247
+ "id": "ae2bc6ac"
248
+ },
249
+ "outputs": [],
250
+ "source": [
251
+ "learn.export('model.pkl')"
252
+ ]
253
+ }
254
+ ],
255
+ "metadata": {
256
+ "accelerator": "GPU",
257
+ "colab": {
258
+ "collapsed_sections": [],
259
+ "name": "model.ipynb",
260
+ "provenance": []
261
+ },
262
+ "gpuClass": "standard",
263
+ "kernelspec": {
264
+ "display_name": "Python 3",
265
+ "name": "python3"
266
+ },
267
+ "language_info": {
268
+ "name": "python"
269
+ }
270
+ },
271
+ "nbformat": 4,
272
+ "nbformat_minor": 0
273
+ }
model.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:afbd4ed3363dd00b44ce149bf20e7d6891e96ca7280b1fa65e7ec4a0994cf115
3
- size 87468645
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8e13d90feb325569e3ac1ce4371ce77195222ac28f044a97b1620aef8f70efe2
3
+ size 87503717