---
license: apache-2.0
language: en
tags:
- generated_from_trainer
datasets:
- speech_commands
metrics:
- accuracy
model-index:
- name: wav2vec2-conformer-rel-pos-large-finetuned-speech-commands
  results:
  - task:
      type: audio-classification
      name: audio classification
    dataset:
      type: speech_commands
      name: speech_commands
      split: v0.02
    metrics:
    - type: accuracy
      value: 0.9724
      name: accuracy
---

# wav2vec2-conformer-rel-pos-large-finetuned-speech-commands

This model is a fine-tuned version of [facebook/wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large) on the [speech_commands](https://huggingface.co/datasets/speech_commands) dataset.

It achieves the following results on the evaluation set:
- Loss: 0.5245
- Accuracy: 0.9724

### Model description

TBD

### Intended uses & limitations

The model can spot one of the following keywords: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go", "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine", "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow", "Backward", "Forward", "Follow", "Learn", "Visual".

The repository includes sample files that I recorded (WAV, 16 kHz sampling rate, mono). The simplest way to use the model is with the `pipeline` API:

```python
>>> from transformers import pipeline
>>> p = pipeline("audio-classification", model="juliensimon/wav2vec2-conformer-rel-pos-large-finetuned-speech-commands")
>>> p("up16k.wav")
[{'score': 0.7008192539215088, 'label': 'up'}, {'score': 0.04346614331007004, 'label': 'off'}, {'score': 0.029526518657803535, 'label': 'left'}, {'score': 0.02905120886862278, 'label': 'stop'}, {'score': 0.027142534032464027, 'label': 'on'}]
>>> p("stop16k.wav")
[{'score': 0.6969656944274902, 'label': 'stop'}, {'score': 0.03391443192958832, 'label': 'up'}, {'score': 0.027382319793105125, 'label': 'seven'}, {'score': 0.020835857838392258, 'label': 'five'}, {'score': 0.018051736056804657, 'label': 'down'}]
>>> p("marvin16k.wav")
[{'score': 0.5276530981063843, 'label': 'marvin'}, {'score': 0.04645705968141556, 'label': 'down'}, {'score': 0.038583893328905106, 'label': 'backward'}, {'score': 0.03578080236911774, 'label': 'wow'}, {'score': 0.03178196772933006, 'label': 'bird'}]
```
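The pipeline decodes and resamples audio files itself, but it also accepts raw samples directly as a dictionary with `raw` and `sampling_rate` keys. As an illustration of the input format the model expects (16 kHz, mono, float samples), here is a standard-library-only sketch for decoding a 16-bit mono WAV file; the helper name is just an example:

```python
import struct
import wave

def read_mono_wav(path):
    """Decode a 16-bit mono WAV file into floats in [-1.0, 1.0]."""
    with wave.open(path, "rb") as f:
        assert f.getnchannels() == 1, "expected mono audio"
        assert f.getsampwidth() == 2, "expected 16-bit samples"
        rate = f.getframerate()
        raw = f.readframes(f.getnframes())
    # Little-endian signed 16-bit integers, normalized to [-1.0, 1.0].
    samples = [s / 32768.0 for s in struct.unpack(f"<{len(raw) // 2}h", raw)]
    return rate, samples
```

With NumPy installed, the result can be fed to the pipeline as `p({"sampling_rate": rate, "raw": np.asarray(samples)})`.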

### Training and evaluation data

- subset: v0.02
- training data: full training set
- evaluation data: full validation set
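The accuracy reported above is plain top-1 accuracy over the validation examples: the fraction of clips whose highest-scoring class matches the label. A minimal sketch of that metric, with made-up logits:

```python
def top1_accuracy(logits, labels):
    """Fraction of examples whose highest-scoring class matches the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)

# Toy example: 3 classes, 4 examples, 3 predicted correctly.
logits = [[0.1, 0.7, 0.2],
          [0.8, 0.1, 0.1],
          [0.3, 0.3, 0.4],
          [0.2, 0.5, 0.3]]
labels = [1, 0, 2, 0]
print(top1_accuracy(logits, labels))  # → 0.75
```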

### Training procedure

The model was fine-tuned on [Amazon SageMaker](https://aws.amazon.com/sagemaker), using an [ml.p3dn.24xlarge](https://aws.amazon.com/fr/ec2/instance-types/p3/) instance (8 NVIDIA V100 GPUs). Total training time for 10 epochs was 4.5 hours.

#### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 3e-05
- train_batch_size: 256
- eval_batch_size: 256
- seed: 42
- gradient_accumulation_steps: 4
- total_train_batch_size: 1024
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
- num_epochs: 10
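
Two of these values can be related back to the others: the effective batch size is 256 × 4 gradient-accumulation steps = 1024, and with `lr_scheduler_type: linear` and `lr_scheduler_warmup_ratio: 0.1` the learning rate ramps from 0 to 3e-05 over the first 10% of the 830 optimizer steps (83 steps), then decays linearly back to 0. A minimal sketch of that schedule (mirroring, not calling, the `transformers` scheduler):

```python
def linear_schedule_with_warmup(step, total_steps=830, warmup_ratio=0.1, peak_lr=3e-05):
    """Learning rate at a given optimizer step: linear warmup, then linear decay to 0."""
    warmup_steps = int(total_steps * warmup_ratio)  # 83 steps here
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * (total_steps - step) / (total_steps - warmup_steps)

# The peak learning rate is reached at the end of warmup, then decays to 0.
print(linear_schedule_with_warmup(0))    # → 0.0
print(linear_schedule_with_warmup(83))   # peak, 3e-05
print(linear_schedule_with_warmup(830))  # → 0.0
```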

#### Training results

| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|:-------------:|:-----:|:----:|:---------------:|:--------:|
| 2.2901        | 1.0   | 83   | 2.0542          | 0.8875   |
| 1.8375        | 2.0   | 166  | 1.5610          | 0.9316   |
| 1.4957        | 3.0   | 249  | 1.1850          | 0.9558   |
| 1.1917        | 4.0   | 332  | 0.9159          | 0.9695   |
| 1.0449        | 5.0   | 415  | 0.7624          | 0.9687   |
| 0.9319        | 6.0   | 498  | 0.6444          | 0.9715   |
| 0.8559        | 7.0   | 581  | 0.5806          | 0.9711   |
| 0.8199        | 8.0   | 664  | 0.5394          | 0.9721   |
| 0.7949        | 9.0   | 747  | 0.5245          | 0.9724   |
| 0.7975        | 10.0  | 830  | 0.5256          | 0.9721   |


#### Framework versions

- Transformers 4.20.1
- Pytorch 1.11.0+cu102
- Datasets 2.3.2
- Tokenizers 0.12.1