Update custom_interface.py
Browse files- custom_interface.py +9 -10
custom_interface.py
CHANGED
@@ -8,22 +8,21 @@ class ASR(Pretrained):
|
|
8 |
|
9 |
def encode_batch(self, wavs, wav_lens=None, normalize=False):
|
10 |
wavs = wavs.to(self.device)
|
11 |
-
wav_lens = wav_lens.to(self.device)
|
12 |
|
13 |
# Forward pass
|
14 |
-
encoded_outputs = self.
|
15 |
# append
|
16 |
tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
|
17 |
-
|
18 |
-
|
19 |
-
decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)
|
20 |
|
21 |
# Output layer for seq2seq log-probabilities
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
|
26 |
-
return
|
27 |
|
28 |
|
29 |
def classify_file(self, path):
|
@@ -31,7 +30,7 @@ class ASR(Pretrained):
|
|
31 |
# Fake a batch:
|
32 |
batch = waveform.unsqueeze(0)
|
33 |
rel_length = torch.tensor([1.0])
|
34 |
-
outputs = self.encode_batch(batch, rel_length)
|
35 |
|
36 |
return outputs
|
37 |
|
|
|
8 |
|
9 |
def encode_batch(self, wavs, wav_lens=None, normalize=False):
|
10 |
wavs = wavs.to(self.device)
|
11 |
+
self.wav_lens = wav_lens.to(self.device)
|
12 |
|
13 |
# Forward pass
|
14 |
+
encoded_outputs = self.mods.encoder_w2v2(wavs.detach())
|
15 |
# append
|
16 |
tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
|
17 |
+
embedded_tokens = self.mods.embedding(tokens_bos)
|
18 |
+
decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, self.wav_lens)
|
|
|
19 |
|
20 |
# Output layer for seq2seq log-probabilities
|
21 |
+
predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
|
22 |
+
predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
|
23 |
+
print(predicted_words)
|
24 |
|
25 |
+
return predicted_words
|
26 |
|
27 |
|
28 |
def classify_file(self, path):
|
|
|
30 |
# Fake a batch:
|
31 |
batch = waveform.unsqueeze(0)
|
32 |
rel_length = torch.tensor([1.0])
|
33 |
+
outputs = self.encode_batch(batch, rel_length)
|
34 |
|
35 |
return outputs
|
36 |
|