poonehmousavi committed
Commit
1cd8cde
1 Parent(s): e631cd7

Upload 4 files

Files changed (4)
  1. README.md +132 -0
  2. custom_interface.py +127 -0
  3. hparams.yaml +77 -0
  4. whisper.ckpt +3 -0
README.md ADDED
@@ -0,0 +1,132 @@
---
language:
- hi
thumbnail: null
pipeline_tag: automatic-speech-recognition
tags:
- whisper
- pytorch
- speechbrain
- Transformer
- hf-asr-leaderboard
license: apache-2.0
datasets:
- commonvoice
metrics:
- wer
- cer
model-index:
- name: asr-whisper-large-v2-commonvoice-hi
  results:
  - task:
      name: Automatic Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: CommonVoice 10.0 (Hindi)
      type: mozilla-foundation/common_voice_10_0
      config: hi
      split: test
      args:
        language: hi
    metrics:
    - name: Test WER
      type: wer
      value: '15.27'
---

<iframe src="https://ghbtns.com/github-btn.html?user=speechbrain&repo=speechbrain&type=star&count=true&size=large&v=2" frameborder="0" scrolling="0" width="170" height="30" title="GitHub"></iframe>
<br/><br/>

# Whisper large-v2 fine-tuned on CommonVoice Hindi

This repository provides all the necessary tools to perform automatic speech
recognition with an end-to-end Whisper model fine-tuned on CommonVoice (Hindi) within
SpeechBrain. For a better experience, we encourage you to learn more about
[SpeechBrain](https://speechbrain.github.io).

The performance of the model is the following:

| Release | Test CER | Test WER | GPUs |
|:-------:|:--------:|:--------:|:-----------:|
| 01-02-23 | 7.00 | 15.27 | 1xV100 16GB |

## Pipeline description

This ASR system is composed of Whisper encoder-decoder blocks:
- The pretrained whisper-large-v2 encoder is frozen.
- The pretrained Whisper tokenizer is used.
- The pretrained whisper-large-v2 decoder ([openai/whisper-large-v2](https://huggingface.co/openai/whisper-large-v2)) is fine-tuned on CommonVoice Hindi.

The resulting acoustic representation is passed to the greedy decoder.

The system is trained with recordings sampled at 16kHz (single channel).
The code will automatically normalize your audio (i.e., resampling + mono channel selection) when calling *transcribe_file* if needed.
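
If you would rather normalize the audio yourself, the following is a minimal sketch of the same preprocessing (mono mixdown + 16 kHz resampling). The `torchaudio` calls and the file name are our own illustration, not part of the original card:

```python
# Minimal sketch: bring an arbitrary recording into the 16 kHz, mono
# format the model was trained on. The file name is a placeholder.
import torchaudio

signal, sr = torchaudio.load("my_recording.wav")  # signal: [channels, time]
if signal.shape[0] > 1:
    signal = signal.mean(dim=0, keepdim=True)     # mixdown to a single channel
if sr != 16000:
    signal = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(signal)
```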

## Install SpeechBrain

First of all, please install transformers and SpeechBrain with the following command:

```bash
pip install speechbrain transformers
```
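
Optionally, you can verify that both packages import cleanly (a quick sanity check of ours, not part of the original card):

```python
# Quick sanity check that the installation succeeded.
import speechbrain
import transformers

print(speechbrain.__version__, transformers.__version__)
```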

Please note that we encourage you to read our tutorials and learn more about
[SpeechBrain](https://speechbrain.github.io).

### Transcribing your own audio files (in Hindi)

```python
from speechbrain.pretrained.interfaces import foreign_class

asr_model = foreign_class(
    source="speechbrain/asr-whisper-large-v2-commonvoice-hi",
    pymodule_file="custom_interface.py",
    classname="WhisperASR",
    hparams_file="hparams.yaml",
    savedir="pretrained_models/asr-whisper-large-v2-commonvoice-hi",
)
asr_model.transcribe_file("speechbrain/asr-whisper-large-v2-commonvoice-hi/example-hi.wav")
```
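If your signals are already loaded as tensors, you can call the interface's `transcribe_batch` method (defined in `custom_interface.py`) directly instead of `transcribe_file`. A minimal sketch, assuming `asr_model` was loaded as above; the file name is a placeholder:

```python
import torch

# load_audio fetches/normalizes the file the same way transcribe_file does.
signal = asr_model.load_audio("my_audio.wav")  # shape [time]
batch = signal.unsqueeze(0)                    # fake a batch: [1, time]
rel_lengths = torch.tensor([1.0])              # relative lengths; longest = 1.0
words, tokens = asr_model.transcribe_batch(batch, rel_lengths)
print(words[0])
```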
### Inference on GPU
To perform inference on the GPU, add `run_opts={"device":"cuda"}` when calling the `foreign_class` method.

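For example, a minimal sketch of loading the same interface on GPU (identical arguments to the call above, with `run_opts` added):

```python
from speechbrain.pretrained.interfaces import foreign_class

# Same call as above; run_opts is forwarded so that the model and
# inference tensors are placed on the GPU.
asr_model = foreign_class(
    source="speechbrain/asr-whisper-large-v2-commonvoice-hi",
    pymodule_file="custom_interface.py",
    classname="WhisperASR",
    hparams_file="hparams.yaml",
    savedir="pretrained_models/asr-whisper-large-v2-commonvoice-hi",
    run_opts={"device": "cuda"},
)
```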
### Training
The model was trained with SpeechBrain.
To train it from scratch, follow these steps:
1. Clone SpeechBrain:
```bash
git clone https://github.com/speechbrain/speechbrain/
```
2. Install it:
```bash
cd speechbrain
pip install -r requirements.txt
pip install -e .
```
3. Run Training:
```bash
cd recipes/CommonVoice/ASR/transformer/
python train_with_whisper.py hparams/train_hi_hf_whisper.yaml --data_folder=your_data_folder
```

You can find our training results (models, logs, etc.) [here](https://drive.google.com/drive/folders/11PKCsyIE703mmDv6n6n_UnD0bUgMPbg_?usp=share_link).

### Limitations
The SpeechBrain team does not provide any warranty on the performance achieved by this model when used on other datasets.

#### Referencing SpeechBrain

```
@misc{SB2021,
    author = {Ravanelli, Mirco and Parcollet, Titouan and Rouhe, Aku and Plantinga, Peter and Rastorgueva, Elena and Lugosch, Loren and Dawalatabad, Nauman and Ju-Chieh, Chou and Heba, Abdel and Grondin, Francois and Aris, William and Liao, Chien-Feng and Cornell, Samuele and Yeh, Sung-Lin and Na, Hwidong and Gao, Yan and Fu, Szu-Wei and Subakan, Cem and De Mori, Renato and Bengio, Yoshua},
    title = {SpeechBrain},
    year = {2021},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/speechbrain/speechbrain}},
}
```

#### About SpeechBrain
SpeechBrain is an open-source and all-in-one speech toolkit. It is designed to be simple, extremely flexible, and user-friendly. Competitive or state-of-the-art performance is obtained in various domains.

Website: https://speechbrain.github.io/

GitHub: https://github.com/speechbrain/speechbrain
custom_interface.py ADDED
@@ -0,0 +1,127 @@
import torch
from speechbrain.pretrained import Pretrained


class WhisperASR(Pretrained):
    """A ready-to-use Whisper ASR model.

    The class can be used to run only the encoder (encode_batch()) or the
    entire encoder-decoder Whisper model (transcribe_batch()) to transcribe
    speech. The given YAML must contain the fields specified in the
    *_NEEDED[] lists.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = foreign_class(source="hf",
    ...     pymodule_file="custom_interface.py",
    ...     classname="WhisperASR",
    ...     hparams_file='hparams.yaml',
    ...     savedir=tmpdir,
    ... )
    >>> asr_model.transcribe_file("tests/samples/example2.wav")
    """

    HPARAMS_NEEDED = ["language"]
    MODULES_NEEDED = ["whisper", "decoder"]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.whisper.tokenizer
        self.tokenizer.set_prefix_tokens(self.hparams.language, "transcribe", False)
        self.hparams.decoder.set_decoder_input_tokens(
            self.tokenizer.prefix_tokens
        )

    def transcribe_file(self, path):
        """Transcribes the given audio file into a sequence of words.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        list
            The transcription produced by this ASR system.
        """
        waveform = self.load_audio(path)
        # Fake a batch:
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, predicted_tokens = self.transcribe_batch(
            batch, rel_length
        )
        return predicted_words

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.tensor
            The encoded batch.
        """
        wavs = wavs.float()
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        encoder_out = self.mods.whisper.forward_encoder(wavs)
        return encoder_out

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed.
        tensor
            Each predicted token id.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
            predicted_words = self.tokenizer.batch_decode(
                predicted_tokens, skip_special_tokens=True
            )
            if self.hparams.normalized_transcripts:
                predicted_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in predicted_words
                ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs the full transcription; note: no gradients through decoding."""
        return self.transcribe_batch(wavs, wav_lens)
hparams.yaml ADDED
@@ -0,0 +1,77 @@
# ################################
# Model: Whisper (Encoder-Decoder) + NLL
# Augmentation: TimeDomainSpecAugment
# Authors: Pooneh Mousavi 2022
# ################################

# HuggingFace hub source for the pretrained Whisper model.
whisper_hub: openai/whisper-large-v2

# Normalize transcripts with the same normalization done in the Whisper
# paper. Refer to Appendix C of that paper for further information.
normalized_transcripts: True

language: hindi

auto_mix_prec: False
sample_rate: 16000

# These values are only used for the searchers.
# They need to be hardcoded and should not be changed with Whisper.
# They are used as part of the searching process.
# The bos token of the searcher will be timestamp_index
# and will be concatenated with the bos, language and task tokens.
timestamp_index: 50363
eos_index: 50257
bos_index: 50258

# Decoding parameters
min_decode_ratio: 0.0
max_decode_ratio: 0.1
test_beam_size: 8

# Model parameters
freeze_whisper: True
freeze_encoder: True

whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    freeze_encoder: !ref <freeze_encoder>
    save_path: whisper_checkpoints
    encoder_only: False

decoder: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
    model: !ref <whisper>
    bos_index: !ref <timestamp_index>
    eos_index: !ref <eos_index>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

# test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
#     module: [!ref <whisper>]
#     bos_index: !ref <timestamp_index>
#     eos_index: !ref <eos_index>
#     min_decode_ratio: !ref <min_decode_ratio>
#     max_decode_ratio: !ref <max_decode_ratio>
#     beam_size: !ref <test_beam_size>

modules:
    whisper: !ref <whisper>
    decoder: !ref <decoder>

pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
    loadables:
        whisper: !ref <whisper>
whisper.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65048b4f44c10e5645ac0a2f43c6086520491f98bf4fcb0711ebeb1fb5f20d09
size 6173767281