psakamoori commited on
Commit
64c2ea8
1 Parent(s): e189d08

CustomeEncoderWav2vec2classifier class

Browse files
Files changed (1) hide show
  1. custom_interface.py +207 -0
custom_interface.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.inference.interfaces import Pretrained
3
+ import openvino as ov
4
+
5
+ class CustomEncoderWav2vec2Classifier(Pretrained):
6
+ """A ready-to-use class for utterance-level classification (e.g, speaker-id,
7
+ language-id, emotion recognition, keyword spotting, etc).
8
+
9
+ The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
10
+ are defined in the yaml file. If you want to
11
+ convert the predicted index into a corresponding text label, please
12
+ provide the path of the label_encoder in a variable called 'lab_encoder_file'
13
+ within the yaml.
14
+
15
+ The class can be used either to run only the encoder (encode_batch()) to
16
+ extract embeddings or to run a classification step (classify_batch()).
17
+ ```
18
+
19
+ Example
20
+ -------
21
+ >>> import torchaudio
22
+ >>> from speechbrain.pretrained import EncoderClassifier
23
+ >>> # Model is downloaded from the speechbrain HuggingFace repo
24
+ >>> tmpdir = getfixture("tmpdir")
25
+ >>> classifier = EncoderClassifier.from_hparams(
26
+ ... source="speechbrain/spkrec-ecapa-voxceleb",
27
+ ... savedir=tmpdir,
28
+ ... )
29
+
30
+ >>> # Compute embeddings
31
+ >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
32
+ >>> embeddings = classifier.encode_batch(signal)
33
+
34
+ >>> # Classification
35
+ >>> prediction = classifier .classify_batch(signal)
36
+ """
37
+
38
+ def __init__(self, *args, model=None,
39
+ audio_file_path=None, backend="pytorch",
40
+ ov_opts={"device_name": "cpu"},
41
+ save_ov_model=False,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.backend = backend
45
+ if self.backend == "openvino":
46
+ print("=" * 30)
47
+ print("OpenVINO Backend Selected")
48
+ print("=" * 30)
49
+
50
+ self.core = ov.Core()
51
+ self.ov_model = None
52
+ # if torch model
53
+ if model:
54
+ print("\n[INFO] Preparing OpenVINO model...")
55
+ self.get_ov_model(model, audio_file_path)
56
+ print("[SUCCESS] OpenVINO IR model compiled for inference!\n")
57
+ if self.ov_model:
58
+ self.device = ov_opts["device_name"]
59
+ print("[INFO] Compiling OpenVINO IR model for inference...")
60
+ self.compiled_model = self.core.compile_model(self.ov_model, config=ov_opts)
61
+ print("[SUCCESS] OpenVINO IR model compiled for inference!\n")
62
+ # Falg to save openvino ir model file to disk
63
+ if save_ov_model:
64
+ # set to default path
65
+ print("[INFO] Saving OpenVINO IR model to disk!\n")
66
+ ov_ir_file_path = "./openvino_model/fp32/speechbrain_emotion_recog_ov_ir_model.xml"
67
+ ov.save_model(self.ov_model, ov_ir_file_path)
68
+ print(f"[SUCCESS] OpenVINO IR model file saved at {ov_ir_file_path}!\n")
69
+
70
+ def encode_batch(self, wavs, wav_lens=None, normalize=False):
71
+ """Encodes the input audio into a single vector embedding.
72
+
73
+ The waveforms should already be in the model's desired format.
74
+ You can call:
75
+ ``normalized = <this>.normalizer(signal, sample_rate)``
76
+ to get a correctly converted signal in most cases.
77
+
78
+ Arguments
79
+ ---------
80
+ wavs : torch.tensor
81
+ Batch of waveforms [batch, time, channels] or [batch, time]
82
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
83
+ wav_lens : torch.tensor
84
+ Lengths of the waveforms relative to the longest one in the
85
+ batch, tensor of shape [batch]. The longest one should have
86
+ relative length 1.0 and others len(waveform) / max_length.
87
+ Used for ignoring padding.
88
+ normalize : bool
89
+ If True, it normalizes the embeddings with the statistics
90
+ contained in mean_var_norm_emb.
91
+
92
+ Returns
93
+ -------
94
+ torch.tensor
95
+ The encoded batch
96
+ """
97
+ # Manage single waveforms in input
98
+ if len(wavs.shape) == 1:
99
+ wavs = wavs.unsqueeze(0)
100
+
101
+ # Assign full length if wav_lens is not assigned
102
+ if wav_lens is None:
103
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
104
+
105
+ # Storing waveform in the specified device
106
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
107
+ wavs = wavs.float()
108
+
109
+ if self.backend == "pytorch":
110
+ # Computing features and embeddings
111
+ outputs = self.mods.wav2vec2(wavs)
112
+ elif self.backend == "openvino":
113
+ # OpenVINO inference
114
+ outputs = self.ov_inference(wavs, wav_lens)
115
+
116
+ # last dim will be used for AdaptativeAVG pool
117
+ outputs = self.mods.avg_pool(outputs, wav_lens)
118
+ outputs = outputs.view(outputs.shape[0], -1)
119
+
120
+ return outputs
121
+
122
+ def classify_batch(self, wavs, wav_lens=None):
123
+ """Performs classification on the top of the encoded features.
124
+
125
+ It returns the posterior probabilities, the index and, if the label
126
+ encoder is specified it also the text label.
127
+
128
+ Arguments
129
+ ---------
130
+ wavs : torch.tensor
131
+ Batch of waveforms [batch, time, channels] or [batch, time]
132
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
133
+ wav_lens : torch.tensor
134
+ Lengths of the waveforms relative to the longest one in the
135
+ batch, tensor of shape [batch]. The longest one should have
136
+ relative length 1.0 and others len(waveform) / max_length.
137
+ Used for ignoring padding.
138
+
139
+ Returns
140
+ -------
141
+ out_prob
142
+ The log posterior probabilities of each class ([batch, N_class])
143
+ score:
144
+ It is the value of the log-posterior for the best class ([batch,])
145
+ index
146
+ The indexes of the best class ([batch,])
147
+ text_lab:
148
+ List with the text labels corresponding to the indexes.
149
+ (label encoder should be provided).
150
+ """
151
+ outputs = self.encode_batch(wavs, wav_lens)
152
+ outputs = self.mods.output_mlp(outputs)
153
+ out_prob = self.hparams.softmax(outputs)
154
+ score, index = torch.max(out_prob, dim=-1)
155
+ text_lab = self.hparams.label_encoder.decode_torch(index)
156
+ return out_prob, score, index, text_lab
157
+
158
+ def classify_file(self, path):
159
+ """Classifies the given audiofile into the given set of labels.
160
+
161
+ Arguments
162
+ ---------
163
+ path : str
164
+ Path to audio file to classify.
165
+
166
+ Returns
167
+ -------
168
+ out_prob
169
+ The log posterior probabilities of each class ([batch, N_class])
170
+ score:
171
+ It is the value of the log-posterior for the best class ([batch,])
172
+ index
173
+ The indexes of the best class ([batch,])
174
+ text_lab:
175
+ List with the text labels corresponding to the indexes.
176
+ (label encoder should be provided).
177
+ """
178
+ waveform = self.load_audio(path)
179
+ # Fake a batch:
180
+ batch = waveform.unsqueeze(0)
181
+ rel_length = torch.tensor([1.0])
182
+ outputs = self.encode_batch(batch, rel_length)
183
+ outputs = self.mods.output_mlp(outputs).squeeze(1)
184
+ out_prob = self.hparams.softmax(outputs)
185
+ score, index = torch.max(out_prob, dim=-1)
186
+ text_lab = self.hparams.label_encoder.decode_torch(index)
187
+ return out_prob, score, index, text_lab
188
+
189
+ def get_ov_model(self, torch_model, path):
190
+ # Prepare input tensor
191
+ waveform = self.load_audio(path)
192
+ wavs = waveform.unsqueeze(0)
193
+
194
+ # Torch to OpenVINO model conversion
195
+ self.ov_model = ov.convert_model(torch_model, example_input=wavs)
196
+
197
+ def ov_inference(self, wavs, wav_lens):
198
+ output_tensor = self.compiled_model(wavs.float())[0]
199
+ output_tensor = torch.from_numpy(output_tensor)
200
+ print("\n[INFO] Performing OpenVINO inference...")
201
+
202
+ return output_tensor
203
+
204
+ def forward(self, wavs, wav_lens=None, normalize=False):
205
+ return self.encode_batch(
206
+ wavs=wavs, wav_lens=wav_lens, normalize=normalize
207
+ )