KevinGeng committed on
Commit
d706e2e
1 Parent(s): 195a188

return to old PPM model

Browse files
Files changed (1) hide show
  1. app.py +3 -7
app.py CHANGED
@@ -23,9 +23,9 @@ transformation = jiwer.Compose([
23
 
24
  # WPM part
25
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
26
- processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
27
- model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
28
- phoneme_model = pipeline(model="vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
29
  class ChangeSampleRate(nn.Module):
30
  def __init__(self, input_rate: int, output_rate: int):
31
  super().__init__()
@@ -83,9 +83,6 @@ def calc_mos(audio_path, ref):
83
  logits = phoneme_model(out_wavs).logits
84
  phone_predicted_ids = torch.argmax(logits, dim=-1)
85
  phone_transcription = processor.batch_decode(phone_predicted_ids)
86
-
87
- # Disable PPM for now
88
- phone_transcription = ['D U M M Y']
89
  lst_phonemes = phone_transcription[0].split(" ")
90
  wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
91
 
@@ -94,7 +91,6 @@ def calc_mos(audio_path, ref):
94
 
95
  ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
96
 
97
- # pdb.set_trace()
98
  return AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans, phone_transcription, ppm, f0_db_fig
99
 
100
 
 
23
 
24
  # WPM part
25
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
26
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
27
+ phoneme_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
28
+ # phoneme_model = pipeline(model="vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
29
  class ChangeSampleRate(nn.Module):
30
  def __init__(self, input_rate: int, output_rate: int):
31
  super().__init__()
 
83
  logits = phoneme_model(out_wavs).logits
84
  phone_predicted_ids = torch.argmax(logits, dim=-1)
85
  phone_transcription = processor.batch_decode(phone_predicted_ids)
 
 
 
86
  lst_phonemes = phone_transcription[0].split(" ")
87
  wav_vad = torchaudio.functional.vad(wav, sample_rate=sr)
88
 
 
91
 
92
  ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
93
 
 
94
  return AVA_MOS, MOS_fig, INTELI_score, INT_fig, trans, phone_transcription, ppm, f0_db_fig
95
 
96