blanchon committed
Commit e18a750
1 Parent(s): 9e822eb

first commit

__pycache__/dataloading.cpython-310.pyc ADDED
Binary file (4.55 kB)
__pycache__/gradio_utils.cpython-310.pyc ADDED
Binary file (1.5 kB)
__pycache__/preprocessing.cpython-310.pyc ADDED
Binary file (8.42 kB)
__pycache__/resnet.cpython-310.pyc ADDED
Binary file (2.13 kB)
app.py ADDED
@@ -0,0 +1,73 @@
+ import numpy as np
+
+ import skorch
+ import torch
+ import torch.nn as nn
+
+ import gradio as gr
+
+ import librosa
+
+ from joblib import dump, load
+
+ from sklearn.pipeline import Pipeline
+ from sklearn.preprocessing import LabelEncoder
+
+ from resnet import ResNet
+ from gradio_utils import load_as_librosa, predict_gradio
+ from dataloading import uniformize, to_numpy
+ from preprocessing import MfccTransformer, TorchTransform
+
+
+ SEED: int = 42
+ np.random.seed(SEED)
+ torch.manual_seed(SEED)
+
+ model = load('./model/model.joblib')
+ only_mffc_transform = load('./model/only_mffc_transform.joblib')
+ label_encoder = load('./model/label_encoder.joblib')
+ SAMPLE_RATE = load("./model/SAMPLE_RATE.joblib")
+ METHOD = load("./model/METHOD.joblib")
+ MAX_TIME = load("./model/MAX_TIME.joblib")
+ N_MFCC = load("./model/N_MFCC.joblib")
+ HOP_LENGHT = load("./model/HOP_LENGHT.joblib")
+
+ sklearn_model = Pipeline(
+     steps=[
+         ("mfcc", only_mffc_transform),
+         ("model", model)
+     ]
+ )
+
+ uniform_lambda = lambda y, sr: uniformize(y, sr, METHOD, MAX_TIME)
+
+ title = r"ResNet 9"
+
+ description = r"""
+ <center>
+ The resnet9 model was trained to classify drone speech commands.
+ <img src="http://zeus.blanchon.cc/dropshare/modia.png" width=200px>
+ </center>
+ """
+ article = r"""
+ - [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385)
+ """
+
+ demo_men = gr.Interface(
+     title = title,
+     description = description,
+     article = article,
+     fn=lambda data: predict_gradio(
+         data=data,
+         uniform_lambda=uniform_lambda,
+         sklearn_model=sklearn_model,
+         label_transform=label_encoder,
+         target_sr=SAMPLE_RATE),
+     inputs = gr.Audio(source="microphone", type="numpy"),
+     outputs = gr.Label(),
+     # allow_flagging = "manual",
+     # flagging_options = ['recule', 'tournedroite', 'arretetoi', 'tournegauche', 'gauche', 'avance', 'droite'],
+     # flagging_dir = "./flag/men"
+ )
+
+ demo_men.launch()
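
A minimal offline sanity check of the same prediction path, without the Gradio UI, could look like the sketch below. It reuses the artifacts committed under ./model/ and the helpers in this repo; the input file name `sample.wav` is hypothetical.

    # offline_check.py -- hypothetical smoke test, not part of the commit
    import librosa
    import numpy as np
    from joblib import load
    from sklearn.pipeline import Pipeline
    from dataloading import uniformize

    sklearn_model = Pipeline(steps=[
        ("mfcc", load("./model/only_mffc_transform.joblib")),
        ("model", load("./model/model.joblib")),
    ])
    label_encoder = load("./model/label_encoder.joblib")
    SAMPLE_RATE = load("./model/SAMPLE_RATE.joblib")
    METHOD = load("./model/METHOD.joblib")
    MAX_TIME = load("./model/MAX_TIME.joblib")

    # Load one clip and uniformize it exactly like the training pipeline did.
    y, sr = librosa.load("sample.wav", sr=SAMPLE_RATE)  # hypothetical file
    y_uniform = uniformize(y, sr, METHOD, MAX_TIME).astype(np.float32)

    # The pipeline expects a (n_clips, n_samples) matrix.
    proba = sklearn_model.predict_proba(y_uniform.reshape(1, -1))
    print(label_encoder.inverse_transform([int(proba.argmax())]))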
best_model_gradio.ipynb ADDED
@@ -0,0 +1,504 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Best Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "import skorch\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "import gradio as gr\n",
+ "\n",
+ "import librosa\n",
+ "\n",
+ "from joblib import dump, load\n",
+ "\n",
+ "from sklearn.pipeline import Pipeline\n",
+ "from sklearn.preprocessing import LabelEncoder\n",
+ "\n",
+ "from resnet import ResNet\n",
+ "from gradio_utils import load_as_librosa, predict_gradio\n",
+ "from dataloading import uniformize, to_numpy\n",
+ "from preprocessing import MfccTransformer, TorchTransform\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Notebook params\n",
+ "SEED : int = 42\n",
+ "np.random.seed(SEED)\n",
+ "torch.manual_seed(SEED)\n",
+ "\n",
+ "# Dataloading params\n",
+ "PATHS: list[str] = [\n",
+ " \"../data/\",\n",
+ " \"../new_data/JulienNestor\",\n",
+ " \"../new_data/classroom_data\",\n",
+ " \"../new_data/class\",\n",
+ " \"../new_data/JulienRaph\",\n",
+ "]\n",
+ "REMOVE_LABEL: list[str] = [\n",
+ " \"penduleinverse\", \"pendule\", \n",
+ " \"decollage\", \"atterrissage\",\n",
+ " \"plushaut\", \"plusbas\",\n",
+ " \"etatdurgence\",\n",
+ " \"faisunflip\", \n",
+ " \"faisUnFlip\", \"arreteToi\", \"etatDurgence\",\n",
+ " # \"tournedroite\", \"arretetoi\", \"tournegauche\"\n",
+ "]\n",
+ "SAMPLE_RATE: int = 16_000\n",
+ "METHOD: str = \"time_stretch\"\n",
+ "MAX_TIME: float = 3.0\n",
+ "\n",
+ "# Features Extraction params\n",
+ "N_MFCC: int = 64\n",
+ "HOP_LENGHT = 2_048"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 1 - Dataloading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 1-Dataloading\n",
+ "from dataloading import load_dataset, to_numpy\n",
+ "dataset, uniform_lambda = load_dataset(PATHS,\n",
+ " remove_label=REMOVE_LABEL,\n",
+ " sr=SAMPLE_RATE,\n",
+ " method=METHOD,\n",
+ " max_time=MAX_TIME\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['recule',\n",
+ " 'tournedroite',\n",
+ " 'arretetoi',\n",
+ " 'tournegauche',\n",
+ " 'gauche',\n",
+ " 'avance',\n",
+ " 'droite']"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "list(dataset[\"ground_truth\"].unique())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 2-Train and split\n",
+ "from sklearn.model_selection import train_test_split\n",
+ "dataset_train, dataset_test = train_test_split(dataset, random_state=0)\n",
+ "\n",
+ "X_train = to_numpy(dataset_train[\"y_uniform\"])\n",
+ "y_train = to_numpy(dataset_train[\"ground_truth\"])\n",
+ "X_test = to_numpy(dataset_test[\"y_uniform\"])\n",
+ "y_test = to_numpy(dataset_test[\"ground_truth\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 2 - Preprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "only_mffc_transform = Pipeline(\n",
+ " steps=[\n",
+ " (\"mfcc\", MfccTransformer(N_MFCC=N_MFCC, reshape_output=False, hop_length=HOP_LENGHT)),\n",
+ " (\"torch\", TorchTransform())\n",
+ " ]\n",
+ ")\n",
+ "\n",
+ "only_mffc_transform.fit(X_train)\n",
+ "\n",
+ "X_train_mfcc_torch = only_mffc_transform.transform(X_train)\n",
+ "X_test_mfcc_torch = only_mffc_transform.transform(X_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train a LabelEncoder (if needed)\n",
+ "label_encoder = LabelEncoder()\n",
+ "label_encoder.fit(y_train)\n",
+ "y_train_enc = label_encoder.transform(y_train)\n",
+ "y_test_enc = label_encoder.transform(y_test)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 3 - ResNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if hasattr(torch, \"has_mps\") and torch.has_mps:\n",
+ " device = torch.device(\"mps\")\n",
+ "elif hasattr(torch, \"has_cuda\") and torch.has_cuda:\n",
+ " device = torch.device(\"cuda\")\n",
+ "else:\n",
+ " device = torch.device(\"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3.1 - nn.Module"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# from resnet import ResNet"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3.2 - Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " epoch train_loss dur\n",
+ "------- ------------ ------\n",
+ " 1 \u001b[36m2.8646\u001b[0m 0.4461\n",
+ " 2 \u001b[36m1.9534\u001b[0m 0.4322\n",
+ " 3 \u001b[36m1.8164\u001b[0m 0.4331\n",
+ " 4 \u001b[36m1.6889\u001b[0m 0.4318\n",
+ " 5 \u001b[36m1.5808\u001b[0m 0.4329\n",
+ " 6 \u001b[36m1.4659\u001b[0m 0.4355\n",
+ " 7 \u001b[36m1.2894\u001b[0m 0.4285\n",
+ " 8 1.3207 0.4280\n",
+ " 9 \u001b[36m1.1546\u001b[0m 0.4274\n",
+ " 10 \u001b[36m1.0586\u001b[0m 0.4287\n",
+ " 11 \u001b[36m1.0195\u001b[0m 0.4313\n",
+ " 12 \u001b[36m0.8246\u001b[0m 0.4302\n",
+ " 13 \u001b[36m0.7612\u001b[0m 0.4330\n",
+ " 14 \u001b[36m0.7296\u001b[0m 0.4315\n",
+ " 15 \u001b[36m0.6690\u001b[0m 0.4293\n",
+ " 16 \u001b[36m0.6205\u001b[0m 0.4291\n",
+ " 17 \u001b[36m0.5764\u001b[0m 0.4290\n",
+ " 18 \u001b[36m0.4839\u001b[0m 0.4284\n",
+ " 19 0.4984 0.4314\n",
+ " 20 \u001b[36m0.4666\u001b[0m 0.4324\n",
+ " 21 \u001b[36m0.4132\u001b[0m 0.4322\n",
+ " 22 0.4440 0.4300\n",
+ " 23 0.4463 0.4300\n",
+ " 24 \u001b[36m0.4075\u001b[0m 0.4287\n",
+ " 25 \u001b[36m0.3908\u001b[0m 0.4282\n",
+ " 26 \u001b[36m0.3759\u001b[0m 0.4278\n",
+ " 27 \u001b[36m0.3612\u001b[0m 0.4296\n",
+ " 28 \u001b[36m0.3189\u001b[0m 0.4281\n",
+ " 29 0.3489 0.4308\n",
+ " 30 0.3308 0.4301\n",
+ " 31 0.3353 0.4299\n",
+ " 32 \u001b[36m0.3074\u001b[0m 0.4298\n",
+ " 33 0.3339 0.4350\n",
+ " 34 \u001b[36m0.2921\u001b[0m 0.4383\n",
+ " 35 \u001b[36m0.2852\u001b[0m 0.4345\n",
+ " 36 0.3170 0.4334\n",
+ " 37 0.2853 0.4304\n",
+ " 38 0.2857 0.4307\n",
+ " 39 \u001b[36m0.2607\u001b[0m 0.4310\n",
+ " 40 0.2765 0.4292\n",
+ " 41 0.2831 0.4305\n",
+ " 42 0.2836 0.4295\n",
+ " 43 0.2742 0.4307\n",
+ " 44 0.2653 0.4302\n",
+ " 45 \u001b[36m0.2370\u001b[0m 0.4335\n",
+ " 46 0.2475 0.4292\n",
+ " 47 0.2692 0.4329\n",
+ " 48 0.2657 0.4306\n",
+ " 49 0.2875 0.4305\n",
+ " 50 0.2839 0.4315\n",
+ " 51 0.2555 0.4307\n",
+ " 52 0.2794 0.4332\n",
+ " 53 \u001b[36m0.2272\u001b[0m 0.4302\n",
+ " 54 0.2519 0.4305\n",
+ " 55 0.2388 0.4307\n",
+ " 56 0.2504 0.4314\n",
+ " 57 0.2345 0.4328\n",
+ " 58 \u001b[36m0.2252\u001b[0m 0.4316\n",
+ " 59 0.2436 0.4329\n",
+ " 60 0.2297 0.4309\n",
+ " 61 0.2594 0.4306\n",
+ " 62 0.2412 0.4300\n",
+ " 63 0.2399 0.4319\n",
+ " 64 0.2600 0.4334\n",
+ " 65 0.2599 0.4304\n",
+ " 66 0.2360 0.4317\n",
+ " 67 0.2537 0.4301\n",
+ " 68 0.2268 0.4299\n",
+ " 69 0.2436 0.4301\n",
+ " 70 \u001b[36m0.2193\u001b[0m 0.4308\n",
+ " 71 0.2284 0.4322\n",
+ " 72 0.2339 0.4317\n",
+ " 73 0.2330 0.4331\n",
+ " 74 \u001b[36m0.2063\u001b[0m 0.4327\n",
+ " 75 0.2568 0.4332\n",
+ " 76 0.2372 0.4324\n",
+ " 77 0.2249 0.4327\n",
+ " 78 0.2449 0.4314\n",
+ " 79 0.2455 0.4310\n",
+ " 80 \u001b[36m0.2003\u001b[0m 0.4321\n",
+ " 81 0.2172 0.4318\n",
+ " 82 0.2278 0.4333\n",
+ " 83 0.2178 0.4334\n",
+ " 84 0.2240 0.4312\n",
+ " 85 0.2329 0.4338\n",
+ " 86 0.2267 0.4326\n",
+ " 87 0.2479 0.4341\n",
+ " 88 0.2266 0.4355\n",
+ " 89 0.2541 0.4350\n",
+ " 90 0.2167 0.4324\n",
+ " 91 0.2282 0.4353\n",
+ " 92 0.2097 0.4367\n",
+ " 93 0.2038 0.4351\n",
+ " 94 0.2078 0.4372\n",
+ " 95 0.2437 0.4344\n",
+ " 96 0.2283 0.4333\n",
+ " 97 0.2263 0.4329\n",
+ " 98 0.2146 0.4346\n",
+ " 99 0.2238 0.4323\n",
+ " 100 0.2035 0.4348\n",
+ " 101 0.2287 0.4348\n",
+ " 102 0.2231 0.4328\n",
+ " 103 0.2171 0.4326\n",
+ " 104 0.2417 0.4329\n",
+ "Stopping since train_loss has not improved in the last 25 epochs.\n",
+ "0.941908713692946\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define net\n",
+ "n_labels = np.unique(dataset.ground_truth).size\n",
+ "net = ResNet(in_channels=1, num_classes=n_labels)\n",
+ "\n",
+ "# Define model\n",
+ "model = skorch.NeuralNetClassifier(\n",
+ " module=net,\n",
+ " criterion=nn.CrossEntropyLoss(),\n",
+ " callbacks=[skorch.callbacks.EarlyStopping(monitor=\"train_loss\", patience=25)],\n",
+ " max_epochs=200,\n",
+ " lr=0.01,\n",
+ " batch_size=128,\n",
+ " train_split=None,\n",
+ " device=device,\n",
+ ")\n",
+ "\n",
+ "model.check_data(X_train_mfcc_torch, y_train_enc)\n",
+ "model.fit(X_train_mfcc_torch, y_train_enc)\n",
+ "\n",
+ "print(model.score(X_test_mfcc_torch, y_test_enc))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['./model/HOP_LENGHT.joblib']"
+ ]
+ },
+ "execution_count": 39,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from joblib import dump, load\n",
+ "\n",
+ "dump(model, './model/model.joblib') \n",
+ "dump(only_mffc_transform, './model/only_mffc_transform.joblib') \n",
+ "dump(label_encoder, './model/label_encoder.joblib')\n",
+ "dump(SAMPLE_RATE, \"./model/SAMPLE_RATE.joblib\")\n",
+ "dump(METHOD, \"./model/METHOD.joblib\")\n",
+ "dump(MAX_TIME, \"./model/MAX_TIME.joblib\")\n",
+ "dump(N_MFCC, \"./model/N_MFCC.joblib\")\n",
+ "dump(HOP_LENGHT, \"./model/HOP_LENGHT.joblib\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = load('./model/model.joblib') \n",
+ "only_mffc_transform = load('./model/only_mffc_transform.joblib') \n",
+ "label_encoder = load('./model/label_encoder.joblib') \n",
+ "SAMPLE_RATE = load(\"./model/SAMPLE_RATE.joblib\")\n",
+ "METHOD = load(\"./model/METHOD.joblib\")\n",
+ "MAX_TIME = load(\"./model/MAX_TIME.joblib\")\n",
+ "N_MFCC = load(\"./model/N_MFCC.joblib\")\n",
+ "HOP_LENGHT = load(\"./model/HOP_LENGHT.joblib\")\n",
+ "\n",
+ "sklearn_model = Pipeline(\n",
+ " steps=[\n",
+ " (\"mfcc\", only_mffc_transform),\n",
+ " (\"model\", model)\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ "uniform_lambda = lambda y, sr: uniformize(y, sr, METHOD, MAX_TIME)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
+ ]
+ }
+ ],
+ "source": [
+ "title = r\"ResNet 9\"\n",
+ "\n",
+ "description = r\"\"\"\n",
+ "<center>\n",
+ "The resnet9 model was trained to classify drone speech command.\n",
+ "<img src=\"http://zeus.blanchon.cc/dropshare/modia.png\" width=200px>\n",
+ "</center>\n",
+ "\"\"\"\n",
+ "article = r\"\"\"\n",
+ "- [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385)\n",
+ "\"\"\"\n",
+ "\n",
+ "demo_men = gr.Interface(\n",
+ " title = title,\n",
+ " description = description,\n",
+ " article = article, \n",
+ " fn=lambda data: predict_gradio(\n",
+ " data=data, \n",
+ " uniform_lambda=uniform_lambda, \n",
+ " sklearn_model=sklearn_model,\n",
+ " label_transform=label_encoder,\n",
+ " target_sr=SAMPLE_RATE),\n",
+ " inputs = gr.Audio(source=\"microphone\", type=\"numpy\"),\n",
+ " outputs = gr.Label(),\n",
+ " # allow_flagging = \"manual\",\n",
+ " # flagging_options = ['recule', 'tournedroite', 'arretetoi', 'tournegauche', 'gauche', 'avance', 'droite'],\n",
+ " # flagging_dir = \"./flag/men\"\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.10.4 ('ml')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.4"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "f1f34988cae7bd54e626a86efbacac2b339eeffffea662e9af12f610fca26db7"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
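
As a possible follow-up to the training cell (not in the committed notebook), per-class performance on the held-out split could be inspected with scikit-learn's report; a sketch, assuming the notebook's variables are still in scope:

    # hypothetical extra cell: per-class metrics on the held-out split
    from sklearn.metrics import classification_report

    y_pred_enc = model.predict(X_test_mfcc_torch)
    print(classification_report(
        label_encoder.inverse_transform(y_test_enc),
        label_encoder.inverse_transform(y_pred_enc),
    ))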
dataloading.py ADDED
@@ -0,0 +1,107 @@
+ import numpy as np
+ import pandas as pd
+
+ import librosa
+
+ from pathlib import Path
+ from typing import Callable, Literal, Optional
+
+ def load_dataset(
+         paths: list[str],
+         remove_label: list[str] = [""],
+         sr: int = 22050,
+         method: Literal["fix_length", "time_stretch"] = "fix_length",
+         max_time: float = 4.0) -> tuple[pd.DataFrame, Callable[[np.ndarray, int], np.ndarray]]:
+     """Folder dataset in-memory loader (returns a fully loaded pandas DataFrame).
+     - For sklearn, the whole dataset is loaded in memory as numpy-compatible columns.
+     - For pytorch, the same data can be converted to tensors on the fly.
+
+     Use `to_numpy(df.y)` to extract a numpy matrix with a (n_row, ...) shape.
+
+     Expects a dataset folder structure such as: paths = [paths1, paths2, ...]
+     - paths1
+         - sub1
+             - blabla_GroundTruth1.wav
+             - blabla_GroundTruth2.wav
+         - sub2
+             - ...
+         - ...
+     - ...
+
+     Args:
+         paths (list[str]): list of dataset directories to parse.
+         remove_label (list[str], optional): list of labels to remove. Defaults to [""].
+         sr (int, optional): sample rate used to resample the audio files. Defaults to 22050.
+         method (Literal["fix_length", "time_stretch"], optional): uniformization method to apply. Defaults to "fix_length".
+         max_time (float, optional): common audio duration in seconds. Defaults to 4.0.
+
+     Returns:
+         df (pd.DataFrame): a DataFrame with the following columns:
+             - absolute_path (str): file-system absolute path of the .wav file.
+             - labels (list[str]): labels describing the sound file (i.e. subdirectories plus "_"-separated filename parts).
+             - ground_truth (str): ground-truth label, i.e. the last "_"-separated part of the filename.
+             - y_original_signal (np.ndarray): sound signal normalized as `float64` and resampled at the given sr by `librosa.load`.
+             - y_original_duration (float): duration of y_original_signal in seconds.
+             - y_uniform (np.ndarray): uniformized sound signal computed from y_original_signal with the chosen method.
+         uniform_transform (Callable[[np.ndarray, int], np.ndarray]): a lambda that uniformizes an audio signal the same way as in df.
+     """
+     data = []
+     uniform_transform = lambda y, sr: uniformize(y, sr, method, max_time)
+     for path in paths:
+         path = Path(path)
+         for wav_file in path.rglob("*.wav"):
+             wav_file_dict = dict()
+             absolute_path = wav_file.absolute()
+             *labels, label = absolute_path.relative_to(path.absolute()).parts
+             label = label.replace(".wav", "").split("_")
+             labels.extend(label)
+             ground_truth = labels[-1]
+             if ground_truth not in remove_label:
+                 y_original, sr = librosa.load(path=absolute_path, sr=sr)
+                 # WARNING: librosa.load resamples to the requested sr,
+                 # normalizes the bit depth between -1 and 1 and converts stereo to mono
+                 wav_file_dict["absolute_path"] = absolute_path
+                 wav_file_dict["labels"] = labels
+                 wav_file_dict["ground_truth"] = ground_truth
+                 ## Save original sound signal
+                 wav_file_dict["y_original_signal"] = y_original
+                 duration = librosa.get_duration(y=y_original, sr=sr)
+                 wav_file_dict["y_original_duration"] = duration
+                 ## Save uniformized sound signal
+                 wav_file_dict["y_uniform"] = uniform_transform(y_original, sr)
+                 data.append(wav_file_dict)
+     df = pd.DataFrame(data)
+     return df, uniform_transform
+
+ def uniformize(
+         audio: np.ndarray,
+         sr: int,
+         method: Literal["fix_length", "time_stretch"] = "fix_length",
+         max_time: float = 4.0
+ ):
+     if method == "fix_length":
+         return librosa.util.fix_length(audio, size=int(np.ceil(max_time*sr)))
+     elif method == "time_stretch":
+         duration = librosa.get_duration(y=audio, sr=sr)
+         return librosa.effects.time_stretch(audio, rate=duration/max_time)
+
+
+ def to_numpy(ds: pd.Series) -> np.ndarray:
+     """Transform a pd.Series (i.e. a column slice) into a numpy array of shape (n_row, flattened cell).
+
+     Args:
+         ds (pd.Series): column to transform into numpy.
+
+     Returns:
+         np.ndarray: resulting np.ndarray built from the ds pd.Series.
+     """
+     numpy_df = np.stack([*ds.to_numpy()])
+     C, *o = numpy_df.shape
+
+     if o:
+         return numpy_df.reshape(numpy_df.shape[0], np.prod(o))
+     else:
+         return numpy_df.reshape(numpy_df.shape[0])
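
A short usage sketch for this module, assuming a `../data/` folder laid out as in the `load_dataset` docstring (paths and parameters mirror the notebook, but the call itself is illustrative):

    # hypothetical usage of dataloading.py
    from dataloading import load_dataset, to_numpy

    df, uniform_transform = load_dataset(
        ["../data/"],                       # hypothetical dataset root
        remove_label=["penduleinverse"],    # labels to drop, if any
        sr=16_000,
        method="time_stretch",
        max_time=3.0,
    )
    X = to_numpy(df["y_uniform"])        # shape (n_clips, n_samples)
    y = to_numpy(df["ground_truth"])     # shape (n_clips,)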
gradio_utils.py ADDED
@@ -0,0 +1,42 @@
+
+ from typing import Callable, Optional
+
+ import numpy as np
+
+
+
+ import librosa
+
+ import gradio as gr
+
+ def predict_gradio(data: tuple[int, np.ndarray],
+                    uniform_lambda: Callable[[np.ndarray, int], np.ndarray],
+                    sklearn_model,
+                    label_transform,
+                    target_sr: int = 22_050) -> Optional[dict]:
+     if data is None:
+         return
+
+     classes = sklearn_model.classes_
+     if label_transform is not None:
+         classes = label_transform.inverse_transform(classes)
+
+
+     y, sr = data[1], data[0]
+     y_original_signal = load_as_librosa(y, sr, target_sr=target_sr)
+     y_uniform = uniform_lambda(y_original_signal, target_sr).astype(np.float32)
+     prediction = sklearn_model.predict_proba(y_uniform.reshape(1, -1))
+     result = {str(label): float(confidence) for (
+         label, confidence) in zip(classes, prediction.flatten())}
+     return result
+
+ def load_as_librosa(y: np.ndarray, sr: int, target_sr: int = 22050) -> np.ndarray:
+     data_dtype = y.dtype
+     dtype_min = np.iinfo(data_dtype).min
+     dtype_max = np.iinfo(data_dtype).max
+     dtype_range = np.abs(dtype_max-dtype_min)
+     y_normalize = (y.astype(np.float32)-dtype_min)/dtype_range
+     y_normalize_resample = librosa.resample(y=y_normalize,
+                                             orig_sr=sr,
+                                             target_sr=target_sr)
+     return y_normalize_resample
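
Gradio's microphone component with `type="numpy"` passes `predict_gradio` a `(sample_rate, samples)` tuple of integer PCM, which `load_as_librosa` rescales and resamples. A synthetic call might look like the sketch below; the random-noise input is made up, and the artifacts are assumed to exist under ./model/ as in this commit.

    # hypothetical call with fake microphone data (sketch only)
    import numpy as np
    from joblib import load
    from sklearn.pipeline import Pipeline
    from dataloading import uniformize
    from gradio_utils import predict_gradio

    sklearn_model = Pipeline(steps=[("mfcc", load("./model/only_mffc_transform.joblib")),
                                    ("model", load("./model/model.joblib"))])
    label_encoder = load("./model/label_encoder.joblib")

    fake_mic = (48_000, (np.random.randn(48_000) * 2**14).astype(np.int16))  # 1 s of int16 "audio"
    result = predict_gradio(
        data=fake_mic,
        uniform_lambda=lambda y, sr: uniformize(y, sr, "time_stretch", 3.0),
        sklearn_model=sklearn_model,
        label_transform=label_encoder,
        target_sr=16_000,
    )
    # result maps each command label to a confidence, as expected by gr.Label()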
model/HOP_LENGHT.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5ed7bcd9e9d07c9918817127d9d4d3862f00d680cf13572fd8776d611bddd7ee
+ size 15
model/MAX_TIME.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5c63e7c444792b99fe2d588a2454f6a5b45f23e4973a77e6f2e3e280d5385bd1
+ size 21
model/METHOD.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0225bfd3de4895f2472fde5df0f7f9d67b1b922e62e84395a41fefb3122a4d09
+ size 27
model/N_MFCC.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4e148c4bd8680b2de4785d81d31a1e4fbbd65c87e687e64c68d68c52aa2c4004
+ size 5
model/SAMPLE_RATE.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:510a2ce6eba70c0d21f882833ca726e75e0d1a7cbae3badd55f96c0a8e909ede
+ size 15
model/label_encoder.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f350bf3ad2da734f600262b0384aa61125de535a3eff8b80640af0f06e319246
+ size 617
model/model.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c88130d5500b9e58fb2bc8e5b3cce918c83fdb94c2361d991e24f79452328b00
+ size 53219183
model/only_mffc_transform.joblib ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d34fac514bbe21f95e0b62b679e86cced3a7b496c5bd12f087516d55bb9be71
+ size 255
preprocessing.py ADDED
@@ -0,0 +1,202 @@
+ import numpy as np
+ import torch
+ import librosa
+
+ from sklearn.base import BaseEstimator, TransformerMixin
+ from typing import Callable, Optional
+
+ class ReductionTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, windows_number: int = 300, statistique: Callable[[np.ndarray], np.ndarray] = np.mean):
+         self.windows_number = windows_number
+         self.statistique = statistique
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         X_ = X.copy()
+         *c_, size_ = X_.shape
+         windows_size_ = size_//self.windows_number
+         metrique_clip = X_[..., :self.windows_number*windows_size_]
+         return np.apply_along_axis(self.statistique,
+                                    axis=-1,
+                                    arr=metrique_clip.reshape((*c_, self.windows_number, windows_size_)))
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         raise NotImplementedError
+
+ class MeanTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, windows_number: int = 300):
+         self.windows_number = windows_number
+         self.windows_size = 0
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         X_ = X.copy()
+         *c_, size_ = X_.shape
+         windows_size_ = size_//self.windows_number
+         self.windows_size = windows_size_
+         metrique_clip = X_[..., :self.windows_number*windows_size_]
+         return np.mean(metrique_clip.reshape((*c_, self.windows_number, windows_size_)), axis=-1)
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         original_size = self.windows_size*self.windows_number
+         X_reconstruct = np.interp(
+             x = np.arange(start=0, stop=original_size, step=1),
+             xp = np.arange(start=0, stop=original_size, step=self.windows_size),
+             fp = X
+         )
+         return X_reconstruct
+
+ class StdTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, windows_number: int = 300):
+         self.windows_number = windows_number
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         X_ = X.copy()
+         *c_, size_ = X_.shape
+         windows_size_ = size_//self.windows_number
+         metrique_clip = X_[..., :self.windows_number*windows_size_]
+         return np.std(metrique_clip.reshape((*c_, self.windows_number, windows_size_)), axis=-1)
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         raise NotImplementedError
+
+ class MfccTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, sr: int = 22050, N_MFCC: int = 12, hop_length: int = 1024, reshape_output: bool = True):
+         self.sr = sr
+         self.N_MFCC = N_MFCC
+         self.hop_length = hop_length
+         self.reshape_output = reshape_output
+
+     def reshape(self, X: np.ndarray) -> np.ndarray:
+         X_ = X.copy()
+         c_, *_ = X_.shape
+         return X_.reshape(c_, -1, self.N_MFCC)
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         X_ = X.copy()
+         c_, *_ = X_.shape
+         mfcc = librosa.feature.mfcc(y=X_,
+                                     sr=self.sr,
+                                     hop_length=self.hop_length,
+                                     n_mfcc=self.N_MFCC
+                                     )
+         if self.reshape_output:
+             mfcc = mfcc.reshape(c_, -1)
+
+         return mfcc
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         X_reconstruct = librosa.feature.inverse.mfcc_to_audio(
+             mfcc = X,
+             n_mels = self.N_MFCC,
+         )
+         return X_reconstruct
+
+ class MelTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, sr: int = 22050, N_MEL: int = 12, hop_length: int = 1024, reshape_output: bool = True):
+         self.sr = sr
+         self.N_MEL = N_MEL
+         self.hop_length = hop_length
+         self.reshape_output = reshape_output
+
+     def reshape(self, X: np.ndarray) -> np.ndarray:
+         X_ = X.copy()
+         c_, *_ = X_.shape
+         return X_.reshape(c_, -1, self.N_MEL)
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         X_ = X.copy()
+         c_, *_ = X_.shape
+         mel = librosa.feature.melspectrogram(y=X_,
+                                              sr=self.sr,
+                                              hop_length=self.hop_length,
+                                              n_mels=self.N_MEL
+                                              )
+         if self.reshape_output:
+             mel = mel.reshape(c_, -1)
+
+         return mel
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         X_reconstruct = librosa.feature.inverse.mel_to_audio(
+             M = X,
+             sr = self.sr,
+             hop_length = self.hop_length
+         )
+         return X_reconstruct
+
+ class TorchTransform(BaseEstimator, TransformerMixin):
+     def __init__(self):
+         pass
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> torch.Tensor:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> torch.Tensor:
+         return torch.tensor(X).unsqueeze(dim=1)
+
+     def inverse_transform(self, X: torch.Tensor) -> np.ndarray:
+         return np.array(X.squeeze(dim=1))
+
+ class ShuffleTransformer(BaseEstimator, TransformerMixin):
+     def __init__(self, p: float = 0.005):
+         self.p = p
+
+     def fit(self, X: np.ndarray, y: Optional[np.ndarray] = None):
+         return self
+
+     def fit_transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         self.fit(X, y)
+         return self.transform(X, y)
+
+     def transform(self, X: np.ndarray, y: Optional[np.ndarray] = None) -> np.ndarray:
+         will_swap = np.random.choice(X.shape[0], int(self.p*X.shape[0]))
+         will_swap_with = np.random.choice(X.shape[0], int(self.p*X.shape[0]))
+         if hasattr(X, "copy"):
+             X_ = X.copy()
+         elif hasattr(X, "clone"):
+             X_ = X.clone()
+         else:
+             X_ = X
+         X_[will_swap, ...] = X_[will_swap_with, ...]
+         return X_
+
+     def inverse_transform(self, X: np.ndarray) -> np.ndarray:
+         raise NotImplementedError
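
The shape flow these transformers implement, under the parameters used in the notebook (16 kHz, 3 s clips, 64 MFCCs, hop 2048), can be checked on random data; a small sketch with assumed shapes:

    # hypothetical shape check for MfccTransformer + TorchTransform
    import numpy as np
    from preprocessing import MfccTransformer, TorchTransform

    X = np.random.randn(4, 48_000).astype(np.float32)   # 4 clips of 3 s at 16 kHz
    mfcc = MfccTransformer(sr=16_000, N_MFCC=64, hop_length=2_048,
                           reshape_output=False).fit_transform(X)
    print(mfcc.shape)                                    # (4, 64, 24): (clip, mfcc, frame)
    batch = TorchTransform().fit_transform(mfcc)
    print(batch.shape)                                   # torch.Size([4, 1, 64, 24]): channel dim for the CNN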
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ numpy
+ matplotlib
+ torch
+ pandas
+ scikit-learn
+ skorch
+ librosa
+ gradio
resnet.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import torch.nn as nn
+
+
+ class ResNet(nn.Module):
+     def __init__(self, in_channels: int, num_classes: int):
+         """ResNet9"""
+         super().__init__()
+
+         self.conv1 = ConvBlock(in_channels, 64)
+         self.conv2 = ConvBlock(64, 128, pool=True)
+         self.res1 = nn.Sequential(
+             ConvBlock(128, 128),
+             ConvBlock(128, 128)
+         )
+
+         self.conv3 = ConvBlock(128, 256)
+         self.conv4 = ConvBlock(256, 512, pool=True)
+         self.res2 = nn.Sequential(
+             ConvBlock(512, 512),
+             ConvBlock(512, 512)
+         )
+
+         self.classifier = nn.Sequential(
+             nn.MaxPool2d(kernel_size=(4, 4)),
+             nn.AdaptiveAvgPool2d(1),
+             nn.Flatten(),
+             nn.Linear(512, 128),
+             nn.Dropout(0.25),
+             nn.Linear(128, num_classes),
+             nn.Dropout(0.25),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.res1(x) + x  # skip
+         x = self.conv3(x)
+         x = self.conv4(x)
+         x = self.res2(x) + x  # skip
+         prediction = self.classifier(x)
+         return prediction
+
+ class ConvBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, pool: bool = False, pool_no: int = 2):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.pool = pool
+         self.pool_no = pool_no
+
+         if self.pool:
+             self.pool_block = nn.Sequential(
+                 nn.ReLU(inplace=True),
+                 nn.MaxPool2d(self.pool_no)
+             )
+         else:
+             self.pool_block = nn.Sequential(
+                 nn.ReLU(inplace=True),
+             )
+
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             self.pool_block
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.block(x)
+         return x
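
A quick forward-pass sketch for this module with a dummy MFCC batch of the size produced by the preprocessing above (the shapes are assumptions for illustration, not part of the commit):

    # hypothetical smoke test for the ResNet9 module
    import torch
    from resnet import ResNet

    net = ResNet(in_channels=1, num_classes=7)   # 7 drone commands in this commit
    dummy = torch.randn(8, 1, 64, 24)            # (batch, channel, n_mfcc, frames)
    logits = net(dummy)
    print(logits.shape)                          # torch.Size([8, 7])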