TheComputerMan commited on
Commit
1fe6b04
1 Parent(s): e06adea

Upload ProsodicConditionExtractor.py

Browse files
Files changed (1) hide show
  1. ProsodicConditionExtractor.py +40 -0
ProsodicConditionExtractor.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import soundfile as sf
2
+ import torch
3
+ import torch.multiprocessing
4
+ import torch.multiprocessing
5
+ from numpy import trim_zeros
6
+ from speechbrain.pretrained import EncoderClassifier
7
+
8
+ from Preprocessing.AudioPreprocessor import AudioPreprocessor
9
+
10
+
11
+ class ProsodicConditionExtractor:
12
+
13
+ def __init__(self, sr, device=torch.device("cpu")):
14
+ self.ap = AudioPreprocessor(input_sr=sr, output_sr=16000, melspec_buckets=80, hop_length=256, n_fft=1024, cut_silence=False)
15
+ # https://huggingface.co/speechbrain/spkrec-ecapa-voxceleb
16
+ self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
17
+ run_opts={"device": str(device)},
18
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_ecapa")
19
+ # https://huggingface.co/speechbrain/spkrec-xvect-voxceleb
20
+ self.speaker_embedding_func_xvector = EncoderClassifier.from_hparams(source="speechbrain/spkrec-xvect-voxceleb",
21
+ run_opts={"device": str(device)},
22
+ savedir="Models/SpeakerEmbedding/speechbrain_speaker_embedding_xvector")
23
+
24
+ def extract_condition_from_reference_wave(self, wave, already_normalized=False):
25
+ if already_normalized:
26
+ norm_wave = wave
27
+ else:
28
+ norm_wave = self.ap.audio_to_wave_tensor(normalize=True, audio=wave)
29
+ norm_wave = torch.tensor(trim_zeros(norm_wave.numpy()))
30
+ spk_emb_ecapa = self.speaker_embedding_func_ecapa.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
31
+ spk_emb_xvector = self.speaker_embedding_func_xvector.encode_batch(wavs=norm_wave.unsqueeze(0)).squeeze()
32
+ combined_utt_condition = torch.cat([spk_emb_ecapa.cpu(),
33
+ spk_emb_xvector.cpu()], dim=0)
34
+ return combined_utt_condition
35
+
36
+
37
+ if __name__ == '__main__':
38
+ wave, sr = sf.read("../audios/1.wav")
39
+ ext = ProsodicConditionExtractor(sr=sr)
40
+ print(ext.extract_condition_from_reference_wave(wave=wave).shape)