{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Best Model" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import numpy as np\n", "\n", "import skorch\n", "import torch\n", "import torch.nn as nn\n", "\n", "import gradio as gr\n", "\n", "import librosa\n", "\n", "from joblib import dump, load\n", "\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "from resnet import ResNet\n", "from gradio_utils import load_as_librosa, predict_gradio\n", "from dataloading import uniformize, to_numpy\n", "from preprocessing import MfccTransformer, TorchTransform\n", "\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# Notebook params\n", "SEED : int = 42\n", "np.random.seed(SEED)\n", "torch.manual_seed(SEED)\n", "\n", "# Dataloading params\n", "PATHS: list = [\n", " \"../data/\",\n", " \"../new_data/JulienNestor\",\n", " \"../new_data/classroom_data\",\n", " \"../new_data/class\",\n", " \"../new_data/JulienRaph\",\n", "]\n", "REMOVE_LABEL: list = [\n", " \"penduleinverse\", \"pendule\", \n", " \"decollage\", \"atterrissage\",\n", " \"plushaut\", \"plusbas\",\n", " \"etatdurgence\",\n", " \"faisunflip\", \n", " \"faisUnFlip\", \"arreteToi\", \"etatDurgence\",\n", " # \"tournedroite\", \"arretetoi\", \"tournegauche\"\n", "]\n", "SAMPLE_RATE: int = 16_000\n", "METHOD: str = \"time_stretch\"\n", "MAX_TIME: float = 3.0\n", "\n", "# Features Extraction params\n", "N_MFCC: int = 64\n", "HOP_LENGHT = 2_048" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1 - Dataloading" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "# 1-Dataloading\n", "from dataloading import load_dataset, to_numpy\n", "dataset, uniform_lambda = load_dataset(PATHS,\n", " remove_label=REMOVE_LABEL,\n", " sr=SAMPLE_RATE,\n", " method=METHOD,\n", " max_time=MAX_TIME\n", " )" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['recule',\n", " 'tournedroite',\n", " 'arretetoi',\n", " 'tournegauche',\n", " 'gauche',\n", " 'avance',\n", " 'droite']" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(dataset[\"ground_truth\"].unique())" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# 2-Train and split\n", "from sklearn.model_selection import train_test_split\n", "dataset_train, dataset_test = train_test_split(dataset, random_state=0)\n", "\n", "X_train = to_numpy(dataset_train[\"y_uniform\"])\n", "y_train = to_numpy(dataset_train[\"ground_truth\"])\n", "X_test = to_numpy(dataset_test[\"y_uniform\"])\n", "y_test = to_numpy(dataset_test[\"ground_truth\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2 - Preprocessing" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "only_mffc_transform = Pipeline(\n", " steps=[\n", " (\"mfcc\", MfccTransformer(N_MFCC=N_MFCC, reshape_output=False, hop_length=HOP_LENGHT)),\n", " (\"torch\", TorchTransform())\n", " ]\n", ")\n", "\n", "only_mffc_transform.fit(X_train)\n", "\n", "X_train_mfcc_torch = only_mffc_transform.transform(X_train)\n", "X_test_mfcc_torch = 
only_mffc_transform.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "# Train a LabelEncoder (if needed)\n", "label_encoder = LabelEncoder()\n", "label_encoder.fit(y_train)\n", "y_train_enc = label_encoder.transform(y_train)\n", "y_test_enc = label_encoder.transform(y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3 - ResNet" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "if hasattr(torch, \"has_mps\") and torch.has_mps:\n", " device = torch.device(\"mps\")\n", "elif hasattr(torch, \"has_cuda\") and torch.has_cuda:\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.1 - nn.Module" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# from resnet import ResNet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3.2 - Train" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " epoch train_loss dur\n", "------- ------------ ------\n", " 1 \u001b[36m2.8646\u001b[0m 0.4461\n", " 2 \u001b[36m1.9534\u001b[0m 0.4322\n", " 3 \u001b[36m1.8164\u001b[0m 0.4331\n", " 4 \u001b[36m1.6889\u001b[0m 0.4318\n", " 5 \u001b[36m1.5808\u001b[0m 0.4329\n", " 6 \u001b[36m1.4659\u001b[0m 0.4355\n", " 7 \u001b[36m1.2894\u001b[0m 0.4285\n", " 8 1.3207 0.4280\n", " 9 \u001b[36m1.1546\u001b[0m 0.4274\n", " 10 \u001b[36m1.0586\u001b[0m 0.4287\n", " 11 \u001b[36m1.0195\u001b[0m 0.4313\n", " 12 \u001b[36m0.8246\u001b[0m 0.4302\n", " 13 \u001b[36m0.7612\u001b[0m 0.4330\n", " 14 \u001b[36m0.7296\u001b[0m 0.4315\n", " 15 \u001b[36m0.6690\u001b[0m 0.4293\n", " 16 \u001b[36m0.6205\u001b[0m 0.4291\n", " 17 \u001b[36m0.5764\u001b[0m 0.4290\n", " 18 \u001b[36m0.4839\u001b[0m 0.4284\n", " 19 0.4984 0.4314\n", " 20 \u001b[36m0.4666\u001b[0m 0.4324\n", " 21 \u001b[36m0.4132\u001b[0m 0.4322\n", " 22 0.4440 0.4300\n", " 23 0.4463 0.4300\n", " 24 \u001b[36m0.4075\u001b[0m 0.4287\n", " 25 \u001b[36m0.3908\u001b[0m 0.4282\n", " 26 \u001b[36m0.3759\u001b[0m 0.4278\n", " 27 \u001b[36m0.3612\u001b[0m 0.4296\n", " 28 \u001b[36m0.3189\u001b[0m 0.4281\n", " 29 0.3489 0.4308\n", " 30 0.3308 0.4301\n", " 31 0.3353 0.4299\n", " 32 \u001b[36m0.3074\u001b[0m 0.4298\n", " 33 0.3339 0.4350\n", " 34 \u001b[36m0.2921\u001b[0m 0.4383\n", " 35 \u001b[36m0.2852\u001b[0m 0.4345\n", " 36 0.3170 0.4334\n", " 37 0.2853 0.4304\n", " 38 0.2857 0.4307\n", " 39 \u001b[36m0.2607\u001b[0m 0.4310\n", " 40 0.2765 0.4292\n", " 41 0.2831 0.4305\n", " 42 0.2836 0.4295\n", " 43 0.2742 0.4307\n", " 44 0.2653 0.4302\n", " 45 \u001b[36m0.2370\u001b[0m 0.4335\n", " 46 0.2475 0.4292\n", " 47 0.2692 0.4329\n", " 48 0.2657 0.4306\n", " 49 0.2875 0.4305\n", " 50 0.2839 0.4315\n", " 51 0.2555 0.4307\n", " 52 0.2794 0.4332\n", " 53 \u001b[36m0.2272\u001b[0m 0.4302\n", " 54 0.2519 0.4305\n", " 55 0.2388 0.4307\n", " 56 0.2504 0.4314\n", " 57 0.2345 0.4328\n", " 58 \u001b[36m0.2252\u001b[0m 0.4316\n", " 59 0.2436 0.4329\n", " 60 0.2297 0.4309\n", " 61 0.2594 0.4306\n", " 62 0.2412 0.4300\n", " 63 0.2399 0.4319\n", " 64 0.2600 0.4334\n", " 65 0.2599 0.4304\n", " 66 0.2360 0.4317\n", " 67 0.2537 0.4301\n", " 68 0.2268 0.4299\n", " 69 0.2436 0.4301\n", " 70 \u001b[36m0.2193\u001b[0m 0.4308\n", " 71 0.2284 0.4322\n", " 72 0.2339 0.4317\n", " 73 0.2330 0.4331\n", " 74 \u001b[36m0.2063\u001b[0m 0.4327\n", " 75 0.2568 
0.4332\n", " 76 0.2372 0.4324\n", " 77 0.2249 0.4327\n", " 78 0.2449 0.4314\n", " 79 0.2455 0.4310\n", " 80 \u001b[36m0.2003\u001b[0m 0.4321\n", " 81 0.2172 0.4318\n", " 82 0.2278 0.4333\n", " 83 0.2178 0.4334\n", " 84 0.2240 0.4312\n", " 85 0.2329 0.4338\n", " 86 0.2267 0.4326\n", " 87 0.2479 0.4341\n", " 88 0.2266 0.4355\n", " 89 0.2541 0.4350\n", " 90 0.2167 0.4324\n", " 91 0.2282 0.4353\n", " 92 0.2097 0.4367\n", " 93 0.2038 0.4351\n", " 94 0.2078 0.4372\n", " 95 0.2437 0.4344\n", " 96 0.2283 0.4333\n", " 97 0.2263 0.4329\n", " 98 0.2146 0.4346\n", " 99 0.2238 0.4323\n", " 100 0.2035 0.4348\n", " 101 0.2287 0.4348\n", " 102 0.2231 0.4328\n", " 103 0.2171 0.4326\n", " 104 0.2417 0.4329\n", "Stopping since train_loss has not improved in the last 25 epochs.\n", "0.941908713692946\n" ] } ], "source": [ "# Define net\n", "n_labels = np.unique(dataset.ground_truth).size\n", "net = ResNet(in_channels=1, num_classes=n_labels)\n", "\n", "# Define model\n", "model = skorch.NeuralNetClassifier(\n", " module=net,\n", " criterion=nn.CrossEntropyLoss(),\n", " callbacks=[skorch.callbacks.EarlyStopping(monitor=\"train_loss\", patience=25)],\n", " max_epochs=200,\n", " lr=0.01,\n", " batch_size=128,\n", " train_split=None,\n", " device=device,\n", ")\n", "\n", "model.check_data(X_train_mfcc_torch, y_train_enc)\n", "model.fit(X_train_mfcc_torch, y_train_enc)\n", "\n", "print(model.score(X_test_mfcc_torch, y_test_enc))" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['./model/HOP_LENGHT.joblib']" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from joblib import dump, load\n", "\n", "dump(model, './model/model.joblib') \n", "dump(only_mffc_transform, './model/only_mffc_transform.joblib') \n", "dump(label_encoder, './model/label_encoder.joblib')\n", "dump(SAMPLE_RATE, \"./model/SAMPLE_RATE.joblib\")\n", "dump(METHOD, \"./model/METHOD.joblib\")\n", "dump(MAX_TIME, \"./model/MAX_TIME.joblib\")\n", "dump(N_MFCC, \"./model/N_MFCC.joblib\")\n", "dump(HOP_LENGHT, \"./model/HOP_LENGHT.joblib\")" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "model = load('./model/model.joblib') \n", "only_mffc_transform = load('./model/only_mffc_transform.joblib') \n", "label_encoder = load('./model/label_encoder.joblib') \n", "SAMPLE_RATE = load(\"./model/SAMPLE_RATE.joblib\")\n", "METHOD = load(\"./model/METHOD.joblib\")\n", "MAX_TIME = load(\"./model/MAX_TIME.joblib\")\n", "N_MFCC = load(\"./model/N_MFCC.joblib\")\n", "HOP_LENGHT = load(\"./model/HOP_LENGHT.joblib\")\n", "\n", "sklearn_model = Pipeline(\n", " steps=[\n", " (\"mfcc\", only_mffc_transform),\n", " (\"model\", model)\n", " ]\n", " )\n", "\n", "uniform_lambda = lambda y, sr: uniformize(y, sr, METHOD, MAX_TIME)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "ename": "", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." ] } ], "source": [ "title = r\"ResNet 9\"\n", "\n", "description = r\"\"\"\n", "