poonehmousavi committed on
Commit
092d812
1 Parent(s): 970c9c3

Upload 3 files

Browse files
Files changed (3) hide show
  1. custom_interface.py +127 -0
  2. hparams.yaml +77 -0
  3. whisper.ckpt +3 -0
custom_interface.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.pretrained import Pretrained
3
+
4
class WhisperASR(Pretrained):
    """A ready-to-use Whisper ASR model.

    The class can be used either to run only the encoder (``encode_batch()``)
    or to run the entire encoder-decoder Whisper model (``transcribe_batch()``
    / ``transcribe_file()``) to transcribe speech. The given YAML must contain
    the fields specified in the *_NEEDED[] lists.

    Example
    -------
    >>> from speechbrain.pretrained.interfaces import foreign_class
    >>> tmpdir = getfixture("tmpdir")
    >>> asr_model = foreign_class(source="hf",
    ...                           pymodule_file="custom_interface.py",
    ...                           classname="WhisperASR",
    ...                           hparams_file='hparams.yaml',
    ...                           savedir=tmpdir,
    ...                           )
    >>> asr_model.transcribe_file("tests/samples/example2.wav")
    """

    HPARAMS_NEEDED = ["language"]
    MODULES_NEEDED = ["whisper", "decoder"]

    def __init__(self, *args, **kwargs):
        """Builds the interface and primes the decoder with Whisper's prefix.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.tokenizer = self.hparams.whisper.tokenizer
        # Configure the tokenizer prefix for the language given in the YAML,
        # the "transcribe" task, and no timestamp prediction.
        self.tokenizer.set_prefix_tokens(
            language=self.hparams.language,
            task="transcribe",
            predict_timestamps=False,
        )
        # The searcher must start decoding from that same prefix.
        self.hparams.decoder.set_decoder_input_tokens(
            self.tokenizer.prefix_tokens
        )

    def transcribe_file(self, path):
        """Transcribes the given audio file.

        Arguments
        ---------
        path : str
            Path to the audio file to transcribe.

        Returns
        -------
        list
            The ``predicted_words`` output of ``transcribe_batch()`` for a
            batch that contains only this file, i.e. a list with a single
            entry holding the transcription.
        """
        waveform = self.load_audio(path)
        # Fake a batch of size one; its only item is by definition the
        # longest, hence a relative length of 1.0.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        predicted_words, _ = self.transcribe_batch(batch, rel_length)
        return predicted_words

    def encode_batch(self, wavs, wav_lens):
        """Encodes the input audio into a sequence of hidden states.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Kept for interface compatibility; it is not passed to the
            Whisper encoder by this method.

        Returns
        -------
        torch.tensor
            The encoded batch.
        """
        wavs = wavs.float().to(self.device)
        return self.mods.whisper.forward_encoder(wavs)

    def transcribe_batch(self, wavs, wav_lens):
        """Transcribes the input audio into a sequence of words.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels].
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        list
            Each waveform in the batch transcribed. When the
            ``normalized_transcripts`` hyperparameter is set, every entry is
            the normalized text split on single spaces (a list of words);
            otherwise every entry is the decoded text as returned by the
            tokenizer.
        predicted_tokens
            The predicted token ids, exactly as returned by the decoder
            module.
        """
        with torch.no_grad():
            wav_lens = wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs, wav_lens)
            predicted_tokens, _ = self.mods.decoder(encoder_out, wav_lens)
            predicted_words = self.tokenizer.batch_decode(
                predicted_tokens, skip_special_tokens=True
            )
            # `normalized_transcripts` is not listed in HPARAMS_NEEDED, so a
            # YAML without it must not crash here: default to no
            # normalization.
            if getattr(self.hparams, "normalized_transcripts", False):
                predicted_words = [
                    self.tokenizer._normalize(text).split(" ")
                    for text in predicted_words
                ]

        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full transcription - note: no gradients through decoding."""
        return self.transcribe_batch(wavs, wav_lens)
hparams.yaml ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ################################
2
+ # Model: Whisper (Encoder-Decoder) + NLL
3
+ # Augmentation: TimeDomainSpecAugment
4
+ # Authors: Pooneh Mousavi 2022
5
+ # ################################
6
+
7
+
8
+ # Hugging Face Hub identifier of the pretrained Whisper model to load.
9
+ whisper_hub: openai/whisper-large-v2
10
+
11
+ # Normalize the predicted transcripts with
12
+ # the same text normalization done in the paper. Refer to Appendix C for further information.
13
+ normalized_transcripts: True
14
+
15
+
16
+ language: mongolian
17
+
18
+ auto_mix_prec: False
19
+ sample_rate: 16000
20
+
21
+ # These values are only used for the searchers.
22
+ # They need to be hardcoded and should not be changed with Whisper.
23
+ # They are used as part of the searching process.
24
+ # The bos token of the searcher will be timestamp_index
25
+ # and will be concatenated with the bos, language and task tokens.
26
+ timestamp_index: 50363
27
+ eos_index: 50257
28
+ bos_index: 50258
29
+
30
+ # Decoding parameters
31
+ min_decode_ratio: 0.0
32
+ max_decode_ratio: 0.1
33
+ test_beam_size: 8
34
+
35
+ # Model parameters
36
+ freeze_whisper: True
37
+ freeze_encoder: True
38
+
39
+
40
+
41
+ whisper: !new:speechbrain.lobes.models.huggingface_whisper.HuggingFaceWhisper
42
+ source: !ref <whisper_hub>
43
+ freeze: !ref <freeze_whisper>
44
+ freeze_encoder: !ref <freeze_encoder>
45
+ save_path: whisper_checkpoints
46
+ encoder_only: False
47
+
48
+
49
+
50
+ decoder: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearch
51
+ model: !ref <whisper>
52
+ bos_index: !ref <timestamp_index>
53
+ eos_index: !ref <eos_index>
54
+ min_decode_ratio: !ref <min_decode_ratio>
55
+ max_decode_ratio: !ref <max_decode_ratio>
56
+
57
+ # test_beam_searcher: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearch
58
+ # module: [!ref <whisper>]
59
+ # bos_index: !ref <timestamp_index>
60
+ # eos_index: !ref <eos_index>
61
+ # min_decode_ratio: !ref <min_decode_ratio>
62
+ # max_decode_ratio: !ref <max_decode_ratio>
63
+ # beam_size: !ref <test_beam_size>
64
+
65
+
66
+
67
+
68
+
69
+ modules:
70
+ whisper: !ref <whisper>
71
+ decoder: !ref <decoder>
72
+
73
+
74
+ pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
75
+ loadables:
76
+ whisper: !ref <whisper>
77
+
whisper.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6230026278db2e8f4cf6a49c6179bd73e1d3e3cebd0202f4615bab830029f4c5
3
+ size 6173767281