kunnark committed on
Commit e794638
1 Parent(s): ea465d6

Update encoder_wav2vec_classifier.py

Files changed (1)
  1. encoder_wav2vec_classifier.py +7 -10
encoder_wav2vec_classifier.py CHANGED
@@ -71,10 +71,10 @@ class EncoderWav2vecClassifier(Pretrained):
         wavs = wavs.float()
 
         # Feature extraction and normalization
-        feats = self.modules.wav2vec2(wavs)
+        feats = self.mods.wav2vec2(wavs)
         feats = feats.transpose(1, 2)
 
-        pooling = self.modules.attentive(feats, wav_lens)  # channels = 1024
+        pooling = self.mods.attentive(feats, wav_lens)  # channels = 1024
         outputs = pooling.transpose(1, 2)
         return outputs
 
@@ -105,7 +105,7 @@ class EncoderWav2vecClassifier(Pretrained):
         (label encoder should be provided).
         """
         outputs = self.encode_batch(wavs, wav_lens)
-        outputs = self.modules.classifier(outputs)
+        outputs = self.mods.classifier(outputs)
         out_prob = self.hparams.softmax(outputs)
         score, index = torch.max(out_prob, dim=-1)
         text_lab = self.hparams.label_encoder.decode_torch(index)
@@ -136,24 +136,21 @@ class EncoderWav2vecClassifier(Pretrained):
         (label encoder should be provided).
         """
         waveform = self.load_audio(path)
+
         # Fake a batch:
         batch = waveform.unsqueeze(0)
         rel_length = torch.tensor([1.0])
         outputs = self.encode_batch(batch, rel_length)
 
-        outputs = self.modules.classifier(outputs)
-        # print("classify_outputs_0", outputs.shape)
+        outputs = self.mods.classifier(outputs)
 
         out_prob = self.hparams.softmax(outputs)
-        # print("classify_out_1_softmax", out_prob)
         score, index = torch.max(out_prob, dim=-1)
         text_lab = self.hparams.label_encoder.decode_torch(index)
-        # print("classify_score_2", score)
-        # print("classify_index_3", index)
-        # print("classify_textlab_4", text_lab)
+
         return out_prob, score, index, text_lab
 
     def forward(self, wavs, wav_lens=None, normalize=False):
         return self.encode_batch(
             wavs=wavs, wav_lens=wav_lens, normalize=normalize
-        )
+        )
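
The commit replaces `self.modules` with `self.mods` (the attribute newer SpeechBrain `Pretrained` interfaces expose loaded modules under) and drops leftover debug prints. For reference, a minimal usage sketch for a custom interface file like this one, loaded through SpeechBrain's `foreign_class`; the repository id and audio path below are placeholders, not taken from this repo:

from speechbrain.pretrained.interfaces import foreign_class

# Load the custom interface defined in encoder_wav2vec_classifier.py.
# "<user>/<repo>" is a placeholder for the Hugging Face model id hosting this file.
classifier = foreign_class(
    source="<user>/<repo>",
    pymodule_file="encoder_wav2vec_classifier.py",
    classname="EncoderWav2vecClassifier",
)

# classify_file returns the posterior over classes, the top score and index,
# and the decoded text label (see the diff above).
out_prob, score, index, text_lab = classifier.classify_file("example.wav")
print(text_lab)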