Jzuluaga commited on
Commit
2395d8b
1 Parent(s): 327d2c6

Create custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +158 -0
custom_interface.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.pretrained import Pretrained
3
+
4
+
5
+ class CustomEncoderWav2vec2Classifier(Pretrained):
6
+ """A ready-to-use class for utterance-level classification (e.g, speaker-id,
7
+ language-id, emotion recognition, keyword spotting, etc).
8
+
9
+ The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
10
+ are defined in the yaml file. If you want to
11
+ convert the predicted index into a corresponding text label, please
12
+ provide the path of the label_encoder in a variable called 'lab_encoder_file'
13
+ within the yaml.
14
+
15
+ The class can be used either to run only the encoder (encode_batch()) to
16
+ extract embeddings or to run a classification step (classify_batch()).
17
+ ```
18
+
19
+ Example
20
+ -------
21
+ >>> import torchaudio
22
+ >>> from speechbrain.pretrained import EncoderClassifier
23
+ >>> # Model is downloaded from the speechbrain HuggingFace repo
24
+ >>> tmpdir = getfixture("tmpdir")
25
+ >>> classifier = EncoderClassifier.from_hparams(
26
+ ... source="speechbrain/spkrec-ecapa-voxceleb",
27
+ ... savedir=tmpdir,
28
+ ... )
29
+
30
+ >>> # Compute embeddings
31
+ >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
32
+ >>> embeddings = classifier.encode_batch(signal)
33
+
34
+ >>> # Classification
35
+ >>> prediction = classifier .classify_batch(signal)
36
+ """
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ super().__init__(*args, **kwargs)
40
+
41
+ def encode_batch(self, wavs, wav_lens=None, normalize=False):
42
+ """Encodes the input audio into a single vector embedding.
43
+
44
+ The waveforms should already be in the model's desired format.
45
+ You can call:
46
+ ``normalized = <this>.normalizer(signal, sample_rate)``
47
+ to get a correctly converted signal in most cases.
48
+
49
+ Arguments
50
+ ---------
51
+ wavs : torch.tensor
52
+ Batch of waveforms [batch, time, channels] or [batch, time]
53
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
54
+ wav_lens : torch.tensor
55
+ Lengths of the waveforms relative to the longest one in the
56
+ batch, tensor of shape [batch]. The longest one should have
57
+ relative length 1.0 and others len(waveform) / max_length.
58
+ Used for ignoring padding.
59
+ normalize : bool
60
+ If True, it normalizes the embeddings with the statistics
61
+ contained in mean_var_norm_emb.
62
+
63
+ Returns
64
+ -------
65
+ torch.tensor
66
+ The encoded batch
67
+ """
68
+ # Manage single waveforms in input
69
+ if len(wavs.shape) == 1:
70
+ wavs = wavs.unsqueeze(0)
71
+
72
+ # Assign full length if wav_lens is not assigned
73
+ if wav_lens is None:
74
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
75
+
76
+ # Storing waveform in the specified device
77
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
78
+ wavs = wavs.float()
79
+
80
+ # Computing features and embeddings
81
+ outputs = self.mods.wav2vec2(wavs)
82
+
83
+ # last dim will be used for AdaptativeAVG pool
84
+ outputs = self.mods.avg_pool(outputs, wav_lens)
85
+ outputs = outputs.view(outputs.shape[0], -1)
86
+ return outputs
87
+
88
+ def classify_batch(self, wavs, wav_lens=None):
89
+ """Performs classification on the top of the encoded features.
90
+
91
+ It returns the posterior probabilities, the index and, if the label
92
+ encoder is specified it also the text label.
93
+
94
+ Arguments
95
+ ---------
96
+ wavs : torch.tensor
97
+ Batch of waveforms [batch, time, channels] or [batch, time]
98
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
99
+ wav_lens : torch.tensor
100
+ Lengths of the waveforms relative to the longest one in the
101
+ batch, tensor of shape [batch]. The longest one should have
102
+ relative length 1.0 and others len(waveform) / max_length.
103
+ Used for ignoring padding.
104
+
105
+ Returns
106
+ -------
107
+ out_prob
108
+ The log posterior probabilities of each class ([batch, N_class])
109
+ score:
110
+ It is the value of the log-posterior for the best class ([batch,])
111
+ index
112
+ The indexes of the best class ([batch,])
113
+ text_lab:
114
+ List with the text labels corresponding to the indexes.
115
+ (label encoder should be provided).
116
+ """
117
+ outputs = self.encode_batch(wavs, wav_lens)
118
+ outputs = self.mods.output_mlp(outputs)
119
+ out_prob = self.hparams.softmax(outputs)
120
+ score, index = torch.max(out_prob, dim=-1)
121
+ text_lab = self.hparams.label_encoder.decode_torch(index)
122
+ return out_prob, score, index, text_lab
123
+
124
+ def classify_file(self, path):
125
+ """Classifies the given audiofile into the given set of labels.
126
+
127
+ Arguments
128
+ ---------
129
+ path : str
130
+ Path to audio file to classify.
131
+
132
+ Returns
133
+ -------
134
+ out_prob
135
+ The log posterior probabilities of each class ([batch, N_class])
136
+ score:
137
+ It is the value of the log-posterior for the best class ([batch,])
138
+ index
139
+ The indexes of the best class ([batch,])
140
+ text_lab:
141
+ List with the text labels corresponding to the indexes.
142
+ (label encoder should be provided).
143
+ """
144
+ waveform = self.load_audio(path)
145
+ # Fake a batch:
146
+ batch = waveform.unsqueeze(0)
147
+ rel_length = torch.tensor([1.0])
148
+ outputs = self.encode_batch(batch, rel_length)
149
+ outputs = self.mods.output_mlp(outputs).squeeze(1)
150
+ out_prob = self.hparams.softmax(outputs)
151
+ score, index = torch.max(out_prob, dim=-1)
152
+ text_lab = self.hparams.label_encoder.decode_torch(index)
153
+ return out_prob, score, index, text_lab
154
+
155
+ def forward(self, wavs, wav_lens=None, normalize=False):
156
+ return self.encode_batch(
157
+ wavs=wavs, wav_lens=wav_lens, normalize=normalize
158
+ )