varunril committed on
Commit
5b5b156
1 Parent(s): 6799ac4

Upload 13 files

README.md ADDED
@@ -0,0 +1,71 @@
+ ---
+ license: apache-2.0
+ tags:
+ - audio-classification
+ - generated_from_trainer
+ datasets:
+ - xtreme_s
+ metrics:
+ - accuracy
+ base_model: openai/whisper-medium
+ model-index:
+ - name: whisper-medium-fleurs-lang-id
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # Whisper Medium FLEURS Language Identification
+
+ This model is a fine-tuned version of [openai/whisper-medium](https://huggingface.co/openai/whisper-medium) on the [FLEURS subset](https://huggingface.co/datasets/google/xtreme_s#language-identification---fleurs-langid) of the [google/xtreme_s](https://huggingface.co/datasets/google/xtreme_s) dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 0.8413
+ - Accuracy: 0.8805
+
+ To reproduce this run, execute the command in [`run.sh`](https://huggingface.co/sanchit-gandhi/whisper-medium-fleurs-lang-id/blob/main/run.sh).
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 3e-05
+ - train_batch_size: 16
+ - eval_batch_size: 32
+ - seed: 0
+ - distributed_type: multi-GPU
+ - gradient_accumulation_steps: 2
+ - total_train_batch_size: 32
+ - optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+ - lr_scheduler_type: linear
+ - lr_scheduler_warmup_ratio: 0.1
+ - num_epochs: 3.0
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy |
+ |:-------------:|:-----:|:-----:|:---------------:|:--------:|
+ | 0.0152 | 1.0 | 8494 | 0.9087 | 0.8431 |
+ | 0.0003 | 2.0 | 16988 | 1.0059 | 0.8460 |
+ | 0.0 | 3.0 | 25482 | 0.8413 | 0.8805 |
+
+
+ ### Framework versions
+
+ - Transformers 4.27.0.dev0
+ - Pytorch 1.13.1
+ - Datasets 2.9.0
+ - Tokenizers 0.13.2
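The README leaves "Intended uses" blank, so here is a minimal, hedged sketch of how a checkpoint like this one might be queried. It assumes the `transformers` `pipeline("audio-classification", ...)` API and the model id that appears in the README link; `top_k_languages` and `classify_file` are hypothetical helper names, not part of this commit. The pure-Python helper shows how raw logits map onto the string-keyed `id2label` table from `config.json`.

```python
import math
from typing import Dict, List, Tuple


def top_k_languages(
    logits: List[float], id2label: Dict[str, str], k: int = 3
) -> List[Tuple[str, float]]:
    """Softmax the logits and return the k most probable (label, probability) pairs."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # id2label uses string keys ("0", "1", ...) as stored in config.json
    return [(id2label[str(i)], probs[i]) for i in ranked[:k]]


def classify_file(path: str, model_id: str = "sanchit-gandhi/whisper-medium-fleurs-lang-id"):
    """Run the audio-classification pipeline on one audio file.

    Assumption: the checkpoint is reachable under `model_id`; the first call
    downloads the weights, so this function is not exercised here.
    """
    from transformers import pipeline  # lazy import keeps the helper above dependency-free

    classifier = pipeline("audio-classification", model=model_id)
    return classifier(path)


# Toy logits for three classes, just to show the ranking step.
toy_labels = {"0": "Afrikaans", "1": "Amharic", "2": "Arabic"}
ranking = top_k_languages([0.0, 2.0, 1.0], toy_labels, k=2)
```

The pipeline already applies a softmax and returns ranked labels, so the helper is only needed when working with raw model outputs.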
all_results.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "epoch": 3.0,
+     "eval_accuracy": 0.8805294322535702,
+     "eval_loss": 0.84130859375,
+     "eval_runtime": 4369.2701,
+     "eval_samples_per_second": 7.885,
+     "eval_steps_per_second": 0.246,
+     "train_loss": 0.06268550049036697,
+     "train_runtime": 389325.9759,
+     "train_samples_per_second": 2.094,
+     "train_steps_per_second": 0.065
+ }
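A quick arithmetic cross-check of the reported numbers, as a sketch only. Every input below is copied from the README and `all_results.json` in this commit. The variable names are illustrative, and the interpretation of the implied process count is an assumption, not something the commit states.

```python
# Values copied from the README hyperparameters, the results table and all_results.json.
per_device_train_batch_size = 16
gradient_accumulation_steps = 2
total_train_batch_size = 32
steps_per_epoch = 8494
num_epochs = 3
train_runtime_s = 389325.9759
eval_runtime_s = 4369.2701
eval_samples_per_second = 7.885

# Number of processes implied by the batch-size figures (assumption: total =
# per-device batch * accumulation steps * number of processes).
implied_world_size = total_train_batch_size // (
    per_device_train_batch_size * gradient_accumulation_steps
)

# Optimizer steps over the whole run; should match the last row of the results table.
total_steps = steps_per_epoch * num_epochs

# Upper bound on training examples seen per epoch (the final batch may be partial).
examples_per_epoch_upper = steps_per_epoch * total_train_batch_size

# Wall-clock training time in hours.
train_hours = train_runtime_s / 3600

# Approximate size of the evaluation split, recovered from runtime * throughput.
approx_eval_examples = round(eval_runtime_s * eval_samples_per_second)

print(implied_world_size, total_steps, examples_per_epoch_upper)
print(f"{train_hours:.1f} h of training, ~{approx_eval_examples} eval examples")
```

The figures imply a single process even though the README lists `distributed_type: multi-GPU`; one plausible reading is that launching through `deepspeed` is reported as distributed regardless of device count.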
config.json ADDED
@@ -0,0 +1,360 @@
+ {
+   "_name_or_path": "sanchit-gandhi/whisper-medium-fleurs-lang-id",
+   "activation_dropout": 0.0,
+   "activation_function": "gelu",
+   "apply_spec_augment": false,
+   "architectures": [
+     "WhisperForAudioClassification"
+   ],
+   "attention_dropout": 0.0,
+   "begin_suppress_tokens": [
+     220,
+     50257
+   ],
+   "bos_token_id": 50257,
+   "classifier_proj_size": 256,
+   "d_model": 1024,
+   "decoder_attention_heads": 16,
+   "decoder_ffn_dim": 4096,
+   "decoder_layerdrop": 0.0,
+   "decoder_layers": 24,
+   "decoder_start_token_id": 50258,
+   "dropout": 0.0,
+   "encoder_attention_heads": 16,
+   "encoder_ffn_dim": 4096,
+   "encoder_layerdrop": 0.0,
+   "encoder_layers": 24,
+   "eos_token_id": 50257,
+   "finetuning_task": "audio-classification",
+   "forced_decoder_ids": [
+     [
+       1,
+       50259
+     ],
+     [
+       2,
+       50359
+     ],
+     [
+       3,
+       50363
+     ]
+   ],
+   "id2label": {
+     "0": "Afrikaans",
+     "1": "Amharic",
+     "2": "Arabic",
+     "3": "Assamese",
+     "4": "Asturian",
+     "5": "Azerbaijani",
+     "6": "Belarusian",
+     "7": "Bulgarian",
+     "8": "Bengali",
+     "9": "Bosnian",
+     "10": "Catalan",
+     "11": "Cebuano",
+     "12": "Sorani-Kurdish",
+     "13": "Mandarin Chinese",
+     "14": "Czech",
+     "15": "Welsh",
+     "16": "Danish",
+     "17": "German",
+     "18": "Greek",
+     "19": "English",
+     "20": "Spanish",
+     "21": "Estonian",
+     "22": "Persian",
+     "23": "Fula",
+     "24": "Finnish",
+     "25": "Filipino",
+     "26": "French",
+     "27": "Irish",
+     "28": "Galician",
+     "29": "Gujarati",
+     "30": "Hausa",
+     "31": "Hebrew",
+     "32": "Hindi",
+     "33": "Croatian",
+     "34": "Hungarian",
+     "35": "Armenian",
+     "36": "Indonesian",
+     "37": "Igbo",
+     "38": "Icelandic",
+     "39": "Italian",
+     "40": "Japanese",
+     "41": "Javanese",
+     "42": "Georgian",
+     "43": "Kamba",
+     "44": "Kabuverdianu",
+     "45": "Kazakh",
+     "46": "Khmer",
+     "47": "Kannada",
+     "48": "Korean",
+     "49": "Kyrgyz",
+     "50": "Luxembourgish",
+     "51": "Ganda",
+     "52": "Lingala",
+     "53": "Lao",
+     "54": "Lithuanian",
+     "55": "Luo",
+     "56": "Latvian",
+     "57": "Maori",
+     "58": "Macedonian",
+     "59": "Malayalam",
+     "60": "Mongolian",
+     "61": "Marathi",
+     "62": "Malay",
+     "63": "Maltese",
+     "64": "Burmese",
+     "65": "Norwegian",
+     "66": "Nepali",
+     "67": "Dutch",
+     "68": "Northern-Sotho",
+     "69": "Nyanja",
+     "70": "Occitan",
+     "71": "Oromo",
+     "72": "Oriya",
+     "73": "Punjabi",
+     "74": "Polish",
+     "75": "Pashto",
+     "76": "Portuguese",
+     "77": "Romanian",
+     "78": "Russian",
+     "79": "Sindhi",
+     "80": "Slovak",
+     "81": "Slovenian",
+     "82": "Shona",
+     "83": "Somali",
+     "84": "Serbian",
+     "85": "Swedish",
+     "86": "Swahili",
+     "87": "Tamil",
+     "88": "Telugu",
+     "89": "Tajik",
+     "90": "Thai",
+     "91": "Turkish",
+     "92": "Ukrainian",
+     "93": "Umbundu",
+     "94": "Urdu",
+     "95": "Uzbek",
+     "96": "Vietnamese",
+     "97": "Wolof",
+     "98": "Xhosa",
+     "99": "Yoruba",
+     "100": "Cantonese Chinese",
+     "101": "Zulu"
+   },
+   "init_std": 0.02,
+   "is_encoder_decoder": true,
+   "label2id": {
+     "Afrikaans": "0",
+     "Amharic": "1",
+     "Arabic": "2",
+     "Armenian": "35",
+     "Assamese": "3",
+     "Asturian": "4",
+     "Azerbaijani": "5",
+     "Belarusian": "6",
+     "Bengali": "8",
+     "Bosnian": "9",
+     "Bulgarian": "7",
+     "Burmese": "64",
+     "Cantonese Chinese": "100",
+     "Catalan": "10",
+     "Cebuano": "11",
+     "Croatian": "33",
+     "Czech": "14",
+     "Danish": "16",
+     "Dutch": "67",
+     "English": "19",
+     "Estonian": "21",
+     "Filipino": "25",
+     "Finnish": "24",
+     "French": "26",
+     "Fula": "23",
+     "Galician": "28",
+     "Ganda": "51",
+     "Georgian": "42",
+     "German": "17",
+     "Greek": "18",
+     "Gujarati": "29",
+     "Hausa": "30",
+     "Hebrew": "31",
+     "Hindi": "32",
+     "Hungarian": "34",
+     "Icelandic": "38",
+     "Igbo": "37",
+     "Indonesian": "36",
+     "Irish": "27",
+     "Italian": "39",
+     "Japanese": "40",
+     "Javanese": "41",
+     "Kabuverdianu": "44",
+     "Kamba": "43",
+     "Kannada": "47",
+     "Kazakh": "45",
+     "Khmer": "46",
+     "Korean": "48",
+     "Kyrgyz": "49",
+     "Lao": "53",
+     "Latvian": "56",
+     "Lingala": "52",
+     "Lithuanian": "54",
+     "Luo": "55",
+     "Luxembourgish": "50",
+     "Macedonian": "58",
+     "Malay": "62",
+     "Malayalam": "59",
+     "Maltese": "63",
+     "Mandarin Chinese": "13",
+     "Maori": "57",
+     "Marathi": "61",
+     "Mongolian": "60",
+     "Nepali": "66",
+     "Northern-Sotho": "68",
+     "Norwegian": "65",
+     "Nyanja": "69",
+     "Occitan": "70",
+     "Oriya": "72",
+     "Oromo": "71",
+     "Pashto": "75",
+     "Persian": "22",
+     "Polish": "74",
+     "Portuguese": "76",
+     "Punjabi": "73",
+     "Romanian": "77",
+     "Russian": "78",
+     "Serbian": "84",
+     "Shona": "82",
+     "Sindhi": "79",
+     "Slovak": "80",
+     "Slovenian": "81",
+     "Somali": "83",
+     "Sorani-Kurdish": "12",
+     "Spanish": "20",
+     "Swahili": "86",
+     "Swedish": "85",
+     "Tajik": "89",
+     "Tamil": "87",
+     "Telugu": "88",
+     "Thai": "90",
+     "Turkish": "91",
+     "Ukrainian": "92",
+     "Umbundu": "93",
+     "Urdu": "94",
+     "Uzbek": "95",
+     "Vietnamese": "96",
+     "Welsh": "15",
+     "Wolof": "97",
+     "Xhosa": "98",
+     "Yoruba": "99",
+     "Zulu": "101"
+   },
+   "mask_feature_length": 10,
+   "mask_feature_min_masks": 0,
+   "mask_feature_prob": 0.0,
+   "mask_time_length": 10,
+   "mask_time_min_masks": 2,
+   "mask_time_prob": 0.05,
+   "max_length": 448,
+   "max_source_positions": 1500,
+   "max_target_positions": 448,
+   "model_type": "whisper",
+   "num_hidden_layers": 24,
+   "num_mel_bins": 80,
+   "pad_token_id": 50257,
+   "scale_embedding": false,
+   "suppress_tokens": [
+     1,
+     2,
+     7,
+     8,
+     9,
+     10,
+     14,
+     25,
+     26,
+     27,
+     28,
+     29,
+     31,
+     58,
+     59,
+     60,
+     61,
+     62,
+     63,
+     90,
+     91,
+     92,
+     93,
+     359,
+     503,
+     522,
+     542,
+     873,
+     893,
+     902,
+     918,
+     922,
+     931,
+     1350,
+     1853,
+     1982,
+     2460,
+     2627,
+     3246,
+     3253,
+     3268,
+     3536,
+     3846,
+     3961,
+     4183,
+     4667,
+     6585,
+     6647,
+     7273,
+     9061,
+     9383,
+     10428,
+     10929,
+     11938,
+     12033,
+     12331,
+     12562,
+     13793,
+     14157,
+     14635,
+     15265,
+     15618,
+     16553,
+     16604,
+     18362,
+     18956,
+     20075,
+     21675,
+     22520,
+     26130,
+     26161,
+     26435,
+     28279,
+     29464,
+     31650,
+     32302,
+     32470,
+     36865,
+     42863,
+     47425,
+     49870,
+     50254,
+     50258,
+     50360,
+     50361,
+     50362
+   ],
+   "torch_dtype": "float16",
+   "transformers_version": "4.30.0.dev0",
+   "use_cache": true,
+   "use_weighted_layer_sum": false,
+   "vocab_size": 51865
+ }
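Both label tables in `config.json` store class ids as strings (`"19": "English"`, `"English": "19"`), which follows from the `str(i)` calls in the training script. The sketch below shows one way to turn them into integer-indexed lookups for downstream code. The helper names are hypothetical and the three-entry excerpt stands in for the full 102-entry table.

```python
from typing import Dict

# Three entries copied from the `id2label` block; the full table covers ids 0 through 101.
id2label_raw: Dict[str, str] = {"0": "Afrikaans", "19": "English", "101": "Zulu"}


def normalise_id2label(raw: Dict[str, str]) -> Dict[int, str]:
    """Convert the JSON string keys into integer class ids."""
    return {int(key): label for key, label in raw.items()}


def invert(id2label: Dict[int, str]) -> Dict[str, int]:
    """Build the reverse lookup, with integer ids as values."""
    return {label: idx for idx, label in id2label.items()}


id2label = normalise_id2label(id2label_raw)
label2id = invert(id2label)
```

JSON object keys are always strings, so some conversion of this kind is needed whenever the ids are used as array indices, for example with the output of an `argmax` over logits.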
ds_config.json ADDED
@@ -0,0 +1,50 @@
+ {
+     "fp16": {
+         "enabled": "auto",
+         "loss_scale": 0,
+         "loss_scale_window": 1000,
+         "initial_scale_power": 16,
+         "hysteresis": 2,
+         "min_loss_scale": 1
+     },
+
+     "optimizer": {
+         "type": "AdamW",
+         "params": {
+             "lr": "auto",
+             "betas": "auto",
+             "eps": "auto",
+             "weight_decay": "auto"
+         }
+     },
+
+     "scheduler": {
+         "type": "WarmupDecayLR",
+         "params": {
+             "last_batch_iteration": -1,
+             "total_num_steps": "auto",
+             "warmup_min_lr": "auto",
+             "warmup_max_lr": "auto",
+             "warmup_num_steps": "auto"
+         }
+     },
+
+     "zero_optimization": {
+         "stage": 2,
+         "offload_optimizer": {
+             "device": "cpu",
+             "pin_memory": true
+         },
+         "allgather_partitions": true,
+         "allgather_bucket_size": 2e8,
+         "overlap_comm": true,
+         "reduce_scatter": true,
+         "reduce_bucket_size": 2e8,
+         "contiguous_gradients": true
+     },
+
+     "gradient_accumulation_steps": "auto",
+     "gradient_clipping": "auto",
+     "train_batch_size": "auto",
+     "train_micro_batch_size_per_gpu": "auto"
+ }
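Most values in `ds_config.json` are the placeholder `"auto"`, which is meant to be filled in from the training arguments at launch time. The sketch below illustrates that substitution idea on an excerpt of the config, using values taken from `run.sh`. `fill_auto` and the dotted-path convention are hypothetical, written only for illustration; this is not the code that Transformers or DeepSpeed actually run.

```python
from typing import Any, Dict


def fill_auto(config: Dict[str, Any], values: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Return a copy of `config` where "auto" entries are replaced by `values[dotted.path]`.

    Entries without a supplied value are left as "auto".
    """
    resolved: Dict[str, Any] = {}
    for key, value in config.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            resolved[key] = fill_auto(value, values, prefix=f"{path}.")
        elif value == "auto" and path in values:
            resolved[key] = values[path]
        else:
            resolved[key] = value
    return resolved


# Excerpt of ds_config.json from this commit.
excerpt = {
    "fp16": {"enabled": "auto", "loss_scale": 0},
    "optimizer": {"type": "AdamW", "params": {"lr": "auto", "betas": "auto"}},
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}

# Values taken from the flags in run.sh (--fp16, --learning_rate, and so on).
cli_values = {
    "fp16.enabled": True,
    "optimizer.params.lr": 3e-5,
    "gradient_accumulation_steps": 2,
    "train_micro_batch_size_per_gpu": 16,
}

resolved = fill_auto(excerpt, cli_values)
```

Keeping the config on `"auto"` means the hyperparameters are defined in exactly one place, the command line, which avoids mismatches between the two files.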
eval_results.json ADDED
@@ -0,0 +1,8 @@
+ {
+     "epoch": 3.0,
+     "eval_accuracy": 0.8805294322535702,
+     "eval_loss": 0.84130859375,
+     "eval_runtime": 4369.2701,
+     "eval_samples_per_second": 7.885,
+     "eval_steps_per_second": 0.246
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dfe1b47efa122ea382e06b17209f1c8b7424d39b6e3224520e601fbb9cd5aaa2
+ size 615050492
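The three lines added for `model.safetensors` are a Git LFS pointer, not the weights themselves. As a rough sketch, the pointer can be parsed into its fields and the byte size used for a back-of-the-envelope parameter count. `parse_lfs_pointer` is a hypothetical helper, and the estimate assumes two bytes per value (`torch_dtype` is `float16` in `config.json`) while ignoring the file header.

```python
from typing import Dict, Union


def parse_lfs_pointer(text: str) -> Dict[str, Union[str, int]]:
    """Split a Git LFS pointer file into version, hash algorithm, digest and size."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


# Pointer contents copied from the model.safetensors diff in this commit.
pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:dfe1b47efa122ea382e06b17209f1c8b7424d39b6e3224520e601fbb9cd5aaa2
size 615050492
"""

info = parse_lfs_pointer(pointer)

# Rough parameter count: 2 bytes per float16 value, header overhead ignored.
approx_params_millions = info["size"] / 2 / 1e6
print(f"{info['size'] / 1e6:.0f} MB on disk, roughly {approx_params_millions:.0f}M parameters")
```

On these assumptions the file holds on the order of 300 million values, which is plausible for an encoder-only classification model derived from Whisper medium, although the commit itself does not state a parameter count.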
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "chunk_length": 30,
+   "feature_extractor_type": "WhisperFeatureExtractor",
+   "feature_size": 80,
+   "hop_length": 160,
+   "n_fft": 400,
+   "n_samples": 480000,
+   "nb_max_frames": 3000,
+   "padding_side": "right",
+   "padding_value": 0.0,
+   "processor_class": "WhisperProcessor",
+   "return_attention_mask": false,
+   "sampling_rate": 16000
+ }
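The numbers in `preprocessor_config.json` are not independent: several follow from `chunk_length`, `sampling_rate` and `hop_length`. The short derivation below recomputes them. The final step, halving the frame count to reach `max_source_positions` in `config.json`, assumes a stride-2 convolution in the Whisper encoder; the variable names are illustrative.

```python
# Inputs copied from preprocessor_config.json.
chunk_length_s = 30
sampling_rate = 16000
hop_length = 160

# Samples in one padded 30-second chunk.
n_samples = chunk_length_s * sampling_rate

# Log-mel frames per chunk: one frame every `hop_length` samples.
nb_max_frames = n_samples // hop_length

# Encoder positions, assuming a stride-2 convolution halves the time axis.
encoder_positions = nb_max_frames // 2

# Duration covered by one frame, in milliseconds.
frame_ms = 1000 * hop_length / sampling_rate

print(n_samples, nb_max_frames, encoder_positions, frame_ms)
```

The recomputed values agree with `n_samples: 480000` and `nb_max_frames: 3000` in this file and with `max_source_positions: 1500` in `config.json`, which also matches `--max_length_seconds 30` in `run.sh`.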
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6ded31e35036a85fe27b810198c6f9dd332d8b506df244c156b3af8524a01bce
+ size 615058493
run.sh ADDED
@@ -0,0 +1,31 @@
+ deepspeed run_audio_classification.py \
+   --deepspeed ds_config.json \
+   --model_name_or_path openai/whisper-medium \
+   --dataset_name google/xtreme_s \
+   --dataset_config_name fleurs.all \
+   --output_dir ./ \
+   --overwrite_output_dir \
+   --remove_unused_columns False \
+   --do_train \
+   --do_eval \
+   --fp16 \
+   --learning_rate 3e-5 \
+   --max_length_seconds 30 \
+   --label_column_name lang_id \
+   --attention_mask False \
+   --warmup_ratio 0.1 \
+   --num_train_epochs 3 \
+   --per_device_train_batch_size 16 \
+   --gradient_accumulation_steps 2 \
+   --gradient_checkpointing True \
+   --per_device_eval_batch_size 32 \
+   --dataloader_num_workers 8 \
+   --logging_strategy steps \
+   --logging_steps 25 \
+   --evaluation_strategy epoch \
+   --save_strategy epoch \
+   --load_best_model_at_end True \
+   --metric_for_best_model accuracy \
+   --seed 0 \
+   --freeze_feature_encoder False \
+   --push_to_hub
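The flags `--learning_rate 3e-5` and `--warmup_ratio 0.1`, together with the 25482 total steps reported in the README, describe a linear warmup followed by a linear decay. The sketch below writes out the textbook form of that schedule. It is an assumption that the run followed this shape exactly, since the actual scheduler is DeepSpeed's `WarmupDecayLR` configured through `ds_config.json`; `linear_warmup_decay` is a hypothetical name.

```python
import math


def linear_warmup_decay(
    step: int,
    base_lr: float = 3e-5,
    total_steps: int = 25482,
    warmup_ratio: float = 0.1,
) -> float:
    """Learning rate at `step`: linear ramp from 0 to base_lr, then linear decay to 0."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


warmup_steps = math.ceil(25482 * 0.1)
for s in (0, warmup_steps // 2, warmup_steps, 25482 // 2, 25482):
    print(s, f"{linear_warmup_decay(s):.2e}")
```

With these inputs the peak rate is reached after 2549 steps, about 30% of the way through the first epoch of 8494 steps.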
run_audio_classification.py ADDED
@@ -0,0 +1,418 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import os
+ import sys
+ import warnings
+ from dataclasses import dataclass, field
+ from random import randint
+ from typing import Optional
+
+ import datasets
+ import evaluate
+ import numpy as np
+ from datasets import DatasetDict, load_dataset
+
+ import transformers
+ from transformers import (
+     AutoConfig,
+     AutoFeatureExtractor,
+     AutoModelForAudioClassification,
+     HfArgumentParser,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+ from transformers.trainer_utils import get_last_checkpoint
+ from transformers.utils import check_min_version, send_example_telemetry
+ from transformers.utils.versions import require_version
+
+
+ logger = logging.getLogger(__name__)
+
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
+ check_min_version("4.27.0.dev0")
+
+ require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")
+
+
+ def random_subsample(wav: np.ndarray, max_length: float, sample_rate: int = 16000):
+     """Randomly sample chunks of `max_length` seconds from the input audio"""
+     sample_length = int(round(sample_rate * max_length))
+     if len(wav) <= sample_length:
+         return wav
+     random_offset = randint(0, len(wav) - sample_length - 1)
+     return wav[random_offset : random_offset + sample_length]
+
+
+ @dataclass
+ class DataTrainingArguments:
+     """
+     Arguments pertaining to what data we are going to input our model for training and eval.
+     Using `HfArgumentParser` we can turn this class
+     into argparse arguments to be able to specify them on
+     the command line.
+     """
+
+     dataset_name: Optional[str] = field(default=None, metadata={"help": "Name of a dataset from the datasets package"})
+     dataset_config_name: Optional[str] = field(
+         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+     )
+     train_file: Optional[str] = field(
+         default=None, metadata={"help": "A file containing the training audio paths and labels."}
+     )
+     eval_file: Optional[str] = field(
+         default=None, metadata={"help": "A file containing the validation audio paths and labels."}
+     )
+     train_split_name: str = field(
+         default="train",
+         metadata={
+             "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
+         },
+     )
+     eval_split_name: str = field(
+         default="validation",
+         metadata={
+             "help": (
+                 "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
+             )
+         },
+     )
+     audio_column_name: str = field(
+         default="audio",
+         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
+     )
+     label_column_name: str = field(
+         default="label", metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'"}
+     )
+     max_train_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of training examples to this "
+                 "value if set."
+             )
+         },
+     )
+     max_eval_samples: Optional[int] = field(
+         default=None,
+         metadata={
+             "help": (
+                 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+                 "value if set."
+             )
+         },
+     )
+     max_length_seconds: float = field(
+         default=20,
+         metadata={"help": "Audio clips will be randomly cut to this length during training if the value is set."},
+     )
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+     """
+
+     model_name_or_path: str = field(
+         default="facebook/wav2vec2-base",
+         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
+     )
+     config_name: Optional[str] = field(
+         default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+     )
+     cache_dir: Optional[str] = field(
+         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from the Hub"}
+     )
+     model_revision: str = field(
+         default="main",
+         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+     )
+     feature_extractor_name: Optional[str] = field(
+         default=None, metadata={"help": "Name or path of preprocessor config."}
+     )
+     freeze_feature_encoder: bool = field(
+         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
+     )
+     attention_mask: bool = field(
+         default=True, metadata={"help": "Whether to generate an attention mask in the feature extractor."}
+     )
+     use_auth_token: bool = field(
+         default=False,
+         metadata={
+             "help": (
+                 "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
+                 "with private models)."
+             )
+         },
+     )
+     freeze_feature_extractor: Optional[bool] = field(
+         default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
+     )
+     ignore_mismatched_sizes: bool = field(
+         default=False,
+         metadata={"help": "Allows loading a pretrained model whose head dimensions are different."},
+     )
+
+     def __post_init__(self):
+         if not self.freeze_feature_extractor and self.freeze_feature_encoder:
+             warnings.warn(
+                 "The argument `--freeze_feature_extractor` is deprecated and "
+                 "will be removed in a future version. Use `--freeze_feature_encoder` "
+                 "instead. Setting `freeze_feature_encoder==True`.",
+                 FutureWarning,
+             )
+         if self.freeze_feature_extractor and not self.freeze_feature_encoder:
+             raise ValueError(
+                 "The argument `--freeze_feature_extractor` is deprecated and "
+                 "should not be used in combination with `--freeze_feature_encoder`. "
+                 "Only make use of `--freeze_feature_encoder`."
+             )
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+         # If we pass only one argument to the script and it's the path to a json file,
+         # let's parse it to get our arguments.
+         model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+     else:
+         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
+     # information sent is the one passed as arguments along with your Python/PyTorch versions.
+     send_example_telemetry("run_audio_classification", model_args, data_args)
+
+     # Setup logging
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%m/%d/%Y %H:%M:%S",
+         handlers=[logging.StreamHandler(sys.stdout)],
+     )
+
+     if training_args.should_log:
+         # The default of training_args.log_level is passive, so we set log level at info here to have that default.
+         transformers.utils.logging.set_verbosity_info()
+
+     log_level = training_args.get_process_log_level()
+     logger.setLevel(log_level)
+     transformers.utils.logging.set_verbosity(log_level)
+     transformers.utils.logging.enable_default_handler()
+     transformers.utils.logging.enable_explicit_format()
+
+     # Log on each process the small summary:
+     logger.warning(
+         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
+         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+     )
+     logger.info(f"Training/evaluation parameters {training_args}")
+
+     # Set seed before initializing model.
+     set_seed(training_args.seed)
+
+     # Detecting last checkpoint.
+     last_checkpoint = None
+     if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+         last_checkpoint = get_last_checkpoint(training_args.output_dir)
+         if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+             raise ValueError(
+                 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                 "Use --overwrite_output_dir to train from scratch."
+             )
+         elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+             logger.info(
+                 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+             )
+
+     # Initialize our dataset and prepare it for the audio classification task.
+     raw_datasets = DatasetDict()
+     raw_datasets["train"] = load_dataset(
+         data_args.dataset_name,
+         data_args.dataset_config_name,
+         split=data_args.train_split_name,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     raw_datasets["eval"] = load_dataset(
+         data_args.dataset_name,
+         data_args.dataset_config_name,
+         split=data_args.eval_split_name,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     if data_args.audio_column_name not in raw_datasets["train"].column_names:
+         raise ValueError(
+             f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
+             "Make sure to set `--audio_column_name` to the correct audio column - one of "
+             f"{', '.join(raw_datasets['train'].column_names)}."
+         )
+
+     if data_args.label_column_name not in raw_datasets["train"].column_names:
+         raise ValueError(
+             f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
+             "Make sure to set `--label_column_name` to the correct label column - one of "
+             f"{', '.join(raw_datasets['train'].column_names)}."
+         )
+
+     # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
+     # transformer outputs in the classifier, but it doesn't always lead to better accuracy
+     feature_extractor = AutoFeatureExtractor.from_pretrained(
+         model_args.feature_extractor_name or model_args.model_name_or_path,
+         return_attention_mask=model_args.attention_mask,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+
+     # `datasets` takes care of automatically loading and resampling the audio,
+     # so we just need to set the correct target sampling rate.
+     raw_datasets = raw_datasets.cast_column(
+         data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
+     )
+
+     model_input_name = feature_extractor.model_input_names[0]
+
+     def train_transforms(batch):
+         """Apply train_transforms across a batch."""
+         subsampled_wavs = []
+         for audio in batch[data_args.audio_column_name]:
+             wav = random_subsample(
+                 audio["array"], max_length=data_args.max_length_seconds, sample_rate=feature_extractor.sampling_rate
+             )
+             subsampled_wavs.append(wav)
+         inputs = feature_extractor(subsampled_wavs, sampling_rate=feature_extractor.sampling_rate)
+         output_batch = {model_input_name: inputs.get(model_input_name)}
+         output_batch["labels"] = list(batch[data_args.label_column_name])
+
+         return output_batch
+
+     def val_transforms(batch):
+         """Apply val_transforms across a batch."""
+         wavs = [audio["array"] for audio in batch[data_args.audio_column_name]]
+         inputs = feature_extractor(wavs, sampling_rate=feature_extractor.sampling_rate)
+         output_batch = {model_input_name: inputs.get(model_input_name)}
+         output_batch["labels"] = list(batch[data_args.label_column_name])
+
+         return output_batch
+
+     # Prepare label mappings.
+     # We'll include these in the model's config to get human readable labels in the Inference API.
+     labels = raw_datasets["train"].features[data_args.label_column_name].names
+     label2id, id2label = {}, {}
+     for i, label in enumerate(labels):
+         label2id[label] = str(i)
+         id2label[str(i)] = label
+
+     # Load the accuracy metric from the evaluate package
+     metric = evaluate.load("accuracy")
+
+     # Define our compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with
+     # `predictions` and `label_ids` fields) and has to return a dictionary string to float.
+     def compute_metrics(eval_pred):
+         """Computes accuracy on a batch of predictions"""
+         predictions = np.argmax(eval_pred.predictions, axis=1)
+         return metric.compute(predictions=predictions, references=eval_pred.label_ids)
+
+     config = AutoConfig.from_pretrained(
+         model_args.config_name or model_args.model_name_or_path,
+         num_labels=len(labels),
+         label2id=label2id,
+         id2label=id2label,
+         finetuning_task="audio-classification",
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+     )
+     model = AutoModelForAudioClassification.from_pretrained(
+         model_args.model_name_or_path,
+         from_tf=bool(".ckpt" in model_args.model_name_or_path),
+         config=config,
+         cache_dir=model_args.cache_dir,
+         revision=model_args.model_revision,
+         use_auth_token=True if model_args.use_auth_token else None,
+         ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
+     )
+
+     # freeze the convolutional waveform encoder
+     if model_args.freeze_feature_encoder:
+         model.freeze_feature_encoder()
+
+     if training_args.do_train:
+         if data_args.max_train_samples is not None:
+             raw_datasets["train"] = (
+                 raw_datasets["train"].shuffle(seed=training_args.seed).select(range(data_args.max_train_samples))
+             )
+         # Set the training transforms
+         raw_datasets["train"].set_transform(train_transforms, output_all_columns=False)
+
+     if training_args.do_eval:
+         if data_args.max_eval_samples is not None:
+             raw_datasets["eval"] = (
+                 raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
+             )
+         # Set the validation transforms
+         raw_datasets["eval"].set_transform(val_transforms, output_all_columns=False)
+
+     # Initialize our trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=raw_datasets["train"] if training_args.do_train else None,
+         eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
+         compute_metrics=compute_metrics,
+         tokenizer=feature_extractor,
+     )
+
+     # Training
+     if training_args.do_train:
+         checkpoint = None
+         if training_args.resume_from_checkpoint is not None:
+             checkpoint = training_args.resume_from_checkpoint
+         elif last_checkpoint is not None:
+             checkpoint = last_checkpoint
+         train_result = trainer.train(resume_from_checkpoint=checkpoint)
+         trainer.save_model()
+         trainer.log_metrics("train", train_result.metrics)
+         trainer.save_metrics("train", train_result.metrics)
+         trainer.save_state()
+
+     # Evaluation
+     if training_args.do_eval:
+         metrics = trainer.evaluate()
+         trainer.log_metrics("eval", metrics)
+         trainer.save_metrics("eval", metrics)
+
+     # Write model card and (optionally) push to hub
+     kwargs = {
+         "finetuned_from": model_args.model_name_or_path,
+         "tasks": "audio-classification",
+         "dataset": data_args.dataset_name,
+         "tags": ["audio-classification"],
+     }
+     if training_args.push_to_hub:
+         trainer.push_to_hub(**kwargs)
+     else:
+         trainer.create_model_card(**kwargs)
+
+
+ if __name__ == "__main__":
+     main()
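To see what the training-time cropping in `random_subsample` does in isolation, the function can be exercised on a synthetic waveform. The sketch below repeats the function body from the script but types the input as a generic sequence, so it runs without NumPy. The 40-second ramp signal and the fixed seed are assumptions made purely for the demonstration.

```python
from random import randint, seed
from typing import Sequence


def random_subsample(wav: Sequence[float], max_length: float, sample_rate: int = 16000):
    """Return a random window of `max_length` seconds, or the input unchanged if it is shorter."""
    sample_length = int(round(sample_rate * max_length))
    if len(wav) <= sample_length:
        return wav
    random_offset = randint(0, len(wav) - sample_length - 1)
    return wav[random_offset : random_offset + sample_length]


seed(0)  # fixed seed so repeated runs pick the same window

# Stand-in for a 40-second clip at 16 kHz: each sample's value is its own index.
clip = list(range(40 * 16000))

# Longer than 30 s: a contiguous 30-second window is cut at a random offset.
window = random_subsample(clip, max_length=30)

# Shorter than 30 s: returned as-is (padding is left to the feature extractor).
short = random_subsample(clip[:16000], max_length=30)

print(len(window), len(short))
```

Only `train_transforms` applies this crop; `val_transforms` passes the full waveform to the feature extractor, so evaluation is deterministic.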
train_results.json ADDED
@@ -0,0 +1,7 @@
+ {
+     "epoch": 3.0,
+     "train_loss": 0.06268550049036697,
+     "train_runtime": 389325.9759,
+     "train_samples_per_second": 2.094,
+     "train_steps_per_second": 0.065
+ }
trainer_state.json ADDED
The diff for this file is too large to render.
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:056cee6adb1b2f1fcd2f38aa61d20cb381ad85614636d6b75ea4483a58612531
+ size 4731