Porjaz commited on
Commit
95bcbb5
1 Parent(s): f23fcf0

Create custom_inference.py

Browse files
Files changed (1) hide show
  1. custom_inference.py +145 -0
custom_inference.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from speechbrain.inference.interfaces import Pretrained
3
+
4
+
5
+ class CustomEncoderWav2vec2Classifier(Pretrained):
6
+ """A ready-to-use class for utterance-level classification (e.g, speaker-id,
7
+ language-id, emotion recognition, keyword spotting, etc).
8
+ The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
9
+ are defined in the yaml file. If you want to
10
+ convert the predicted index into a corresponding text label, please
11
+ provide the path of the label_encoder in a variable called 'lab_encoder_file'
12
+ within the yaml.
13
+ The class can be used either to run only the encoder (encode_batch()) to
14
+ extract embeddings or to run a classification step (classify_batch()).
15
+ ```
16
+ Example
17
+ -------
18
+ >>> import torchaudio
19
+ >>> from speechbrain.pretrained import EncoderClassifier
20
+ >>> # Model is downloaded from the speechbrain HuggingFace repo
21
+ >>> tmpdir = getfixture("tmpdir")
22
+ >>> classifier = EncoderClassifier.from_hparams(
23
+ ... source="speechbrain/spkrec-ecapa-voxceleb",
24
+ ... savedir=tmpdir,
25
+ ... )
26
+ >>> # Compute embeddings
27
+ >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
28
+ >>> embeddings = classifier.encode_batch(signal)
29
+ >>> # Classification
30
+ >>> prediction = classifier .classify_batch(signal)
31
+ """
32
+
33
+ def __init__(self, *args, **kwargs):
34
+ super().__init__(*args, **kwargs)
35
+
36
+ def encode_batch(self, wavs, wav_lens=None, normalize=False):
37
+ """Encodes the input audio into a single vector embedding.
38
+ The waveforms should already be in the model's desired format.
39
+ You can call:
40
+ ``normalized = <this>.normalizer(signal, sample_rate)``
41
+ to get a correctly converted signal in most cases.
42
+ Arguments
43
+ ---------
44
+ wavs : torch.tensor
45
+ Batch of waveforms [batch, time, channels] or [batch, time]
46
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
47
+ wav_lens : torch.tensor
48
+ Lengths of the waveforms relative to the longest one in the
49
+ batch, tensor of shape [batch]. The longest one should have
50
+ relative length 1.0 and others len(waveform) / max_length.
51
+ Used for ignoring padding.
52
+ normalize : bool
53
+ If True, it normalizes the embeddings with the statistics
54
+ contained in mean_var_norm_emb.
55
+ Returns
56
+ -------
57
+ torch.tensor
58
+ The encoded batch
59
+ """
60
+ # Manage single waveforms in input
61
+ if len(wavs.shape) == 1:
62
+ wavs = wavs.unsqueeze(0)
63
+
64
+ # Assign full length if wav_lens is not assigned
65
+ if wav_lens is None:
66
+ wav_lens = torch.ones(wavs.shape[0], device=self.device)
67
+
68
+ # Storing waveform in the specified device
69
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
70
+ wavs = wavs.float()
71
+
72
+ # Computing features and embeddings
73
+ outputs = self.mods.wav2vec2(wavs)
74
+
75
+ # last dim will be used for AdaptativeAVG pool
76
+ outputs = self.mods.avg_pool(outputs, wav_lens)
77
+ outputs = outputs.view(outputs.shape[0], -1)
78
+ return outputs
79
+
80
+ def classify_batch(self, wavs, wav_lens=None):
81
+ """Performs classification on the top of the encoded features.
82
+ It returns the posterior probabilities, the index and, if the label
83
+ encoder is specified it also the text label.
84
+ Arguments
85
+ ---------
86
+ wavs : torch.tensor
87
+ Batch of waveforms [batch, time, channels] or [batch, time]
88
+ depending on the model. Make sure the sample rate is fs=16000 Hz.
89
+ wav_lens : torch.tensor
90
+ Lengths of the waveforms relative to the longest one in the
91
+ batch, tensor of shape [batch]. The longest one should have
92
+ relative length 1.0 and others len(waveform) / max_length.
93
+ Used for ignoring padding.
94
+ Returns
95
+ -------
96
+ out_prob
97
+ The log posterior probabilities of each class ([batch, N_class])
98
+ score:
99
+ It is the value of the log-posterior for the best class ([batch,])
100
+ index
101
+ The indexes of the best class ([batch,])
102
+ text_lab:
103
+ List with the text labels corresponding to the indexes.
104
+ (label encoder should be provided).
105
+ """
106
+ outputs = self.encode_batch(wavs, wav_lens)
107
+ outputs = self.mods.output_mlp(outputs)
108
+ out_prob = self.hparams.softmax(outputs)
109
+ score, index = torch.max(out_prob, dim=-1)
110
+ text_lab = self.hparams.label_encoder.decode_torch(index)
111
+ return out_prob, score, index, text_lab
112
+
113
+ def classify_file(self, path):
114
+ """Classifies the given audiofile into the given set of labels.
115
+ Arguments
116
+ ---------
117
+ path : str
118
+ Path to audio file to classify.
119
+ Returns
120
+ -------
121
+ out_prob
122
+ The log posterior probabilities of each class ([batch, N_class])
123
+ score:
124
+ It is the value of the log-posterior for the best class ([batch,])
125
+ index
126
+ The indexes of the best class ([batch,])
127
+ text_lab:
128
+ List with the text labels corresponding to the indexes.
129
+ (label encoder should be provided).
130
+ """
131
+ waveform = self.load_audio(path)
132
+ # Fake a batch:
133
+ batch = waveform.unsqueeze(0)
134
+ rel_length = torch.tensor([1.0])
135
+ outputs = self.encode_batch(batch, rel_length)
136
+ outputs = self.mods.output_mlp(outputs).squeeze(1)
137
+ out_prob = self.hparams.softmax(outputs)
138
+ score, index = torch.max(out_prob, dim=-1)
139
+ text_lab = self.hparams.label_encoder.decode_torch(index)
140
+ return out_prob, score, index, text_lab
141
+
142
+ def forward(self, wavs, wav_lens=None, normalize=False):
143
+ return self.encode_batch(
144
+ wavs=wavs, wav_lens=wav_lens, normalize=normalize
145
+ )