Transformers documentation

Audio classification

You are viewing v4.19.4 version. A newer version v4.42.0 is available.

Audio classification assigns a label or class to audio data. It is similar to text classification, except an audio input is continuous and must be discretized, whereas text can be split into tokens. Some practical applications of audio classification include identifying intent, speakers, and even animal species by their sounds.
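To make "continuous and must be discretized" concrete, here is a minimal sketch that samples a sine wave at a fixed sampling rate using only the standard library. The function name `sample_sine` and the 440 Hz tone are illustrative assumptions, not part of this guide or of any library.

```python
import math


def sample_sine(freq_hz, duration_s, sampling_rate):
    """Return discrete samples of a sine wave, taken sampling_rate times per second."""
    n_samples = int(duration_s * sampling_rate)
    return [math.sin(2 * math.pi * freq_hz * n / sampling_rate) for n in range(n_samples)]


# Half a second of a hypothetical 440 Hz tone sampled at 8000 Hz.
samples = sample_sine(freq_hz=440.0, duration_s=0.5, sampling_rate=8000)
print(len(samples))  # 4000
```

The resulting list of floats is the kind of 1-dimensional array a model receives in place of the continuous signal.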

This guide will show you how to fine-tune Wav2Vec2 on the MInDS-14 dataset to classify intent.

See the audio classification task page for more information about its associated models, datasets, and metrics.

Load MInDS-14 dataset

Load the MInDS-14 dataset from the 🤗 Datasets library:

>>> from datasets import load_dataset, Audio

>>> minds = load_dataset("PolyAI/minds14", name="en-US", split="train")

Split this dataset into a train and test set:

>>> minds = minds.train_test_split(test_size=0.2)

Then take a look at the dataset:

>>> minds
DatasetDict({
    train: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 450
    })
    test: Dataset({
        features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
        num_rows: 113
    })
})
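
The row counts above follow from the 0.2 test fraction. The sketch below reproduces the arithmetic under the assumption that the test size is rounded up and the train size is rounded down; `split_sizes` is a hypothetical helper, not a 🤗 Datasets function.

```python
import math


def split_sizes(n_rows, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test count is rounded up and the train count is
    rounded down, which matches the row counts shown in this guide.
    """
    n_test = math.ceil(test_size * n_rows)
    n_train = math.floor((1 - test_size) * n_rows)
    return n_train, n_test


# 450 + 113 = 563 rows in the original "train" split.
n_train, n_test = split_sizes(n_rows=450 + 113, test_size=0.2)
print(n_train, n_test)  # 450 113
```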

While the dataset contains a lot of other useful information, like lang_id and english_transcription, you will focus on the audio and intent_class in this guide. Remove the other columns:

>>> minds = minds.remove_columns(["path", "transcription", "english_transcription", "lang_id"])

Take a look at an example now:

>>> minds["train"][0]
{'audio': {'array': array([ 0.        ,  0.        ,  0.        , ..., -0.00048828,
         -0.00024414, -0.00024414], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602b9a5fbb1e6d0fbce91f52.wav',
  'sampling_rate': 8000},
 'intent_class': 2}

The audio column contains a 1-dimensional array of the speech signal. Accessing this column loads the audio file and, if necessary, resamples it. The intent_class column is an integer that represents the class id of the intent. Create a dictionary that maps each label name to an integer and vice versa. The mapping will help the model recover the label name from the label number:

>>> labels = minds["train"].features["intent_class"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label

Now you can convert the label number to a label name for more information:

>>> id2label[str(2)]
'app_error'

Each keyword - or label - corresponds to a number; 2 indicates app_error in the example above.
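
As a self-contained illustration of the two mappings, the sketch below builds them from a short, hypothetical label list. Only "app_error" at index 2 comes from this guide; the other names are placeholders. Because the keys of id2label are strings here, an integer class id has to be converted with str() before the lookup.

```python
# Hypothetical label list: only "app_error" at index 2 is taken from the guide.
labels = ["placeholder_a", "placeholder_b", "app_error"]

# Same construction as in the guide, written as dict comprehensions.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

predicted_class = 2  # an integer class id, for example from a dataset row
print(id2label[str(predicted_class)])  # app_error

# The two mappings are inverses of each other.
round_trip_ok = all(id2label[label2id[name]] == name for name in labels)
print(round_trip_ok)  # True
```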


Preprocess

Load the Wav2Vec2 feature extractor to process the audio signal:

>>> from transformers import AutoFeatureExtractor

>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")

The MInDS-14 dataset has a sampling rate of 8000 Hz (8 kHz), while the pretrained Wav2Vec2 model expects audio sampled at 16000 Hz (16 kHz). You will need to resample the dataset to use the pretrained Wav2Vec2 model:

>>> minds = minds.cast_column("audio", Audio(sampling_rate=16_000))
>>> minds["train"][0]
{'audio': {'array': array([ 2.2098757e-05,  4.6582241e-05, -2.2803260e-05, ...,
         -2.8419291e-04, -2.3305941e-04, -1.1425107e-04], dtype=float32),
  'path': '/root/.cache/huggingface/datasets/downloads/extracted/f14948e0e84be638dd7943ac36518a4cf3324e8b7aa331c5ab11541518e9368c/en-US~APP_ERROR/602b9a5fbb1e6d0fbce91f52.wav',
  'sampling_rate': 16000},
 'intent_class': 2}
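
To see what resampling does to the array, here is a deliberately naive linear-interpolation resampler written with the standard library. It is an illustrative assumption only: `resample_linear` is a hypothetical helper, and the resampling performed by 🤗 Datasets is more sophisticated. The point is that going from 8 kHz to 16 kHz doubles the number of samples while the clip duration stays the same.

```python
def resample_linear(samples, orig_sr, target_sr):
    """Naive linear-interpolation resampler, for illustration only."""
    if not samples:
        return []
    n_out = round(len(samples) * target_sr / orig_sr)
    out = []
    for i in range(n_out):
        pos = i * orig_sr / target_sr       # position measured in input samples
        lo = min(int(pos), len(samples) - 1)
        hi = min(lo + 1, len(samples) - 1)  # clamp at the last input sample
        frac = pos - lo
        out.append(samples[lo] * (1 - frac) + samples[hi] * frac)
    return out


original = [0.0, 1.0, 0.0, -1.0]  # four made-up samples at 8 kHz
resampled = resample_linear(original, orig_sr=8000, target_sr=16000)
print(len(original), len(resampled))  # 4 8
```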

The preprocessing function needs to:

  1. Call the audio column to load and, if necessary, resample the audio file.
  2. Check that the sampling rate of the audio file matches the sampling rate of the audio data the model was pretrained with. You can find this information on the Wav2Vec2 model card.
  3. Set a maximum input length so that longer inputs are truncated to it before batching. The code below uses 16000 samples, which is one second of audio at 16 kHz.

>>> def preprocess_function(examples):
...     audio_arrays = [x["array"] for x in examples["audio"]]
...     inputs = feature_extractor(
...         audio_arrays, sampling_rate=feature_extractor.sampling_rate, max_length=16000, truncation=True
...     )
...     return inputs
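
The sampling-rate check and the truncation can be pictured with plain Python lists. This is a sketch under stated assumptions, not the feature extractor's implementation: `check_and_truncate` and its parameters are hypothetical names, and the clips are silent placeholder data.

```python
def check_and_truncate(audio_arrays, sampling_rate, expected_rate=16_000, max_length=16_000):
    """Reject a mismatched sampling rate, then cut every clip to max_length samples."""
    if sampling_rate != expected_rate:
        raise ValueError(
            f"Expected audio sampled at {expected_rate} Hz, got {sampling_rate} Hz"
        )
    # Slicing leaves clips shorter than max_length unchanged.
    return [list(array[:max_length]) for array in audio_arrays]


short_clip = [0.0] * 8_000   # 0.5 s of placeholder audio at 16 kHz
long_clip = [0.0] * 40_000   # 2.5 s of placeholder audio at 16 kHz
batch = check_and_truncate([short_clip, long_clip], sampling_rate=16_000)
print([len(clip) for clip in batch])  # [8000, 16000]
```

With max_length=16000 at 16 kHz, at most the first second of each recording is kept.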

Use the 🤗 Datasets map function to apply the preprocessing function over the entire dataset. You can speed up the map function by setting batched=True to process multiple elements of the dataset at once. Remove the columns you don't need, and rename intent_class to label because that is the column name the model expects:

>>> encoded_minds = minds.map(preprocess_function, remove_columns="audio", batched=True)
>>> encoded_minds = encoded_minds.rename_column("intent_class", "label")
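
Conceptually, a batched map hands the function a dictionary of column lists rather than one row at a time. The sketch below imitates that behavior on plain Python data. `batched_map` and `add_length` are hypothetical helpers written for this illustration and do not reflect how 🤗 Datasets is implemented.

```python
def batched_map(rows, function, batch_size=2, remove_columns=()):
    """Apply function to column-wise batches and merge its output back into the rows."""
    out_rows = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Turn a list of row dicts into a dict of column lists.
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_columns = function(batch)
        for i, row in enumerate(chunk):
            kept = {k: v for k, v in row.items() if k not in remove_columns}
            kept.update({k: values[i] for k, values in new_columns.items()})
            out_rows.append(kept)
    return out_rows


def add_length(batch):
    # Receives every "audio" value of the batch at once.
    return {"num_samples": [len(audio) for audio in batch["audio"]]}


rows = [
    {"audio": [0.1, 0.2], "intent_class": 2},
    {"audio": [0.3], "intent_class": 0},
]
encoded = batched_map(rows, add_length, remove_columns=("audio",))
# Rename intent_class to label, mirroring the rename_column call above.
encoded = [{("label" if k == "intent_class" else k): v for k, v in row.items()} for row in encoded]
print(encoded)  # [{'label': 2, 'num_samples': 2}, {'label': 0, 'num_samples': 1}]
```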


Train

Load Wav2Vec2 with AutoModelForAudioClassification. Specify the number of labels, and pass the model the mapping between label number and label class:

>>> from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer

>>> num_labels = len(id2label)
>>> model = AutoModelForAudioClassification.from_pretrained(
...     "facebook/wav2vec2-base", num_labels=num_labels, label2id=label2id, id2label=id2label
... )
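
The id2label mapping is what turns a model output back into a readable intent. The sketch below shows the idea with made-up scores: the logits are invented for illustration, and the label names other than "app_error" are placeholders.

```python
# Hypothetical mapping with string keys, as built earlier in this guide.
id2label = {"0": "placeholder_a", "1": "placeholder_b", "2": "app_error"}

# Made-up scores, one per class; a real model would produce these.
logits = [0.1, -1.3, 2.4]

# The predicted class is the index of the largest score (argmax).
predicted_id = max(range(len(logits)), key=lambda i: logits[i])
predicted_label = id2label[str(predicted_id)]
print(predicted_id, predicted_label)  # 2 app_error
```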

If you aren't familiar with fine-tuning a model with the Trainer, take a look at the basic fine-tuning tutorial first.

At this point, only three steps remain:

  1. Define your training hyperparameters in TrainingArguments.
  2. Pass the training arguments to Trainer along with the model, datasets, and feature extractor.
  3. Call train() to fine-tune your model.

>>> training_args = TrainingArguments(
...     output_dir="./results",
...     evaluation_strategy="epoch",
...     save_strategy="epoch",
...     learning_rate=3e-5,
...     num_train_epochs=5,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=encoded_minds["train"],
...     eval_dataset=encoded_minds["test"],
...     tokenizer=feature_extractor,
... )

>>> trainer.train()
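
For a rough sense of how long this run is, the sketch below estimates the number of optimizer steps. It assumes a batch size of 8 per device, a single device, no gradient accumulation, and that the final partial batch of each epoch is kept. `total_optimizer_steps` is a hypothetical helper, and your own setup may differ.

```python
import math


def total_optimizer_steps(n_examples, batch_size, num_epochs):
    """Estimate optimizer steps, assuming one device and no gradient accumulation."""
    steps_per_epoch = math.ceil(n_examples / batch_size)  # last partial batch counts
    return steps_per_epoch * num_epochs


# 450 training rows and num_train_epochs=5 come from this guide; batch size 8 is an assumption.
print(total_optimizer_steps(n_examples=450, batch_size=8, num_epochs=5))  # 285
```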

For a more in-depth example of how to fine-tune a model for audio classification, take a look at the corresponding PyTorch notebook.