d22cs051 committed on
Commit fc8786d
1 Parent(s): 2918b59

add app for speaker verification

.gitignore ADDED
@@ -0,0 +1 @@
+ *__
app.py ADDED
@@ -0,0 +1,81 @@
+ import gradio as gr
+ import torch
+ from torch import nn
+ from verification import init_model
+
+
+
+ # model definition
+ class WaveLMSpeakerVerifi(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.feature_extractor = init_model("wavlm_base_plus")
+         self.cosine_sim = nn.CosineSimilarity(dim=-1)
+         self.sigmoid = nn.Sigmoid()
+
+
+     def forward(self, audio1, audio2):
+         audio1_emb = self.feature_extractor(audio1)
+         audio2_emb = self.feature_extractor(audio2)
+         similarity = self.cosine_sim(audio1_emb, audio2_emb)
+         similarity = (similarity + 1) / 2  # map cosine range (-1, 1) -> (0, 1)
+         return similarity
+
+
+ class SourceSeparationApp:
+     def __init__(self, model_path, device="cpu"):
+         self.model = self.load_model(model_path)
+         self.device = device
+
+     def load_model(self, model_path):
+         checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
+         fine_tuned_model = WaveLMSpeakerVerifi()
+         fine_tuned_model.load_state_dict(checkpoint["model"])
+         return fine_tuned_model
+
+     def verify_speaker(self, audio_file1, audio_file2):
+         # gr.Audio passes each recording as a (sample_rate, numpy_array) tuple
+         input_audio_tensor1, sr1 = audio_file1[1], audio_file1[0]
+         input_audio_tensor2, sr2 = audio_file2[1], audio_file2[0]
+
+         if self.model is None:
+             return "Error: Model not loaded."
+
+         # convert input audio to PyTorch tensors with a batch dimension
+         input_audio_tensor1 = torch.tensor(input_audio_tensor1, dtype=torch.float).unsqueeze(0)
+         input_audio_tensor1 = input_audio_tensor1.to(self.device)
+         input_audio_tensor2 = torch.tensor(input_audio_tensor2, dtype=torch.float).unsqueeze(0)
+         input_audio_tensor2 = input_audio_tensor2.to(self.device)
+
+         # speaker verification using the loaded model
+         self.model.to(self.device)
+         self.model.eval()
+         with torch.inference_mode():
+             similarity = self.model(input_audio_tensor1, input_audio_tensor2)
+
+         return similarity.item()
+
+     def run(self):
+         audio_input1 = gr.Audio(label="Upload or record audio")
+         audio_input2 = gr.Audio(label="Upload or record audio")
+         output_text = gr.Label(label="Similarity Score Result:")
+         gr.Interface(
+             fn=self.verify_speaker,
+             inputs=[audio_input1, audio_input2],
+             outputs=[output_text],
+             title="Speaker Verification",
+             description="Speaker verification using a fine-tuned WavLM Base+ model.",
+             examples=[
+                 # different speakers (expected score close to 0)
+                 ["samples/844424933481805-705-m.wav", "samples/844424932691175-645-f.wav"],
+                 # same speaker (expected score close to 1)
+                 ["samples/844424931281875-277-f.wav", "samples/844424930801214-277-f.wav"],
+             ],
+         ).launch()
+
+
+ if __name__ == "__main__":
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model_path = "fine-tuning-wavlm-base-plus-checkpoint.ckpt"  # replace with your model path
+     app = SourceSeparationApp(model_path, device=device)
+     app.run()
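For reference, a minimal local-usage sketch (not part of the commit): it assumes the repository is cloned with the LFS checkpoint pulled and the dependencies installed, and it calls verify_speaker directly with the (sample_rate, samples) tuples that gr.Audio would otherwise hand to the callback.

import soundfile as sf
from app import SourceSeparationApp

# hypothetical driver script, run from the repo root
app = SourceSeparationApp("fine-tuning-wavlm-base-plus-checkpoint.ckpt", device="cpu")
wav1, sr1 = sf.read("samples/844424931281875-277-f.wav")
wav2, sr2 = sf.read("samples/844424930801214-277-f.wav")
score = app.verify_speaker((sr1, wav1), (sr2, wav2))  # same-speaker pair
print(f"similarity in [0, 1]: {score:.3f}")

Running python app.py instead launches the Gradio interface around the same callback.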
fine-tuning-wavlm-base-plus-checkpoint.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8a631d58a7197bb43d1f79564a9d21959a96dba78405921246e63287c3ae79a8
+ size 470929785
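The checkpoint is stored as a Git LFS pointer (about 471 MB), so a plain clone only contains this stub; the usual way to fetch the actual weights is:

git lfs install
git lfs pull --include="fine-tuning-wavlm-base-plus-checkpoint.ckpt"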
models/__init__.py ADDED
File without changes
models/ecapa_tdnn.py ADDED
@@ -0,0 +1,301 @@
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio.transforms as trans
+ from .utils import UpstreamExpert
+
+
+ ''' Res2Conv1d + BatchNorm1d + ReLU
+ '''
+
+
+ class Res2Conv1dReluBn(nn.Module):
+     '''
+     in_channels == out_channels == channels
+     '''
+
+     def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
+         super().__init__()
+         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
+         self.scale = scale
+         self.width = channels // scale
+         self.nums = scale if scale == 1 else scale - 1
+
+         self.convs = []
+         self.bns = []
+         for i in range(self.nums):
+             self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
+             self.bns.append(nn.BatchNorm1d(self.width))
+         self.convs = nn.ModuleList(self.convs)
+         self.bns = nn.ModuleList(self.bns)
+
+     def forward(self, x):
+         out = []
+         spx = torch.split(x, self.width, 1)
+         for i in range(self.nums):
+             if i == 0:
+                 sp = spx[i]
+             else:
+                 sp = sp + spx[i]
+             # Order: conv -> relu -> bn
+             sp = self.convs[i](sp)
+             sp = self.bns[i](F.relu(sp))
+             out.append(sp)
+         if self.scale != 1:
+             out.append(spx[self.nums])
+         out = torch.cat(out, dim=1)
+
+         return out
+
+
+ ''' Conv1d + BatchNorm1d + ReLU
+ '''
+
+
+ class Conv1dReluBn(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
+         super().__init__()
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+         self.bn = nn.BatchNorm1d(out_channels)
+
+     def forward(self, x):
+         return self.bn(F.relu(self.conv(x)))
+
+
+ ''' The SE connection of 1D case.
+ '''
+
+
+ class SE_Connect(nn.Module):
+     def __init__(self, channels, se_bottleneck_dim=128):
+         super().__init__()
+         self.linear1 = nn.Linear(channels, se_bottleneck_dim)
+         self.linear2 = nn.Linear(se_bottleneck_dim, channels)
+
+     def forward(self, x):
+         out = x.mean(dim=2)
+         out = F.relu(self.linear1(out))
+         out = torch.sigmoid(self.linear2(out))
+         out = x * out.unsqueeze(2)
+
+         return out
+
+
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
+ '''
+
+
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
+ #     return nn.Sequential(
+ #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
+ #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
+ #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
+ #         SE_Connect(channels)
+ #     )
+
+
+ class SE_Res2Block(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
+         super().__init__()
+         self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
+         self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
+         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
+
+         self.shortcut = None
+         if in_channels != out_channels:
+             self.shortcut = nn.Conv1d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=1,
+             )
+
+     def forward(self, x):
+         residual = x
+         if self.shortcut:
+             residual = self.shortcut(x)
+
+         x = self.Conv1dReluBn1(x)
+         x = self.Res2Conv1dReluBn(x)
+         x = self.Conv1dReluBn2(x)
+         x = self.SE_Connect(x)
+
+         return x + residual
+
+
+ ''' Attentive weighted mean and standard deviation pooling.
+ '''
+
+
+ class AttentiveStatsPool(nn.Module):
+     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
+         super().__init__()
+         self.global_context_att = global_context_att
+
+         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
+         if global_context_att:
+             self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
+         else:
+             self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
+         self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper
+
+     def forward(self, x):
+
+         if self.global_context_att:
+             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
+             context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
+             x_in = torch.cat((x, context_mean, context_std), dim=1)
+         else:
+             x_in = x
+
+         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
+         alpha = torch.tanh(self.linear1(x_in))
+         # alpha = F.relu(self.linear1(x_in))
+         alpha = torch.softmax(self.linear2(alpha), dim=2)
+         mean = torch.sum(alpha * x, dim=2)
+         residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
+         std = torch.sqrt(residuals.clamp(min=1e-9))
+         return torch.cat([mean, std], dim=1)
+
+
+ class ECAPA_TDNN(nn.Module):
+     def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
+                  feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
+         super().__init__()
+
+         self.feat_type = feat_type
+         self.feature_selection = feature_selection
+         self.update_extract = update_extract
+         self.sr = sr
+
+         if feat_type == "fbank" or feat_type == "mfcc":
+             self.update_extract = False
+
+         win_len = int(sr * 0.025)
+         hop_len = int(sr * 0.01)
+
+         if feat_type == 'fbank':
+             self.feature_extract = trans.MelSpectrogram(sample_rate=sr, n_fft=512, win_length=win_len,
+                                                         hop_length=hop_len, f_min=0.0, f_max=sr // 2,
+                                                         pad=0, n_mels=feat_dim)
+         elif feat_type == 'mfcc':
+             melkwargs = {
+                 'n_fft': 512,
+                 'win_length': win_len,
+                 'hop_length': hop_len,
+                 'f_min': 0.0,
+                 'f_max': sr // 2,
+                 'pad': 0
+             }
+             self.feature_extract = trans.MFCC(sample_rate=sr, n_mfcc=feat_dim, log_mels=False,
+                                               melkwargs=melkwargs)
+         else:
+             if config_path is None:
+                 self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
+             else:
+                 self.feature_extract = UpstreamExpert(config_path)
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
+                 self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
+             if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
+                 self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
+
+         self.feat_num = self.get_feat_num()
+         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
+
+         if feat_type != 'fbank' and feat_type != 'mfcc':
+             freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
+             for name, param in self.feature_extract.named_parameters():
+                 for freeze_val in freeze_list:
+                     if freeze_val in name:
+                         param.requires_grad = False
+                         break
+
+         if not self.update_extract:
+             for param in self.feature_extract.parameters():
+                 param.requires_grad = False
+
+         self.instance_norm = nn.InstanceNorm1d(feat_dim)
+         # self.channels = [channels] * 4 + [channels * 3]
+         self.channels = [channels] * 4 + [1536]
+
+         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
+         self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
+         self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
+         self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
+
+         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
+         cat_channels = channels * 3
+         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
+         self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
+         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
+         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
+
+
+     def get_feat_num(self):
+         self.feature_extract.eval()
+         wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
+         with torch.no_grad():
+             features = self.feature_extract(wav)
+         select_feature = features[self.feature_selection]
+         if isinstance(select_feature, (list, tuple)):
+             return len(select_feature)
+         else:
+             return 1
+
+     def get_feat(self, x):
+         if self.update_extract:
+             x = self.feature_extract([sample for sample in x])
+         else:
+             with torch.no_grad():
+                 if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
+                     x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
+                 else:
+                     x = self.feature_extract([sample for sample in x])
+
+         if self.feat_type == 'fbank':
+             x = x.log()
+
+         if self.feat_type != "fbank" and self.feat_type != "mfcc":
+             x = x[self.feature_selection]
+             if isinstance(x, (list, tuple)):
+                 x = torch.stack(x, dim=0)
+             else:
+                 x = x.unsqueeze(0)
+             norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+             x = (norm_weights * x).sum(dim=0)
+             x = torch.transpose(x, 1, 2) + 1e-6
+
+         x = self.instance_norm(x)
+         return x
+
+     def forward(self, x):
+         x = self.get_feat(x)
+
+         out1 = self.layer1(x)
+         out2 = self.layer2(out1)
+         out3 = self.layer3(out2)
+         out4 = self.layer4(out3)
+
+         out = torch.cat([out2, out3, out4], dim=1)
+         out = F.relu(self.conv(out))
+         out = self.bn(self.pooling(out))
+         out = self.linear(out)
+
+         return out
+
+
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='fbank', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
+     return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
+                       feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
+
+ if __name__ == '__main__':
+     x = torch.zeros(2, 32000)
+     model = ECAPA_TDNN_SMALL(feat_dim=768, emb_dim=256, feat_type='hubert_base', feature_selection="hidden_states",
+                              update_extract=False)
+
+     out = model(x)
+     # print(model)
+     print(out.shape)
+
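As a quick sanity check (a sketch, not part of the commit), the small variant can be used on its own to extract speaker embeddings; feat_dim=768 matches the wavlm_base_plus upstream that verification.py selects, and the first call downloads that upstream via torch.hub (s3prl).

import torch
from models.ecapa_tdnn import ECAPA_TDNN_SMALL

model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type="wavlm_base_plus")
model.eval()
wav = torch.randn(1, 32000)   # one 2-second utterance at 16 kHz
with torch.no_grad():
    emb = model(wav)          # speaker embedding of shape (1, 256)
print(emb.shape)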
models/utils.py ADDED
@@ -0,0 +1,78 @@
+ import torch
+ import fairseq
+ from packaging import version
+ import torch.nn.functional as F
+ from fairseq import tasks
+ from fairseq.checkpoint_utils import load_checkpoint_to_cpu
+ from fairseq.dataclass.utils import convert_namespace_to_omegaconf
+ from omegaconf import OmegaConf
+ from s3prl.upstream.interfaces import UpstreamBase
+ from torch.nn.utils.rnn import pad_sequence
+
+ def load_model(filepath):
+     state = torch.load(filepath, map_location=lambda storage, loc: storage)
+     # state = load_checkpoint_to_cpu(filepath)
+     state["cfg"] = OmegaConf.create(state["cfg"])
+
+     if "args" in state and state["args"] is not None:
+         cfg = convert_namespace_to_omegaconf(state["args"])
+     elif "cfg" in state and state["cfg"] is not None:
+         cfg = state["cfg"]
+     else:
+         raise RuntimeError(
+             f"Neither args nor cfg exist in state keys = {state.keys()}"
+         )
+
+     task = tasks.setup_task(cfg.task)
+     if "task_state" in state:
+         task.load_state_dict(state["task_state"])
+
+     model = task.build_model(cfg.model)
+
+     return model, cfg, task
+
+
+ ###################
+ # UPSTREAM EXPERT #
+ ###################
+ class UpstreamExpert(UpstreamBase):
+     def __init__(self, ckpt, **kwargs):
+         super().__init__(**kwargs)
+         assert version.parse(fairseq.__version__) > version.parse(
+             "0.10.2"
+         ), "Please install the fairseq master branch."
+
+         model, cfg, task = load_model(ckpt)
+         self.model = model
+         self.task = task
+
+         if len(self.hooks) == 0:
+             module_name = "self.model.encoder.layers"
+             for module_id in range(len(eval(module_name))):
+                 self.add_hook(
+                     f"{module_name}[{module_id}]",
+                     lambda input, output: input[0].transpose(0, 1),
+                 )
+             self.add_hook("self.model.encoder", lambda input, output: output[0])
+
+     def forward(self, wavs):
+         if self.task.cfg.normalize:
+             wavs = [F.layer_norm(wav, wav.shape) for wav in wavs]
+
+         device = wavs[0].device
+         wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
+         wav_padding_mask = ~torch.lt(
+             torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
+             wav_lengths.unsqueeze(1),
+         )
+         padded_wav = pad_sequence(wavs, batch_first=True)
+
+         features, feat_padding_mask = self.model.extract_features(
+             padded_wav,
+             padding_mask=wav_padding_mask,
+             mask=None,
+         )
+         return {
+             "default": features,
+         }
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ scipy==1.7.1
+ fire==0.4.0
+ sklearn==0.0
+ s3prl==0.3.1
+ torchaudio==0.9.0
+ sentencepiece==0.1.96
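Note (not part of the commit): app.py, verification.py and models/utils.py also import gradio, soundfile, torch and fairseq, which are not pinned above (torch is pulled in transitively by torchaudio==0.9.0). A hedged setup sketch, with versions left unpinned as an assumption:

pip install -r requirements.txt
pip install gradio soundfile fairseq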
samples/844424930801214-277-f.wav ADDED
Binary file (592 kB)
samples/844424931281875-277-f.wav ADDED
Binary file (573 kB)
samples/844424932691175-645-f.wav ADDED
Binary file (303 kB)
samples/844424933481805-705-m.wav ADDED
Binary file (619 kB)
verification.py ADDED
@@ -0,0 +1,106 @@
+ import soundfile as sf
+ import torch
+ import fire
+ import torch.nn.functional as F
+ from torchaudio.transforms import Resample
+ from models.ecapa_tdnn import ECAPA_TDNN_SMALL
+
+ MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', "wavlm_base_plus", "wavlm_large"]
+
+
+ def init_model(model_name, checkpoint=None):
+     if model_name == 'unispeech_sat':
+         config_path = 'config/unispeech_sat.th'
+         model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='unispeech_sat', config_path=config_path)
+     elif model_name == 'wavlm_base_plus':
+         config_path = None
+         model = ECAPA_TDNN_SMALL(feat_dim=768, feat_type='wavlm_base_plus', config_path=config_path)
+     elif model_name == 'wavlm_large':
+         config_path = None
+         model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=config_path)
+     elif model_name == 'hubert_large':
+         config_path = None
+         model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=config_path)
+     elif model_name == 'wav2vec2_xlsr':
+         config_path = None
+         model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=config_path)
+     else:
+         model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
+
+     if checkpoint is not None:
+         state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
+         model.load_state_dict(state_dict['model'], strict=False)
+     return model
+
+
+ def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None):
+
+     assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
+     model = init_model(model_name, checkpoint)
+
+     wav1, sr1 = sf.read(wav1)
+     wav2, sr2 = sf.read(wav2)
+
+     wav1 = torch.from_numpy(wav1).unsqueeze(0).float()
+     wav2 = torch.from_numpy(wav2).unsqueeze(0).float()
+     resample1 = Resample(orig_freq=sr1, new_freq=16000)
+     resample2 = Resample(orig_freq=sr2, new_freq=16000)
+     wav1 = resample1(wav1)
+     wav2 = resample2(wav2)
+
+     if use_gpu:
+         model = model.cuda()
+         wav1 = wav1.cuda()
+         wav2 = wav2.cuda()
+
+     model.eval()
+     with torch.no_grad():
+         emb1 = model(wav1)
+         emb2 = model(wav2)
+
+     sim = F.cosine_similarity(emb1, emb2)
+     # print("The similarity score between two audios is {:.4f} (-1.0, 1.0).".format(sim[0].item()))
+     return sim[0].item()
+
+
+ def verification_batch(model_name, batch_wav1, batch_wav2, use_gpu=True, checkpoint=None):
+     assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
+     model = init_model(model_name, checkpoint)
+
+     # print(str(batch_wav1[0]))
+
+     sr1 = sf.read(str(batch_wav1[0]))[1]
+     sr2 = sf.read(str(batch_wav2[0]))[1]
+
+     # print(sr1)
+
+     batch_wav1 = [torch.from_numpy(sf.read(wav)[0][:50000]).unsqueeze(0).float() for wav in batch_wav1]
+     batch_wav2 = [torch.from_numpy(sf.read(wav)[0][:50000]).unsqueeze(0).float() for wav in batch_wav2]
+
+     resample1 = Resample(orig_freq=sr1, new_freq=16000)
+     resample2 = Resample(orig_freq=sr2, new_freq=16000)
+
+     batch_wav1 = torch.cat([resample1(wav) for wav in batch_wav1], 0)
+     batch_wav2 = torch.cat([resample2(wav) for wav in batch_wav2], 0)
+
+     # print(batch_wav1.shape)
+     # print(batch_wav2.shape)
+
+     if use_gpu:
+         model = model.cuda()
+         batch_wav1 = batch_wav1.cuda()
+         batch_wav2 = batch_wav2.cuda()
+
+     model.eval()
+     with torch.no_grad():
+         emb1 = model(batch_wav1)
+         emb2 = model(batch_wav2)
+
+     sim = F.cosine_similarity(emb1, emb2, dim=-1)
+
+     return sim.cpu().numpy()
+
+
+ if __name__ == "__main__":
+     fire.Fire(verification)
+
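Since verification is exposed through fire.Fire, the same scoring can be run from the command line; a usage sketch against the bundled samples (CPU, flag names follow the function signature above):

python verification.py --model_name wavlm_base_plus \
    --wav1 samples/844424931281875-277-f.wav \
    --wav2 samples/844424930801214-277-f.wav \
    --use_gpu False

The optional --checkpoint flag expects a state dict stored under a 'model' key, as loaded in init_model.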