poonehmousavi commited on
Commit
5925f2c
1 Parent(s): bd71d20

Delete custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +0 -127
custom_interface.py DELETED
@@ -1,127 +0,0 @@
1
- import torch
2
- from speechbrain.pretrained import Pretrained
3
-
4
- class WhisperASR(Pretrained):
5
- """A ready-to-use Whisper ASR model
6
-
7
- The class can be used to run only the encoder (encode()) to run the entire encoder-decoder whisper model
8
- (transcribe()) to transcribe speech. The given YAML must contains the fields
9
- specified in the *_NEEDED[] lists.
10
-
11
- Example
12
- -------
13
- >>> from speechbrain.pretrained.interfaces import foreign_class
14
- >>> tmpdir = getfixture("tmpdir")
15
- >>> asr_model = foreign_class(source="hf",
16
- ... pymodule_file="custom_interface.py",
17
- ... classname="WhisperASR",
18
- ... hparams_file='hparams.yaml',
19
- ... savedir=tmpdir,
20
- ... )
21
- >>> asr_model.transcribe_file("tests/samples/example2.wav")
22
- """
23
-
24
- HPARAMS_NEEDED = ['language']
25
- MODULES_NEEDED = ["whisper", "decoder"]
26
-
27
- def __init__(self, *args, **kwargs):
28
- super().__init__(*args, **kwargs)
29
- self.tokenizer = self.hparams.whisper.tokenizer
30
- self.tokenizer.set_prefix_tokens(self.hparams.language, "transcribe", False)
31
- self.hparams.decoder.set_decoder_input_tokens(
32
- self.tokenizer.prefix_tokens
33
- )
34
-
35
- def transcribe_file(self, path):
36
- """Transcribes the given audiofile into a sequence of words.
37
-
38
- Arguments
39
- ---------
40
- path : str
41
- Path to audio file which to transcribe.
42
-
43
- Returns
44
- -------
45
- str
46
- The audiofile transcription produced by this ASR system.
47
- """
48
- waveform = self.load_audio(path)
49
- # Fake a batch:
50
- batch = waveform.unsqueeze(0)
51
- rel_length = torch.tensor([1.0])
52
- predicted_words, predicted_tokens = self.transcribe_batch(
53
- batch, rel_length
54
- )
55
- return predicted_words
56
-
57
- def encode_batch(self, wavs, wav_lens):
58
- """Encodes the input audio into a sequence of hidden states
59
-
60
- The waveforms should already be in the model's desired format.
61
- You can call:
62
- ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
63
- to get a correctly converted signal in most cases.
64
-
65
- Arguments
66
- ---------
67
- wavs : torch.tensor
68
- Batch of waveforms [batch, time, channels].
69
- wav_lens : torch.tensor
70
- Lengths of the waveforms relative to the longest one in the
71
- batch, tensor of shape [batch]. The longest one should have
72
- relative length 1.0 and others len(waveform) / max_length.
73
- Used for ignoring padding.
74
-
75
- Returns
76
- -------
77
- torch.tensor
78
- The encoded batch
79
- """
80
- wavs = wavs.float()
81
- wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
82
- encoder_out = self.mods.whisper.forward_encoder(wavs)
83
- return encoder_out
84
-
85
- def transcribe_batch(self, wavs, wav_lens):
86
- """Transcribes the input audio into a sequence of words
87
-
88
- The waveforms should already be in the model's desired format.
89
- You can call:
90
- ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)``
91
- to get a correctly converted signal in most cases.
92
-
93
- Arguments
94
- ---------
95
- wavs : torch.tensor
96
- Batch of waveforms [batch, time, channels].
97
- wav_lens : torch.tensor
98
- Lengths of the waveforms relative to the longest one in the
99
- batch, tensor of shape [batch]. The longest one should have
100
- relative length 1.0 and others len(waveform) / max_length.
101
- Used for ignoring padding.
102
-
103
- Returns
104
- -------
105
- list
106
- Each waveform in the batch transcribed.
107
- tensor
108
- Each predicted token id.
109
- """
110
- with torch.no_grad():
111
- wav_lens = wav_lens.to(self.device)
112
- encoder_out = self.encode_batch(wavs, wav_lens)
113
- predicted_tokens, scores = self.mods.decoder(encoder_out, wav_lens)
114
- predicted_words = self.tokenizer.batch_decode(
115
- predicted_tokens, skip_special_tokens=True)
116
- if self.hparams.normalized_transcripts:
117
- predicted_words = [
118
- self.tokenizer._normalize(text).split(" ")
119
- for text in predicted_words
120
- ]
121
-
122
-
123
- return predicted_words, predicted_tokens
124
-
125
- def forward(self, wavs, wav_lens):
126
- """Runs full transcription - note: no gradients through decoding"""
127
- return self.transcribe_batch(wavs, wav_lens)