Porjaz commited on
Commit
722014e
1 Parent(s): ce7435b

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +9 -10
custom_interface.py CHANGED
@@ -8,22 +8,21 @@ class ASR(Pretrained):
8
 
9
  def encode_batch(self, wavs, wav_lens=None, normalize=False):
10
  wavs = wavs.to(self.device)
11
- wav_lens = wav_lens.to(self.device)
12
 
13
  # Forward pass
14
- encoded_outputs = self.modules.encoder_w2v2(wavs.detach())
15
  # append
16
  tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
17
- print(tokens_bos.size())
18
- embedded_tokens = self.modules.embedding(tokens_bos)
19
- decoder_outputs, _ = self.modules.decoder(embedded_tokens, encoded_outputs, self.sig_lens)
20
 
21
  # Output layer for seq2seq log-probabilities
22
- logits = self.modules.seq_lin(decoder_outputs)
23
- predictions = {"seq_logprobs": self.hparams.log_softmax(logits)}
24
- predictions["tokens"], _, _, _ = self.hparams.test_search(encoded_outputs, self.sig_lens)
25
 
26
- return predictions
27
 
28
 
29
  def classify_file(self, path):
@@ -31,7 +30,7 @@ class ASR(Pretrained):
31
  # Fake a batch:
32
  batch = waveform.unsqueeze(0)
33
  rel_length = torch.tensor([1.0])
34
- outputs = self.encode_batch(batch, rel_length)["tokens"]
35
 
36
  return outputs
37
 
 
8
 
9
  def encode_batch(self, wavs, wav_lens=None, normalize=False):
10
  wavs = wavs.to(self.device)
11
+ self.wav_lens = wav_lens.to(self.device)
12
 
13
  # Forward pass
14
+ encoded_outputs = self.mods.encoder_w2v2(wavs.detach())
15
  # append
16
  tokens_bos = torch.zeros((wavs.size(0), 1), dtype=torch.long).to(self.device)
17
+ embedded_tokens = self.mods.embedding(tokens_bos)
18
+ decoder_outputs, _ = self.mods.decoder(embedded_tokens, encoded_outputs, self.wav_lens)
 
19
 
20
  # Output layer for seq2seq log-probabilities
21
+ predictions = self.hparams.test_search(encoded_outputs, self.wav_lens)[0]
22
+ predicted_words = [self.hparams.tokenizer.decode_ids(prediction).split(" ") for prediction in predictions]
23
+ print(predicted_words)
24
 
25
+ return predicted_words
26
 
27
 
28
  def classify_file(self, path):
 
30
  # Fake a batch:
31
  batch = waveform.unsqueeze(0)
32
  rel_length = torch.tensor([1.0])
33
+ outputs = self.encode_batch(batch, rel_length)
34
 
35
  return outputs
36