w11wo commited on
Commit
03cf889
1 Parent(s): 9d3f3b7

Added eval script

Browse files
Files changed (2) hide show
  1. eval.py +473 -0
  2. eval_teacher.sh +14 -0
eval.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ import os
19
+ import sys
20
+ import warnings
21
+ from dataclasses import dataclass, field
22
+ from typing import Optional
23
+
24
+ import datasets
25
+ import numpy as np
26
+ import torch
27
+ from datasets import DatasetDict, load_dataset
28
+
29
+ import transformers
30
+ from transformers import (
31
+ AutoConfig,
32
+ AutoFeatureExtractor,
33
+ AutoModelForAudioClassification,
34
+ EvalPrediction,
35
+ HfArgumentParser,
36
+ Trainer,
37
+ TrainingArguments,
38
+ set_seed,
39
+ )
40
+ from transformers.trainer_utils import get_last_checkpoint
41
+ from transformers.utils import send_example_telemetry
42
+ from transformers.utils.versions import require_version
43
+
44
+ from sklearn.metrics import (
45
+ accuracy_score,
46
+ average_precision_score,
47
+ f1_score,
48
+ roc_auc_score,
49
+ )
50
+
51
+ logger = logging.getLogger(__name__)
52
+
53
+ require_version(
54
+ "datasets>=1.14.0",
55
+ "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt",
56
+ )
57
+
58
+
59
+ class MultiLabelTrainer(Trainer):
60
+ def compute_loss(self, model, inputs, return_outputs=False):
61
+ labels = inputs.pop("labels")
62
+ outputs = model(**inputs)
63
+ logits = outputs.logits
64
+ bce_loss_fct = torch.nn.BCEWithLogitsLoss()
65
+ loss = bce_loss_fct(
66
+ logits.view(-1, self.model.config.num_labels),
67
+ labels.float().view(-1, self.model.config.num_labels),
68
+ )
69
+ return (loss, outputs) if return_outputs else loss
70
+
71
+
72
+ @dataclass
73
+ class DataTrainingArguments:
74
+ """
75
+ Arguments pertaining to what data we are going to input our model for training and eval.
76
+ Using `HfArgumentParser` we can turn this class
77
+ into argparse arguments to be able to specify them on
78
+ the command line.
79
+ """
80
+
81
+ dataset_name: Optional[str] = field(
82
+ default=None, metadata={"help": "Name of a dataset from the datasets package"}
83
+ )
84
+ dataset_config_name: Optional[str] = field(
85
+ default=None,
86
+ metadata={
87
+ "help": "The configuration name of the dataset to use (via the datasets library)."
88
+ },
89
+ )
90
+ train_file: Optional[str] = field(
91
+ default=None,
92
+ metadata={"help": "A file containing the training audio paths and labels."},
93
+ )
94
+ eval_file: Optional[str] = field(
95
+ default=None,
96
+ metadata={"help": "A file containing the validation audio paths and labels."},
97
+ )
98
+ train_split_name: str = field(
99
+ default="train",
100
+ metadata={
101
+ "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
102
+ },
103
+ )
104
+ eval_split_name: str = field(
105
+ default="validation",
106
+ metadata={
107
+ "help": (
108
+ "The name of the training data set split to use (via the datasets library). Defaults to 'validation'"
109
+ )
110
+ },
111
+ )
112
+ audio_column_name: str = field(
113
+ default="audio",
114
+ metadata={
115
+ "help": "The name of the dataset column containing the audio data. Defaults to 'audio'"
116
+ },
117
+ )
118
+ label_column_name: Optional[str] = field(
119
+ default="label",
120
+ metadata={
121
+ "help": "The name of the dataset column containing the labels. Defaults to 'label'"
122
+ },
123
+ )
124
+ max_train_samples: Optional[int] = field(
125
+ default=None,
126
+ metadata={
127
+ "help": (
128
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
129
+ "value if set."
130
+ )
131
+ },
132
+ )
133
+ max_eval_samples: Optional[int] = field(
134
+ default=None,
135
+ metadata={
136
+ "help": (
137
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
138
+ "value if set."
139
+ )
140
+ },
141
+ )
142
+ max_length_seconds: float = field(
143
+ default=20,
144
+ metadata={
145
+ "help": "Audio clips will be randomly cut to this length during training if the value is set."
146
+ },
147
+ )
148
+
149
+
150
+ @dataclass
151
+ class ModelArguments:
152
+ """
153
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
154
+ """
155
+
156
+ model_name_or_path: str = field(
157
+ default="facebook/wav2vec2-base",
158
+ metadata={
159
+ "help": "Path to pretrained model or model identifier from huggingface.co/models"
160
+ },
161
+ )
162
+ config_name: Optional[str] = field(
163
+ default=None,
164
+ metadata={
165
+ "help": "Pretrained config name or path if not the same as model_name"
166
+ },
167
+ )
168
+ cache_dir: Optional[str] = field(
169
+ default=None,
170
+ metadata={
171
+ "help": "Where do you want to store the pretrained models downloaded from the Hub"
172
+ },
173
+ )
174
+ model_revision: str = field(
175
+ default="main",
176
+ metadata={
177
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."
178
+ },
179
+ )
180
+ feature_extractor_name: Optional[str] = field(
181
+ default=None, metadata={"help": "Name or path of preprocessor config."}
182
+ )
183
+ freeze_feature_encoder: bool = field(
184
+ default=True,
185
+ metadata={"help": "Whether to freeze the feature encoder layers of the model."},
186
+ )
187
+ attention_mask: bool = field(
188
+ default=True,
189
+ metadata={
190
+ "help": "Whether to generate an attention mask in the feature extractor."
191
+ },
192
+ )
193
+ use_auth_token: bool = field(
194
+ default=False,
195
+ metadata={
196
+ "help": (
197
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
198
+ "with private models)."
199
+ )
200
+ },
201
+ )
202
+ freeze_feature_extractor: Optional[bool] = field(
203
+ default=None,
204
+ metadata={
205
+ "help": "Whether to freeze the feature extractor layers of the model."
206
+ },
207
+ )
208
+ ignore_mismatched_sizes: bool = field(
209
+ default=False,
210
+ metadata={
211
+ "help": "Will enable to load a pretrained model whose head dimensions are different."
212
+ },
213
+ )
214
+
215
+ def __post_init__(self):
216
+ if not self.freeze_feature_extractor and self.freeze_feature_encoder:
217
+ warnings.warn(
218
+ "The argument `--freeze_feature_extractor` is deprecated and "
219
+ "will be removed in a future version. Use `--freeze_feature_encoder`"
220
+ "instead. Setting `freeze_feature_encoder==True`.",
221
+ FutureWarning,
222
+ )
223
+ if self.freeze_feature_extractor and not self.freeze_feature_encoder:
224
+ raise ValueError(
225
+ "The argument `--freeze_feature_extractor` is deprecated and "
226
+ "should not be used in combination with `--freeze_feature_encoder`."
227
+ "Only make use of `--freeze_feature_encoder`."
228
+ )
229
+
230
+
231
+ def main():
232
+ # See all possible arguments in src/transformers/training_args.py
233
+ # or by passing the --help flag to this script.
234
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
235
+
236
+ parser = HfArgumentParser(
237
+ (ModelArguments, DataTrainingArguments, TrainingArguments)
238
+ )
239
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
240
+ # If we pass only one argument to the script and it's the path to a json file,
241
+ # let's parse it to get our arguments.
242
+ model_args, data_args, training_args = parser.parse_json_file(
243
+ json_file=os.path.abspath(sys.argv[1])
244
+ )
245
+ else:
246
+ (model_args, data_args, training_args) = parser.parse_args_into_dataclasses()
247
+
248
+ # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
249
+ # information sent is the one passed as arguments along with your Python/PyTorch versions.
250
+ send_example_telemetry("run_audio_classification", model_args, data_args)
251
+
252
+ # Setup logging
253
+ logging.basicConfig(
254
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
255
+ datefmt="%m/%d/%Y %H:%M:%S",
256
+ handlers=[logging.StreamHandler(sys.stdout)],
257
+ )
258
+
259
+ if training_args.should_log:
260
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
261
+ transformers.utils.logging.set_verbosity_info()
262
+
263
+ log_level = training_args.get_process_log_level()
264
+ logger.setLevel(log_level)
265
+ transformers.utils.logging.set_verbosity(log_level)
266
+ transformers.utils.logging.enable_default_handler()
267
+ transformers.utils.logging.enable_explicit_format()
268
+
269
+ # Log on each process the small summary:
270
+ logger.warning(
271
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu} "
272
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
273
+ )
274
+ logger.info(f"Training/evaluation parameters {training_args}")
275
+
276
+ # Set seed before initializing model.
277
+ set_seed(training_args.seed)
278
+
279
+ # Detecting last checkpoint.
280
+ last_checkpoint = None
281
+ if (
282
+ os.path.isdir(training_args.output_dir)
283
+ and training_args.do_train
284
+ and not training_args.overwrite_output_dir
285
+ ):
286
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
287
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
288
+ raise ValueError(
289
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
290
+ "Use --overwrite_output_dir to train from scratch."
291
+ )
292
+ elif (
293
+ last_checkpoint is not None and training_args.resume_from_checkpoint is None
294
+ ):
295
+ logger.info(
296
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
297
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
298
+ )
299
+
300
+ # Initialize our dataset and prepare it for the audio classification task.
301
+ raw_datasets = DatasetDict()
302
+ raw_datasets["train"] = load_dataset(
303
+ data_args.dataset_name,
304
+ data_args.dataset_config_name,
305
+ split=data_args.train_split_name,
306
+ use_auth_token=True if model_args.use_auth_token else None,
307
+ )
308
+ raw_datasets["eval"] = load_dataset(
309
+ data_args.dataset_name,
310
+ data_args.dataset_config_name,
311
+ split=data_args.eval_split_name,
312
+ use_auth_token=True if model_args.use_auth_token else None,
313
+ )
314
+
315
+ if data_args.audio_column_name not in raw_datasets["train"].column_names:
316
+ raise ValueError(
317
+ f"--audio_column_name {data_args.audio_column_name} not found in dataset '{data_args.dataset_name}'. "
318
+ "Make sure to set `--audio_column_name` to the correct audio column - one of "
319
+ f"{', '.join(raw_datasets['train'].column_names)}."
320
+ )
321
+
322
+ # Setting `return_attention_mask=True` is the way to get a correctly masked mean-pooling over
323
+ # transformer outputs in the classifier, but it doesn't always lead to better accuracy
324
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
325
+ model_args.feature_extractor_name or model_args.model_name_or_path,
326
+ return_attention_mask=model_args.attention_mask,
327
+ cache_dir=model_args.cache_dir,
328
+ revision=model_args.model_revision,
329
+ use_auth_token=True if model_args.use_auth_token else None,
330
+ )
331
+
332
+ # `datasets` takes care of automatically loading and resampling the audio,
333
+ # so we just need to set the correct target sampling rate.
334
+ raw_datasets = raw_datasets.cast_column(
335
+ data_args.audio_column_name,
336
+ datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
337
+ )
338
+
339
+ model_input_name = feature_extractor.model_input_names[0]
340
+
341
+ def preprocess_data(examples):
342
+ # get audio arrays
343
+ audio_arrays = [x["array"] for x in examples[data_args.audio_column_name]]
344
+ # encode batch of audio
345
+ inputs = feature_extractor(
346
+ audio_arrays, sampling_rate=feature_extractor.sampling_rate
347
+ )
348
+ # add labels
349
+ labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
350
+ # create numpy array of shape (batch_size, num_labels)
351
+ labels_matrix = np.zeros((len(audio_arrays), len(labels)))
352
+ # fill numpy array
353
+ for idx, label in enumerate(labels):
354
+ labels_matrix[:, idx] = labels_batch[label]
355
+
356
+ output_batch = {model_input_name: inputs.get(model_input_name)}
357
+ output_batch["labels"] = labels_matrix.tolist()
358
+
359
+ return output_batch
360
+
361
+ def multi_label_metrics(predictions, labels, threshold=0.5):
362
+ # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
363
+ sigmoid = torch.nn.Sigmoid()
364
+ probs = sigmoid(torch.Tensor(predictions)).cpu().numpy()
365
+ # next, use threshold to turn them into integer predictions
366
+ y_pred = np.zeros(probs.shape)
367
+ y_pred[np.where(probs >= threshold)] = 1
368
+ # finally, compute metrics
369
+ f1_micro_average = f1_score(y_true=labels, y_pred=y_pred, average="micro")
370
+ roc_auc = roc_auc_score(labels, y_pred, average="micro")
371
+ accuracy = accuracy_score(labels, y_pred)
372
+ mAP = average_precision_score(labels, probs, average="micro")
373
+ # return as dictionary
374
+ metrics = {
375
+ "f1": f1_micro_average,
376
+ "roc_auc": roc_auc,
377
+ "accuracy": accuracy,
378
+ "mAP": mAP,
379
+ }
380
+ return metrics
381
+
382
+ def compute_metrics(p: EvalPrediction):
383
+ """Computes mean average precision (mAP) score"""
384
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
385
+ result = multi_label_metrics(predictions=preds, labels=p.label_ids)
386
+ return result
387
+
388
+ config = AutoConfig.from_pretrained(
389
+ model_args.config_name or model_args.model_name_or_path,
390
+ cache_dir=model_args.cache_dir,
391
+ revision=model_args.model_revision,
392
+ use_auth_token=True if model_args.use_auth_token else None,
393
+ )
394
+ model = AutoModelForAudioClassification.from_pretrained(
395
+ model_args.model_name_or_path,
396
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
397
+ config=config,
398
+ cache_dir=model_args.cache_dir,
399
+ revision=model_args.model_revision,
400
+ use_auth_token=True if model_args.use_auth_token else None,
401
+ ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
402
+ )
403
+
404
+ labels = list(config.id2label.values())
405
+
406
+ # freeze the convolutional waveform encoder
407
+ if model_args.freeze_feature_encoder:
408
+ model.freeze_feature_encoder()
409
+
410
+ if training_args.do_train:
411
+ if data_args.max_train_samples is not None:
412
+ raw_datasets["train"] = (
413
+ raw_datasets["train"]
414
+ .shuffle(seed=training_args.seed)
415
+ .select(range(data_args.max_train_samples))
416
+ )
417
+ # Set the training transforms
418
+ raw_datasets["train"].set_transform(preprocess_data, output_all_columns=False)
419
+
420
+ if training_args.do_eval:
421
+ if data_args.max_eval_samples is not None:
422
+ raw_datasets["eval"] = (
423
+ raw_datasets["eval"]
424
+ .shuffle(seed=training_args.seed)
425
+ .select(range(data_args.max_eval_samples))
426
+ )
427
+ # Set the validation transforms
428
+ raw_datasets["eval"].set_transform(preprocess_data, output_all_columns=False)
429
+
430
+ # Initialize our trainer
431
+ trainer = MultiLabelTrainer(
432
+ model=model,
433
+ args=training_args,
434
+ train_dataset=raw_datasets["train"] if training_args.do_train else None,
435
+ eval_dataset=raw_datasets["eval"] if training_args.do_eval else None,
436
+ compute_metrics=compute_metrics,
437
+ tokenizer=feature_extractor,
438
+ )
439
+
440
+ # Training
441
+ if training_args.do_train:
442
+ checkpoint = None
443
+ if training_args.resume_from_checkpoint is not None:
444
+ checkpoint = training_args.resume_from_checkpoint
445
+ elif last_checkpoint is not None:
446
+ checkpoint = last_checkpoint
447
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
448
+ trainer.save_model()
449
+ trainer.log_metrics("train", train_result.metrics)
450
+ trainer.save_metrics("train", train_result.metrics)
451
+ trainer.save_state()
452
+
453
+ # Evaluation
454
+ if training_args.do_eval:
455
+ metrics = trainer.evaluate()
456
+ trainer.log_metrics("eval", metrics)
457
+ trainer.save_metrics("eval", metrics)
458
+
459
+ # Write model card and (optionally) push to hub
460
+ kwargs = {
461
+ "finetuned_from": model_args.model_name_or_path,
462
+ "tasks": "audio-classification",
463
+ "dataset": data_args.dataset_name,
464
+ "tags": ["audio-classification"],
465
+ }
466
+ if training_args.push_to_hub:
467
+ trainer.push_to_hub(**kwargs)
468
+ else:
469
+ trainer.create_model_card(**kwargs)
470
+
471
+
472
+ if __name__ == "__main__":
473
+ main()
eval_teacher.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python eval.py \
2
+ --model_name_or_path MIT/ast-finetuned-audioset-10-10-0.4593 \
3
+ --dataset_name bookbot/audioset \
4
+ --output_dir ast-audioset-test \
5
+ --overwrite_output_dir \
6
+ --remove_unused_columns False \
7
+ --freeze_feature_encoder False \
8
+ --do_eval \
9
+ --fp16 \
10
+ --attention_mask False \
11
+ --per_device_eval_batch_size 32 \
12
+ --dataloader_num_workers 4 \
13
+ --seed 0 \
14
+ --report_to tensorboard