ramkamal2000 commited on
Commit
0844354
·
1 Parent(s): 665e760

adding untracked files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. app.py +203 -0
  3. checkpoints/freevc.pth +3 -0
  4. commons.py +171 -0
  5. configs/freevc.json +54 -0
  6. mel_processing.py +112 -0
  7. models.py +351 -0
  8. modules.py +342 -0
  9. requirements.txt +11 -0
  10. sample_inputs/ntr.wav +3 -0
  11. sample_inputs/out.wav +3 -0
  12. sample_inputs/p225_001.wav +3 -0
  13. sample_inputs/p226_002.wav +3 -0
  14. sample_inputs/reference.wav +3 -0
  15. sample_inputs/target.wav +3 -0
  16. sample_inputs/timcast1.wav +3 -0
  17. speaker_encoder/__init__.py +0 -0
  18. speaker_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
  19. speaker_encoder/__pycache__/audio.cpython-39.pyc +0 -0
  20. speaker_encoder/__pycache__/hparams.cpython-39.pyc +0 -0
  21. speaker_encoder/__pycache__/params_data.cpython-39.pyc +0 -0
  22. speaker_encoder/__pycache__/voice_encoder.cpython-39.pyc +0 -0
  23. speaker_encoder/audio.py +107 -0
  24. speaker_encoder/ckpt/pretrained_bak_5805000.pt +3 -0
  25. speaker_encoder/compute_embed.py +40 -0
  26. speaker_encoder/config.py +45 -0
  27. speaker_encoder/data_objects/__init__.py +2 -0
  28. speaker_encoder/data_objects/random_cycler.py +37 -0
  29. speaker_encoder/data_objects/speaker.py +40 -0
  30. speaker_encoder/data_objects/speaker_batch.py +12 -0
  31. speaker_encoder/data_objects/speaker_verification_dataset.py +56 -0
  32. speaker_encoder/data_objects/utterance.py +26 -0
  33. speaker_encoder/hparams.py +31 -0
  34. speaker_encoder/inference.py +177 -0
  35. speaker_encoder/model.py +135 -0
  36. speaker_encoder/params_data.py +29 -0
  37. speaker_encoder/params_model.py +11 -0
  38. speaker_encoder/preprocess.py +285 -0
  39. speaker_encoder/train.py +125 -0
  40. speaker_encoder/visualizations.py +178 -0
  41. speaker_encoder/voice_encoder.py +173 -0
  42. utils.py +306 -0
  43. wavlm/WavLM-Large.pt +3 -0
  44. wavlm/WavLM-Large.pt.txt +1 -0
  45. wavlm/WavLM.py +742 -0
  46. wavlm/__init__.py +1 -0
  47. wavlm/__pycache__/WavLM.cpython-39.pyc +0 -0
  48. wavlm/__pycache__/__init__.cpython-39.pyc +0 -0
  49. wavlm/__pycache__/modules.cpython-39.pyc +0 -0
  50. wavlm/modules.py +827 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.wav filter=lfs diff=lfs merge=lfs -text
