Porjaz committed on
Commit
5cc84b2
1 Parent(s): 9c031bd

Update app.py

Files changed (1): app.py +178 -0
app.py CHANGED
@@ -0,0 +1,178 @@
+ import gradio as gr
+ import torch
+ from speechbrain.inference.interfaces import Pretrained, foreign_class
+
+
+ class CustomEncoderWav2vec2Classifier(Pretrained):
+     """A ready-to-use class for utterance-level classification (e.g., speaker-id,
+     language-id, emotion recognition, keyword spotting, etc.).
+     The class assumes that a self-supervised encoder like wav2vec2/hubert and a
+     classifier model are defined in the yaml file. If you want to convert the
+     predicted index into a corresponding text label, please provide the path of
+     the label_encoder in a variable called 'lab_encoder_file' within the yaml.
+     The class can be used either to run only the encoder (encode_batch()) to
+     extract embeddings or to run a classification step (classify_batch()).
+     Example
+     -------
+     >>> import torchaudio
+     >>> from speechbrain.pretrained import EncoderClassifier
+     >>> # Model is downloaded from the speechbrain HuggingFace repo
+     >>> tmpdir = getfixture("tmpdir")
+     >>> classifier = EncoderClassifier.from_hparams(
+     ...     source="speechbrain/spkrec-ecapa-voxceleb",
+     ...     savedir=tmpdir,
+     ... )
+     >>> # Compute embeddings
+     >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
+     >>> embeddings = classifier.encode_batch(signal)
+     >>> # Classification
+     >>> prediction = classifier.classify_batch(signal)
+     """
+
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+
+     def encode_batch(self, wavs, wav_lens=None, normalize=False):
+         """Encodes the input audio into a single vector embedding.
+         The waveforms should already be in the model's desired format.
+         You can call:
+         ``normalized = <this>.normalizer(signal, sample_rate)``
+         to get a correctly converted signal in most cases.
+         Arguments
+         ---------
+         wavs : torch.tensor
+             Batch of waveforms [batch, time, channels] or [batch, time]
+             depending on the model. Make sure the sample rate is fs=16000 Hz.
+         wav_lens : torch.tensor
+             Lengths of the waveforms relative to the longest one in the
+             batch, tensor of shape [batch]. The longest one should have
+             relative length 1.0 and others len(waveform) / max_length.
+             Used for ignoring padding.
+         normalize : bool
+             If True, it normalizes the embeddings with the statistics
+             contained in mean_var_norm_emb.
+         Returns
+         -------
+         torch.tensor
+             The encoded batch
+         """
+         # Manage single waveforms in input
+         if len(wavs.shape) == 1:
+             wavs = wavs.unsqueeze(0)
+
+         # Assign full length if wav_lens is not assigned
+         if wav_lens is None:
+             wav_lens = torch.ones(wavs.shape[0], device=self.device)
+
+         # Storing waveform in the specified device
+         wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
+         wavs = wavs.float()
+
+         # Computing features and embeddings
+         outputs = self.mods.wav2vec2(wavs)
+
+         # The time dimension is reduced with adaptive average pooling
+         outputs = self.mods.avg_pool(outputs, wav_lens)
+         outputs = outputs.view(outputs.shape[0], -1)
+         return outputs
+
+     def classify_batch(self, wavs, wav_lens=None):
+         """Performs classification on top of the encoded features.
+         It returns the posterior probabilities, the index and, if the label
+         encoder is specified, the corresponding text label.
+         Arguments
+         ---------
+         wavs : torch.tensor
+             Batch of waveforms [batch, time, channels] or [batch, time]
+             depending on the model. Make sure the sample rate is fs=16000 Hz.
+         wav_lens : torch.tensor
+             Lengths of the waveforms relative to the longest one in the
+             batch, tensor of shape [batch]. The longest one should have
+             relative length 1.0 and others len(waveform) / max_length.
+             Used for ignoring padding.
+         Returns
+         -------
+         out_prob
+             The log posterior probabilities of each class ([batch, N_class])
+         score
+             The value of the log-posterior for the best class ([batch,])
+         index
+             The indexes of the best class ([batch,])
+         text_lab
+             List with the text labels corresponding to the indexes
+             (a label encoder should be provided).
+         """
+         outputs = self.encode_batch(wavs, wav_lens)
+         outputs = self.mods.label_lin(outputs)
+         out_prob = self.hparams.softmax(outputs)
+         score, index = torch.max(out_prob, dim=-1)
+         text_lab = self.hparams.label_encoder.decode_torch(index)
+         return out_prob, score, index, text_lab
+
+     def classify_file(self, path):
+         """Classifies the given audio file into the given set of labels.
+         Arguments
+         ---------
+         path : str
+             Path to audio file to classify.
+         Returns
+         -------
+         out_prob
+             The log posterior probabilities of each class ([batch, N_class])
+         score
+             The value of the log-posterior for the best class ([batch,])
+         index
+             The indexes of the best class ([batch,])
+         text_lab
+             The text label corresponding to the best index.
+         """
+         waveform = self.load_audio(path)
+         # Fake a batch:
+         batch = waveform.unsqueeze(0)
+         rel_length = torch.tensor([1.0])
+         outputs = self.encode_batch(batch, rel_length)
+         outputs = self.mods.label_lin(outputs).squeeze(1)
+         out_prob = self.hparams.softmax(outputs)
+         score, index = torch.max(out_prob, dim=-1)
+         text_lab = self.hparams.label_encoder.decode_torch(index)
+         # Map the numeric labels produced by the label encoder to
+         # human-readable emotion names.
+         label_map = {
+             "1": "neutral",
+             "2": "sadness",
+             "3": "joy",
+             "4": "anger",
+             "5": "affection",
+         }
+         text_lab = label_map.get(text_lab[0], text_lab[0])
+
+         return out_prob, score, index, text_lab
+
+     def forward(self, wavs, wav_lens=None, normalize=False):
+         return self.encode_batch(
+             wavs=wavs, wav_lens=wav_lens, normalize=normalize
+         )
+
+
+ def return_prediction(mic, file):
+     # Use the microphone recording if one was made, otherwise the uploaded file.
+     audio_path = mic if mic is not None else file
+     out_prob, score, index, text_lab = classifier.classify_file(audio_path)
+     return text_lab
+
+
+ # Load the model once at startup; return_prediction reuses this instance.
+ classifier = foreign_class(
+     source="Porjaz/wavlm-base-emo-fi",
+     pymodule_file="custom_interface.py",
+     classname="CustomEncoderWav2vec2Classifier",
+ )
+
+ gradio_app = gr.Interface(
+     return_prediction,
+     inputs=[
+         gr.Audio(sources=["microphone"], type="filepath"),
+         gr.Audio(sources=["upload"], type="filepath"),
+     ],
+     outputs="text",
+     title="Finnish-Emotion-Recognition",
+ )
+
+ gradio_app.launch(share=True)