SleepyJesse commited on
Commit
bfc8bd3
·
verified ·
1 Parent(s): d033c87

Upload ai_music_detection_new_large.ipynb

Browse files
Files changed (1) hide show
  1. ai_music_detection_new_large.ipynb +819 -0
ai_music_detection_new_large.ipynb ADDED
@@ -0,0 +1,819 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Requirement already satisfied: librosa in /opt/conda/lib/python3.10/site-packages (0.10.2.post1)\n",
13
+ "Requirement already satisfied: soundfile in /opt/conda/lib/python3.10/site-packages (0.12.1)\n",
14
+ "Requirement already satisfied: torchaudio in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
15
+ "Requirement already satisfied: audiomentations in /opt/conda/lib/python3.10/site-packages (0.37.0)\n",
16
+ "Requirement already satisfied: evaluate in /opt/conda/lib/python3.10/site-packages (0.4.3)\n",
17
+ "Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.10/site-packages (8.1.5)\n",
18
+ "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (3.9.3)\n",
19
+ "Requirement already satisfied: tensorboard in /opt/conda/lib/python3.10/site-packages (2.18.0)\n",
20
+ "Requirement already satisfied: datasets[audio] in /opt/conda/lib/python3.10/site-packages (3.1.0)\n",
21
+ "Requirement already satisfied: transformers[torch] in /opt/conda/lib/python3.10/site-packages (4.47.0)\n",
22
+ "Requirement already satisfied: audioread>=2.1.9 in /opt/conda/lib/python3.10/site-packages (from librosa) (3.0.1)\n",
23
+ "Requirement already satisfied: numpy!=1.22.0,!=1.22.1,!=1.22.2,>=1.20.3 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.26.3)\n",
24
+ "Requirement already satisfied: scipy>=1.2.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.12.0)\n",
25
+ "Requirement already satisfied: scikit-learn>=0.20.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.5.2)\n",
26
+ "Requirement already satisfied: joblib>=0.14 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.4.2)\n",
27
+ "Requirement already satisfied: decorator>=4.3.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (5.1.1)\n",
28
+ "Requirement already satisfied: numba>=0.51.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.60.0)\n",
29
+ "Requirement already satisfied: pooch>=1.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.8.2)\n",
30
+ "Requirement already satisfied: soxr>=0.3.2 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.5.0.post1)\n",
31
+ "Requirement already satisfied: typing-extensions>=4.1.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (4.9.0)\n",
32
+ "Requirement already satisfied: lazy-loader>=0.1 in /opt/conda/lib/python3.10/site-packages (from librosa) (0.4)\n",
33
+ "Requirement already satisfied: msgpack>=1.0 in /opt/conda/lib/python3.10/site-packages (from librosa) (1.1.0)\n",
34
+ "Requirement already satisfied: cffi>=1.0 in /opt/conda/lib/python3.10/site-packages (from soundfile) (1.16.0)\n",
35
+ "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchaudio) (2.2.0)\n",
36
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.13.1)\n",
37
+ "Requirement already satisfied: pyarrow>=15.0.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (18.1.0)\n",
38
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.3.8)\n",
39
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (2.2.3)\n",
40
+ "Requirement already satisfied: requests>=2.32.2 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (2.32.3)\n",
41
+ "Requirement already satisfied: tqdm>=4.66.3 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (4.67.1)\n",
42
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.5.0)\n",
43
+ "Requirement already satisfied: multiprocess<0.70.17 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.70.16)\n",
44
+ "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /opt/conda/lib/python3.10/site-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets[audio]) (2023.12.2)\n",
45
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (3.11.10)\n",
46
+ "Requirement already satisfied: huggingface-hub>=0.23.0 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (0.26.3)\n",
47
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (23.1)\n",
48
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from datasets[audio]) (6.0.1)\n",
49
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (2024.11.6)\n",
50
+ "Requirement already satisfied: tokenizers<0.22,>=0.21 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (0.21.0)\n",
51
+ "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (0.4.5)\n",
52
+ "Requirement already satisfied: accelerate>=0.26.0 in /opt/conda/lib/python3.10/site-packages (from transformers[torch]) (1.1.1)\n",
53
+ "Requirement already satisfied: numpy-minmax<1,>=0.3.0 in /opt/conda/lib/python3.10/site-packages (from audiomentations) (0.3.1)\n",
54
+ "Requirement already satisfied: numpy-rms<1,>=0.4.2 in /opt/conda/lib/python3.10/site-packages (from audiomentations) (0.4.2)\n",
55
+ "Requirement already satisfied: comm>=0.1.3 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (0.2.2)\n",
56
+ "Requirement already satisfied: ipython>=6.1.0 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (8.20.0)\n",
57
+ "Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (5.7.1)\n",
58
+ "Requirement already satisfied: widgetsnbextension~=4.0.12 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (4.0.13)\n",
59
+ "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /opt/conda/lib/python3.10/site-packages (from ipywidgets) (3.0.13)\n",
60
+ "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.3.1)\n",
61
+ "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n",
62
+ "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (4.55.2)\n",
63
+ "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.4.7)\n",
64
+ "Requirement already satisfied: pillow>=8 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (10.0.1)\n",
65
+ "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (3.2.0)\n",
66
+ "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)\n",
67
+ "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (2.1.0)\n",
68
+ "Requirement already satisfied: grpcio>=1.48.2 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (1.68.1)\n",
69
+ "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (3.7)\n",
70
+ "Requirement already satisfied: protobuf!=4.24.0,>=3.19.6 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (5.29.1)\n",
71
+ "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (68.2.2)\n",
72
+ "Requirement already satisfied: six>1.9 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (1.16.0)\n",
73
+ "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (0.7.2)\n",
74
+ "Requirement already satisfied: werkzeug>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from tensorboard) (3.1.3)\n",
75
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.10/site-packages (from accelerate>=0.26.0->transformers[torch]) (5.9.0)\n",
76
+ "Requirement already satisfied: pycparser in /opt/conda/lib/python3.10/site-packages (from cffi>=1.0->soundfile) (2.21)\n",
77
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (2.4.4)\n",
78
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.3.1)\n",
79
+ "Requirement already satisfied: async-timeout<6.0,>=4.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (5.0.1)\n",
80
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (23.1.0)\n",
81
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.5.0)\n",
82
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (6.1.0)\n",
83
+ "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (0.2.1)\n",
84
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.10/site-packages (from aiohttp->datasets[audio]) (1.18.3)\n",
85
+ "Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.18.1)\n",
86
+ "Requirement already satisfied: matplotlib-inline in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.1.6)\n",
87
+ "Requirement already satisfied: prompt-toolkit<3.1.0,>=3.0.41 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (3.0.43)\n",
88
+ "Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (2.15.1)\n",
89
+ "Requirement already satisfied: stack-data in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (0.2.0)\n",
90
+ "Requirement already satisfied: exceptiongroup in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (1.2.0)\n",
91
+ "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.10/site-packages (from ipython>=6.1.0->ipywidgets) (4.8.0)\n",
92
+ "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /opt/conda/lib/python3.10/site-packages (from numba>=0.51.0->librosa) (0.43.0)\n",
93
+ "Requirement already satisfied: platformdirs>=2.5.0 in /opt/conda/lib/python3.10/site-packages (from pooch>=1.1->librosa) (3.10.0)\n",
94
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (2.0.4)\n",
95
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (3.4)\n",
96
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (1.26.18)\n",
97
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.32.2->datasets[audio]) (2023.11.17)\n",
98
+ "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/conda/lib/python3.10/site-packages (from scikit-learn>=0.20.0->librosa) (3.5.0)\n",
99
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (1.12)\n",
100
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (3.1)\n",
101
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchaudio) (3.1.2)\n",
102
+ "Requirement already satisfied: MarkupSafe>=2.1.1 in /opt/conda/lib/python3.10/site-packages (from werkzeug>=1.0.1->tensorboard) (2.1.3)\n",
103
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets[audio]) (2023.3.post1)\n",
104
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.10/site-packages (from pandas->datasets[audio]) (2024.2)\n",
105
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.10/site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets) (0.8.3)\n",
106
+ "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.10/site-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets) (0.7.0)\n",
107
+ "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.10/site-packages (from prompt-toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets) (0.2.5)\n",
108
+ "Requirement already satisfied: executing in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.8.3)\n",
109
+ "Requirement already satisfied: asttokens in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (2.0.5)\n",
110
+ "Requirement already satisfied: pure-eval in /opt/conda/lib/python3.10/site-packages (from stack-data->ipython>=6.1.0->ipywidgets) (0.2.2)\n",
111
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchaudio) (1.3.0)\n",
112
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
113
+ "\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
114
+ ]
115
+ }
116
+ ],
117
+ "source": [
118
+ "%pip install librosa soundfile torchaudio datasets[audio] transformers[torch] audiomentations evaluate ipywidgets matplotlib tensorboard"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 2,
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "import torch\n",
128
+ "import torchaudio\n",
129
+ "import librosa\n",
130
+ "import soundfile as sf\n",
131
+ "import numpy as np\n",
132
+ "import os\n",
133
+ "import matplotlib.pyplot as plt\n",
134
+ "import IPython.display as ipd\n",
135
+ "import datasets\n",
136
+ "import evaluate\n",
137
+ "from concurrent.futures import ProcessPoolExecutor\n",
138
+ "from transformers import ASTForAudioClassification, ASTFeatureExtractor, ASTConfig, TrainingArguments, Trainer\n",
139
+ "from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift\n",
140
+ "from tqdm import tqdm"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 3,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "MODEL_DIR = \"/workspace\""
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 4,
155
+ "metadata": {},
156
+ "outputs": [
157
+ {
158
+ "data": {
159
+ "application/vnd.jupyter.widget-view+json": {
160
+ "model_id": "03839539e36849f181d0d21bdbf63073",
161
+ "version_major": 2,
162
+ "version_minor": 0
163
+ },
164
+ "text/plain": [
165
+ "Resolving data files: 0%| | 0/147 [00:00<?, ?it/s]"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "data": {
173
+ "application/vnd.jupyter.widget-view+json": {
174
+ "model_id": "cf03c7eeca6647fc8853e507ca878b03",
175
+ "version_major": 2,
176
+ "version_minor": 0
177
+ },
178
+ "text/plain": [
179
+ "Resolving data files: 0%| | 0/147 [00:00<?, ?it/s]"
180
+ ]
181
+ },
182
+ "metadata": {},
183
+ "output_type": "display_data"
184
+ },
185
+ {
186
+ "data": {
187
+ "application/vnd.jupyter.widget-view+json": {
188
+ "model_id": "147d4213469f4fd294bdd89a4f633f1b",
189
+ "version_major": 2,
190
+ "version_minor": 0
191
+ },
192
+ "text/plain": [
193
+ "Loading dataset shards: 0%| | 0/115 [00:00<?, ?it/s]"
194
+ ]
195
+ },
196
+ "metadata": {},
197
+ "output_type": "display_data"
198
+ },
199
+ {
200
+ "name": "stdout",
201
+ "output_type": "stream",
202
+ "text": [
203
+ "DatasetDict({\n",
204
+ " train: Dataset({\n",
205
+ " features: ['audio', 'source', 'ai_generated'],\n",
206
+ " num_rows: 20000\n",
207
+ " })\n",
208
+ "})\n"
209
+ ]
210
+ }
211
+ ],
212
+ "source": [
213
+ "# Load the dataset\n",
214
+ "ds = datasets.load_dataset(\"SleepyJesse/ai_music_large\")\n",
215
+ "# Resample the audio files to 16kHz\n",
216
+ "ds = ds.cast_column(\"audio\", datasets.Audio(sampling_rate=16000, mono=True))\n",
217
+ "print(ds)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": 5,
223
+ "metadata": {
224
+ "editable": true,
225
+ "slideshow": {
226
+ "slide_type": ""
227
+ },
228
+ "tags": []
229
+ },
230
+ "outputs": [],
231
+ "source": [
232
+ "# Cast the \"ai_generated\" column (boolean) to class labels (\"ai_generated\" or \"human\")\n",
233
+ "class_labels = datasets.ClassLabel(names=[\"human\", \"ai_generated\"])\n",
234
+ "labels = [1 if x else 0 for x in ds['train']['ai_generated']]\n",
235
+ "ds['train'] = ds['train'].add_column(\"labels\", labels, feature=class_labels)"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 6,
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "# Remove the \"ai_generated\" and \"source\" columns\n",
245
+ "ds[\"train\"] = ds[\"train\"].remove_columns(\"ai_generated\")\n",
246
+ "ds[\"train\"] = ds[\"train\"].remove_columns(\"source\")\n",
247
+ "\n",
248
+ "# Rename the \"audio\" column to \"input_values\" to match the expected input key for the processor\n",
249
+ "ds = ds.rename_column(\"audio\", \"input_values\")"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "code",
254
+ "execution_count": 7,
255
+ "metadata": {},
256
+ "outputs": [
257
+ {
258
+ "name": "stdout",
259
+ "output_type": "stream",
260
+ "text": [
261
+ "DatasetDict({\n",
262
+ " train: Dataset({\n",
263
+ " features: ['input_values', 'labels'],\n",
264
+ " num_rows: 20000\n",
265
+ " })\n",
266
+ "})\n",
267
+ "{'input_values': {'path': '030312.mp3', 'array': array([ 0. , 0. , 0. , ..., -0.00048378,\n",
268
+ " -0.00049008, 0. ]), 'sampling_rate': 16000}, 'labels': 0}\n"
269
+ ]
270
+ }
271
+ ],
272
+ "source": [
273
+ "print(ds)\n",
274
+ "print(ds[\"train\"][0])"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": 8,
280
+ "metadata": {},
281
+ "outputs": [],
282
+ "source": [
283
+ "model_name = \"MIT/ast-finetuned-audioset-10-10-0.4593\" # Pre-trained AST model\n",
284
+ "feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)\n",
285
+ "model_input_name = feature_extractor.model_input_names[0]\n",
286
+ "sampling_rate = feature_extractor.sampling_rate"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": 9,
292
+ "metadata": {},
293
+ "outputs": [],
294
+ "source": [
295
+ "# Define a function to preprocess the audio data\n",
296
+ "def preprocess_audio(batch):\n",
297
+ " wavs = [audio[\"array\"] for audio in batch[\"input_values\"]]\n",
298
+ " # inputs are spectrograms as torch.tensors now\n",
299
+ " inputs = feature_extractor(wavs, sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
300
+ "\n",
301
+ " output_batch = {model_input_name: inputs.get(model_input_name), \"labels\": list(batch[\"labels\"])}\n",
302
+ " return output_batch\n",
303
+ "\n",
304
+ "# Apply the preprocessing function to the dataset\n",
305
+ "ds[\"train\"].set_transform(preprocess_audio, output_all_columns=False)"
306
+ ]
307
+ },
308
+ {
309
+ "cell_type": "code",
310
+ "execution_count": 10,
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "# Create audio augmentations\n",
315
+ "audio_augmentations = Compose([\n",
316
+ " AddGaussianSNR(min_snr_db=10, max_snr_db=20),\n",
317
+ " Gain(min_gain_db=-6, max_gain_db=6),\n",
318
+ " GainTransition(min_gain_db=-6, max_gain_db=6, min_duration=0.01, max_duration=0.3, duration_unit=\"fraction\"),\n",
319
+ " ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),\n",
320
+ " TimeStretch(min_rate=0.8, max_rate=1.2),\n",
321
+ " PitchShift(min_semitones=-4, max_semitones=4),\n",
322
+ "], p=0.5, shuffle=True)"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": 11,
328
+ "metadata": {},
329
+ "outputs": [],
330
+ "source": [
331
+ "# Define the preprocessing function for the audio augmentations\n",
332
+ "def preprocess_audio_with_transforms(batch):\n",
333
+ " # we apply augmentations on each waveform\n",
334
+ " wavs = [audio_augmentations(audio[\"array\"], sample_rate=sampling_rate) for audio in batch[\"input_values\"]]\n",
335
+ " inputs = feature_extractor(wavs, sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
336
+ "\n",
337
+ " output_batch = {model_input_name: inputs.get(model_input_name), \"labels\": list(batch[\"labels\"])}\n",
338
+ " return output_batch"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 12,
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "# Calculate values for normalization (mean and std) for the dataset (Only need to run this once per dataset)\n",
348
+ "# feature_extractor.do_normalize = False # Disable normalization\n",
349
+ "\n",
350
+ "# means = []\n",
351
+ "# stds = []\n",
352
+ "\n",
353
+ "# def calculate_mean_std(index):\n",
354
+ "# try:\n",
355
+ "# audio_input = ds[\"train\"][index][\"input_values\"]\n",
356
+ "# except Exception as e:\n",
357
+ "# print(f\"Error processing index {index}: {e}\")\n",
358
+ "# return None, None\n",
359
+ "# cur_mean = torch.mean(audio_input)\n",
360
+ "# cur_std = torch.std(audio_input)\n",
361
+ "# return cur_mean, cur_std\n",
362
+ "\n",
363
+ "# with ProcessPoolExecutor() as executor:\n",
364
+ "# results = list(tqdm(executor.map(calculate_mean_std, range(len(ds[\"train\"]))), total=len(ds[\"train\"])))\n",
365
+ "\n",
366
+ "# means, stds = zip(*results)\n",
367
+ "# means = [x.item() for x in means if x is not None]\n",
368
+ "# stds = [x.item() for x in stds if x is not None]\n",
369
+ "# feature_extractor.mean = torch.tensor(means).mean().item()\n",
370
+ "# feature_extractor.std = torch.tensor(stds).mean().item()\n",
371
+ "# feature_extractor.do_normalize = True # Enable normalization\n",
372
+ "\n",
373
+ "# print(f\"Mean: {feature_extractor.mean}\")\n",
374
+ "# print(f\"Std: {feature_extractor.std}\")\n",
375
+ "# print(\"Save these values for normalization if you're using the same dataset in the future.\")"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": 13,
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "# Remove corrupted audio files (4481, 8603 in ai_music_large)\n",
385
+ "corrupted_audio_indices = [4481, 8603]\n",
386
+ "keep_indices = [i for i in range(len(ds[\"train\"])) if i not in corrupted_audio_indices]\n",
387
+ "ds[\"train\"] = ds[\"train\"].select(keep_indices, writer_batch_size=50)"
388
+ ]
389
+ },
390
+ {
391
+ "cell_type": "code",
392
+ "execution_count": 14,
393
+ "metadata": {},
394
+ "outputs": [],
395
+ "source": [
396
+ "# Set the normalization values in the feature extractor (the following values are for the ai_music_large dataset)\n",
397
+ "feature_extractor.mean = -4.855465888977051\n",
398
+ "feature_extractor.std = 3.2848217487335205"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": 15,
404
+ "metadata": {},
405
+ "outputs": [
406
+ {
407
+ "name": "stdout",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Mean: -4.855465888977051\n",
411
+ "Std: 3.2848217487335205\n"
412
+ ]
413
+ }
414
+ ],
415
+ "source": [
416
+ "print(f\"Mean: {feature_extractor.mean}\")\n",
417
+ "print(f\"Std: {feature_extractor.std}\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "code",
422
+ "execution_count": 16,
423
+ "metadata": {},
424
+ "outputs": [
425
+ {
426
+ "name": "stdout",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "DatasetDict({\n",
430
+ " train: Dataset({\n",
431
+ " features: ['input_values', 'labels'],\n",
432
+ " num_rows: 15998\n",
433
+ " })\n",
434
+ " test: Dataset({\n",
435
+ " features: ['input_values', 'labels'],\n",
436
+ " num_rows: 4000\n",
437
+ " })\n",
438
+ "})\n"
439
+ ]
440
+ }
441
+ ],
442
+ "source": [
443
+ "# Split the dataset\n",
444
+ "if \"test\" not in ds:\n",
445
+ " ds = ds[\"train\"].train_test_split(test_size=0.2, shuffle=True, seed=42, stratify_by_column=\"labels\")\n",
446
+ "\n",
447
+ "print(ds)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": 17,
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "# Set transforms for the train and test sets\n",
457
+ "ds[\"train\"].set_transform(preprocess_audio_with_transforms, output_all_columns=False)\n",
458
+ "ds[\"test\"].set_transform(preprocess_audio, output_all_columns=False)"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 18,
464
+ "metadata": {},
465
+ "outputs": [
466
+ {
467
+ "name": "stderr",
468
+ "output_type": "stream",
469
+ "text": [
470
+ "Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:\n",
471
+ "- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([2]) in the model instantiated\n",
472
+ "- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated\n",
473
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
474
+ ]
475
+ }
476
+ ],
477
+ "source": [
478
+ "# Load config from the pre-trained model\n",
479
+ "config = ASTConfig.from_pretrained(model_name)\n",
480
+ "\n",
481
+ "# Update the config with the labels we have in the dataset\n",
482
+ "config.num_labels = len(ds[\"train\"].features[\"labels\"].names)\n",
483
+ "config.label2id = {name: id for id, name in enumerate(ds[\"train\"].features[\"labels\"].names)}\n",
484
+ "config.id2label = {id: name for name, id in config.label2id.items()}\n",
485
+ "\n",
486
+ "# Initialize the model\n",
487
+ "model = ASTForAudioClassification.from_pretrained(model_name, config=config, ignore_mismatched_sizes=True)\n",
488
+ "model.init_weights()"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": 19,
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": [
497
+ "# Configure the training arguments\n",
498
+ "training_args = TrainingArguments(\n",
499
+ " output_dir=MODEL_DIR + \"/out/ast_classifier_small\",\n",
500
+ " logging_dir=MODEL_DIR + \"/logs/ast_classifier_small\",\n",
501
+ " report_to=\"tensorboard\",\n",
502
+ " learning_rate=5e-5,\n",
503
+ " push_to_hub=False,\n",
504
+ " num_train_epochs=10,\n",
505
+ " per_device_train_batch_size=8,\n",
506
+ " eval_strategy=\"epoch\",\n",
507
+ " save_strategy=\"epoch\",\n",
508
+ " eval_steps=1,\n",
509
+ " save_steps=1,\n",
510
+ " logging_steps=10,\n",
511
+ " metric_for_best_model=\"accuracy\",\n",
512
+ " dataloader_num_workers=24,\n",
513
+ " dataloader_prefetch_factor=4,\n",
514
+ " dataloader_persistent_workers=True,\n",
515
+ ")"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": 20,
521
+ "metadata": {},
522
+ "outputs": [],
523
+ "source": [
524
+ "# Define evaluation metrics\n",
525
+ "accuracy = evaluate.load(\"accuracy\")\n",
526
+ "recall = evaluate.load(\"recall\")\n",
527
+ "precision = evaluate.load(\"precision\")\n",
528
+ "f1 = evaluate.load(\"f1\")\n",
529
+ "\n",
530
+ "average = \"macro\" if config.num_labels > 2 else \"binary\"\n",
531
+ "\n",
532
+ "def compute_metrics(eval_pred):\n",
533
+ " logits = eval_pred.predictions\n",
534
+ " predictions = np.argmax(logits, axis=-1)\n",
535
+ " metrics = accuracy.compute(predictions=predictions, references=eval_pred.label_ids)\n",
536
+ " metrics.update(precision.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
537
+ " metrics.update(recall.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
538
+ " metrics.update(f1.compute(predictions=predictions, references=eval_pred.label_ids, average=average))\n",
539
+ " return metrics"
540
+ ]
541
+ },
542
+ {
543
+ "cell_type": "code",
544
+ "execution_count": 21,
545
+ "metadata": {},
546
+ "outputs": [
547
+ {
548
+ "name": "stderr",
549
+ "output_type": "stream",
550
+ "text": [
551
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
552
+ " warnings.warn(\n",
553
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
554
+ " warnings.warn(\n",
555
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
556
+ " warnings.warn(\n",
557
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
558
+ " warnings.warn(\n",
559
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
560
+ " warnings.warn(\n",
561
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
562
+ " warnings.warn(\n",
563
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
564
+ " warnings.warn(\n",
565
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
566
+ " warnings.warn(\n",
567
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
568
+ " warnings.warn(\n",
569
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
570
+ " warnings.warn(\n",
571
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
572
+ " warnings.warn(\n",
573
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
574
+ " warnings.warn(\n",
575
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
576
+ " warnings.warn(\n",
577
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
578
+ " warnings.warn(\n",
579
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
580
+ " warnings.warn(\n",
581
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
582
+ " warnings.warn(\n",
583
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
584
+ " warnings.warn(\n",
585
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
586
+ " warnings.warn(\n",
587
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
588
+ " warnings.warn(\n",
589
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
590
+ " warnings.warn(\n",
591
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
592
+ " warnings.warn(\n",
593
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
594
+ " warnings.warn(\n",
595
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
596
+ " warnings.warn(\n",
597
+ "/opt/conda/lib/python3.10/site-packages/audiomentations/core/transforms_interface.py:62: UserWarning: Warning: input samples dtype is np.float64. Converting to np.float32\n",
598
+ " warnings.warn(\n"
599
+ ]
600
+ },
601
+ {
602
+ "data": {
603
+ "text/html": [
604
+ "\n",
605
+ " <div>\n",
606
+ " \n",
607
+ " <progress value='20000' max='20000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
608
+ " [20000/20000 2:46:13, Epoch 10/10]\n",
609
+ " </div>\n",
610
+ " <table border=\"1\" class=\"dataframe\">\n",
611
+ " <thead>\n",
612
+ " <tr style=\"text-align: left;\">\n",
613
+ " <th>Epoch</th>\n",
614
+ " <th>Training Loss</th>\n",
615
+ " <th>Validation Loss</th>\n",
616
+ " <th>Accuracy</th>\n",
617
+ " <th>Precision</th>\n",
618
+ " <th>Recall</th>\n",
619
+ " <th>F1</th>\n",
620
+ " </tr>\n",
621
+ " </thead>\n",
622
+ " <tbody>\n",
623
+ " <tr>\n",
624
+ " <td>5</td>\n",
625
+ " <td>0.031900</td>\n",
626
+ " <td>0.066149</td>\n",
627
+ " <td>0.982250</td>\n",
628
+ " <td>0.999482</td>\n",
629
+ " <td>0.965000</td>\n",
630
+ " <td>0.981938</td>\n",
631
+ " </tr>\n",
632
+ " <tr>\n",
633
+ " <td>6</td>\n",
634
+ " <td>0.234200</td>\n",
635
+ " <td>0.031733</td>\n",
636
+ " <td>0.992000</td>\n",
637
+ " <td>0.992000</td>\n",
638
+ " <td>0.992000</td>\n",
639
+ " <td>0.992000</td>\n",
640
+ " </tr>\n",
641
+ " <tr>\n",
642
+ " <td>7</td>\n",
643
+ " <td>0.063600</td>\n",
644
+ " <td>0.046821</td>\n",
645
+ " <td>0.992500</td>\n",
646
+ " <td>0.998987</td>\n",
647
+ " <td>0.986000</td>\n",
648
+ " <td>0.992451</td>\n",
649
+ " </tr>\n",
650
+ " <tr>\n",
651
+ " <td>8</td>\n",
652
+ " <td>0.210500</td>\n",
653
+ " <td>0.017158</td>\n",
654
+ " <td>0.995500</td>\n",
655
+ " <td>0.997990</td>\n",
656
+ " <td>0.993000</td>\n",
657
+ " <td>0.995489</td>\n",
658
+ " </tr>\n",
659
+ " <tr>\n",
660
+ " <td>9</td>\n",
661
+ " <td>0.001100</td>\n",
662
+ " <td>0.016046</td>\n",
663
+ " <td>0.996750</td>\n",
664
+ " <td>0.998995</td>\n",
665
+ " <td>0.994500</td>\n",
666
+ " <td>0.996743</td>\n",
667
+ " </tr>\n",
668
+ " <tr>\n",
669
+ " <td>10</td>\n",
670
+ " <td>0.001500</td>\n",
671
+ " <td>0.011154</td>\n",
672
+ " <td>0.997500</td>\n",
673
+ " <td>0.998497</td>\n",
674
+ " <td>0.996500</td>\n",
675
+ " <td>0.997497</td>\n",
676
+ " </tr>\n",
677
+ " </tbody>\n",
678
+ "</table><p>"
679
+ ],
680
+ "text/plain": [
681
+ "<IPython.core.display.HTML object>"
682
+ ]
683
+ },
684
+ "metadata": {},
685
+ "output_type": "display_data"
686
+ },
687
+ {
688
+ "data": {
689
+ "text/plain": [
690
+ "TrainOutput(global_step=20000, training_loss=0.04969663131231209, metrics={'train_runtime': 9996.4827, 'train_samples_per_second': 16.004, 'train_steps_per_second': 2.001, 'total_flos': 1.084389872624468e+19, 'train_loss': 0.04969663131231209, 'epoch': 10.0})"
691
+ ]
692
+ },
693
+ "execution_count": 21,
694
+ "metadata": {},
695
+ "output_type": "execute_result"
696
+ }
697
+ ],
698
+ "source": [
699
+ "# Initialize the Trainer\n",
700
+ "trainer = Trainer(\n",
701
+ " model=model,\n",
702
+ " args=training_args,\n",
703
+ " train_dataset=ds[\"train\"],\n",
704
+ " eval_dataset=ds[\"test\"],\n",
705
+ " compute_metrics=compute_metrics,\n",
706
+ ")\n",
707
+ "\n",
708
+ "# Train the model\n",
709
+ "trainer.train(resume_from_checkpoint=True)"
710
+ ]
711
+ },
712
+ {
713
+ "cell_type": "code",
714
+ "execution_count": 22,
715
+ "metadata": {},
716
+ "outputs": [],
717
+ "source": [
718
+ "trainer.save_model(output_dir=\"./model\")"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": 26,
724
+ "metadata": {},
725
+ "outputs": [
726
+ {
727
+ "name": "stdout",
728
+ "output_type": "stream",
729
+ "text": [
730
+ "Dataset({\n",
731
+ " features: ['input_values', 'labels'],\n",
732
+ " num_rows: 22\n",
733
+ "})\n"
734
+ ]
735
+ }
736
+ ],
737
+ "source": [
738
+ "import glob\n",
739
+ "unseen_files = glob.glob(\"/workspace/ai/*\")\n",
740
+ "unseen_set = datasets.Dataset.from_dict({\"input_values\": unseen_files}).cast_column(\"input_values\", datasets.Audio(sampling_rate=16000, mono=True))\n",
741
+ "unseen_set = unseen_set.add_column(name=\"labels\", column=[1 for _ in range(len(unseen_set))])\n",
742
+ "unseen_set.set_transform(preprocess_audio, output_all_columns=False)\n",
743
+ "print(unseen_set)"
744
+ ]
745
+ },
746
+ {
747
+ "cell_type": "code",
748
+ "execution_count": 27,
749
+ "metadata": {},
750
+ "outputs": [
751
+ {
752
+ "data": {
753
+ "text/html": [],
754
+ "text/plain": [
755
+ "<IPython.core.display.HTML object>"
756
+ ]
757
+ },
758
+ "metadata": {},
759
+ "output_type": "display_data"
760
+ },
761
+ {
762
+ "data": {
763
+ "text/plain": [
764
+ "PredictionOutput(predictions=array([[-6.1512766, 6.060689 ],\n",
765
+ " [-5.978138 , 5.743587 ],\n",
766
+ " [-4.0873733, 4.6713266],\n",
767
+ " [-4.008548 , 4.2211466],\n",
768
+ " [-5.873764 , 5.7459254],\n",
769
+ " [-6.206414 , 6.235821 ],\n",
770
+ " [-4.825156 , 4.8879967],\n",
771
+ " [ 2.4498227, -1.9184169],\n",
772
+ " [-5.554337 , 5.638381 ],\n",
773
+ " [-6.2935424, 6.2818317],\n",
774
+ " [-5.4350233, 5.3958435],\n",
775
+ " [-5.253522 , 5.241722 ],\n",
776
+ " [-3.9684274, 3.9555552],\n",
777
+ " [-6.3393865, 6.066998 ],\n",
778
+ " [-6.2268295, 5.997632 ],\n",
779
+ " [-6.1494975, 6.1331954],\n",
780
+ " [-5.7538185, 5.824904 ],\n",
781
+ " [ 3.1460629, -2.850086 ],\n",
782
+ " [-1.4815819, 2.1283977],\n",
783
+ " [-5.2852707, 5.146372 ],\n",
784
+ " [-6.6310973, 6.3678217],\n",
785
+ " [ 4.265127 , -3.486055 ]], dtype=float32), label_ids=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), metrics={'test_loss': 0.8253983855247498, 'test_accuracy': 0.8636363636363636, 'test_precision': 1.0, 'test_recall': 0.8636363636363636, 'test_f1': 0.926829268292683, 'test_runtime': 6.5504, 'test_samples_per_second': 3.359, 'test_steps_per_second': 0.458})"
786
+ ]
787
+ },
788
+ "execution_count": 27,
789
+ "metadata": {},
790
+ "output_type": "execute_result"
791
+ }
792
+ ],
793
+ "source": [
794
+ "trainer.predict(unseen_set)"
795
+ ]
796
+ }
797
+ ],
798
+ "metadata": {
799
+ "kernelspec": {
800
+ "display_name": "base",
801
+ "language": "python",
802
+ "name": "python3"
803
+ },
804
+ "language_info": {
805
+ "codemirror_mode": {
806
+ "name": "ipython",
807
+ "version": 3
808
+ },
809
+ "file_extension": ".py",
810
+ "mimetype": "text/x-python",
811
+ "name": "python",
812
+ "nbconvert_exporter": "python",
813
+ "pygments_lexer": "ipython3",
814
+ "version": "3.10.13"
815
+ }
816
+ },
817
+ "nbformat": 4,
818
+ "nbformat_minor": 4
819
+ }