suvash commited on
Commit
e8abe35
1 Parent(s): 3d3df40

add training and inference notebooks

Browse files
notebooks/food-101-infer-gradio.ipynb ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "87345732-d868-473b-b1a1-5c25839ce25b",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from fastai.vision.all import *"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "79b9fbad-7b99-40fd-8768-b0a091bf85cb",
17
+ "metadata": {},
18
+ "outputs": [
19
+ {
20
+ "name": "stderr",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "/conda/envs/py310-cuda116/lib/python3.10/site-packages/paramiko/transport.py:236: CryptographyDeprecationWarning: Blowfish has been deprecated\n",
24
+ " \"class\": algorithms.Blowfish,\n"
25
+ ]
26
+ }
27
+ ],
28
+ "source": [
29
+ "import gradio"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "id": "5409c6a7-5cae-42bb-8335-587a04471f22",
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "MODELS_PATH = Path(\"./models\")"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 4,
45
+ "id": "4e836799-6858-438a-8d70-d95f98cf54f7",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "EXAMPLES_PATH = Path('./examples')"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 5,
55
+ "id": "9ed20c60-9f23-4795-bb4b-79b00af0f6d1",
56
+ "metadata": {},
57
+ "outputs": [
58
+ {
59
+ "data": {
60
+ "text/plain": [
61
+ "(#2) [Path('models/food-101-resnet34.pkl'),Path('models/food-101-resnet50.pkl')]"
62
+ ]
63
+ },
64
+ "execution_count": 5,
65
+ "metadata": {},
66
+ "output_type": "execute_result"
67
+ }
68
+ ],
69
+ "source": [
70
+ "MODELS_PATH.ls()"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 6,
76
+ "id": "0969ba8e-b0df-4550-a900-5d5a30fb0187",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "data": {
81
+ "text/plain": [
82
+ "(#9) [Path('examples/pad_thai.jpeg'),Path('examples/takoyaki.jpeg'),Path('examples/momo.jpeg'),Path('examples/falafel.jpeg'),Path('examples/paella.jpeg'),Path('examples/ravioli.jpeg'),Path('examples/huevos_rancheros.jpeg'),Path('examples/edamame.jpeg'),Path('examples/sushi.jpeg')]"
83
+ ]
84
+ },
85
+ "execution_count": 6,
86
+ "metadata": {},
87
+ "output_type": "execute_result"
88
+ }
89
+ ],
90
+ "source": [
91
+ "EXAMPLES_PATH.ls()"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 7,
97
+ "id": "e9143742-c6bc-44f6-8ecd-3826502c84ac",
98
+ "metadata": {},
99
+ "outputs": [],
100
+ "source": [
101
+ "def label_func(filepath):\n",
102
+ " return filepath.parent.name"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 8,
108
+ "id": "c6ad64e8-f163-4472-b2f0-c0aa50ead4d8",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "learn = load_learner(MODELS_PATH/'food-101-resnet50.pkl')"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": 9,
118
+ "id": "d1370d20-fd51-4512-bd28-5f170d216c7b",
119
+ "metadata": {},
120
+ "outputs": [
121
+ {
122
+ "data": {
123
+ "text/plain": [
124
+ "['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheese_plate', 'cheesecake', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles']"
125
+ ]
126
+ },
127
+ "execution_count": 9,
128
+ "metadata": {},
129
+ "output_type": "execute_result"
130
+ }
131
+ ],
132
+ "source": [
133
+ "labels = learn.dls.vocab\n",
134
+ "labels"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "8f666b42-9fdd-45ca-81ca-7e98dd191369",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": []
144
+ },
145
+ {
146
+ "cell_type": "code",
147
+ "execution_count": 11,
148
+ "id": "a360dd6b-75a9-43e5-b91d-c6963ea462ea",
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "def predict(img):\n",
153
+ " img = PILImage.create(img)\n",
154
+ " _pred, _pred_w_idx, probs = learn.predict(img)\n",
155
+ " labels_probs = {labels[i]: float(probs[i]) for i, _ in enumerate(labels)}\n",
156
+ " return labels_probs"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 12,
162
+ "id": "febc7266-8587-4530-811b-f2fa9117dcd5",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": [
166
+ "with open('gradio_article.md') as f:\n",
167
+ " article = f.read()"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 13,
173
+ "id": "8fd4ffb4-11ca-4b25-999c-cde2a4e236b4",
174
+ "metadata": {},
175
+ "outputs": [
176
+ {
177
+ "name": "stderr",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "/conda/envs/py310-cuda116/lib/python3.10/site-packages/gradio/interface.py:419: UserWarning: The `enable_queue` parameter in the `Interface`will be deprecated and may not work properly. Please use the `enable_queue` parameter in `launch()` instead\n",
181
+ " warnings.warn(\n"
182
+ ]
183
+ },
184
+ {
185
+ "name": "stdout",
186
+ "output_type": "stream",
187
+ "text": [
188
+ "Running on local URL: http://localhost:9999/\n",
189
+ "\n",
190
+ "To create a public link, set `share=True` in `launch()`.\n"
191
+ ]
192
+ },
193
+ {
194
+ "data": {
195
+ "text/html": [
196
+ "\n",
197
+ " <iframe\n",
198
+ " width=\"900\"\n",
199
+ " height=\"500\"\n",
200
+ " src=\"http://localhost:9999/\"\n",
201
+ " frameborder=\"0\"\n",
202
+ " allowfullscreen\n",
203
+ " \n",
204
+ " ></iframe>\n",
205
+ " "
206
+ ],
207
+ "text/plain": [
208
+ "<IPython.lib.display.IFrame at 0x7f69f5315840>"
209
+ ]
210
+ },
211
+ "metadata": {},
212
+ "output_type": "display_data"
213
+ },
214
+ {
215
+ "data": {
216
+ "text/plain": [
217
+ "(<fastapi.applications.FastAPI at 0x7f69f7595330>,\n",
218
+ " 'http://localhost:9999/',\n",
219
+ " None)"
220
+ ]
221
+ },
222
+ "execution_count": 13,
223
+ "metadata": {},
224
+ "output_type": "execute_result"
225
+ }
226
+ ],
227
+ "source": [
228
+ "interface_options = {\n",
229
+ " \"title\": \"Food-101 Classifier\",\n",
230
+ " \"description\": \"A food image classifier trained on the Food-101 (https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/) dataset with fastai with a ResNet50 CNN model.\",\n",
231
+ " \"article\": article,\n",
232
+ " \"examples\" : [f'{EXAMPLES_PATH}/{f.name}' for f in EXAMPLES_PATH.iterdir()],\n",
233
+ " \"interpretation\": \"default\",\n",
234
+ " \"layout\": \"horizontal\",\n",
235
+ " \"allow_flagging\": \"never\",\n",
236
+ " \"enable_queue\": True \n",
237
+ "}\n",
238
+ "\n",
239
+ "demo = gradio.Interface(fn=predict,\n",
240
+ " inputs=gradio.inputs.Image(shape=(512, 512)),\n",
241
+ " outputs=gradio.outputs.Label(num_top_classes=5),\n",
242
+ " **interface_options)\n",
243
+ "\n",
244
+ "demo_options = {\n",
245
+ " \"inline\": True,\n",
246
+ " \"inbrowser\": False,\n",
247
+ " \"share\": False,\n",
248
+ " \"show_error\": True,\n",
249
+ " \"server_name\": \"0.0.0.0\",\n",
250
+ " \"server_port\": 9999,\n",
251
+ " \"enable_queue\": True,\n",
252
+ "}\n",
253
+ "\n",
254
+ "demo.launch(**demo_options)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "570f8a3c-367e-4a7f-808d-8fa2e925a444",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": []
264
+ }
265
+ ],
266
+ "metadata": {
267
+ "kernelspec": {
268
+ "display_name": "Python 3 (ipykernel)",
269
+ "language": "python",
270
+ "name": "python3"
271
+ },
272
+ "language_info": {
273
+ "codemirror_mode": {
274
+ "name": "ipython",
275
+ "version": 3
276
+ },
277
+ "file_extension": ".py",
278
+ "mimetype": "text/x-python",
279
+ "name": "python",
280
+ "nbconvert_exporter": "python",
281
+ "pygments_lexer": "ipython3",
282
+ "version": "3.10.4"
283
+ }
284
+ },
285
+ "nbformat": 4,
286
+ "nbformat_minor": 5
287
+ }
notebooks/food-101-train-resnet.ipynb ADDED
The diff for this file is too large to render. See raw diff