36
+ *.tar filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ import gradio as gr
3
+
4
+ def greet(name):
5
+ return "Hello " + name + "!!"
6
+
7
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
8
+ iface.launch()
9
+ '''
10
+ import gradio
11
+ import os
12
+ import shutil
13
+ import gradio as gr
14
+ import sys
15
+ import string
16
+ import time
17
+ import argparse
18
+ import json
19
+ import numpy as np
20
+ import torch
21
+ import librosa
22
+ import subprocess
23
+
24
+ from pydub import AudioSegment
25
+ from scipy.io.wavfile import write, read
26
+ from transformers import WavLMModel
27
+
28
+ from TTS.tts.utils.synthesis import synthesis
29
+ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
30
+ try:
31
+ from TTS.utils.audio import AudioProcessor
32
+ except:
33
+ from TTS.utils.audio import AudioProcessor
34
+ from TTS.tts.models import setup_model
35
+ from TTS.config import load_config
36
+ from TTS.tts.models.vits import *
37
+ from TTS.tts.utils.speakers import SpeakerManager
38
+
39
+ import utils
40
+ from models import SynthesizerTrn
41
+ from mel_processing import mel_spectrogram_torch
42
+ from speaker_encoder.voice_encoder import SpeakerEncoder
43
+
44
+ TTS_PATH = "TTS/"
45
+ sys.path.append(TTS_PATH) # set this if TTS is not installed globally
46
+
47
+ OUT_PATH = 'out/'
48
+ os.makedirs(OUT_PATH, exist_ok=True)
49
+
50
+ TTS_SPEAKERS = "yourTTS_config/speakers.json"
51
+ USE_CUDA = torch.cuda.is_available()
52
+ device = torch.device("cuda" if USE_CUDA else "cpu")
53
+
54
+ CONFIG_PATH = 'yourTTS_config/config.json'
55
+ C = load_config(CONFIG_PATH)
56
+ ap = AudioProcessor(**C.audio)
57
+
58
+ speaker_embedding = None
59
+ C.model_args['d_vector_file'] = TTS_SPEAKERS
60
+ C.model_args['use_speaker_encoder_as_loss'] = False
61
+
62
+ model = setup_model(C)
63
+
64
+ TTS_LANGUAGES = "yourTTS_config/language_ids.json"
65
+ model.language_manager.set_language_ids_from_file(TTS_LANGUAGES)
66
+
67
+ # print(model.language_manager.num_languages, model.embedded_language_dim)
68
+ # print(model.emb_l)
69
+
70
+ MODEL_PATH = 'yourTTS_config/best_model.pth.tar'
71
+ cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
72
+
73
+ model_weights = cp['model'].copy()
74
+ for key in list(model_weights.keys()):
75
+ if "speaker_encoder" in key:
76
+ del model_weights[key]
77
+
78
+ model.load_state_dict(model_weights)
79
+ model.eval()
80
+
81
+ if USE_CUDA:
82
+ model = model.cuda()
83
+
84
+ use_griffin_lim = False
85
+
86
+ CONFIG_SE_PATH = "yourtts_config/config_se.json"
87
+ CHECKPOINT_SE_PATH = "yourtts_config/SE_checkpoint.pth.tar"
88
+ SE_speaker_manager = SpeakerManager(encoder_model_path=CHECKPOINT_SE_PATH, encoder_config_path=CONFIG_SE_PATH, use_cuda=USE_CUDA)
89
+
90
+ def compute_spec(ref_file):
91
+ y, sr = librosa.load(ref_file, sr=ap.sample_rate)
92
+ spec = ap.spectrogram(y)
93
+ spec = torch.FloatTensor(spec).unsqueeze(0)
94
+ return spec
95
+
96
+ print("Loading FreeVC...")
97
+ hps = utils.get_hparams_from_file("configs/freevc.json")
98
+ freevc = SynthesizerTrn(
99
+ hps.data.filter_length // 2 + 1,
100
+ hps.train.segment_size // hps.data.hop_length,
101
+ **hps.model).to(device)
102
+ _ = freevc.eval()
103
+ _ = utils.load_checkpoint("checkpoints/freevc.pth", freevc, None)
104
+ smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')
105
+
106
+ print("Loading WavLM for content...")
107
+ cmodel = utils.get_cmodel(device).to(device)
108
+ # cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)
109
+
110
+ def voice_conversion_yourtts(da, ta):
111
+
112
+ # write(target_audio, ta[0], ta[1])
113
+ # write(driving_audio, da[0], da[1])
114
+
115
+ # !ffmpeg-normalize $target_audio -nt rms -t=-27 -o $target_audio -ar 16000 -f
116
+ # !ffmpeg-normalize $reference_audio -nt rms -t=-27 -o $reference_audio -ar 16000 -f
117
+ # !ffmpeg-normalize $driving_audio -nt rms -t=-27 -o $driving_audio -ar 16000 -f
118
+
119
+ files = [da, ta]
120
+
121
+ for file in files:
122
+ subprocess.run(["ffmpeg-normalize", file, "-nt", "rms", "-t=-27", "-o", file, "-ar", "16000", "-f"])
123
+
124
+ # ta_ = read(target_audio)
125
+
126
+ target_emb = SE_speaker_manager.compute_d_vector_from_clip([ta])
127
+ target_emb = torch.FloatTensor(target_emb).unsqueeze(0)
128
+
129
+ driving_emb = SE_speaker_manager.compute_d_vector_from_clip([da])
130
+ driving_emb = torch.FloatTensor(driving_emb).unsqueeze(0)
131
+
132
+ # Convert the voice
133
+
134
+ driving_spec = compute_spec(da)
135
+ y_lengths = torch.tensor([driving_spec.size(-1)])
136
+ if USE_CUDA:
137
+ ref_wav_voc, _, _ = model.voice_conversion(driving_spec.cuda(), y_lengths.cuda(), driving_emb.cuda(), target_emb.cuda())
138
+ ref_wav_voc = ref_wav_voc.squeeze().cpu().detach().numpy()
139
+ else:
140
+ ref_wav_voc, _, _ = model.voice_conversion(driving_spec, y_lengths, driving_emb, target_emb)
141
+ ref_wav_voc = ref_wav_voc.squeeze().detach().numpy()
142
+
143
+ # print("Reference Audio after decoder:")
144
+ # IPython.display.display(Audio(ref_wav_voc, rate=ap.sample_rate))
145
+
146
+ return (ap.sample_rate, ref_wav_voc)
147
+
148
+ def voice_conversion_freevc(src, tgt):
149
+ with torch.no_grad():
150
+ wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
151
+ wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
152
+ g_tgt = smodel.embed_utterance(wav_tgt)
153
+ g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
154
+ wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
155
+ wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
156
+ # c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
157
+ c = utils.get_content(cmodel, wav_src)
158
+ audio = freevc.infer(c, g=g_tgt)
159
+ audio = audio[0][0].data.cpu().float().numpy()
160
+ write("out.wav", hps.data.sampling_rate, audio)
161
+ out = "out.wav"
162
+ return out
163
+
164
+ model1 = gr.Dropdown(choices=["FreeVC", "YourTTS"], value="FreeVC",type="value", label="Model")
165
+ model2 = gr.Dropdown(choices=["FreeVC", "YourTTS"], value="FreeVC",type="value", label="Model")
166
+
167
+ audio1 = gr.inputs.Audio(label="Source Speaker - Input Audio", type='filepath')
168
+ audio2 = gr.inputs.Audio(label="Target Speaker - Input Audio", type='filepath')
169
+ microphone = gr.inputs.Audio(label="Source Speaker - Input Audio", source='microphone')
170
+ audio3 = gr.inputs.Audio(label="Target Speaker - Input Audio", type='filepath')
171
+
172
+ inputs_1 = [model1, audio1, audio2]
173
+ inputs_2 = [model2, microphone, audio3]
174
+
175
+ outputs_1 = gr.outputs.Audio(label="Target Speaker - Output Audio", type='filepath')
176
+ outputs_2 = gr.outputs.Audio(label="Target Speaker - Output Audio", type='filepath')
177
+
178
+ def voice_conversion(mod, sa, ta):
179
+
180
+ if mod=='FreeVC':
181
+ return voice_conversion_yourtts(sa, ta)
182
+ else:
183
+ return voice_conversion_freevc(sa, ta)
184
+
185
+ examples_1 = [['FreeVC', 'sample_inputs/ntr.wav', 'sample_inputs/timcast1.wav'], ['YourTTS', 'sample_inputs/ntr.wav', 'sample_inputs/timcast1.wav']]
186
+
187
+ vc_1 = gr.Interface(
188
+ fn=voice_conversion,
189
+ inputs=inputs_1,
190
+ outputs=outputs_1,
191
+ examples=examples_1,
192
+ description="Use this cool tool to convert your voice to another person's! \n Upload files in wav format for the source speaker and the target speaker.\n \nThis demonstration is made by T B Ramkamal, for partial credit towards completion of my Dual Degree Project"
193
+ )
194
+
195
+ vc_2 = gr.Interface(
196
+ fn=voice_conversion,
197
+ inputs=inputs_2,
198
+ outputs=outputs_2,
199
+ description="Use this cool tool to convert your voice to another person's! \n Upload files in wav format for the target speaker and record the voice of the input speaker using the microphone.\n \nThis demonstration is made by T B Ramkamal, for partial credit towards completion of my Dual Degree Project"
200
+ )
201
+
202
+ demo = gr.TabbedInterface([vc_1, vc_2], ["wav Input", "Microphone Input"], title="Voice Conversion")
203
+ demo.launch(debug='True')
checkpoints/freevc.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2cc2d047f63b80d1d6780e37611cec11a01d597560393b1fe6118158b3bd47f
3
+ size 472644351
commons.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size*dilation - dilation)/2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
68
+ b, d, t = x.size()
69
+ if x_lengths is None:
70
+ x_lengths = t
71
+ ids_str_max = x_lengths - segment_size
72
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
73
+ ret = slice_segments(x, ids_str, segment_size)
74
+ return ret, ids_str
75
+
76
+
77
+ def get_timing_signal_1d(
78
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
79
+ position = torch.arange(length, dtype=torch.float)
80
+ num_timescales = channels // 2
81
+ log_timescale_increment = (
82
+ math.log(float(max_timescale) / float(min_timescale)) /
83
+ (num_timescales - 1))
84
+ inv_timescales = min_timescale * torch.exp(
85
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
86
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
87
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
88
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
89
+ signal = signal.view(1, channels, length)
90
+ return signal
91
+
92
+
93
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
94
+ b, channels, length = x.size()
95
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
96
+ return x + signal.to(dtype=x.dtype, device=x.device)
97
+
98
+
99
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
100
+ b, channels, length = x.size()
101
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
102
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
103
+
104
+
105
+ def subsequent_mask(length):
106
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
107
+ return mask
108
+
109
+
110
+ @torch.jit.script
111
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
112
+ n_channels_int = n_channels[0]
113
+ in_act = input_a + input_b
114
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
115
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
116
+ acts = t_act * s_act
117
+ return acts
118
+
119
+
120
+ def convert_pad_shape(pad_shape):
121
+ l = pad_shape[::-1]
122
+ pad_shape = [item for sublist in l for item in sublist]
123
+ return pad_shape
124
+
125
+
126
+ def shift_1d(x):
127
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
128
+ return x
129
+
130
+
131
+ def sequence_mask(length, max_length=None):
132
+ if max_length is None:
133
+ max_length = length.max()
134
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
135
+ return x.unsqueeze(0) < length.unsqueeze(1)
136
+
137
+
138
+ def generate_path(duration, mask):
139
+ """
140
+ duration: [b, 1, t_x]
141
+ mask: [b, 1, t_y, t_x]
142
+ """
143
+ device = duration.device
144
+
145
+ b, _, t_y, t_x = mask.shape
146
+ cum_duration = torch.cumsum(duration, -1)
147
+
148
+ cum_duration_flat = cum_duration.view(b * t_x)
149
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
150
+ path = path.view(b, t_x, t_y)
151
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
152
+ path = path.unsqueeze(1).transpose(2,3) * mask
153
+ return path
154
+
155
+
156
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
160
+ norm_type = float(norm_type)
161
+ if clip_value is not None:
162
+ clip_value = float(clip_value)
163
+
164
+ total_norm = 0
165
+ for p in parameters:
166
+ param_norm = p.grad.data.norm(norm_type)
167
+ total_norm += param_norm.item() ** norm_type
168
+ if clip_value is not None:
169
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
170
+ total_norm = total_norm ** (1. / norm_type)
171
+ return total_norm
configs/freevc.json ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "log_interval": 200,
4
+ "eval_interval": 10000,
5
+ "seed": 1234,
6
+ "epochs": 10000,
7
+ "learning_rate": 2e-4,
8
+ "betas": [0.8, 0.99],
9
+ "eps": 1e-9,
10
+ "batch_size": 64,
11
+ "fp16_run": false,
12
+ "lr_decay": 0.999875,
13
+ "segment_size": 8960,
14
+ "init_lr_ratio": 1,
15
+ "warmup_epochs": 0,
16
+ "c_mel": 45,
17
+ "c_kl": 1.0,
18
+ "use_sr": true,
19
+ "max_speclen": 128,
20
+ "port": "8001"
21
+ },
22
+ "data": {
23
+ "training_files":"filelists/train.txt",
24
+ "validation_files":"filelists/val.txt",
25
+ "max_wav_value": 32768.0,
26
+ "sampling_rate": 16000,
27
+ "filter_length": 1280,
28
+ "hop_length": 320,
29
+ "win_length": 1280,
30
+ "n_mel_channels": 80,
31
+ "mel_fmin": 0.0,
32
+ "mel_fmax": null
33
+ },
34
+ "model": {
35
+ "inter_channels": 192,
36
+ "hidden_channels": 192,
37
+ "filter_channels": 768,
38
+ "n_heads": 2,
39
+ "n_layers": 6,
40
+ "kernel_size": 3,
41
+ "p_dropout": 0.1,
42
+ "resblock": "1",
43
+ "resblock_kernel_sizes": [3,7,11],
44
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
45
+ "upsample_rates": [10,8,2,2],
46
+ "upsample_initial_channel": 512,
47
+ "upsample_kernel_sizes": [16,16,4,4],
48
+ "n_layers_q": 3,
49
+ "use_spectral_norm": false,
50
+ "gin_channels": 256,
51
+ "ssl_dim": 1024,
52
+ "use_spk": true
53
+ }
54
+ }
mel_processing.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.data
8
+ import numpy as np
9
+ import librosa
10
+ import librosa.util as librosa_util
11
+ from librosa.util import normalize, pad_center, tiny
12
+ from scipy.signal import get_window
13
+ from scipy.io.wavfile import read
14
+ from librosa.filters import mel as librosa_mel_fn
15
+
16
+ MAX_WAV_VALUE = 32768.0
17
+
18
+
19
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
20
+ """
21
+ PARAMS
22
+ ------
23
+ C: compression factor
24
+ """
25
+ return torch.log(torch.clamp(x, min=clip_val) * C)
26
+
27
+
28
+ def dynamic_range_decompression_torch(x, C=1):
29
+ """
30
+ PARAMS
31
+ ------
32
+ C: compression factor used to compress
33
+ """
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global hann_window
58
+ dtype_device = str(y.dtype) + '_' + str(y.device)
59
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
60
+ if wnsize_dtype_device not in hann_window:
61
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
70
+ return spec
71
+
72
+
73
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
74
+ global mel_basis
75
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
76
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
77
+ if fmax_dtype_device not in mel_basis:
78
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
79
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
80
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
81
+ spec = spectral_normalize_torch(spec)
82
+ return spec
83
+
84
+
85
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
86
+ if torch.min(y) < -1.:
87
+ print('min value is ', torch.min(y))
88
+ if torch.max(y) > 1.:
89
+ print('max value is ', torch.max(y))
90
+
91
+ global mel_basis, hann_window
92
+ dtype_device = str(y.dtype) + '_' + str(y.device)
93
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
94
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
95
+ if fmax_dtype_device not in mel_basis:
96
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
97
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
98
+ if wnsize_dtype_device not in hann_window:
99
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
100
+
101
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
102
+ y = y.squeeze(1)
103
+
104
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
105
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
106
+
107
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
108
+
109
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
110
+ spec = spectral_normalize_torch(spec)
111
+
112
+ return spec
models.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import commons
8
+ import modules
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+ from commons import init_weights, get_padding
13
+
14
+
15
+ class ResidualCouplingBlock(nn.Module):
16
+ def __init__(self,
17
+ channels,
18
+ hidden_channels,
19
+ kernel_size,
20
+ dilation_rate,
21
+ n_layers,
22
+ n_flows=4,
23
+ gin_channels=0):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.hidden_channels = hidden_channels
27
+ self.kernel_size = kernel_size
28
+ self.dilation_rate = dilation_rate
29
+ self.n_layers = n_layers
30
+ self.n_flows = n_flows
31
+ self.gin_channels = gin_channels
32
+
33
+ self.flows = nn.ModuleList()
34
+ for i in range(n_flows):
35
+ self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
36
+ self.flows.append(modules.Flip())
37
+
38
+ def forward(self, x, x_mask, g=None, reverse=False):
39
+ if not reverse:
40
+ for flow in self.flows:
41
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
42
+ else:
43
+ for flow in reversed(self.flows):
44
+ x = flow(x, x_mask, g=g, reverse=reverse)
45
+ return x
46
+
47
+
48
+ class Encoder(nn.Module):
49
+ def __init__(self,
50
+ in_channels,
51
+ out_channels,
52
+ hidden_channels,
53
+ kernel_size,
54
+ dilation_rate,
55
+ n_layers,
56
+ gin_channels=0):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ self.out_channels = out_channels
60
+ self.hidden_channels = hidden_channels
61
+ self.kernel_size = kernel_size
62
+ self.dilation_rate = dilation_rate
63
+ self.n_layers = n_layers
64
+ self.gin_channels = gin_channels
65
+
66
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
67
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
68
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
69
+
70
+ def forward(self, x, x_lengths, g=None):
71
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
72
+ x = self.pre(x) * x_mask
73
+ x = self.enc(x, x_mask, g=g)
74
+ stats = self.proj(x) * x_mask
75
+ m, logs = torch.split(stats, self.out_channels, dim=1)
76
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
77
+ return z, m, logs, x_mask
78
+
79
+
80
+ class Generator(torch.nn.Module):
81
+ def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
82
+ super(Generator, self).__init__()
83
+ self.num_kernels = len(resblock_kernel_sizes)
84
+ self.num_upsamples = len(upsample_rates)
85
+ self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
86
+ resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
87
+
88
+ self.ups = nn.ModuleList()
89
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
90
+ self.ups.append(weight_norm(
91
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
92
+ k, u, padding=(k-u)//2)))
93
+
94
+ self.resblocks = nn.ModuleList()
95
+ for i in range(len(self.ups)):
96
+ ch = upsample_initial_channel//(2**(i+1))
97
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
98
+ self.resblocks.append(resblock(ch, k, d))
99
+
100
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
101
+ self.ups.apply(init_weights)
102
+
103
+ if gin_channels != 0:
104
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
105
+
106
+ def forward(self, x, g=None):
107
+ x = self.conv_pre(x)
108
+ if g is not None:
109
+ x = x + self.cond(g)
110
+
111
+ for i in range(self.num_upsamples):
112
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
113
+ x = self.ups[i](x)
114
+ xs = None
115
+ for j in range(self.num_kernels):
116
+ if xs is None:
117
+ xs = self.resblocks[i*self.num_kernels+j](x)
118
+ else:
119
+ xs += self.resblocks[i*self.num_kernels+j](x)
120
+ x = xs / self.num_kernels
121
+ x = F.leaky_relu(x)
122
+ x = self.conv_post(x)
123
+ x = torch.tanh(x)
124
+
125
+ return x
126
+
127
+ def remove_weight_norm(self):
128
+ print('Removing weight norm...')
129
+ for l in self.ups:
130
+ remove_weight_norm(l)
131
+ for l in self.resblocks:
132
+ l.remove_weight_norm()
133
+
134
+
135
+ class DiscriminatorP(torch.nn.Module):
136
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
137
+ super(DiscriminatorP, self).__init__()
138
+ self.period = period
139
+ self.use_spectral_norm = use_spectral_norm
140
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
141
+ self.convs = nn.ModuleList([
142
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
143
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
144
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
145
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
146
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
147
+ ])
148
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
149
+
150
+ def forward(self, x):
151
+ fmap = []
152
+
153
+ # 1d to 2d
154
+ b, c, t = x.shape
155
+ if t % self.period != 0: # pad first
156
+ n_pad = self.period - (t % self.period)
157
+ x = F.pad(x, (0, n_pad), "reflect")
158
+ t = t + n_pad
159
+ x = x.view(b, c, t // self.period, self.period)
160
+
161
+ for l in self.convs:
162
+ x = l(x)
163
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
164
+ fmap.append(x)
165
+ x = self.conv_post(x)
166
+ fmap.append(x)
167
+ x = torch.flatten(x, 1, -1)
168
+
169
+ return x, fmap
170
+
171
+
172
+ class DiscriminatorS(torch.nn.Module):
173
+ def __init__(self, use_spectral_norm=False):
174
+ super(DiscriminatorS, self).__init__()
175
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
176
+ self.convs = nn.ModuleList([
177
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
178
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
179
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
180
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
181
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
182
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
183
+ ])
184
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
185
+
186
+ def forward(self, x):
187
+ fmap = []
188
+
189
+ for l in self.convs:
190
+ x = l(x)
191
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
192
+ fmap.append(x)
193
+ x = self.conv_post(x)
194
+ fmap.append(x)
195
+ x = torch.flatten(x, 1, -1)
196
+
197
+ return x, fmap
198
+
199
+
200
+ class MultiPeriodDiscriminator(torch.nn.Module):
201
+ def __init__(self, use_spectral_norm=False):
202
+ super(MultiPeriodDiscriminator, self).__init__()
203
+ periods = [2,3,5,7,11]
204
+
205
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
206
+ discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
207
+ self.discriminators = nn.ModuleList(discs)
208
+
209
+ def forward(self, y, y_hat):
210
+ y_d_rs = []
211
+ y_d_gs = []
212
+ fmap_rs = []
213
+ fmap_gs = []
214
+ for i, d in enumerate(self.discriminators):
215
+ y_d_r, fmap_r = d(y)
216
+ y_d_g, fmap_g = d(y_hat)
217
+ y_d_rs.append(y_d_r)
218
+ y_d_gs.append(y_d_g)
219
+ fmap_rs.append(fmap_r)
220
+ fmap_gs.append(fmap_g)
221
+
222
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
223
+
224
+
225
+ class SpeakerEncoder(torch.nn.Module):
226
+ def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
227
+ super(SpeakerEncoder, self).__init__()
228
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
229
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
230
+ self.relu = nn.ReLU()
231
+
232
+ def forward(self, mels):
233
+ self.lstm.flatten_parameters()
234
+ _, (hidden, _) = self.lstm(mels)
235
+ embeds_raw = self.relu(self.linear(hidden[-1]))
236
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
237
+
238
+ def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
239
+ mel_slices = []
240
+ for i in range(0, total_frames-partial_frames, partial_hop):
241
+ mel_range = torch.arange(i, i+partial_frames)
242
+ mel_slices.append(mel_range)
243
+
244
+ return mel_slices
245
+
246
+ def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
247
+ mel_len = mel.size(1)
248
+ last_mel = mel[:,-partial_frames:]
249
+
250
+ if mel_len > partial_frames:
251
+ mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
252
+ mels = list(mel[:,s] for s in mel_slices)
253
+ mels.append(last_mel)
254
+ mels = torch.stack(tuple(mels), 0).squeeze(1)
255
+
256
+ with torch.no_grad():
257
+ partial_embeds = self(mels)
258
+ embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
259
+ #embed = embed / torch.linalg.norm(embed, 2)
260
+ else:
261
+ with torch.no_grad():
262
+ embed = self(last_mel)
263
+
264
+ return embed
265
+
266
+
267
+ class SynthesizerTrn(nn.Module):
268
+ """
269
+ Synthesizer for Training
270
+ """
271
+
272
+ def __init__(self,
273
+ spec_channels,
274
+ segment_size,
275
+ inter_channels,
276
+ hidden_channels,
277
+ filter_channels,
278
+ n_heads,
279
+ n_layers,
280
+ kernel_size,
281
+ p_dropout,
282
+ resblock,
283
+ resblock_kernel_sizes,
284
+ resblock_dilation_sizes,
285
+ upsample_rates,
286
+ upsample_initial_channel,
287
+ upsample_kernel_sizes,
288
+ gin_channels,
289
+ ssl_dim,
290
+ use_spk,
291
+ **kwargs):
292
+
293
+ super().__init__()
294
+ self.spec_channels = spec_channels
295
+ self.inter_channels = inter_channels
296
+ self.hidden_channels = hidden_channels
297
+ self.filter_channels = filter_channels
298
+ self.n_heads = n_heads
299
+ self.n_layers = n_layers
300
+ self.kernel_size = kernel_size
301
+ self.p_dropout = p_dropout
302
+ self.resblock = resblock
303
+ self.resblock_kernel_sizes = resblock_kernel_sizes
304
+ self.resblock_dilation_sizes = resblock_dilation_sizes
305
+ self.upsample_rates = upsample_rates
306
+ self.upsample_initial_channel = upsample_initial_channel
307
+ self.upsample_kernel_sizes = upsample_kernel_sizes
308
+ self.segment_size = segment_size
309
+ self.gin_channels = gin_channels
310
+ self.ssl_dim = ssl_dim
311
+ self.use_spk = use_spk
312
+
313
+ self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
314
+ self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
315
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
316
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
317
+
318
+ if not self.use_spk:
319
+ self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)
320
+
321
+ def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
322
+ if c_lengths == None:
323
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
324
+ if spec_lengths == None:
325
+ spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)
326
+
327
+ if not self.use_spk:
328
+ g = self.enc_spk(mel.transpose(1,2))
329
+ g = g.unsqueeze(-1)
330
+
331
+ _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
332
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
333
+ z_p = self.flow(z, spec_mask, g=g)
334
+
335
+ z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
336
+ o = self.dec(z_slice, g=g)
337
+
338
+ return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
339
+
340
+ def infer(self, c, g=None, mel=None, c_lengths=None):
341
+ if c_lengths == None:
342
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
343
+ if not self.use_spk:
344
+ g = self.enc_spk.embed_utterance(mel.transpose(1,2))
345
+ g = g.unsqueeze(-1)
346
+
347
+ z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
348
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
349
+ o = self.dec(z * c_mask, g=g)
350
+
351
+ return o
modules.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+
15
+
16
+ LRELU_SLOPE = 0.1
17
+
18
+
19
+ class LayerNorm(nn.Module):
20
+ def __init__(self, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.gamma = nn.Parameter(torch.ones(channels))
26
+ self.beta = nn.Parameter(torch.zeros(channels))
27
+
28
+ def forward(self, x):
29
+ x = x.transpose(1, -1)
30
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31
+ return x.transpose(1, -1)
32
+
33
+
34
+ class ConvReluNorm(nn.Module):
35
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36
+ super().__init__()
37
+ self.in_channels = in_channels
38
+ self.hidden_channels = hidden_channels
39
+ self.out_channels = out_channels
40
+ self.kernel_size = kernel_size
41
+ self.n_layers = n_layers
42
+ self.p_dropout = p_dropout
43
+ assert n_layers > 1, "Number of layers should be larger than 0."
44
+
45
+ self.conv_layers = nn.ModuleList()
46
+ self.norm_layers = nn.ModuleList()
47
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48
+ self.norm_layers.append(LayerNorm(hidden_channels))
49
+ self.relu_drop = nn.Sequential(
50
+ nn.ReLU(),
51
+ nn.Dropout(p_dropout))
52
+ for _ in range(n_layers-1):
53
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54
+ self.norm_layers.append(LayerNorm(hidden_channels))
55
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56
+ self.proj.weight.data.zero_()
57
+ self.proj.bias.data.zero_()
58
+
59
+ def forward(self, x, x_mask):
60
+ x_org = x
61
+ for i in range(self.n_layers):
62
+ x = self.conv_layers[i](x * x_mask)
63
+ x = self.norm_layers[i](x)
64
+ x = self.relu_drop(x)
65
+ x = x_org + self.proj(x)
66
+ return x * x_mask
67
+
68
+
69
+ class DDSConv(nn.Module):
70
+ """
71
+ Dialted and Depth-Separable Convolution
72
+ """
73
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.kernel_size = kernel_size
77
+ self.n_layers = n_layers
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = nn.Dropout(p_dropout)
81
+ self.convs_sep = nn.ModuleList()
82
+ self.convs_1x1 = nn.ModuleList()
83
+ self.norms_1 = nn.ModuleList()
84
+ self.norms_2 = nn.ModuleList()
85
+ for i in range(n_layers):
86
+ dilation = kernel_size ** i
87
+ padding = (kernel_size * dilation - dilation) // 2
88
+ self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89
+ groups=channels, dilation=dilation, padding=padding
90
+ ))
91
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92
+ self.norms_1.append(LayerNorm(channels))
93
+ self.norms_2.append(LayerNorm(channels))
94
+
95
+ def forward(self, x, x_mask, g=None):
96
+ if g is not None:
97
+ x = x + g
98
+ for i in range(self.n_layers):
99
+ y = self.convs_sep[i](x * x_mask)
100
+ y = self.norms_1[i](y)
101
+ y = F.gelu(y)
102
+ y = self.convs_1x1[i](y)
103
+ y = self.norms_2[i](y)
104
+ y = F.gelu(y)
105
+ y = self.drop(y)
106
+ x = x + y
107
+ return x * x_mask
108
+
109
+
110
+ class WN(torch.nn.Module):
111
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112
+ super(WN, self).__init__()
113
+ assert(kernel_size % 2 == 1)
114
+ self.hidden_channels =hidden_channels
115
+ self.kernel_size = kernel_size,
116
+ self.dilation_rate = dilation_rate
117
+ self.n_layers = n_layers
118
+ self.gin_channels = gin_channels
119
+ self.p_dropout = p_dropout
120
+
121
+ self.in_layers = torch.nn.ModuleList()
122
+ self.res_skip_layers = torch.nn.ModuleList()
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ if gin_channels != 0:
126
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128
+
129
+ for i in range(n_layers):
130
+ dilation = dilation_rate ** i
131
+ padding = int((kernel_size * dilation - dilation) / 2)
132
+ in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133
+ dilation=dilation, padding=padding)
134
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135
+ self.in_layers.append(in_layer)
136
+
137
+ # last one is not necessary
138
+ if i < n_layers - 1:
139
+ res_skip_channels = 2 * hidden_channels
140
+ else:
141
+ res_skip_channels = hidden_channels
142
+
143
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145
+ self.res_skip_layers.append(res_skip_layer)
146
+
147
+ def forward(self, x, x_mask, g=None, **kwargs):
148
+ output = torch.zeros_like(x)
149
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
150
+
151
+ if g is not None:
152
+ g = self.cond_layer(g)
153
+
154
+ for i in range(self.n_layers):
155
+ x_in = self.in_layers[i](x)
156
+ if g is not None:
157
+ cond_offset = i * 2 * self.hidden_channels
158
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159
+ else:
160
+ g_l = torch.zeros_like(x_in)
161
+
162
+ acts = commons.fused_add_tanh_sigmoid_multiply(
163
+ x_in,
164
+ g_l,
165
+ n_channels_tensor)
166
+ acts = self.drop(acts)
167
+
168
+ res_skip_acts = self.res_skip_layers[i](acts)
169
+ if i < self.n_layers - 1:
170
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
171
+ x = (x + res_acts) * x_mask
172
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
173
+ else:
174
+ output = output + res_skip_acts
175
+ return output * x_mask
176
+
177
+ def remove_weight_norm(self):
178
+ if self.gin_channels != 0:
179
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
180
+ for l in self.in_layers:
181
+ torch.nn.utils.remove_weight_norm(l)
182
+ for l in self.res_skip_layers:
183
+ torch.nn.utils.remove_weight_norm(l)
184
+
185
+
186
+ class ResBlock1(torch.nn.Module):
187
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188
+ super(ResBlock1, self).__init__()
189
+ self.convs1 = nn.ModuleList([
190
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191
+ padding=get_padding(kernel_size, dilation[0]))),
192
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193
+ padding=get_padding(kernel_size, dilation[1]))),
194
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195
+ padding=get_padding(kernel_size, dilation[2])))
196
+ ])
197
+ self.convs1.apply(init_weights)
198
+
199
+ self.convs2 = nn.ModuleList([
200
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201
+ padding=get_padding(kernel_size, 1))),
202
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203
+ padding=get_padding(kernel_size, 1))),
204
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205
+ padding=get_padding(kernel_size, 1)))
206
+ ])
207
+ self.convs2.apply(init_weights)
208
+
209
+ def forward(self, x, x_mask=None):
210
+ for c1, c2 in zip(self.convs1, self.convs2):
211
+ xt = F.leaky_relu(x, LRELU_SLOPE)
212
+ if x_mask is not None:
213
+ xt = xt * x_mask
214
+ xt = c1(xt)
215
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
216
+ if x_mask is not None:
217
+ xt = xt * x_mask
218
+ xt = c2(xt)
219
+ x = xt + x
220
+ if x_mask is not None:
221
+ x = x * x_mask
222
+ return x
223
+
224
+ def remove_weight_norm(self):
225
+ for l in self.convs1:
226
+ remove_weight_norm(l)
227
+ for l in self.convs2:
228
+ remove_weight_norm(l)
229
+
230
+
231
+ class ResBlock2(torch.nn.Module):
232
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233
+ super(ResBlock2, self).__init__()
234
+ self.convs = nn.ModuleList([
235
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]))),
237
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238
+ padding=get_padding(kernel_size, dilation[1])))
239
+ ])
240
+ self.convs.apply(init_weights)
241
+
242
+ def forward(self, x, x_mask=None):
243
+ for c in self.convs:
244
+ xt = F.leaky_relu(x, LRELU_SLOPE)
245
+ if x_mask is not None:
246
+ xt = xt * x_mask
247
+ xt = c(xt)
248
+ x = xt + x
249
+ if x_mask is not None:
250
+ x = x * x_mask
251
+ return x
252
+
253
+ def remove_weight_norm(self):
254
+ for l in self.convs:
255
+ remove_weight_norm(l)
256
+
257
+
258
+ class Log(nn.Module):
259
+ def forward(self, x, x_mask, reverse=False, **kwargs):
260
+ if not reverse:
261
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262
+ logdet = torch.sum(-y, [1, 2])
263
+ return y, logdet
264
+ else:
265
+ x = torch.exp(x) * x_mask
266
+ return x
267
+
268
+
269
+ class Flip(nn.Module):
270
+ def forward(self, x, *args, reverse=False, **kwargs):
271
+ x = torch.flip(x, [1])
272
+ if not reverse:
273
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274
+ return x, logdet
275
+ else:
276
+ return x
277
+
278
+
279
+ class ElementwiseAffine(nn.Module):
280
+ def __init__(self, channels):
281
+ super().__init__()
282
+ self.channels = channels
283
+ self.m = nn.Parameter(torch.zeros(channels,1))
284
+ self.logs = nn.Parameter(torch.zeros(channels,1))
285
+
286
+ def forward(self, x, x_mask, reverse=False, **kwargs):
287
+ if not reverse:
288
+ y = self.m + torch.exp(self.logs) * x
289
+ y = y * x_mask
290
+ logdet = torch.sum(self.logs * x_mask, [1,2])
291
+ return y, logdet
292
+ else:
293
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
294
+ return x
295
+
296
+
297
+ class ResidualCouplingLayer(nn.Module):
298
+ def __init__(self,
299
+ channels,
300
+ hidden_channels,
301
+ kernel_size,
302
+ dilation_rate,
303
+ n_layers,
304
+ p_dropout=0,
305
+ gin_channels=0,
306
+ mean_only=False):
307
+ assert channels % 2 == 0, "channels should be divisible by 2"
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.hidden_channels = hidden_channels
311
+ self.kernel_size = kernel_size
312
+ self.dilation_rate = dilation_rate
313
+ self.n_layers = n_layers
314
+ self.half_channels = channels // 2
315
+ self.mean_only = mean_only
316
+
317
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320
+ self.post.weight.data.zero_()
321
+ self.post.bias.data.zero_()
322
+
323
+ def forward(self, x, x_mask, g=None, reverse=False):
324
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325
+ h = self.pre(x0) * x_mask
326
+ h = self.enc(h, x_mask, g=g)
327
+ stats = self.post(h) * x_mask
328
+ if not self.mean_only:
329
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
330
+ else:
331
+ m = stats
332
+ logs = torch.zeros_like(m)
333
+
334
+ if not reverse:
335
+ x1 = m + x1 * torch.exp(logs) * x_mask
336
+ x = torch.cat([x0, x1], 1)
337
+ logdet = torch.sum(logs, [1,2])
338
+ return x, logdet
339
+ else:
340
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
341
+ x = torch.cat([x0, x1], 1)
342
+ return x
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/Edresson/Coqui-TTS@multilingual-torchaudio-SE
2
+ torchaudio==0.9.0
3
+ pydub
4
+ ffmpeg-normalize==1.21.0
5
+ numpy
6
+ scipy
7
+ torch
8
+ transformers
9
+ librosa==0.8.1
10
+ webrtcvad==2.0.10
11
+ gradio==3.22.1
sample_inputs/ntr.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b75c56ba7545d0a96bf6a12c02ef38edc4beded66fd4d32d1b92543045e43617
3
+ size 1940444
sample_inputs/out.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:935604738717e5bcdfe59dc5fd2842143cfb9a6073228b06c98d53f5b4ec9bec
3
+ size 103738
sample_inputs/p225_001.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b15dc9bbf0ea3cb0f0f02cf20bb60538a77fc5c6b87769e20cfece81891c78d4
3
+ size 52058
sample_inputs/p226_002.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dc155a7dbfab1a3150a321082789f610b9c9328e84e0273c3b648273acf0a56
3
+ size 138084
sample_inputs/reference.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dbdd762a5bca492f82244838df309e94149236a7c659ef192cd6779d2b69983
3
+ size 640078
sample_inputs/target.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a43db3ea5e4ffa59b9a6675940661160a255ba9d3bef4133855298a1ff38ee8e
3
+ size 704078
sample_inputs/timcast1.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fb4d35e5e20c59e6deb69694da0bd403f80704e2f3d9b8d4c4d1a5b558bc6c1
3
+ size 1764044
speaker_encoder/__init__.py ADDED
File without changes
speaker_encoder/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (138 Bytes). View file
 
speaker_encoder/__pycache__/audio.cpython-39.pyc ADDED
Binary file (3.75 kB). View file
 
speaker_encoder/__pycache__/hparams.cpython-39.pyc ADDED
Binary file (497 Bytes). View file
 
speaker_encoder/__pycache__/params_data.cpython-39.pyc ADDED
Binary file (445 Bytes). View file
 
speaker_encoder/__pycache__/voice_encoder.cpython-39.pyc ADDED
Binary file (8.28 kB). View file
 
speaker_encoder/audio.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from speaker_encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ import numpy as np
6
+ import webrtcvad
7
+ import librosa
8
+ import struct
9
+
10
+ int16_max = (2 ** 15) - 1
11
+
12
+
13
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
14
+ source_sr: Optional[int] = None):
15
+ """
16
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
17
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
18
+
19
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
20
+ just .wav), either the waveform as a numpy array of floats.
21
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
22
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
23
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
24
+ this argument will be ignored.
25
+ """
26
+ # Load the wav from disk if needed
27
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
28
+ wav, source_sr = librosa.load(fpath_or_wav, sr=None)
29
+ else:
30
+ wav = fpath_or_wav
31
+
32
+ # Resample the wav if needed
33
+ if source_sr is not None and source_sr != sampling_rate:
34
+ wav = librosa.resample(wav, source_sr, sampling_rate)
35
+
36
+ # Apply the preprocessing: normalize volume and shorten long silences
37
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
38
+ wav = trim_long_silences(wav)
39
+
40
+ return wav
41
+
42
+
43
+ def wav_to_mel_spectrogram(wav):
44
+ """
45
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
46
+ Note: this not a log-mel spectrogram.
47
+ """
48
+ frames = librosa.feature.melspectrogram(
49
+ y=wav,
50
+ sr=sampling_rate,
51
+ n_fft=int(sampling_rate * mel_window_length / 1000),
52
+ hop_length=int(sampling_rate * mel_window_step / 1000),
53
+ n_mels=mel_n_channels
54
+ )
55
+ return frames.astype(np.float32).T
56
+
57
+
58
+ def trim_long_silences(wav):
59
+ """
60
+ Ensures that segments without voice in the waveform remain no longer than a
61
+ threshold determined by the VAD parameters in params.py.
62
+
63
+ :param wav: the raw waveform as a numpy array of floats
64
+ :return: the same waveform with silences trimmed away (length <= original wav length)
65
+ """
66
+ # Compute the voice detection window size
67
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
68
+
69
+ # Trim the end of the audio to have a multiple of the window size
70
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
71
+
72
+ # Convert the float waveform to 16-bit mono PCM
73
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
74
+
75
+ # Perform voice activation detection
76
+ voice_flags = []
77
+ vad = webrtcvad.Vad(mode=3)
78
+ for window_start in range(0, len(wav), samples_per_window):
79
+ window_end = window_start + samples_per_window
80
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
81
+ sample_rate=sampling_rate))
82
+ voice_flags = np.array(voice_flags)
83
+
84
+ # Smooth the voice detection with a moving average
85
+ def moving_average(array, width):
86
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
87
+ ret = np.cumsum(array_padded, dtype=float)
88
+ ret[width:] = ret[width:] - ret[:-width]
89
+ return ret[width - 1:] / width
90
+
91
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
92
+ audio_mask = np.round(audio_mask).astype(np.bool)
93
+
94
+ # Dilate the voiced regions
95
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
96
+ audio_mask = np.repeat(audio_mask, samples_per_window)
97
+
98
+ return wav[audio_mask == True]
99
+
100
+
101
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
102
+ if increase_only and decrease_only:
103
+ raise ValueError("Both increase only and decrease only are set")
104
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
105
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
106
+ return wav
107
+ return wav * (10 ** (dBFS_change / 20))
speaker_encoder/ckpt/pretrained_bak_5805000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
3
+ size 17090379
speaker_encoder/compute_embed.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder import inference as encoder
2
+ from multiprocessing.pool import Pool
3
+ from functools import partial
4
+ from pathlib import Path
5
+ # from utils import logmmse
6
+ # from tqdm import tqdm
7
+ # import numpy as np
8
+ # import librosa
9
+
10
+
11
+ def embed_utterance(fpaths, encoder_model_fpath):
12
+ if not encoder.is_loaded():
13
+ encoder.load_model(encoder_model_fpath)
14
+
15
+ # Compute the speaker embedding of the utterance
16
+ wav_fpath, embed_fpath = fpaths
17
+ wav = np.load(wav_fpath)
18
+ wav = encoder.preprocess_wav(wav)
19
+ embed = encoder.embed_utterance(wav)
20
+ np.save(embed_fpath, embed, allow_pickle=False)
21
+
22
+
23
+ def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
24
+
25
+ wav_dir = outdir_root.joinpath("audio")
26
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
27
+ assert wav_dir.exists() and metadata_fpath.exists()
28
+ embed_dir = synthesizer_root.joinpath("embeds")
29
+ embed_dir.mkdir(exist_ok=True)
30
+
31
+ # Gather the input wave filepath and the target output embed filepath
32
+ with metadata_fpath.open("r") as metadata_file:
33
+ metadata = [line.split("|") for line in metadata_file]
34
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
35
+
36
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
37
+ # Embed the utterances in separate threads
38
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
39
+ job = Pool(n_processes).imap(func, fpaths)
40
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
speaker_encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
speaker_encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.random_cycler import RandomCycler
2
+ from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from speaker_encoder.data_objects.speaker import Speaker
4
+ from speaker_encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
speaker_encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
speaker_encoder/hparams.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Mel-filterbank
2
+ mel_window_length = 25 # In milliseconds
3
+ mel_window_step = 10 # In milliseconds
4
+ mel_n_channels = 40
5
+
6
+
7
+ ## Audio
8
+ sampling_rate = 16000
9
+ # Number of spectrogram frames in a partial utterance
10
+ partials_n_frames = 160 # 1600 ms
11
+
12
+
13
+ ## Voice Activation Detection
14
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
15
+ # This sets the granularity of the VAD. Should not need to be changed.
16
+ vad_window_length = 30 # In milliseconds
17
+ # Number of frames to average together when performing the moving average smoothing.
18
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
19
+ vad_moving_average_width = 8
20
+ # Maximum number of consecutive silent frames a segment can have.
21
+ vad_max_silence_length = 6
22
+
23
+
24
+ ## Audio volume normalization
25
+ audio_norm_target_dBFS = -30
26
+
27
+
28
+ ## Model parameters
29
+ model_hidden_size = 256
30
+ model_embedding_size = 256
31
+ model_num_layers = 3
speaker_encoder/inference.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_data import *
2
+ from speaker_encoder.model import SpeakerEncoder
3
+ from speaker_encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ cbar.set_clim(*color_range)
175
+
176
+ ax.set_xticks([]), ax.set_yticks([])
177
+ ax.set_title(title)
speaker_encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.params_model import *
2
+ from speaker_encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels, # 40
19
+ hidden_size=model_hidden_size, # 256
20
+ num_layers=model_num_layers, # 3
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
speaker_encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
speaker_encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
speaker_encoder/preprocess.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from speaker_encoder.params_data import *
3
+ from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from speaker_encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from speaker_encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ # Function to preprocess utterances for one speaker
122
+ def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
123
+ # Give a name to the speaker that includes its dataset
124
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
125
+
126
+ # Create an output directory with that name, as well as a txt file containing a
127
+ # reference to each source file.
128
+ speaker_out_dir = out_dir.joinpath(speaker_name)
129
+ speaker_out_dir.mkdir(exist_ok=True)
130
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
131
+
132
+ # There's a possibility that the preprocessing was interrupted earlier, check if
133
+ # there already is a sources file.
134
+ # if sources_fpath.exists():
135
+ # try:
136
+ # with sources_fpath.open("r") as sources_file:
137
+ # existing_fnames = {line.split(",")[0] for line in sources_file}
138
+ # except:
139
+ # existing_fnames = {}
140
+ # else:
141
+ # existing_fnames = {}
142
+ existing_fnames = {}
143
+ # Gather all audio files for that speaker recursively
144
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
145
+
146
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
147
+ # Check if the target output file already exists
148
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
149
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
150
+ if skip_existing and out_fname in existing_fnames:
151
+ continue
152
+
153
+ # Load and preprocess the waveform
154
+ wav = audio.preprocess_wav(in_fpath)
155
+ if len(wav) == 0:
156
+ continue
157
+
158
+ # Create the mel spectrogram, discard those that are too short
159
+ frames = audio.wav_to_mel_spectrogram(wav)
160
+ if len(frames) < partials_n_frames:
161
+ continue
162
+
163
+ out_fpath = speaker_out_dir.joinpath(out_fname)
164
+ np.save(out_fpath, frames)
165
+ # logger.add_sample(duration=len(wav) / sampling_rate)
166
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
167
+
168
+ sources_file.close()
169
+ return len(wav)
170
+
171
+ def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
172
+ skip_existing, logger):
173
+ # from multiprocessing import Pool, cpu_count
174
+ from pathos.multiprocessing import ProcessingPool as Pool
175
+ # Function to preprocess utterances for one speaker
176
+ def __preprocess_speaker(speaker_dir: Path):
177
+ # Give a name to the speaker that includes its dataset
178
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
179
+
180
+ # Create an output directory with that name, as well as a txt file containing a
181
+ # reference to each source file.
182
+ speaker_out_dir = out_dir.joinpath(speaker_name)
183
+ speaker_out_dir.mkdir(exist_ok=True)
184
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
185
+
186
+ existing_fnames = {}
187
+ # Gather all audio files for that speaker recursively
188
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
189
+ wav_lens = []
190
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
191
+ # Check if the target output file already exists
192
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
193
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
194
+ if skip_existing and out_fname in existing_fnames:
195
+ continue
196
+
197
+ # Load and preprocess the waveform
198
+ wav = audio.preprocess_wav(in_fpath)
199
+ if len(wav) == 0:
200
+ continue
201
+
202
+ # Create the mel spectrogram, discard those that are too short
203
+ frames = audio.wav_to_mel_spectrogram(wav)
204
+ if len(frames) < partials_n_frames:
205
+ continue
206
+
207
+ out_fpath = speaker_out_dir.joinpath(out_fname)
208
+ np.save(out_fpath, frames)
209
+ # logger.add_sample(duration=len(wav) / sampling_rate)
210
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
211
+ wav_lens.append(len(wav))
212
+ sources_file.close()
213
+ return wav_lens
214
+
215
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
216
+ # Process the utterances for each speaker
217
+ # with ThreadPool(8) as pool:
218
+ # list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
219
+ # unit="speakers"))
220
+ pool = Pool(processes=20)
221
+ for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
222
+ for wav_len in wav_lens:
223
+ logger.add_sample(duration=wav_len / sampling_rate)
224
+ print(f'{i}/{len(speaker_dirs)} \r')
225
+
226
+ logger.finalize()
227
+ print("Done preprocessing %s.\n" % dataset_name)
228
+
229
+
230
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
231
+ for dataset_name in librispeech_datasets["train"]["other"]:
232
+ # Initialize the preprocessing
233
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
234
+ if not dataset_root:
235
+ return
236
+
237
+ # Preprocess all speakers
238
+ speaker_dirs = list(dataset_root.glob("*"))
239
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
240
+ skip_existing, logger)
241
+
242
+
243
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
244
+ # Initialize the preprocessing
245
+ dataset_name = "VoxCeleb1"
246
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
247
+ if not dataset_root:
248
+ return
249
+
250
+ # Get the contents of the meta file
251
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
252
+ metadata = [line.split("\t") for line in metafile][1:]
253
+
254
+ # Select the ID and the nationality, filter out non-anglophone speakers
255
+ nationalities = {line[0]: line[3] for line in metadata}
256
+ # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
257
+ # nationality.lower() in anglophone_nationalites]
258
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
259
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
260
+ (len(keep_speaker_ids), len(nationalities)))
261
+
262
+ # Get the speaker directories for anglophone speakers only
263
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
264
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
265
+ speaker_dir.name in keep_speaker_ids]
266
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
267
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
268
+
269
+ # Preprocess all speakers
270
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
271
+ skip_existing, logger)
272
+
273
+
274
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
275
+ # Initialize the preprocessing
276
+ dataset_name = "VoxCeleb2"
277
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
278
+ if not dataset_root:
279
+ return
280
+
281
+ # Get the speaker directories
282
+ # Preprocess all speakers
283
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
284
+ _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
285
+ skip_existing, logger)
speaker_encoder/train.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.visualizations import Visualizations
2
+ from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from speaker_encoder.params_model import *
4
+ from speaker_encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # FIXME
11
+ return
12
+ # For correct profiling (cuda operations are async)
13
+ if device.type == "cuda":
14
+ torch.cuda.synchronize(device)
15
+
16
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
17
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
18
+ no_visdom: bool):
19
+ # Create a dataset and a dataloader
20
+ dataset = SpeakerVerificationDataset(clean_data_root)
21
+ loader = SpeakerVerificationDataLoader(
22
+ dataset,
23
+ speakers_per_batch, # 64
24
+ utterances_per_speaker, # 10
25
+ num_workers=8,
26
+ )
27
+
28
+ # Setup the device on which to run the forward pass and the loss. These can be different,
29
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
30
+ # hyperparameters) faster on the CPU.
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ # FIXME: currently, the gradient is None if loss_device is cuda
33
+ loss_device = torch.device("cpu")
34
+
35
+ # Create the model and the optimizer
36
+ model = SpeakerEncoder(device, loss_device)
37
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
38
+ init_step = 1
39
+
40
+ # Configure file path for the model
41
+ state_fpath = models_dir.joinpath(run_id + ".pt")
42
+ backup_dir = models_dir.joinpath(run_id + "_backups")
43
+
44
+ # Load any existing model
45
+ if not force_restart:
46
+ if state_fpath.exists():
47
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
48
+ checkpoint = torch.load(state_fpath)
49
+ init_step = checkpoint["step"]
50
+ model.load_state_dict(checkpoint["model_state"])
51
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
52
+ optimizer.param_groups[0]["lr"] = learning_rate_init
53
+ else:
54
+ print("No model \"%s\" found, starting training from scratch." % run_id)
55
+ else:
56
+ print("Starting the training from scratch.")
57
+ model.train()
58
+
59
+ # Initialize the visualization environment
60
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
61
+ vis.log_dataset(dataset)
62
+ vis.log_params()
63
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
64
+ vis.log_implementation({"Device": device_name})
65
+
66
+ # Training loop
67
+ profiler = Profiler(summarize_every=10, disabled=False)
68
+ for step, speaker_batch in enumerate(loader, init_step):
69
+ profiler.tick("Blocking, waiting for batch (threaded)")
70
+
71
+ # Forward pass
72
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
73
+ sync(device)
74
+ profiler.tick("Data to %s" % device)
75
+ embeds = model(inputs)
76
+ sync(device)
77
+ profiler.tick("Forward pass")
78
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
79
+ loss, eer = model.loss(embeds_loss)
80
+ sync(loss_device)
81
+ profiler.tick("Loss")
82
+
83
+ # Backward pass
84
+ model.zero_grad()
85
+ loss.backward()
86
+ profiler.tick("Backward pass")
87
+ model.do_gradient_ops()
88
+ optimizer.step()
89
+ profiler.tick("Parameter update")
90
+
91
+ # Update visualizations
92
+ # learning_rate = optimizer.param_groups[0]["lr"]
93
+ vis.update(loss.item(), eer, step)
94
+
95
+ # Draw projections and save them to the backup folder
96
+ if umap_every != 0 and step % umap_every == 0:
97
+ print("Drawing and saving projections (step %d)" % step)
98
+ backup_dir.mkdir(exist_ok=True)
99
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
100
+ embeds = embeds.detach().cpu().numpy()
101
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
102
+ vis.save()
103
+
104
+ # Overwrite the latest version of the model
105
+ if save_every != 0 and step % save_every == 0:
106
+ print("Saving the model (step %d)" % step)
107
+ torch.save({
108
+ "step": step + 1,
109
+ "model_state": model.state_dict(),
110
+ "optimizer_state": optimizer.state_dict(),
111
+ }, state_fpath)
112
+
113
+ # Make a backup
114
+ if backup_every != 0 and step % backup_every == 0:
115
+ print("Making a backup (step %d)" % step)
116
+ backup_dir.mkdir(exist_ok=True)
117
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
118
+ torch.save({
119
+ "step": step + 1,
120
+ "model_state": model.state_dict(),
121
+ "optimizer_state": optimizer.state_dict(),
122
+ }, backup_fpath)
123
+
124
+ profiler.tick("Extras (visualizations, saving)")
125
+
speaker_encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from speaker_encoder import params_data
69
+ from speaker_encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
speaker_encoder/voice_encoder.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speaker_encoder.hparams import *
2
+ from speaker_encoder import audio
3
+ from pathlib import Path
4
+ from typing import Union, List
5
+ from torch import nn
6
+ from time import perf_counter as timer
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ class SpeakerEncoder(nn.Module):
12
+ def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
13
+ """
14
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
15
+ If None, defaults to cuda if it is available on your machine, otherwise the model will
16
+ run on cpu. Outputs are always returned on the cpu, as numpy arrays.
17
+ """
18
+ super().__init__()
19
+
20
+ # Define the network
21
+ self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
22
+ self.linear = nn.Linear(model_hidden_size, model_embedding_size)
23
+ self.relu = nn.ReLU()
24
+
25
+ # Get the target device
26
+ if device is None:
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ elif isinstance(device, str):
29
+ device = torch.device(device)
30
+ self.device = device
31
+
32
+ # Load the pretrained model'speaker weights
33
+ # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
34
+ # if not weights_fpath.exists():
35
+ # raise Exception("Couldn't find the voice encoder pretrained model at %s." %
36
+ # weights_fpath)
37
+
38
+ start = timer()
39
+ checkpoint = torch.load(weights_fpath, map_location="cpu")
40
+
41
+ self.load_state_dict(checkpoint["model_state"], strict=False)
42
+ self.to(device)
43
+
44
+ if verbose:
45
+ print("Loaded the voice encoder model on %s in %.2f seconds." %
46
+ (device.type, timer() - start))
47
+
48
+ def forward(self, mels: torch.FloatTensor):
49
+ """
50
+ Computes the embeddings of a batch of utterance spectrograms.
51
+ :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
52
+ (batch_size, n_frames, n_channels)
53
+ :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
54
+ Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
55
+ """
56
+ # Pass the input through the LSTM layers and retrieve the final hidden state of the last
57
+ # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
58
+ _, (hidden, _) = self.lstm(mels)
59
+ embeds_raw = self.relu(self.linear(hidden[-1]))
60
+ return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)
61
+
62
+ @staticmethod
63
+ def compute_partial_slices(n_samples: int, rate, min_coverage):
64
+ """
65
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to
66
+ obtain partial utterances of <partials_n_frames> each. Both the waveform and the
67
+ mel spectrogram slices are returned, so as to make each partial utterance waveform
68
+ correspond to its spectrogram.
69
+
70
+ The returned ranges may be indexing further than the length of the waveform. It is
71
+ recommended that you pad the waveform with zeros up to wav_slices[-1].stop.
72
+
73
+ :param n_samples: the number of samples in the waveform
74
+ :param rate: how many partial utterances should occur per second. Partial utterances must
75
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
76
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
77
+ the minimum rate is thus 0.625.
78
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
79
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
80
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
81
+ it will be discarded. If there aren't enough frames for one partial utterance,
82
+ this parameter is ignored so that the function always returns at least one slice.
83
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
84
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
85
+ utterances.
86
+ """
87
+ assert 0 < min_coverage <= 1
88
+
89
+ # Compute how many frames separate two partial utterances
90
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
91
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
92
+ frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
93
+ assert 0 < frame_step, "The rate is too high"
94
+ assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
95
+ (sampling_rate / (samples_per_frame * partials_n_frames))
96
+
97
+ # Compute the slices
98
+ wav_slices, mel_slices = [], []
99
+ steps = max(1, n_frames - partials_n_frames + frame_step + 1)
100
+ for i in range(0, steps, frame_step):
101
+ mel_range = np.array([i, i + partials_n_frames])
102
+ wav_range = mel_range * samples_per_frame
103
+ mel_slices.append(slice(*mel_range))
104
+ wav_slices.append(slice(*wav_range))
105
+
106
+ # Evaluate whether extra padding is warranted or not
107
+ last_wav_range = wav_slices[-1]
108
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
109
+ if coverage < min_coverage and len(mel_slices) > 1:
110
+ mel_slices = mel_slices[:-1]
111
+ wav_slices = wav_slices[:-1]
112
+
113
+ return wav_slices, mel_slices
114
+
115
+ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
116
+ """
117
+ Computes an embedding for a single utterance. The utterance is divided in partial
118
+ utterances and an embedding is computed for each. The complete utterance embedding is the
119
+ L2-normed average embedding of the partial utterances.
120
+
121
+ TODO: independent batched version of this function
122
+
123
+ :param wav: a preprocessed utterance waveform as a numpy array of float32
124
+ :param return_partials: if True, the partial embeddings will also be returned along with
125
+ the wav slices corresponding to each partial utterance.
126
+ :param rate: how many partial utterances should occur per second. Partial utterances must
127
+ cover the span of the entire utterance, thus the rate should not be lower than the inverse
128
+ of the duration of a partial utterance. By default, partial utterances are 1.6s long and
129
+ the minimum rate is thus 0.625.
130
+ :param min_coverage: when reaching the last partial utterance, it may or may not have
131
+ enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
132
+ then the last partial utterance will be considered by zero-padding the audio. Otherwise,
133
+ it will be discarded. If there aren't enough frames for one partial utterance,
134
+ this parameter is ignored so that the function always returns at least one slice.
135
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
136
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
137
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
138
+ returned.
139
+ """
140
+ # Compute where to split the utterance into partials and pad the waveform with zeros if
141
+ # the partial utterances cover a larger range.
142
+ wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
143
+ max_wave_length = wav_slices[-1].stop
144
+ if max_wave_length >= len(wav):
145
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
146
+
147
+ # Split the utterance into partials and forward them through the model
148
+ mel = audio.wav_to_mel_spectrogram(wav)
149
+ mels = np.array([mel[s] for s in mel_slices])
150
+ with torch.no_grad():
151
+ mels = torch.from_numpy(mels).to(self.device)
152
+ partial_embeds = self(mels).cpu().numpy()
153
+
154
+ # Compute the utterance embedding from the partial embeddings
155
+ raw_embed = np.mean(partial_embeds, axis=0)
156
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
157
+
158
+ if return_partials:
159
+ return embed, partial_embeds, wav_slices
160
+ return embed
161
+
162
+ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
163
+ """
164
+ Compute the embedding of a collection of wavs (presumably from the same speaker) by
165
+ averaging their embedding and L2-normalizing it.
166
+
167
+ :param wavs: list of wavs a numpy arrays of float32.
168
+ :param kwargs: extra arguments to embed_utterance()
169
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
170
+ """
171
+ raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
172
+ for wav in wavs], axis=0)
173
+ return raw_embed / np.linalg.norm(raw_embed, 2)
utils.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import subprocess
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+ import torch
10
+ from torch.nn import functional as F
11
+ from commons import sequence_mask
12
+ from wavlm import WavLM, WavLMConfig
13
+
14
+ MATPLOTLIB_FLAG = False
15
+
16
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
17
+ logger = logging
18
+
19
+
20
+ def get_cmodel(rank):
21
+ checkpoint = torch.load('wavlm/WavLM-Large.pt')
22
+ cfg = WavLMConfig(checkpoint['cfg'])
23
+ cmodel = WavLM(cfg)
24
+ cmodel.load_state_dict(checkpoint['model'])
25
+ cmodel.eval()
26
+ return cmodel
27
+
28
+
29
+ def get_content(cmodel, y):
30
+ with torch.no_grad():
31
+ c = cmodel.extract_features(y.squeeze(1))[0]
32
+ c = c.transpose(1, 2)
33
+ return c
34
+
35
+
36
+ def get_vocoder(rank):
37
+ with open("hifigan/config.json", "r") as f:
38
+ config = json.load(f)
39
+ config = hifigan.AttrDict(config)
40
+ vocoder = hifigan.Generator(config)
41
+ ckpt = torch.load("hifigan/generator_v1")
42
+ vocoder.load_state_dict(ckpt["generator"])
43
+ vocoder.eval()
44
+ vocoder.remove_weight_norm()
45
+ vocoder.cuda(rank)
46
+ return vocoder
47
+
48
+
49
+ def transform(mel, height): # 68-92
50
+ #r = np.random.random()
51
+ #rate = r * 0.3 + 0.85 # 0.85-1.15
52
+ #height = int(mel.size(-2) * rate)
53
+ tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
54
+ if height >= mel.size(-2):
55
+ return tgt[:, :mel.size(-2), :]
56
+ else:
57
+ silence = tgt[:,-1:,:].repeat(1,mel.size(-2)-height,1)
58
+ silence += torch.randn_like(silence) / 10
59
+ return torch.cat((tgt, silence), 1)
60
+
61
+
62
+ def stretch(mel, width): # 0.5-2
63
+ return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))
64
+
65
+
66
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
67
+ assert os.path.isfile(checkpoint_path)
68
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
69
+ iteration = checkpoint_dict['iteration']
70
+ learning_rate = checkpoint_dict['learning_rate']
71
+ if optimizer is not None:
72
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
73
+ saved_state_dict = checkpoint_dict['model']
74
+ if hasattr(model, 'module'):
75
+ state_dict = model.module.state_dict()
76
+ else:
77
+ state_dict = model.state_dict()
78
+ new_state_dict= {}
79
+ for k, v in state_dict.items():
80
+ try:
81
+ new_state_dict[k] = saved_state_dict[k]
82
+ except:
83
+ logger.info("%s is not in the checkpoint" % k)
84
+ new_state_dict[k] = v
85
+ if hasattr(model, 'module'):
86
+ model.module.load_state_dict(new_state_dict)
87
+ else:
88
+ model.load_state_dict(new_state_dict)
89
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
90
+ checkpoint_path, iteration))
91
+ return model, optimizer, learning_rate, iteration
92
+
93
+
94
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
95
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
96
+ iteration, checkpoint_path))
97
+ if hasattr(model, 'module'):
98
+ state_dict = model.module.state_dict()
99
+ else:
100
+ state_dict = model.state_dict()
101
+ torch.save({'model': state_dict,
102
+ 'iteration': iteration,
103
+ 'optimizer': optimizer.state_dict(),
104
+ 'learning_rate': learning_rate}, checkpoint_path)
105
+
106
+
107
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
108
+ for k, v in scalars.items():
109
+ writer.add_scalar(k, v, global_step)
110
+ for k, v in histograms.items():
111
+ writer.add_histogram(k, v, global_step)
112
+ for k, v in images.items():
113
+ writer.add_image(k, v, global_step, dataformats='HWC')
114
+ for k, v in audios.items():
115
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
116
+
117
+
118
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
119
+ f_list = glob.glob(os.path.join(dir_path, regex))
120
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
121
+ x = f_list[-1]
122
+ print(x)
123
+ return x
124
+
125
+
126
+ def plot_spectrogram_to_numpy(spectrogram):
127
+ global MATPLOTLIB_FLAG
128
+ if not MATPLOTLIB_FLAG:
129
+ import matplotlib
130
+ matplotlib.use("Agg")
131
+ MATPLOTLIB_FLAG = True
132
+ mpl_logger = logging.getLogger('matplotlib')
133
+ mpl_logger.setLevel(logging.WARNING)
134
+ import matplotlib.pylab as plt
135
+ import numpy as np
136
+
137
+ fig, ax = plt.subplots(figsize=(10,2))
138
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
139
+ interpolation='none')
140
+ plt.colorbar(im, ax=ax)
141
+ plt.xlabel("Frames")
142
+ plt.ylabel("Channels")
143
+ plt.tight_layout()
144
+
145
+ fig.canvas.draw()
146
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
147
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
148
+ plt.close()
149
+ return data
150
+
151
+
152
+ def plot_alignment_to_numpy(alignment, info=None):
153
+ global MATPLOTLIB_FLAG
154
+ if not MATPLOTLIB_FLAG:
155
+ import matplotlib
156
+ matplotlib.use("Agg")
157
+ MATPLOTLIB_FLAG = True
158
+ mpl_logger = logging.getLogger('matplotlib')
159
+ mpl_logger.setLevel(logging.WARNING)
160
+ import matplotlib.pylab as plt
161
+ import numpy as np
162
+
163
+ fig, ax = plt.subplots(figsize=(6, 4))
164
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
165
+ interpolation='none')
166
+ fig.colorbar(im, ax=ax)
167
+ xlabel = 'Decoder timestep'
168
+ if info is not None:
169
+ xlabel += '\n\n' + info
170
+ plt.xlabel(xlabel)
171
+ plt.ylabel('Encoder timestep')
172
+ plt.tight_layout()
173
+
174
+ fig.canvas.draw()
175
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
176
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
177
+ plt.close()
178
+ return data
179
+
180
+
181
+ def load_wav_to_torch(full_path):
182
+ sampling_rate, data = read(full_path)
183
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
184
+
185
+
186
+ def load_filepaths_and_text(filename, split="|"):
187
+ with open(filename, encoding='utf-8') as f:
188
+ filepaths_and_text = [line.strip().split(split) for line in f]
189
+ return filepaths_and_text
190
+
191
+
192
+ def get_hparams(init=True):
193
+ parser = argparse.ArgumentParser()
194
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
195
+ help='JSON file for configuration')
196
+ parser.add_argument('-m', '--model', type=str, required=True,
197
+ help='Model name')
198
+
199
+ args = parser.parse_args()
200
+ model_dir = os.path.join("./logs", args.model)
201
+
202
+ if not os.path.exists(model_dir):
203
+ os.makedirs(model_dir)
204
+
205
+ config_path = args.config
206
+ config_save_path = os.path.join(model_dir, "config.json")
207
+ if init:
208
+ with open(config_path, "r") as f:
209
+ data = f.read()
210
+ with open(config_save_path, "w") as f:
211
+ f.write(data)
212
+ else:
213
+ with open(config_save_path, "r") as f:
214
+ data = f.read()
215
+ config = json.loads(data)
216
+
217
+ hparams = HParams(**config)
218
+ hparams.model_dir = model_dir
219
+ return hparams
220
+
221
+
222
+ def get_hparams_from_dir(model_dir):
223
+ config_save_path = os.path.join(model_dir, "config.json")
224
+ with open(config_save_path, "r") as f:
225
+ data = f.read()
226
+ config = json.loads(data)
227
+
228
+ hparams =HParams(**config)
229
+ hparams.model_dir = model_dir
230
+ return hparams
231
+
232
+
233
+ def get_hparams_from_file(config_path):
234
+ with open(config_path, "r") as f:
235
+ data = f.read()
236
+ config = json.loads(data)
237
+
238
+ hparams =HParams(**config)
239
+ return hparams
240
+
241
+
242
+ def check_git_hash(model_dir):
243
+ source_dir = os.path.dirname(os.path.realpath(__file__))
244
+ if not os.path.exists(os.path.join(source_dir, ".git")):
245
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
246
+ source_dir
247
+ ))
248
+ return
249
+
250
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
251
+
252
+ path = os.path.join(model_dir, "githash")
253
+ if os.path.exists(path):
254
+ saved_hash = open(path).read()
255
+ if saved_hash != cur_hash:
256
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
257
+ saved_hash[:8], cur_hash[:8]))
258
+ else:
259
+ open(path, "w").write(cur_hash)
260
+
261
+
262
+ def get_logger(model_dir, filename="train.log"):
263
+ global logger
264
+ logger = logging.getLogger(os.path.basename(model_dir))
265
+ logger.setLevel(logging.DEBUG)
266
+
267
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
268
+ if not os.path.exists(model_dir):
269
+ os.makedirs(model_dir)
270
+ h = logging.FileHandler(os.path.join(model_dir, filename))
271
+ h.setLevel(logging.DEBUG)
272
+ h.setFormatter(formatter)
273
+ logger.addHandler(h)
274
+ return logger
275
+
276
+
277
+ class HParams():
278
+ def __init__(self, **kwargs):
279
+ for k, v in kwargs.items():
280
+ if type(v) == dict:
281
+ v = HParams(**v)
282
+ self[k] = v
283
+
284
+ def keys(self):
285
+ return self.__dict__.keys()
286
+
287
+ def items(self):
288
+ return self.__dict__.items()
289
+
290
+ def values(self):
291
+ return self.__dict__.values()
292
+
293
+ def __len__(self):
294
+ return len(self.__dict__)
295
+
296
+ def __getitem__(self, key):
297
+ return getattr(self, key)
298
+
299
+ def __setitem__(self, key, value):
300
+ return setattr(self, key, value)
301
+
302
+ def __contains__(self, key):
303
+ return key in self.__dict__
304
+
305
+ def __repr__(self):
306
+ return self.__dict__.__repr__()
wavlm/WavLM-Large.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
3
+ size 1261965425
wavlm/WavLM-Large.pt.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://github.com/microsoft/unilm/tree/master/wavlm
wavlm/WavLM.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import logging
12
+ from typing import List, Optional, Tuple
13
+
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.nn import LayerNorm
20
+ from wavlm.modules import (
21
+ Fp32GroupNorm,
22
+ Fp32LayerNorm,
23
+ GradMultiply,
24
+ MultiheadAttention,
25
+ SamePad,
26
+ init_bert_params,
27
+ get_activation_fn,
28
+ TransposeLast,
29
+ GLU_Linear,
30
+ )
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ def compute_mask_indices(
36
+ shape: Tuple[int, int],
37
+ padding_mask: Optional[torch.Tensor],
38
+ mask_prob: float,
39
+ mask_length: int,
40
+ mask_type: str = "static",
41
+ mask_other: float = 0.0,
42
+ min_masks: int = 0,
43
+ no_overlap: bool = False,
44
+ min_space: int = 0,
45
+ ) -> np.ndarray:
46
+ """
47
+ Computes random mask spans for a given shape
48
+
49
+ Args:
50
+ shape: the the shape for which to compute masks.
51
+ should be of size 2 where first element is batch size and 2nd is timesteps
52
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
53
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
54
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
55
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
56
+ mask_type: how to compute mask lengths
57
+ static = fixed size
58
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
59
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
60
+ poisson = sample from possion distribution with lambda = mask length
61
+ min_masks: minimum number of masked spans
62
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
63
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
64
+ """
65
+
66
+ bsz, all_sz = shape
67
+ mask = np.full((bsz, all_sz), False)
68
+
69
+ all_num_mask = int(
70
+ # add a random number for probabilistic rounding
71
+ mask_prob * all_sz / float(mask_length)
72
+ + np.random.rand()
73
+ )
74
+
75
+ all_num_mask = max(min_masks, all_num_mask)
76
+
77
+ mask_idcs = []
78
+ for i in range(bsz):
79
+ if padding_mask is not None:
80
+ sz = all_sz - padding_mask[i].long().sum().item()
81
+ num_mask = int(
82
+ # add a random number for probabilistic rounding
83
+ mask_prob * sz / float(mask_length)
84
+ + np.random.rand()
85
+ )
86
+ num_mask = max(min_masks, num_mask)
87
+ else:
88
+ sz = all_sz
89
+ num_mask = all_num_mask
90
+
91
+ if mask_type == "static":
92
+ lengths = np.full(num_mask, mask_length)
93
+ elif mask_type == "uniform":
94
+ lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
95
+ elif mask_type == "normal":
96
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
97
+ lengths = [max(1, int(round(x))) for x in lengths]
98
+ elif mask_type == "poisson":
99
+ lengths = np.random.poisson(mask_length, size=num_mask)
100
+ lengths = [int(round(x)) for x in lengths]
101
+ else:
102
+ raise Exception("unknown mask selection " + mask_type)
103
+
104
+ if sum(lengths) == 0:
105
+ lengths[0] = min(mask_length, sz - 1)
106
+
107
+ if no_overlap:
108
+ mask_idc = []
109
+
110
+ def arrange(s, e, length, keep_length):
111
+ span_start = np.random.randint(s, e - length)
112
+ mask_idc.extend(span_start + i for i in range(length))
113
+
114
+ new_parts = []
115
+ if span_start - s - min_space >= keep_length:
116
+ new_parts.append((s, span_start - min_space + 1))
117
+ if e - span_start - keep_length - min_space > keep_length:
118
+ new_parts.append((span_start + length + min_space, e))
119
+ return new_parts
120
+
121
+ parts = [(0, sz)]
122
+ min_length = min(lengths)
123
+ for length in sorted(lengths, reverse=True):
124
+ lens = np.fromiter(
125
+ (e - s if e - s >= length + min_space else 0 for s, e in parts),
126
+ np.int,
127
+ )
128
+ l_sum = np.sum(lens)
129
+ if l_sum == 0:
130
+ break
131
+ probs = lens / np.sum(lens)
132
+ c = np.random.choice(len(parts), p=probs)
133
+ s, e = parts.pop(c)
134
+ parts.extend(arrange(s, e, length, min_length))
135
+ mask_idc = np.asarray(mask_idc)
136
+ else:
137
+ min_len = min(lengths)
138
+ if sz - min_len <= num_mask:
139
+ min_len = sz - num_mask - 1
140
+
141
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
142
+
143
+ mask_idc = np.asarray(
144
+ [
145
+ mask_idc[j] + offset
146
+ for j in range(len(mask_idc))
147
+ for offset in range(lengths[j])
148
+ ]
149
+ )
150
+
151
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
152
+
153
+ min_len = min([len(m) for m in mask_idcs])
154
+ for i, mask_idc in enumerate(mask_idcs):
155
+ if len(mask_idc) > min_len:
156
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
157
+ mask[i, mask_idc] = True
158
+
159
+ return mask
160
+
161
+
162
+ class WavLMConfig:
163
+ def __init__(self, cfg=None):
164
+ self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
165
+ self.encoder_layers: int = 12 # num encoder layers in the transformer
166
+
167
+ self.encoder_embed_dim: int = 768 # encoder embedding dimension
168
+ self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
169
+ self.encoder_attention_heads: int = 12 # num encoder attention heads
170
+ self.activation_fn: str = "gelu" # activation function to use
171
+
172
+ self.layer_norm_first: bool = False # apply layernorm first in the transformer
173
+ self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
174
+ self.conv_bias: bool = False # include bias in conv encoder
175
+ self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
176
+
177
+ self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
178
+
179
+ # dropouts
180
+ self.dropout: float = 0.1 # dropout probability for the transformer
181
+ self.attention_dropout: float = 0.1 # dropout probability for attention weights
182
+ self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
183
+ self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
184
+ self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
185
+ self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
186
+
187
+ # masking
188
+ self.mask_length: int = 10 # mask length
189
+ self.mask_prob: float = 0.65 # probability of replacing a token with mask
190
+ self.mask_selection: str = "static" # how to choose mask length
191
+ self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
192
+ self.no_mask_overlap: bool = False # whether to allow masks to overlap
193
+ self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
194
+
195
+ # channel masking
196
+ self.mask_channel_length: int = 10 # length of the mask for features (channels)
197
+ self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
198
+ self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
199
+ self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
200
+ self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
201
+ self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
202
+
203
+ # positional embeddings
204
+ self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
205
+ self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
206
+
207
+ # relative position embedding
208
+ self.relative_position_embedding: bool = False # apply relative position embedding
209
+ self.num_buckets: int = 320 # number of buckets for relative position embedding
210
+ self.max_distance: int = 1280 # maximum distance for relative position embedding
211
+ self.gru_rel_pos: bool = False # apply gated relative position embedding
212
+
213
+ if cfg is not None:
214
+ self.update(cfg)
215
+
216
+ def update(self, cfg: dict):
217
+ self.__dict__.update(cfg)
218
+
219
+
220
+ class WavLM(nn.Module):
221
+ def __init__(
222
+ self,
223
+ cfg: WavLMConfig,
224
+ ) -> None:
225
+ super().__init__()
226
+ logger.info(f"WavLM Config: {cfg.__dict__}")
227
+
228
+ self.cfg = cfg
229
+ feature_enc_layers = eval(cfg.conv_feature_layers)
230
+ self.embed = feature_enc_layers[-1][0]
231
+
232
+ self.feature_extractor = ConvFeatureExtractionModel(
233
+ conv_layers=feature_enc_layers,
234
+ dropout=0.0,
235
+ mode=cfg.extractor_mode,
236
+ conv_bias=cfg.conv_bias,
237
+ )
238
+
239
+ self.post_extract_proj = (
240
+ nn.Linear(self.embed, cfg.encoder_embed_dim)
241
+ if self.embed != cfg.encoder_embed_dim
242
+ else None
243
+ )
244
+
245
+ self.mask_prob = cfg.mask_prob
246
+ self.mask_selection = cfg.mask_selection
247
+ self.mask_other = cfg.mask_other
248
+ self.mask_length = cfg.mask_length
249
+ self.no_mask_overlap = cfg.no_mask_overlap
250
+ self.mask_min_space = cfg.mask_min_space
251
+
252
+ self.mask_channel_prob = cfg.mask_channel_prob
253
+ self.mask_channel_selection = cfg.mask_channel_selection
254
+ self.mask_channel_other = cfg.mask_channel_other
255
+ self.mask_channel_length = cfg.mask_channel_length
256
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
257
+ self.mask_channel_min_space = cfg.mask_channel_min_space
258
+
259
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
260
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
261
+
262
+ self.feature_grad_mult = cfg.feature_grad_mult
263
+
264
+ self.mask_emb = nn.Parameter(
265
+ torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
266
+ )
267
+
268
+ self.encoder = TransformerEncoder(cfg)
269
+ self.layer_norm = LayerNorm(self.embed)
270
+
271
+ def apply_mask(self, x, padding_mask):
272
+ B, T, C = x.shape
273
+ if self.mask_prob > 0:
274
+ mask_indices = compute_mask_indices(
275
+ (B, T),
276
+ padding_mask,
277
+ self.mask_prob,
278
+ self.mask_length,
279
+ self.mask_selection,
280
+ self.mask_other,
281
+ min_masks=2,
282
+ no_overlap=self.no_mask_overlap,
283
+ min_space=self.mask_min_space,
284
+ )
285
+ mask_indices = torch.from_numpy(mask_indices).to(x.device)
286
+ x[mask_indices] = self.mask_emb
287
+ else:
288
+ mask_indices = None
289
+
290
+ if self.mask_channel_prob > 0:
291
+ mask_channel_indices = compute_mask_indices(
292
+ (B, C),
293
+ None,
294
+ self.mask_channel_prob,
295
+ self.mask_channel_length,
296
+ self.mask_channel_selection,
297
+ self.mask_channel_other,
298
+ no_overlap=self.no_mask_channel_overlap,
299
+ min_space=self.mask_channel_min_space,
300
+ )
301
+ mask_channel_indices = (
302
+ torch.from_numpy(mask_channel_indices)
303
+ .to(x.device)
304
+ .unsqueeze(1)
305
+ .expand(-1, T, -1)
306
+ )
307
+ x[mask_channel_indices] = 0
308
+
309
+ return x, mask_indices
310
+
311
+ def forward_padding_mask(
312
+ self, features: torch.Tensor, padding_mask: torch.Tensor,
313
+ ) -> torch.Tensor:
314
+ extra = padding_mask.size(1) % features.size(1)
315
+ if extra > 0:
316
+ padding_mask = padding_mask[:, :-extra]
317
+ padding_mask = padding_mask.view(
318
+ padding_mask.size(0), features.size(1), -1
319
+ )
320
+ #padding_mask = padding_mask.all(-1)
321
+ padding_mask = padding_mask.any(-1)
322
+ return padding_mask
323
+
324
+ def extract_features(
325
+ self,
326
+ source: torch.Tensor,
327
+ padding_mask: Optional[torch.Tensor] = None,
328
+ mask: bool = False,
329
+ ret_conv: bool = False,
330
+ output_layer: Optional[int] = None,
331
+ ret_layer_results: bool = False,
332
+ ):
333
+ if self.feature_grad_mult > 0:
334
+ features = self.feature_extractor(source)
335
+ if self.feature_grad_mult != 1.0:
336
+ features = GradMultiply.apply(features, self.feature_grad_mult)
337
+ else:
338
+ with torch.no_grad():
339
+ features = self.feature_extractor(source)
340
+
341
+ features = features.transpose(1, 2)
342
+ features = self.layer_norm(features)
343
+
344
+ if padding_mask is not None:
345
+ padding_mask = self.forward_padding_mask(features, padding_mask)
346
+
347
+ if self.post_extract_proj is not None:
348
+ features = self.post_extract_proj(features)
349
+
350
+ features = self.dropout_input(features)
351
+
352
+ if mask:
353
+ x, mask_indices = self.apply_mask(
354
+ features, padding_mask
355
+ )
356
+ else:
357
+ x = features
358
+
359
+ # feature: (B, T, D), float
360
+ # target: (B, T), long
361
+ # x: (B, T, D), float
362
+ # padding_mask: (B, T), bool
363
+ # mask_indices: (B, T), bool
364
+ x, layer_results = self.encoder(
365
+ x,
366
+ padding_mask=padding_mask,
367
+ layer=None if output_layer is None else output_layer - 1
368
+ )
369
+
370
+ res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
371
+
372
+ feature = res["features"] if ret_conv else res["x"]
373
+ if ret_layer_results:
374
+ feature = (feature, res["layer_results"])
375
+ return feature, res["padding_mask"]
376
+
377
+
378
+ class ConvFeatureExtractionModel(nn.Module):
379
+ def __init__(
380
+ self,
381
+ conv_layers: List[Tuple[int, int, int]],
382
+ dropout: float = 0.0,
383
+ mode: str = "default",
384
+ conv_bias: bool = False,
385
+ conv_type: str = "default"
386
+ ):
387
+ super().__init__()
388
+
389
+ assert mode in {"default", "layer_norm"}
390
+
391
+ def block(
392
+ n_in,
393
+ n_out,
394
+ k,
395
+ stride,
396
+ is_layer_norm=False,
397
+ is_group_norm=False,
398
+ conv_bias=False,
399
+ ):
400
+ def make_conv():
401
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
402
+ nn.init.kaiming_normal_(conv.weight)
403
+ return conv
404
+
405
+ assert (
406
+ is_layer_norm and is_group_norm
407
+ ) == False, "layer norm and group norm are exclusive"
408
+
409
+ if is_layer_norm:
410
+ return nn.Sequential(
411
+ make_conv(),
412
+ nn.Dropout(p=dropout),
413
+ nn.Sequential(
414
+ TransposeLast(),
415
+ Fp32LayerNorm(dim, elementwise_affine=True),
416
+ TransposeLast(),
417
+ ),
418
+ nn.GELU(),
419
+ )
420
+ elif is_group_norm:
421
+ return nn.Sequential(
422
+ make_conv(),
423
+ nn.Dropout(p=dropout),
424
+ Fp32GroupNorm(dim, dim, affine=True),
425
+ nn.GELU(),
426
+ )
427
+ else:
428
+ return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
429
+
430
+ self.conv_type = conv_type
431
+ if self.conv_type == "default":
432
+ in_d = 1
433
+ self.conv_layers = nn.ModuleList()
434
+ for i, cl in enumerate(conv_layers):
435
+ assert len(cl) == 3, "invalid conv definition: " + str(cl)
436
+ (dim, k, stride) = cl
437
+
438
+ self.conv_layers.append(
439
+ block(
440
+ in_d,
441
+ dim,
442
+ k,
443
+ stride,
444
+ is_layer_norm=mode == "layer_norm",
445
+ is_group_norm=mode == "default" and i == 0,
446
+ conv_bias=conv_bias,
447
+ )
448
+ )
449
+ in_d = dim
450
+ elif self.conv_type == "conv2d":
451
+ in_d = 1
452
+ self.conv_layers = nn.ModuleList()
453
+ for i, cl in enumerate(conv_layers):
454
+ assert len(cl) == 3
455
+ (dim, k, stride) = cl
456
+
457
+ self.conv_layers.append(
458
+ torch.nn.Conv2d(in_d, dim, k, stride)
459
+ )
460
+ self.conv_layers.append(torch.nn.ReLU())
461
+ in_d = dim
462
+ elif self.conv_type == "custom":
463
+ in_d = 1
464
+ idim = 80
465
+ self.conv_layers = nn.ModuleList()
466
+ for i, cl in enumerate(conv_layers):
467
+ assert len(cl) == 3
468
+ (dim, k, stride) = cl
469
+ self.conv_layers.append(
470
+ torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
471
+ )
472
+ self.conv_layers.append(
473
+ torch.nn.LayerNorm([dim, idim])
474
+ )
475
+ self.conv_layers.append(torch.nn.ReLU())
476
+ in_d = dim
477
+ if (i + 1) % 2 == 0:
478
+ self.conv_layers.append(
479
+ torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
480
+ )
481
+ idim = int(math.ceil(idim / 2))
482
+ else:
483
+ pass
484
+
485
+ def forward(self, x, mask=None):
486
+
487
+ # BxT -> BxCxT
488
+ x = x.unsqueeze(1)
489
+ if self.conv_type == "custom":
490
+ for conv in self.conv_layers:
491
+ if isinstance(conv, nn.LayerNorm):
492
+ x = x.transpose(1, 2)
493
+ x = conv(x).transpose(1, 2)
494
+ else:
495
+ x = conv(x)
496
+ x = x.transpose(2, 3).contiguous()
497
+ x = x.view(x.size(0), -1, x.size(-1))
498
+ else:
499
+ for conv in self.conv_layers:
500
+ x = conv(x)
501
+ if self.conv_type == "conv2d":
502
+ b, c, t, f = x.size()
503
+ x = x.transpose(2, 3).contiguous().view(b, c * f, t)
504
+ return x
505
+
506
+
507
+ class TransformerEncoder(nn.Module):
508
+ def __init__(self, args):
509
+ super().__init__()
510
+
511
+ self.dropout = args.dropout
512
+ self.embedding_dim = args.encoder_embed_dim
513
+
514
+ self.pos_conv = nn.Conv1d(
515
+ self.embedding_dim,
516
+ self.embedding_dim,
517
+ kernel_size=args.conv_pos,
518
+ padding=args.conv_pos // 2,
519
+ groups=args.conv_pos_groups,
520
+ )
521
+ dropout = 0
522
+ std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
523
+ nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
524
+ nn.init.constant_(self.pos_conv.bias, 0)
525
+
526
+ self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
527
+ self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
528
+
529
+ if hasattr(args, "relative_position_embedding"):
530
+ self.relative_position_embedding = args.relative_position_embedding
531
+ self.num_buckets = args.num_buckets
532
+ self.max_distance = args.max_distance
533
+ else:
534
+ self.relative_position_embedding = False
535
+ self.num_buckets = 0
536
+ self.max_distance = 0
537
+
538
+ self.layers = nn.ModuleList(
539
+ [
540
+ TransformerSentenceEncoderLayer(
541
+ embedding_dim=self.embedding_dim,
542
+ ffn_embedding_dim=args.encoder_ffn_embed_dim,
543
+ num_attention_heads=args.encoder_attention_heads,
544
+ dropout=self.dropout,
545
+ attention_dropout=args.attention_dropout,
546
+ activation_dropout=args.activation_dropout,
547
+ activation_fn=args.activation_fn,
548
+ layer_norm_first=args.layer_norm_first,
549
+ has_relative_attention_bias=(self.relative_position_embedding and i == 0),
550
+ num_buckets=self.num_buckets,
551
+ max_distance=self.max_distance,
552
+ gru_rel_pos=args.gru_rel_pos,
553
+ )
554
+ for i in range(args.encoder_layers)
555
+ ]
556
+ )
557
+
558
+ self.layer_norm_first = args.layer_norm_first
559
+ self.layer_norm = LayerNorm(self.embedding_dim)
560
+ self.layerdrop = args.encoder_layerdrop
561
+
562
+ self.apply(init_bert_params)
563
+
564
+ def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
565
+ x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
566
+
567
+ if self.layer_norm_first and layer is None:
568
+ x = self.layer_norm(x)
569
+
570
+ return x, layer_results
571
+
572
+ def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
573
+
574
+ if padding_mask is not None:
575
+ x[padding_mask] = 0
576
+
577
+ x_conv = self.pos_conv(x.transpose(1, 2))
578
+ x_conv = x_conv.transpose(1, 2)
579
+ x += x_conv
580
+
581
+ if not self.layer_norm_first:
582
+ x = self.layer_norm(x)
583
+
584
+ x = F.dropout(x, p=self.dropout, training=self.training)
585
+
586
+ # B x T x C -> T x B x C
587
+ x = x.transpose(0, 1)
588
+
589
+ layer_results = []
590
+ z = None
591
+ if tgt_layer is not None:
592
+ layer_results.append((x, z))
593
+ r = None
594
+ pos_bias = None
595
+ for i, layer in enumerate(self.layers):
596
+ dropout_probability = np.random.random()
597
+ if not self.training or (dropout_probability > self.layerdrop):
598
+ x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
599
+ self_attn_mask=streaming_mask, pos_bias=pos_bias)
600
+ if tgt_layer is not None:
601
+ layer_results.append((x, z))
602
+ if i == tgt_layer:
603
+ r = x
604
+ break
605
+
606
+ if r is not None:
607
+ x = r
608
+
609
+ # T x B x C -> B x T x C
610
+ x = x.transpose(0, 1)
611
+
612
+ return x, layer_results
613
+
614
+
615
+ class TransformerSentenceEncoderLayer(nn.Module):
616
+ """
617
+ Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
618
+ models.
619
+ """
620
+
621
+ def __init__(
622
+ self,
623
+ embedding_dim: float = 768,
624
+ ffn_embedding_dim: float = 3072,
625
+ num_attention_heads: float = 8,
626
+ dropout: float = 0.1,
627
+ attention_dropout: float = 0.1,
628
+ activation_dropout: float = 0.1,
629
+ activation_fn: str = "relu",
630
+ layer_norm_first: bool = False,
631
+ has_relative_attention_bias: bool = False,
632
+ num_buckets: int = 0,
633
+ max_distance: int = 0,
634
+ rescale_init: bool = False,
635
+ gru_rel_pos: bool = False,
636
+ ) -> None:
637
+
638
+ super().__init__()
639
+ # Initialize parameters
640
+ self.embedding_dim = embedding_dim
641
+ self.dropout = dropout
642
+ self.activation_dropout = activation_dropout
643
+
644
+ # Initialize blocks
645
+ self.activation_name = activation_fn
646
+ self.activation_fn = get_activation_fn(activation_fn)
647
+ self.self_attn = MultiheadAttention(
648
+ self.embedding_dim,
649
+ num_attention_heads,
650
+ dropout=attention_dropout,
651
+ self_attention=True,
652
+ has_relative_attention_bias=has_relative_attention_bias,
653
+ num_buckets=num_buckets,
654
+ max_distance=max_distance,
655
+ rescale_init=rescale_init,
656
+ gru_rel_pos=gru_rel_pos,
657
+ )
658
+
659
+ self.dropout1 = nn.Dropout(dropout)
660
+ self.dropout2 = nn.Dropout(self.activation_dropout)
661
+ self.dropout3 = nn.Dropout(dropout)
662
+
663
+ self.layer_norm_first = layer_norm_first
664
+
665
+ # layer norm associated with the self attention layer
666
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
667
+
668
+ if self.activation_name == "glu":
669
+ self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
670
+ else:
671
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
672
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
673
+
674
+ # layer norm associated with the position wise feed-forward NN
675
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
676
+
677
+ def forward(
678
+ self,
679
+ x: torch.Tensor,
680
+ self_attn_mask: torch.Tensor = None,
681
+ self_attn_padding_mask: torch.Tensor = None,
682
+ need_weights: bool = False,
683
+ pos_bias=None
684
+ ):
685
+ """
686
+ LayerNorm is applied either before or after the self-attention/ffn
687
+ modules similar to the original Transformer imlementation.
688
+ """
689
+ residual = x
690
+
691
+ if self.layer_norm_first:
692
+ x = self.self_attn_layer_norm(x)
693
+ x, attn, pos_bias = self.self_attn(
694
+ query=x,
695
+ key=x,
696
+ value=x,
697
+ key_padding_mask=self_attn_padding_mask,
698
+ need_weights=False,
699
+ attn_mask=self_attn_mask,
700
+ position_bias=pos_bias
701
+ )
702
+ x = self.dropout1(x)
703
+ x = residual + x
704
+
705
+ residual = x
706
+ x = self.final_layer_norm(x)
707
+ if self.activation_name == "glu":
708
+ x = self.fc1(x)
709
+ else:
710
+ x = self.activation_fn(self.fc1(x))
711
+ x = self.dropout2(x)
712
+ x = self.fc2(x)
713
+ x = self.dropout3(x)
714
+ x = residual + x
715
+ else:
716
+ x, attn, pos_bias = self.self_attn(
717
+ query=x,
718
+ key=x,
719
+ value=x,
720
+ key_padding_mask=self_attn_padding_mask,
721
+ need_weights=need_weights,
722
+ attn_mask=self_attn_mask,
723
+ position_bias=pos_bias
724
+ )
725
+
726
+ x = self.dropout1(x)
727
+ x = residual + x
728
+
729
+ x = self.self_attn_layer_norm(x)
730
+
731
+ residual = x
732
+ if self.activation_name == "glu":
733
+ x = self.fc1(x)
734
+ else:
735
+ x = self.activation_fn(self.fc1(x))
736
+ x = self.dropout2(x)
737
+ x = self.fc2(x)
738
+ x = self.dropout3(x)
739
+ x = residual + x
740
+ x = self.final_layer_norm(x)
741
+
742
+ return x, attn, pos_bias
wavlm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from wavlm.WavLM import WavLM, WavLMConfig
wavlm/__pycache__/WavLM.cpython-39.pyc ADDED
Binary file (16.5 kB). View file
 
wavlm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (191 Bytes). View file
 
wavlm/__pycache__/modules.cpython-39.pyc ADDED
Binary file (19.3 kB). View file
 
wavlm/modules.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4
+ # Copyright (c) 2021 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ from typing import Dict, Optional, Tuple
13
+ import torch
14
+ from torch import Tensor, nn
15
+ from torch.nn import Parameter
16
+ import torch.nn.functional as F
17
+
18
+
19
+ class TransposeLast(nn.Module):
20
+ def __init__(self, deconstruct_idx=None):
21
+ super().__init__()
22
+ self.deconstruct_idx = deconstruct_idx
23
+
24
+ def forward(self, x):
25
+ if self.deconstruct_idx is not None:
26
+ x = x[self.deconstruct_idx]
27
+ return x.transpose(-2, -1)
28
+
29
+
30
+ class Fp32LayerNorm(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ def forward(self, input):
35
+ output = F.layer_norm(
36
+ input.float(),
37
+ self.normalized_shape,
38
+ self.weight.float() if self.weight is not None else None,
39
+ self.bias.float() if self.bias is not None else None,
40
+ self.eps,
41
+ )
42
+ return output.type_as(input)
43
+
44
+
45
+ class Fp32GroupNorm(nn.GroupNorm):
46
+ def __init__(self, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+
49
+ def forward(self, input):
50
+ output = F.group_norm(
51
+ input.float(),
52
+ self.num_groups,
53
+ self.weight.float() if self.weight is not None else None,
54
+ self.bias.float() if self.bias is not None else None,
55
+ self.eps,
56
+ )
57
+ return output.type_as(input)
58
+
59
+
60
+ class GradMultiply(torch.autograd.Function):
61
+ @staticmethod
62
+ def forward(ctx, x, scale):
63
+ ctx.scale = scale
64
+ res = x.new(x)
65
+ return res
66
+
67
+ @staticmethod
68
+ def backward(ctx, grad):
69
+ return grad * ctx.scale, None
70
+
71
+
72
+ class SamePad(nn.Module):
73
+ def __init__(self, kernel_size, causal=False):
74
+ super().__init__()
75
+ if causal:
76
+ self.remove = kernel_size - 1
77
+ else:
78
+ self.remove = 1 if kernel_size % 2 == 0 else 0
79
+
80
+ def forward(self, x):
81
+ if self.remove > 0:
82
+ x = x[:, :, : -self.remove]
83
+ return x
84
+
85
+
86
+ class Swish(nn.Module):
87
+ """Swish function
88
+ """
89
+
90
+ def __init__(self):
91
+ """Construct an MultiHeadedAttention object."""
92
+ super(Swish, self).__init__()
93
+ self.act = torch.nn.Sigmoid()
94
+
95
+ def forward(self, x):
96
+ return x * self.act(x)
97
+
98
+
99
+ class GLU_Linear(nn.Module):
100
+ def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
101
+ super(GLU_Linear, self).__init__()
102
+
103
+ self.glu_type = glu_type
104
+ self.output_dim = output_dim
105
+
106
+ if glu_type == "sigmoid":
107
+ self.glu_act = torch.nn.Sigmoid()
108
+ elif glu_type == "swish":
109
+ self.glu_act = Swish()
110
+ elif glu_type == "relu":
111
+ self.glu_act = torch.nn.ReLU()
112
+ elif glu_type == "gelu":
113
+ self.glu_act = torch.nn.GELU()
114
+
115
+ if bias_in_glu:
116
+ self.linear = nn.Linear(input_dim, output_dim * 2, True)
117
+ else:
118
+ self.linear = nn.Linear(input_dim, output_dim * 2, False)
119
+
120
+ def forward(self, x):
121
+ # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
122
+ x = self.linear(x)
123
+
124
+ if self.glu_type == "bilinear":
125
+ x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
126
+ else:
127
+ x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
128
+
129
+ return x
130
+
131
+
132
+ def gelu_accurate(x):
133
+ if not hasattr(gelu_accurate, "_a"):
134
+ gelu_accurate._a = math.sqrt(2 / math.pi)
135
+ return (
136
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
137
+ )
138
+
139
+
140
+ def gelu(x: torch.Tensor) -> torch.Tensor:
141
+ return torch.nn.functional.gelu(x.float()).type_as(x)
142
+
143
+
144
+ def get_activation_fn(activation: str):
145
+ """Returns the activation function corresponding to `activation`"""
146
+
147
+ if activation == "relu":
148
+ return F.relu
149
+ elif activation == "gelu":
150
+ return gelu
151
+ elif activation == "gelu_fast":
152
+ warnings.warn(
153
+ "--activation-fn=gelu_fast has been renamed to gelu_accurate"
154
+ )
155
+ return gelu_accurate
156
+ elif activation == "gelu_accurate":
157
+ return gelu_accurate
158
+ elif activation == "tanh":
159
+ return torch.tanh
160
+ elif activation == "linear":
161
+ return lambda x: x
162
+ elif activation == "glu":
163
+ return lambda x: x
164
+ else:
165
+ raise RuntimeError("--activation-fn {} not supported".format(activation))
166
+
167
+
168
+ def init_bert_params(module):
169
+ """
170
+ Initialize the weights specific to the BERT Model.
171
+ This overrides the default initializations depending on the specified arguments.
172
+ 1. If normal_init_linear_weights is set then weights of linear
173
+ layer will be initialized using the normal distribution and
174
+ bais will be set to the specified value.
175
+ 2. If normal_init_embed_weights is set then weights of embedding
176
+ layer will be initialized using the normal distribution.
177
+ 3. If normal_init_proj_weights is set then weights of
178
+ in_project_weight for MultiHeadAttention initialized using
179
+ the normal distribution (to be validated).
180
+ """
181
+
182
+ def normal_(data):
183
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
184
+ # so that the RNG is consistent with and without FSDP
185
+ data.copy_(
186
+ data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
187
+ )
188
+
189
+ if isinstance(module, nn.Linear):
190
+ normal_(module.weight.data)
191
+ if module.bias is not None:
192
+ module.bias.data.zero_()
193
+ if isinstance(module, nn.Embedding):
194
+ normal_(module.weight.data)
195
+ if module.padding_idx is not None:
196
+ module.weight.data[module.padding_idx].zero_()
197
+ if isinstance(module, MultiheadAttention):
198
+ normal_(module.q_proj.weight.data)
199
+ normal_(module.k_proj.weight.data)
200
+ normal_(module.v_proj.weight.data)
201
+
202
+
203
+ def quant_noise(module, p, block_size):
204
+ """
205
+ Wraps modules and applies quantization noise to the weights for
206
+ subsequent quantization with Iterative Product Quantization as
207
+ described in "Training with Quantization Noise for Extreme Model Compression"
208
+
209
+ Args:
210
+ - module: nn.Module
211
+ - p: amount of Quantization Noise
212
+ - block_size: size of the blocks for subsequent quantization with iPQ
213
+
214
+ Remarks:
215
+ - Module weights must have the right sizes wrt the block size
216
+ - Only Linear, Embedding and Conv2d modules are supported for the moment
217
+ - For more detail on how to quantize by blocks with convolutional weights,
218
+ see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
219
+ - We implement the simplest form of noise here as stated in the paper
220
+ which consists in randomly dropping blocks
221
+ """
222
+
223
+ # if no quantization noise, don't register hook
224
+ if p <= 0:
225
+ return module
226
+
227
+ # supported modules
228
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
229
+
230
+ # test whether module.weight has the right sizes wrt block_size
231
+ is_conv = module.weight.ndim == 4
232
+
233
+ # 2D matrix
234
+ if not is_conv:
235
+ assert (
236
+ module.weight.size(1) % block_size == 0
237
+ ), "Input features must be a multiple of block sizes"
238
+
239
+ # 4D matrix
240
+ else:
241
+ # 1x1 convolutions
242
+ if module.kernel_size == (1, 1):
243
+ assert (
244
+ module.in_channels % block_size == 0
245
+ ), "Input channels must be a multiple of block sizes"
246
+ # regular convolutions
247
+ else:
248
+ k = module.kernel_size[0] * module.kernel_size[1]
249
+ assert k % block_size == 0, "Kernel size must be a multiple of block size"
250
+
251
+ def _forward_pre_hook(mod, input):
252
+ # no noise for evaluation
253
+ if mod.training:
254
+ if not is_conv:
255
+ # gather weight and sizes
256
+ weight = mod.weight
257
+ in_features = weight.size(1)
258
+ out_features = weight.size(0)
259
+
260
+ # split weight matrix into blocks and randomly drop selected blocks
261
+ mask = torch.zeros(
262
+ in_features // block_size * out_features, device=weight.device
263
+ )
264
+ mask.bernoulli_(p)
265
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
266
+
267
+ else:
268
+ # gather weight and sizes
269
+ weight = mod.weight
270
+ in_channels = mod.in_channels
271
+ out_channels = mod.out_channels
272
+
273
+ # split weight matrix into blocks and randomly drop selected blocks
274
+ if mod.kernel_size == (1, 1):
275
+ mask = torch.zeros(
276
+ int(in_channels // block_size * out_channels),
277
+ device=weight.device,
278
+ )
279
+ mask.bernoulli_(p)
280
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
281
+ else:
282
+ mask = torch.zeros(
283
+ weight.size(0), weight.size(1), device=weight.device
284
+ )
285
+ mask.bernoulli_(p)
286
+ mask = (
287
+ mask.unsqueeze(2)
288
+ .unsqueeze(3)
289
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
290
+ )
291
+
292
+ # scale weights and apply mask
293
+ mask = mask.to(
294
+ torch.bool
295
+ ) # x.bool() is not currently supported in TorchScript
296
+ s = 1 / (1 - p)
297
+ mod.weight.data = s * weight.masked_fill(mask, 0)
298
+
299
+ module.register_forward_pre_hook(_forward_pre_hook)
300
+ return module
301
+
302
+
303
+ class MultiheadAttention(nn.Module):
304
+ """Multi-headed attention.
305
+
306
+ See "Attention Is All You Need" for more details.
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ embed_dim,
312
+ num_heads,
313
+ kdim=None,
314
+ vdim=None,
315
+ dropout=0.0,
316
+ bias=True,
317
+ add_bias_kv=False,
318
+ add_zero_attn=False,
319
+ self_attention=False,
320
+ encoder_decoder_attention=False,
321
+ q_noise=0.0,
322
+ qn_block_size=8,
323
+ has_relative_attention_bias=False,
324
+ num_buckets=32,
325
+ max_distance=128,
326
+ gru_rel_pos=False,
327
+ rescale_init=False,
328
+ ):
329
+ super().__init__()
330
+ self.embed_dim = embed_dim
331
+ self.kdim = kdim if kdim is not None else embed_dim
332
+ self.vdim = vdim if vdim is not None else embed_dim
333
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
+
335
+ self.num_heads = num_heads
336
+ self.dropout_module = nn.Dropout(dropout)
337
+
338
+ self.has_relative_attention_bias = has_relative_attention_bias
339
+ self.num_buckets = num_buckets
340
+ self.max_distance = max_distance
341
+ if self.has_relative_attention_bias:
342
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
343
+
344
+ self.head_dim = embed_dim // num_heads
345
+ self.q_head_dim = self.head_dim
346
+ self.k_head_dim = self.head_dim
347
+ assert (
348
+ self.head_dim * num_heads == self.embed_dim
349
+ ), "embed_dim must be divisible by num_heads"
350
+ self.scaling = self.head_dim ** -0.5
351
+
352
+ self.self_attention = self_attention
353
+ self.encoder_decoder_attention = encoder_decoder_attention
354
+
355
+ assert not self.self_attention or self.qkv_same_dim, (
356
+ "Self-attention requires query, key and " "value to be of the same size"
357
+ )
358
+
359
+ k_bias = True
360
+ if rescale_init:
361
+ k_bias = False
362
+
363
+ k_embed_dim = embed_dim
364
+ q_embed_dim = embed_dim
365
+
366
+ self.k_proj = quant_noise(
367
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
368
+ )
369
+ self.v_proj = quant_noise(
370
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
371
+ )
372
+ self.q_proj = quant_noise(
373
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
374
+ )
375
+
376
+ self.out_proj = quant_noise(
377
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
378
+ )
379
+
380
+ if add_bias_kv:
381
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
382
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
383
+ else:
384
+ self.bias_k = self.bias_v = None
385
+
386
+ self.add_zero_attn = add_zero_attn
387
+
388
+ self.gru_rel_pos = gru_rel_pos
389
+ if self.gru_rel_pos:
390
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
391
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
392
+
393
+ self.reset_parameters()
394
+
395
+ def reset_parameters(self):
396
+ if self.qkv_same_dim:
397
+ # Empirically observed the convergence to be much better with
398
+ # the scaled initialization
399
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
400
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
401
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
402
+ else:
403
+ nn.init.xavier_uniform_(self.k_proj.weight)
404
+ nn.init.xavier_uniform_(self.v_proj.weight)
405
+ nn.init.xavier_uniform_(self.q_proj.weight)
406
+
407
+ nn.init.xavier_uniform_(self.out_proj.weight)
408
+ if self.out_proj.bias is not None:
409
+ nn.init.constant_(self.out_proj.bias, 0.0)
410
+ if self.bias_k is not None:
411
+ nn.init.xavier_normal_(self.bias_k)
412
+ if self.bias_v is not None:
413
+ nn.init.xavier_normal_(self.bias_v)
414
+ if self.has_relative_attention_bias:
415
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
416
+
417
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
418
+ num_buckets = self.num_buckets
419
+ max_distance = self.max_distance
420
+ relative_buckets = 0
421
+
422
+ if bidirectional:
423
+ num_buckets = num_buckets // 2
424
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
425
+ relative_positions = torch.abs(relative_positions)
426
+ else:
427
+ relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
428
+
429
+ max_exact = num_buckets // 2
430
+ is_small = relative_positions < max_exact
431
+
432
+ relative_postion_if_large = max_exact + (
433
+ torch.log(relative_positions.float() / max_exact)
434
+ / math.log(max_distance / max_exact)
435
+ * (num_buckets - max_exact)
436
+ ).to(torch.long)
437
+ relative_postion_if_large = torch.min(
438
+ relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
439
+ )
440
+
441
+ relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
442
+ return relative_buckets
443
+
444
+ def compute_bias(self, query_length, key_length):
445
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
446
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
447
+ relative_position = memory_position - context_position
448
+ relative_position_bucket = self._relative_positions_bucket(
449
+ relative_position,
450
+ bidirectional=True
451
+ )
452
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
453
+ values = self.relative_attention_bias(relative_position_bucket)
454
+ values = values.permute([2, 0, 1])
455
+ return values
456
+
457
+ def forward(
458
+ self,
459
+ query,
460
+ key: Optional[Tensor],
461
+ value: Optional[Tensor],
462
+ key_padding_mask: Optional[Tensor] = None,
463
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
464
+ need_weights: bool = True,
465
+ static_kv: bool = False,
466
+ attn_mask: Optional[Tensor] = None,
467
+ before_softmax: bool = False,
468
+ need_head_weights: bool = False,
469
+ position_bias: Optional[Tensor] = None
470
+ ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
471
+ """Input shape: Time x Batch x Channel
472
+
473
+ Args:
474
+ key_padding_mask (ByteTensor, optional): mask to exclude
475
+ keys that are pads, of shape `(batch, src_len)`, where
476
+ padding elements are indicated by 1s.
477
+ need_weights (bool, optional): return the attention weights,
478
+ averaged over heads (default: False).
479
+ attn_mask (ByteTensor, optional): typically used to
480
+ implement causal attention, where the mask prevents the
481
+ attention from looking forward in time (default: None).
482
+ before_softmax (bool, optional): return the raw attention
483
+ weights and values before the attention softmax.
484
+ need_head_weights (bool, optional): return the attention
485
+ weights for each head. Implies *need_weights*. Default:
486
+ return the average attention weights over all heads.
487
+ """
488
+ if need_head_weights:
489
+ need_weights = True
490
+
491
+ is_tpu = query.device.type == "xla"
492
+
493
+ tgt_len, bsz, embed_dim = query.size()
494
+ src_len = tgt_len
495
+ assert embed_dim == self.embed_dim
496
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
497
+ if key is not None:
498
+ src_len, key_bsz, _ = key.size()
499
+ if not torch.jit.is_scripting():
500
+ assert key_bsz == bsz
501
+ assert value is not None
502
+ assert src_len, bsz == value.shape[:2]
503
+
504
+ if self.has_relative_attention_bias and position_bias is None:
505
+ position_bias = self.compute_bias(tgt_len, src_len)
506
+ position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
507
+
508
+ if (
509
+ not is_tpu # don't use PyTorch version on TPUs
510
+ and incremental_state is None
511
+ and not static_kv
512
+ # A workaround for quantization to work. Otherwise JIT compilation
513
+ # treats bias in linear module as method.
514
+ and not torch.jit.is_scripting()
515
+ and self.q_head_dim == self.head_dim
516
+ ):
517
+ assert key is not None and value is not None
518
+ assert attn_mask is None
519
+
520
+ attn_mask_rel_pos = None
521
+ if position_bias is not None:
522
+ attn_mask_rel_pos = position_bias
523
+ if self.gru_rel_pos:
524
+ query_layer = query.transpose(0, 1)
525
+ new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
526
+ query_layer = query_layer.view(*new_x_shape)
527
+ query_layer = query_layer.permute(0, 2, 1, 3)
528
+ _B, _H, _L, __ = query_layer.size()
529
+
530
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
531
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
532
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
533
+ attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
534
+
535
+ attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
536
+ k_proj_bias = self.k_proj.bias
537
+ if k_proj_bias is None:
538
+ k_proj_bias = torch.zeros_like(self.q_proj.bias)
539
+
540
+ x, attn = F.multi_head_attention_forward(
541
+ query,
542
+ key,
543
+ value,
544
+ self.embed_dim,
545
+ self.num_heads,
546
+ torch.empty([0]),
547
+ torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
548
+ self.bias_k,
549
+ self.bias_v,
550
+ self.add_zero_attn,
551
+ self.dropout_module.p,
552
+ self.out_proj.weight,
553
+ self.out_proj.bias,
554
+ self.training,
555
+ # self.training or self.dropout_module.apply_during_inference,
556
+ key_padding_mask,
557
+ need_weights,
558
+ attn_mask_rel_pos,
559
+ use_separate_proj_weight=True,
560
+ q_proj_weight=self.q_proj.weight,
561
+ k_proj_weight=self.k_proj.weight,
562
+ v_proj_weight=self.v_proj.weight,
563
+ )
564
+ return x, attn, position_bias
565
+
566
+ if incremental_state is not None:
567
+ saved_state = self._get_input_buffer(incremental_state)
568
+ if saved_state is not None and "prev_key" in saved_state:
569
+ # previous time steps are cached - no need to recompute
570
+ # key and value if they are static
571
+ if static_kv:
572
+ assert self.encoder_decoder_attention and not self.self_attention
573
+ key = value = None
574
+ else:
575
+ saved_state = None
576
+
577
+ if self.self_attention:
578
+ q = self.q_proj(query)
579
+ k = self.k_proj(query)
580
+ v = self.v_proj(query)
581
+ elif self.encoder_decoder_attention:
582
+ # encoder-decoder attention
583
+ q = self.q_proj(query)
584
+ if key is None:
585
+ assert value is None
586
+ k = v = None
587
+ else:
588
+ k = self.k_proj(key)
589
+ v = self.v_proj(key)
590
+
591
+ else:
592
+ assert key is not None and value is not None
593
+ q = self.q_proj(query)
594
+ k = self.k_proj(key)
595
+ v = self.v_proj(value)
596
+ q *= self.scaling
597
+
598
+ if self.bias_k is not None:
599
+ assert self.bias_v is not None
600
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
601
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
602
+ if attn_mask is not None:
603
+ attn_mask = torch.cat(
604
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
605
+ )
606
+ if key_padding_mask is not None:
607
+ key_padding_mask = torch.cat(
608
+ [
609
+ key_padding_mask,
610
+ key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
611
+ ],
612
+ dim=1,
613
+ )
614
+
615
+ q = (
616
+ q.contiguous()
617
+ .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
618
+ .transpose(0, 1)
619
+ )
620
+ if k is not None:
621
+ k = (
622
+ k.contiguous()
623
+ .view(-1, bsz * self.num_heads, self.k_head_dim)
624
+ .transpose(0, 1)
625
+ )
626
+ if v is not None:
627
+ v = (
628
+ v.contiguous()
629
+ .view(-1, bsz * self.num_heads, self.head_dim)
630
+ .transpose(0, 1)
631
+ )
632
+
633
+ if saved_state is not None:
634
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
635
+ if "prev_key" in saved_state:
636
+ _prev_key = saved_state["prev_key"]
637
+ assert _prev_key is not None
638
+ prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
639
+ if static_kv:
640
+ k = prev_key
641
+ else:
642
+ assert k is not None
643
+ k = torch.cat([prev_key, k], dim=1)
644
+ src_len = k.size(1)
645
+ if "prev_value" in saved_state:
646
+ _prev_value = saved_state["prev_value"]
647
+ assert _prev_value is not None
648
+ prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
649
+ if static_kv:
650
+ v = prev_value
651
+ else:
652
+ assert v is not None
653
+ v = torch.cat([prev_value, v], dim=1)
654
+ prev_key_padding_mask: Optional[Tensor] = None
655
+ if "prev_key_padding_mask" in saved_state:
656
+ prev_key_padding_mask = saved_state["prev_key_padding_mask"]
657
+ assert k is not None and v is not None
658
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
659
+ key_padding_mask=key_padding_mask,
660
+ prev_key_padding_mask=prev_key_padding_mask,
661
+ batch_size=bsz,
662
+ src_len=k.size(1),
663
+ static_kv=static_kv,
664
+ )
665
+
666
+ saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
667
+ saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
668
+ saved_state["prev_key_padding_mask"] = key_padding_mask
669
+ # In this branch incremental_state is never None
670
+ assert incremental_state is not None
671
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
672
+ assert k is not None
673
+ assert k.size(1) == src_len
674
+
675
+ # This is part of a workaround to get around fork/join parallelism
676
+ # not supporting Optional types.
677
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
678
+ key_padding_mask = None
679
+
680
+ if key_padding_mask is not None:
681
+ assert key_padding_mask.size(0) == bsz
682
+ assert key_padding_mask.size(1) == src_len
683
+
684
+ if self.add_zero_attn:
685
+ assert v is not None
686
+ src_len += 1
687
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
688
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
689
+ if attn_mask is not None:
690
+ attn_mask = torch.cat(
691
+ [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
692
+ )
693
+ if key_padding_mask is not None:
694
+ key_padding_mask = torch.cat(
695
+ [
696
+ key_padding_mask,
697
+ torch.zeros(key_padding_mask.size(0), 1).type_as(
698
+ key_padding_mask
699
+ ),
700
+ ],
701
+ dim=1,
702
+ )
703
+
704
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
705
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
706
+
707
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
708
+
709
+ if attn_mask is not None:
710
+ attn_mask = attn_mask.unsqueeze(0)
711
+ attn_weights += attn_mask
712
+
713
+ if key_padding_mask is not None:
714
+ # don't attend to padding symbols
715
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
716
+ if not is_tpu:
717
+ attn_weights = attn_weights.masked_fill(
718
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
719
+ float("-inf"),
720
+ )
721
+ else:
722
+ attn_weights = attn_weights.transpose(0, 2)
723
+ attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
724
+ attn_weights = attn_weights.transpose(0, 2)
725
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
726
+
727
+ if before_softmax:
728
+ return attn_weights, v, position_bias
729
+
730
+ if position_bias is not None:
731
+ if self.gru_rel_pos == 1:
732
+ query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
733
+ _B, _H, _L, __ = query_layer.size()
734
+ gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
735
+ _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
736
+ gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
737
+ position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
738
+
739
+ position_bias = position_bias.view(attn_weights.size())
740
+
741
+ attn_weights = attn_weights + position_bias
742
+
743
+ attn_weights_float = F.softmax(
744
+ attn_weights, dim=-1
745
+ )
746
+ attn_weights = attn_weights_float.type_as(attn_weights)
747
+ attn_probs = self.dropout_module(attn_weights)
748
+
749
+ assert v is not None
750
+ attn = torch.bmm(attn_probs, v)
751
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
752
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
753
+ attn = self.out_proj(attn)
754
+ attn_weights: Optional[Tensor] = None
755
+ if need_weights:
756
+ attn_weights = attn_weights_float.view(
757
+ bsz, self.num_heads, tgt_len, src_len
758
+ ).transpose(1, 0)
759
+ if not need_head_weights:
760
+ # average attention weights over heads
761
+ attn_weights = attn_weights.mean(dim=0)
762
+
763
+ return attn, attn_weights, position_bias
764
+
765
+ @staticmethod
766
+ def _append_prev_key_padding_mask(
767
+ key_padding_mask: Optional[Tensor],
768
+ prev_key_padding_mask: Optional[Tensor],
769
+ batch_size: int,
770
+ src_len: int,
771
+ static_kv: bool,
772
+ ) -> Optional[Tensor]:
773
+ # saved key padding masks have shape (bsz, seq_len)
774
+ if prev_key_padding_mask is not None and static_kv:
775
+ new_key_padding_mask = prev_key_padding_mask
776
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
777
+ new_key_padding_mask = torch.cat(
778
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
779
+ )
780
+ # During incremental decoding, as the padding token enters and
781
+ # leaves the frame, there will be a time when prev or current
782
+ # is None
783
+ elif prev_key_padding_mask is not None:
784
+ if src_len > prev_key_padding_mask.size(1):
785
+ filler = torch.zeros(
786
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
787
+ device=prev_key_padding_mask.device,
788
+ )
789
+ new_key_padding_mask = torch.cat(
790
+ [prev_key_padding_mask.float(), filler.float()], dim=1
791
+ )
792
+ else:
793
+ new_key_padding_mask = prev_key_padding_mask.float()
794
+ elif key_padding_mask is not None:
795
+ if src_len > key_padding_mask.size(1):
796
+ filler = torch.zeros(
797
+ (batch_size, src_len - key_padding_mask.size(1)),
798
+ device=key_padding_mask.device,
799
+ )
800
+ new_key_padding_mask = torch.cat(
801
+ [filler.float(), key_padding_mask.float()], dim=1
802
+ )
803
+ else:
804
+ new_key_padding_mask = key_padding_mask.float()
805
+ else:
806
+ new_key_padding_mask = prev_key_padding_mask
807
+ return new_key_padding_mask
808
+
809
+ def _get_input_buffer(
810
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
811
+ ) -> Dict[str, Optional[Tensor]]:
812
+ result = self.get_incremental_state(incremental_state, "attn_state")
813
+ if result is not None:
814
+ return result
815
+ else:
816
+ empty_result: Dict[str, Optional[Tensor]] = {}
817
+ return empty_result
818
+
819
+ def _set_input_buffer(
820
+ self,
821
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
822
+ buffer: Dict[str, Optional[Tensor]],
823
+ ):
824
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
825
+
826
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
827
+ return attn_weights