Spaces · Runtime error
ramkamal2000 committed
Commit 0844354
Parent(s): 665e760
adding untracked files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- app.py +203 -0
- checkpoints/freevc.pth +3 -0
- commons.py +171 -0
- configs/freevc.json +54 -0
- mel_processing.py +112 -0
- models.py +351 -0
- modules.py +342 -0
- requirements.txt +11 -0
- sample_inputs/ntr.wav +3 -0
- sample_inputs/out.wav +3 -0
- sample_inputs/p225_001.wav +3 -0
- sample_inputs/p226_002.wav +3 -0
- sample_inputs/reference.wav +3 -0
- sample_inputs/target.wav +3 -0
- sample_inputs/timcast1.wav +3 -0
- speaker_encoder/__init__.py +0 -0
- speaker_encoder/__pycache__/__init__.cpython-39.pyc +0 -0
- speaker_encoder/__pycache__/audio.cpython-39.pyc +0 -0
- speaker_encoder/__pycache__/hparams.cpython-39.pyc +0 -0
- speaker_encoder/__pycache__/params_data.cpython-39.pyc +0 -0
- speaker_encoder/__pycache__/voice_encoder.cpython-39.pyc +0 -0
- speaker_encoder/audio.py +107 -0
- speaker_encoder/ckpt/pretrained_bak_5805000.pt +3 -0
- speaker_encoder/compute_embed.py +40 -0
- speaker_encoder/config.py +45 -0
- speaker_encoder/data_objects/__init__.py +2 -0
- speaker_encoder/data_objects/random_cycler.py +37 -0
- speaker_encoder/data_objects/speaker.py +40 -0
- speaker_encoder/data_objects/speaker_batch.py +12 -0
- speaker_encoder/data_objects/speaker_verification_dataset.py +56 -0
- speaker_encoder/data_objects/utterance.py +26 -0
- speaker_encoder/hparams.py +31 -0
- speaker_encoder/inference.py +177 -0
- speaker_encoder/model.py +135 -0
- speaker_encoder/params_data.py +29 -0
- speaker_encoder/params_model.py +11 -0
- speaker_encoder/preprocess.py +285 -0
- speaker_encoder/train.py +125 -0
- speaker_encoder/visualizations.py +178 -0
- speaker_encoder/voice_encoder.py +173 -0
- utils.py +306 -0
- wavlm/WavLM-Large.pt +3 -0
- wavlm/WavLM-Large.pt.txt +1 -0
- wavlm/WavLM.py +742 -0
- wavlm/__init__.py +1 -0
- wavlm/__pycache__/WavLM.cpython-39.pyc +0 -0
- wavlm/__pycache__/__init__.cpython-39.pyc +0 -0
- wavlm/__pycache__/modules.cpython-39.pyc +0 -0
- wavlm/modules.py +827 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,203 @@
'''
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
'''
import gradio
import os
import shutil
import gradio as gr
import sys
import string
import time
import argparse
import json
import numpy as np
import torch
import librosa
import subprocess

from pydub import AudioSegment
from scipy.io.wavfile import write, read
from transformers import WavLMModel

from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
try:
    from TTS.utils.audio import AudioProcessor
except:
    from TTS.utils.audio import AudioProcessor
from TTS.tts.models import setup_model
from TTS.config import load_config
from TTS.tts.models.vits import *
from TTS.tts.utils.speakers import SpeakerManager

import utils
from models import SynthesizerTrn
from mel_processing import mel_spectrogram_torch
from speaker_encoder.voice_encoder import SpeakerEncoder

TTS_PATH = "TTS/"
sys.path.append(TTS_PATH)  # set this if TTS is not installed globally

OUT_PATH = 'out/'
os.makedirs(OUT_PATH, exist_ok=True)

TTS_SPEAKERS = "yourTTS_config/speakers.json"
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

CONFIG_PATH = 'yourTTS_config/config.json'
C = load_config(CONFIG_PATH)
ap = AudioProcessor(**C.audio)

speaker_embedding = None
C.model_args['d_vector_file'] = TTS_SPEAKERS
C.model_args['use_speaker_encoder_as_loss'] = False

model = setup_model(C)

TTS_LANGUAGES = "yourTTS_config/language_ids.json"
model.language_manager.set_language_ids_from_file(TTS_LANGUAGES)

# print(model.language_manager.num_languages, model.embedded_language_dim)
# print(model.emb_l)

MODEL_PATH = 'yourTTS_config/best_model.pth.tar'
cp = torch.load(MODEL_PATH, map_location=torch.device('cpu'))

model_weights = cp['model'].copy()
for key in list(model_weights.keys()):
    if "speaker_encoder" in key:
        del model_weights[key]

model.load_state_dict(model_weights)
model.eval()

if USE_CUDA:
    model = model.cuda()

use_griffin_lim = False

CONFIG_SE_PATH = "yourtts_config/config_se.json"
CHECKPOINT_SE_PATH = "yourtts_config/SE_checkpoint.pth.tar"
SE_speaker_manager = SpeakerManager(encoder_model_path=CHECKPOINT_SE_PATH, encoder_config_path=CONFIG_SE_PATH, use_cuda=USE_CUDA)

def compute_spec(ref_file):
    y, sr = librosa.load(ref_file, sr=ap.sample_rate)
    spec = ap.spectrogram(y)
    spec = torch.FloatTensor(spec).unsqueeze(0)
    return spec

print("Loading FreeVC...")
hps = utils.get_hparams_from_file("configs/freevc.json")
freevc = SynthesizerTrn(
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model).to(device)
_ = freevc.eval()
_ = utils.load_checkpoint("checkpoints/freevc.pth", freevc, None)
smodel = SpeakerEncoder('speaker_encoder/ckpt/pretrained_bak_5805000.pt')

print("Loading WavLM for content...")
cmodel = utils.get_cmodel(device).to(device)
# cmodel = WavLMModel.from_pretrained("microsoft/wavlm-large").to(device)

def voice_conversion_yourtts(da, ta):

    # write(target_audio, ta[0], ta[1])
    # write(driving_audio, da[0], da[1])

    # !ffmpeg-normalize $target_audio -nt rms -t=-27 -o $target_audio -ar 16000 -f
    # !ffmpeg-normalize $reference_audio -nt rms -t=-27 -o $reference_audio -ar 16000 -f
    # !ffmpeg-normalize $driving_audio -nt rms -t=-27 -o $driving_audio -ar 16000 -f

    files = [da, ta]

    for file in files:
        subprocess.run(["ffmpeg-normalize", file, "-nt", "rms", "-t=-27", "-o", file, "-ar", "16000", "-f"])

    # ta_ = read(target_audio)

    target_emb = SE_speaker_manager.compute_d_vector_from_clip([ta])
    target_emb = torch.FloatTensor(target_emb).unsqueeze(0)

    driving_emb = SE_speaker_manager.compute_d_vector_from_clip([da])
    driving_emb = torch.FloatTensor(driving_emb).unsqueeze(0)

    # Convert the voice

    driving_spec = compute_spec(da)
    y_lengths = torch.tensor([driving_spec.size(-1)])
    if USE_CUDA:
        ref_wav_voc, _, _ = model.voice_conversion(driving_spec.cuda(), y_lengths.cuda(), driving_emb.cuda(), target_emb.cuda())
        ref_wav_voc = ref_wav_voc.squeeze().cpu().detach().numpy()
    else:
        ref_wav_voc, _, _ = model.voice_conversion(driving_spec, y_lengths, driving_emb, target_emb)
        ref_wav_voc = ref_wav_voc.squeeze().detach().numpy()

    # print("Reference Audio after decoder:")
    # IPython.display.display(Audio(ref_wav_voc, rate=ap.sample_rate))

    return (ap.sample_rate, ref_wav_voc)

def voice_conversion_freevc(src, tgt):
    with torch.no_grad():
        wav_tgt, _ = librosa.load(tgt, sr=hps.data.sampling_rate)
        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
        g_tgt = smodel.embed_utterance(wav_tgt)
        g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(device)
        wav_src, _ = librosa.load(src, sr=hps.data.sampling_rate)
        wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
        # c = cmodel(wav_src).last_hidden_state.transpose(1, 2).to(device)
        c = utils.get_content(cmodel, wav_src)
        audio = freevc.infer(c, g=g_tgt)
        audio = audio[0][0].data.cpu().float().numpy()
        write("out.wav", hps.data.sampling_rate, audio)
        out = "out.wav"
        return out

model1 = gr.Dropdown(choices=["FreeVC", "YourTTS"], value="FreeVC", type="value", label="Model")
model2 = gr.Dropdown(choices=["FreeVC", "YourTTS"], value="FreeVC", type="value", label="Model")

audio1 = gr.inputs.Audio(label="Source Speaker - Input Audio", type='filepath')
audio2 = gr.inputs.Audio(label="Target Speaker - Input Audio", type='filepath')
microphone = gr.inputs.Audio(label="Source Speaker - Input Audio", source='microphone')
audio3 = gr.inputs.Audio(label="Target Speaker - Input Audio", type='filepath')

inputs_1 = [model1, audio1, audio2]
inputs_2 = [model2, microphone, audio3]

outputs_1 = gr.outputs.Audio(label="Target Speaker - Output Audio", type='filepath')
outputs_2 = gr.outputs.Audio(label="Target Speaker - Output Audio", type='filepath')

def voice_conversion(mod, sa, ta):

    if mod == 'FreeVC':
        return voice_conversion_yourtts(sa, ta)
    else:
        return voice_conversion_freevc(sa, ta)

examples_1 = [['FreeVC', 'sample_inputs/ntr.wav', 'sample_inputs/timcast1.wav'], ['YourTTS', 'sample_inputs/ntr.wav', 'sample_inputs/timcast1.wav']]

vc_1 = gr.Interface(
    fn=voice_conversion,
    inputs=inputs_1,
    outputs=outputs_1,
    examples=examples_1,
    description="Use this cool tool to convert your voice to another person's! \n Upload files in wav format for the source speaker and the target speaker.\n \nThis demonstration is made by T B Ramkamal, for partial credit towards completion of my Dual Degree Project"
)

vc_2 = gr.Interface(
    fn=voice_conversion,
    inputs=inputs_2,
    outputs=outputs_2,
    description="Use this cool tool to convert your voice to another person's! \n Upload files in wav format for the target speaker and record the voice of the input speaker using the microphone.\n \nThis demonstration is made by T B Ramkamal, for partial credit towards completion of my Dual Degree Project"
)

demo = gr.TabbedInterface([vc_1, vc_2], ["wav Input", "Microphone Input"], title="Voice Conversion")
demo.launch(debug='True')
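For orientation, here is a condensed sketch of the FreeVC inference path that app.py wires into the Gradio UI. It is drawn directly from voice_conversion_freevc above and is not an additional file in this commit; it assumes the checkpoint, config, and speaker-encoder weights added here sit at the paths shown, and that utils.get_cmodel / utils.get_content behave as app.py uses them.

# Minimal sketch of the FreeVC conversion path, condensed from app.py (not part of the commit).
import torch
import librosa
from scipy.io.wavfile import write

import utils
from models import SynthesizerTrn
from speaker_encoder.voice_encoder import SpeakerEncoder

device = "cuda" if torch.cuda.is_available() else "cpu"
hps = utils.get_hparams_from_file("configs/freevc.json")
freevc = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                        hps.train.segment_size // hps.data.hop_length,
                        **hps.model).to(device).eval()
utils.load_checkpoint("checkpoints/freevc.pth", freevc, None)
smodel = SpeakerEncoder("speaker_encoder/ckpt/pretrained_bak_5805000.pt")
cmodel = utils.get_cmodel(device)  # WavLM content encoder, as loaded in app.py

def convert(src_path, tgt_path, out_path="out.wav"):
    with torch.no_grad():
        # target speaker embedding
        wav_tgt, _ = librosa.load(tgt_path, sr=hps.data.sampling_rate)
        wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
        g_tgt = torch.from_numpy(smodel.embed_utterance(wav_tgt)).unsqueeze(0).to(device)
        # source content features from WavLM
        wav_src, _ = librosa.load(src_path, sr=hps.data.sampling_rate)
        wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(device)
        c = utils.get_content(cmodel, wav_src)
        # decode with the target speaker embedding
        audio = freevc.infer(c, g=g_tgt)[0][0].cpu().float().numpy()
    write(out_path, hps.data.sampling_rate, audio)
    return out_path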
checkpoints/freevc.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e2cc2d047f63b80d1d6780e37611cec11a01d597560393b1fe6118158b3bd47f
size 472644351
commons.py
ADDED
@@ -0,0 +1,171 @@
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
  classname = m.__class__.__name__
  if classname.find("Conv") != -1:
    m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
  return int((kernel_size*dilation - dilation)/2)


def convert_pad_shape(pad_shape):
  l = pad_shape[::-1]
  pad_shape = [item for sublist in l for item in sublist]
  return pad_shape


def intersperse(lst, item):
  result = [item] * (len(lst) * 2 + 1)
  result[1::2] = lst
  return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
  """KL(P||Q)"""
  kl = (logs_q - logs_p) - 0.5
  kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
  return kl


def rand_gumbel(shape):
  """Sample from the Gumbel distribution, protect from overflows."""
  uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
  return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
  g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
  return g


def slice_segments(x, ids_str, segment_size=4):
  ret = torch.zeros_like(x[:, :, :segment_size])
  for i in range(x.size(0)):
    idx_str = ids_str[i]
    idx_end = idx_str + segment_size
    ret[i] = x[i, :, idx_str:idx_end]
  return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
  b, d, t = x.size()
  if x_lengths is None:
    x_lengths = t
  ids_str_max = x_lengths - segment_size + 1
  ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
  ret = slice_segments(x, ids_str, segment_size)
  return ret, ids_str


def rand_spec_segments(x, x_lengths=None, segment_size=4):
  b, d, t = x.size()
  if x_lengths is None:
    x_lengths = t
  ids_str_max = x_lengths - segment_size
  ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
  ret = slice_segments(x, ids_str, segment_size)
  return ret, ids_str


def get_timing_signal_1d(
    length, channels, min_timescale=1.0, max_timescale=1.0e4):
  position = torch.arange(length, dtype=torch.float)
  num_timescales = channels // 2
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      (num_timescales - 1))
  inv_timescales = min_timescale * torch.exp(
      torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
  scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
  signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
  signal = F.pad(signal, [0, 0, 0, channels % 2])
  signal = signal.view(1, channels, length)
  return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
  b, channels, length = x.size()
  signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
  return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
  b, channels, length = x.size()
  signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
  return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
  mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
  return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
  n_channels_int = n_channels[0]
  in_act = input_a + input_b
  t_act = torch.tanh(in_act[:, :n_channels_int, :])
  s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
  acts = t_act * s_act
  return acts


def convert_pad_shape(pad_shape):
  l = pad_shape[::-1]
  pad_shape = [item for sublist in l for item in sublist]
  return pad_shape


def shift_1d(x):
  x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
  return x


def sequence_mask(length, max_length=None):
  if max_length is None:
    max_length = length.max()
  x = torch.arange(max_length, dtype=length.dtype, device=length.device)
  return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
  """
  duration: [b, 1, t_x]
  mask: [b, 1, t_y, t_x]
  """
  device = duration.device

  b, _, t_y, t_x = mask.shape
  cum_duration = torch.cumsum(duration, -1)

  cum_duration_flat = cum_duration.view(b * t_x)
  path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
  path = path.view(b, t_x, t_y)
  path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
  path = path.unsqueeze(1).transpose(2,3) * mask
  return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
  if isinstance(parameters, torch.Tensor):
    parameters = [parameters]
  parameters = list(filter(lambda p: p.grad is not None, parameters))
  norm_type = float(norm_type)
  if clip_value is not None:
    clip_value = float(clip_value)

  total_norm = 0
  for p in parameters:
    param_norm = p.grad.data.norm(norm_type)
    total_norm += param_norm.item() ** norm_type
    if clip_value is not None:
      p.grad.data.clamp_(min=-clip_value, max=clip_value)
  total_norm = total_norm ** (1. / norm_type)
  return total_norm
configs/freevc.json
ADDED
@@ -0,0 +1,54 @@
{
  "train": {
    "log_interval": 200,
    "eval_interval": 10000,
    "seed": 1234,
    "epochs": 10000,
    "learning_rate": 2e-4,
    "betas": [0.8, 0.99],
    "eps": 1e-9,
    "batch_size": 64,
    "fp16_run": false,
    "lr_decay": 0.999875,
    "segment_size": 8960,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "use_sr": true,
    "max_speclen": 128,
    "port": "8001"
  },
  "data": {
    "training_files": "filelists/train.txt",
    "validation_files": "filelists/val.txt",
    "max_wav_value": 32768.0,
    "sampling_rate": 16000,
    "filter_length": 1280,
    "hop_length": 320,
    "win_length": 1280,
    "n_mel_channels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": null
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "upsample_rates": [10,8,2,2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16,16,4,4],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 256,
    "ssl_dim": 1024,
    "use_spk": true
  }
}
mel_processing.py
ADDED
@@ -0,0 +1,112 @@
import math
import os
import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np
import librosa
import librosa.util as librosa_util
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + '_' + str(spec.device)
    fmax_dtype_device = str(fmax) + '_' + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    fmax_dtype_device = str(fmax) + '_' + dtype_device
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
models.py
ADDED
@@ -0,0 +1,351 @@
import copy
import math
import torch
from torch import nn
from torch.nn import functional as F

import commons
import modules

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from commons import init_weights, get_padding


class ResidualCouplingBlock(nn.Module):
  def __init__(self,
      channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      n_flows=4,
      gin_channels=0):
    super().__init__()
    self.channels = channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.n_flows = n_flows
    self.gin_channels = gin_channels

    self.flows = nn.ModuleList()
    for i in range(n_flows):
      self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
      self.flows.append(modules.Flip())

  def forward(self, x, x_mask, g=None, reverse=False):
    if not reverse:
      for flow in self.flows:
        x, _ = flow(x, x_mask, g=g, reverse=reverse)
    else:
      for flow in reversed(self.flows):
        x = flow(x, x_mask, g=g, reverse=reverse)
    return x


class Encoder(nn.Module):
  def __init__(self,
      in_channels,
      out_channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      gin_channels=0):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.gin_channels = gin_channels

    self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
    self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
    self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

  def forward(self, x, x_lengths, g=None):
    x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
    x = self.pre(x) * x_mask
    x = self.enc(x, x_mask, g=g)
    stats = self.proj(x) * x_mask
    m, logs = torch.split(stats, self.out_channels, dim=1)
    z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
    return z, m, logs, x_mask


class Generator(torch.nn.Module):
  def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
    super(Generator, self).__init__()
    self.num_kernels = len(resblock_kernel_sizes)
    self.num_upsamples = len(upsample_rates)
    self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
    resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

    self.ups = nn.ModuleList()
    for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
      self.ups.append(weight_norm(
          ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                          k, u, padding=(k-u)//2)))

    self.resblocks = nn.ModuleList()
    for i in range(len(self.ups)):
      ch = upsample_initial_channel//(2**(i+1))
      for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
        self.resblocks.append(resblock(ch, k, d))

    self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
    self.ups.apply(init_weights)

    if gin_channels != 0:
      self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

  def forward(self, x, g=None):
    x = self.conv_pre(x)
    if g is not None:
      x = x + self.cond(g)

    for i in range(self.num_upsamples):
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      x = self.ups[i](x)
      xs = None
      for j in range(self.num_kernels):
        if xs is None:
          xs = self.resblocks[i*self.num_kernels+j](x)
        else:
          xs += self.resblocks[i*self.num_kernels+j](x)
      x = xs / self.num_kernels
    x = F.leaky_relu(x)
    x = self.conv_post(x)
    x = torch.tanh(x)

    return x

  def remove_weight_norm(self):
    print('Removing weight norm...')
    for l in self.ups:
      remove_weight_norm(l)
    for l in self.resblocks:
      l.remove_weight_norm()


class DiscriminatorP(torch.nn.Module):
  def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
    super(DiscriminatorP, self).__init__()
    self.period = period
    self.use_spectral_norm = use_spectral_norm
    norm_f = weight_norm if use_spectral_norm == False else spectral_norm
    self.convs = nn.ModuleList([
        norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
        norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
        norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
        norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
        norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
    ])
    self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

  def forward(self, x):
    fmap = []

    # 1d to 2d
    b, c, t = x.shape
    if t % self.period != 0:  # pad first
      n_pad = self.period - (t % self.period)
      x = F.pad(x, (0, n_pad), "reflect")
      t = t + n_pad
    x = x.view(b, c, t // self.period, self.period)

    for l in self.convs:
      x = l(x)
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      fmap.append(x)
    x = self.conv_post(x)
    fmap.append(x)
    x = torch.flatten(x, 1, -1)

    return x, fmap


class DiscriminatorS(torch.nn.Module):
  def __init__(self, use_spectral_norm=False):
    super(DiscriminatorS, self).__init__()
    norm_f = weight_norm if use_spectral_norm == False else spectral_norm
    self.convs = nn.ModuleList([
        norm_f(Conv1d(1, 16, 15, 1, padding=7)),
        norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
        norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
        norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
        norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
        norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
    ])
    self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

  def forward(self, x):
    fmap = []

    for l in self.convs:
      x = l(x)
      x = F.leaky_relu(x, modules.LRELU_SLOPE)
      fmap.append(x)
    x = self.conv_post(x)
    fmap.append(x)
    x = torch.flatten(x, 1, -1)

    return x, fmap


class MultiPeriodDiscriminator(torch.nn.Module):
  def __init__(self, use_spectral_norm=False):
    super(MultiPeriodDiscriminator, self).__init__()
    periods = [2,3,5,7,11]

    discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
    discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
    self.discriminators = nn.ModuleList(discs)

  def forward(self, y, y_hat):
    y_d_rs = []
    y_d_gs = []
    fmap_rs = []
    fmap_gs = []
    for i, d in enumerate(self.discriminators):
      y_d_r, fmap_r = d(y)
      y_d_g, fmap_g = d(y_hat)
      y_d_rs.append(y_d_r)
      y_d_gs.append(y_d_g)
      fmap_rs.append(fmap_r)
      fmap_gs.append(fmap_g)

    return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class SpeakerEncoder(torch.nn.Module):
  def __init__(self, mel_n_channels=80, model_num_layers=3, model_hidden_size=256, model_embedding_size=256):
    super(SpeakerEncoder, self).__init__()
    self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
    self.linear = nn.Linear(model_hidden_size, model_embedding_size)
    self.relu = nn.ReLU()

  def forward(self, mels):
    self.lstm.flatten_parameters()
    _, (hidden, _) = self.lstm(mels)
    embeds_raw = self.relu(self.linear(hidden[-1]))
    return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

  def compute_partial_slices(self, total_frames, partial_frames, partial_hop):
    mel_slices = []
    for i in range(0, total_frames-partial_frames, partial_hop):
      mel_range = torch.arange(i, i+partial_frames)
      mel_slices.append(mel_range)

    return mel_slices

  def embed_utterance(self, mel, partial_frames=128, partial_hop=64):
    mel_len = mel.size(1)
    last_mel = mel[:,-partial_frames:]

    if mel_len > partial_frames:
      mel_slices = self.compute_partial_slices(mel_len, partial_frames, partial_hop)
      mels = list(mel[:,s] for s in mel_slices)
      mels.append(last_mel)
      mels = torch.stack(tuple(mels), 0).squeeze(1)

      with torch.no_grad():
        partial_embeds = self(mels)
      embed = torch.mean(partial_embeds, axis=0).unsqueeze(0)
      #embed = embed / torch.linalg.norm(embed, 2)
    else:
      with torch.no_grad():
        embed = self(last_mel)

    return embed


class SynthesizerTrn(nn.Module):
  """
  Synthesizer for Training
  """

  def __init__(self,
    spec_channels,
    segment_size,
    inter_channels,
    hidden_channels,
    filter_channels,
    n_heads,
    n_layers,
    kernel_size,
    p_dropout,
    resblock,
    resblock_kernel_sizes,
    resblock_dilation_sizes,
    upsample_rates,
    upsample_initial_channel,
    upsample_kernel_sizes,
    gin_channels,
    ssl_dim,
    use_spk,
    **kwargs):

    super().__init__()
    self.spec_channels = spec_channels
    self.inter_channels = inter_channels
    self.hidden_channels = hidden_channels
    self.filter_channels = filter_channels
    self.n_heads = n_heads
    self.n_layers = n_layers
    self.kernel_size = kernel_size
    self.p_dropout = p_dropout
    self.resblock = resblock
    self.resblock_kernel_sizes = resblock_kernel_sizes
    self.resblock_dilation_sizes = resblock_dilation_sizes
    self.upsample_rates = upsample_rates
    self.upsample_initial_channel = upsample_initial_channel
    self.upsample_kernel_sizes = upsample_kernel_sizes
    self.segment_size = segment_size
    self.gin_channels = gin_channels
    self.ssl_dim = ssl_dim
    self.use_spk = use_spk

    self.enc_p = Encoder(ssl_dim, inter_channels, hidden_channels, 5, 1, 16)
    self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
    self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
    self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

    if not self.use_spk:
      self.enc_spk = SpeakerEncoder(model_hidden_size=gin_channels, model_embedding_size=gin_channels)

  def forward(self, c, spec, g=None, mel=None, c_lengths=None, spec_lengths=None):
    if c_lengths == None:
      c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
    if spec_lengths == None:
      spec_lengths = (torch.ones(spec.size(0)) * spec.size(-1)).to(spec.device)

    if not self.use_spk:
      g = self.enc_spk(mel.transpose(1,2))
    g = g.unsqueeze(-1)

    _, m_p, logs_p, _ = self.enc_p(c, c_lengths)
    z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
    z_p = self.flow(z, spec_mask, g=g)

    z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
    o = self.dec(z_slice, g=g)

    return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

  def infer(self, c, g=None, mel=None, c_lengths=None):
    if c_lengths == None:
      c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
    if not self.use_spk:
      g = self.enc_spk.embed_utterance(mel.transpose(1,2))
    g = g.unsqueeze(-1)

    z_p, m_p, logs_p, c_mask = self.enc_p(c, c_lengths)
    z = self.flow(z_p, c_mask, g=g, reverse=True)
    o = self.dec(z * c_mask, g=g)

    return o
modules.py
ADDED
@@ -0,0 +1,342 @@
import copy
import math
import numpy as np
import scipy
import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm

import commons
from commons import init_weights, get_padding


LRELU_SLOPE = 0.1


class LayerNorm(nn.Module):
  def __init__(self, channels, eps=1e-5):
    super().__init__()
    self.channels = channels
    self.eps = eps

    self.gamma = nn.Parameter(torch.ones(channels))
    self.beta = nn.Parameter(torch.zeros(channels))

  def forward(self, x):
    x = x.transpose(1, -1)
    x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
    return x.transpose(1, -1)


class ConvReluNorm(nn.Module):
  def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
    super().__init__()
    self.in_channels = in_channels
    self.hidden_channels = hidden_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.n_layers = n_layers
    self.p_dropout = p_dropout
    assert n_layers > 1, "Number of layers should be larger than 0."

    self.conv_layers = nn.ModuleList()
    self.norm_layers = nn.ModuleList()
    self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
    self.norm_layers.append(LayerNorm(hidden_channels))
    self.relu_drop = nn.Sequential(
        nn.ReLU(),
        nn.Dropout(p_dropout))
    for _ in range(n_layers-1):
      self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
      self.norm_layers.append(LayerNorm(hidden_channels))
    self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
    self.proj.weight.data.zero_()
    self.proj.bias.data.zero_()

  def forward(self, x, x_mask):
    x_org = x
    for i in range(self.n_layers):
      x = self.conv_layers[i](x * x_mask)
      x = self.norm_layers[i](x)
      x = self.relu_drop(x)
    x = x_org + self.proj(x)
    return x * x_mask


class DDSConv(nn.Module):
  """
  Dialted and Depth-Separable Convolution
  """
  def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
    super().__init__()
    self.channels = channels
    self.kernel_size = kernel_size
    self.n_layers = n_layers
    self.p_dropout = p_dropout

    self.drop = nn.Dropout(p_dropout)
    self.convs_sep = nn.ModuleList()
    self.convs_1x1 = nn.ModuleList()
    self.norms_1 = nn.ModuleList()
    self.norms_2 = nn.ModuleList()
    for i in range(n_layers):
      dilation = kernel_size ** i
      padding = (kernel_size * dilation - dilation) // 2
      self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
          groups=channels, dilation=dilation, padding=padding
      ))
      self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
      self.norms_1.append(LayerNorm(channels))
      self.norms_2.append(LayerNorm(channels))

  def forward(self, x, x_mask, g=None):
    if g is not None:
      x = x + g
    for i in range(self.n_layers):
      y = self.convs_sep[i](x * x_mask)
      y = self.norms_1[i](y)
      y = F.gelu(y)
      y = self.convs_1x1[i](y)
      y = self.norms_2[i](y)
      y = F.gelu(y)
      y = self.drop(y)
      x = x + y
    return x * x_mask


class WN(torch.nn.Module):
  def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
    super(WN, self).__init__()
    assert(kernel_size % 2 == 1)
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size,
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.gin_channels = gin_channels
    self.p_dropout = p_dropout

    self.in_layers = torch.nn.ModuleList()
    self.res_skip_layers = torch.nn.ModuleList()
    self.drop = nn.Dropout(p_dropout)

    if gin_channels != 0:
      cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
      self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

    for i in range(n_layers):
      dilation = dilation_rate ** i
      padding = int((kernel_size * dilation - dilation) / 2)
      in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
                                 dilation=dilation, padding=padding)
      in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
      self.in_layers.append(in_layer)

      # last one is not necessary
      if i < n_layers - 1:
        res_skip_channels = 2 * hidden_channels
      else:
        res_skip_channels = hidden_channels

      res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
      res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
      self.res_skip_layers.append(res_skip_layer)

  def forward(self, x, x_mask, g=None, **kwargs):
    output = torch.zeros_like(x)
    n_channels_tensor = torch.IntTensor([self.hidden_channels])

    if g is not None:
      g = self.cond_layer(g)

    for i in range(self.n_layers):
      x_in = self.in_layers[i](x)
      if g is not None:
        cond_offset = i * 2 * self.hidden_channels
        g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
      else:
        g_l = torch.zeros_like(x_in)

      acts = commons.fused_add_tanh_sigmoid_multiply(
          x_in,
          g_l,
          n_channels_tensor)
      acts = self.drop(acts)

      res_skip_acts = self.res_skip_layers[i](acts)
      if i < self.n_layers - 1:
        res_acts = res_skip_acts[:,:self.hidden_channels,:]
        x = (x + res_acts) * x_mask
        output = output + res_skip_acts[:,self.hidden_channels:,:]
      else:
        output = output + res_skip_acts
    return output * x_mask

  def remove_weight_norm(self):
    if self.gin_channels != 0:
      torch.nn.utils.remove_weight_norm(self.cond_layer)
    for l in self.in_layers:
      torch.nn.utils.remove_weight_norm(l)
    for l in self.res_skip_layers:
      torch.nn.utils.remove_weight_norm(l)


class ResBlock1(torch.nn.Module):
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
    super(ResBlock1, self).__init__()
    self.convs1 = nn.ModuleList([
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                           padding=get_padding(kernel_size, dilation[0]))),
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                           padding=get_padding(kernel_size, dilation[1]))),
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                           padding=get_padding(kernel_size, dilation[2])))
    ])
    self.convs1.apply(init_weights)

    self.convs2 = nn.ModuleList([
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1))),
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1))),
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                           padding=get_padding(kernel_size, 1)))
    ])
    self.convs2.apply(init_weights)

  def forward(self, x, x_mask=None):
    for c1, c2 in zip(self.convs1, self.convs2):
      xt = F.leaky_relu(x, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c1(xt)
      xt = F.leaky_relu(xt, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c2(xt)
      x = xt + x
    if x_mask is not None:
      x = x * x_mask
    return x

  def remove_weight_norm(self):
    for l in self.convs1:
      remove_weight_norm(l)
    for l in self.convs2:
      remove_weight_norm(l)


class ResBlock2(torch.nn.Module):
  def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
    super(ResBlock2, self).__init__()
    self.convs = nn.ModuleList([
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                           padding=get_padding(kernel_size, dilation[0]))),
        weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                           padding=get_padding(kernel_size, dilation[1])))
    ])
    self.convs.apply(init_weights)

  def forward(self, x, x_mask=None):
    for c in self.convs:
      xt = F.leaky_relu(x, LRELU_SLOPE)
      if x_mask is not None:
        xt = xt * x_mask
      xt = c(xt)
      x = xt + x
    if x_mask is not None:
      x = x * x_mask
    return x

  def remove_weight_norm(self):
    for l in self.convs:
      remove_weight_norm(l)


class Log(nn.Module):
  def forward(self, x, x_mask, reverse=False, **kwargs):
    if not reverse:
      y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
      logdet = torch.sum(-y, [1, 2])
      return y, logdet
    else:
      x = torch.exp(x) * x_mask
      return x


class Flip(nn.Module):
  def forward(self, x, *args, reverse=False, **kwargs):
    x = torch.flip(x, [1])
    if not reverse:
      logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
      return x, logdet
    else:
      return x


class ElementwiseAffine(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.channels = channels
    self.m = nn.Parameter(torch.zeros(channels,1))
    self.logs = nn.Parameter(torch.zeros(channels,1))

  def forward(self, x, x_mask, reverse=False, **kwargs):
    if not reverse:
      y = self.m + torch.exp(self.logs) * x
      y = y * x_mask
      logdet = torch.sum(self.logs * x_mask, [1,2])
      return y, logdet
    else:
      x = (x - self.m) * torch.exp(-self.logs) * x_mask
      return x


class ResidualCouplingLayer(nn.Module):
  def __init__(self,
      channels,
      hidden_channels,
      kernel_size,
      dilation_rate,
      n_layers,
      p_dropout=0,
      gin_channels=0,
      mean_only=False):
    assert channels % 2 == 0, "channels should be divisible by 2"
    super().__init__()
    self.channels = channels
    self.hidden_channels = hidden_channels
    self.kernel_size = kernel_size
    self.dilation_rate = dilation_rate
    self.n_layers = n_layers
    self.half_channels = channels // 2
    self.mean_only = mean_only

    self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
    self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
    self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
    self.post.weight.data.zero_()
    self.post.bias.data.zero_()

  def forward(self, x, x_mask, g=None, reverse=False):
    x0, x1 = torch.split(x, [self.half_channels]*2, 1)
    h = self.pre(x0) * x_mask
    h = self.enc(h, x_mask, g=g)
    stats = self.post(h) * x_mask
    if not self.mean_only:
      m, logs = torch.split(stats, [self.half_channels]*2, 1)
    else:
      m = stats
      logs = torch.zeros_like(m)

    if not reverse:
      x1 = m + x1 * torch.exp(logs) * x_mask
      x = torch.cat([x0, x1], 1)
      logdet = torch.sum(logs, [1,2])
      return x, logdet
    else:
      x1 = (x1 - m) * torch.exp(-logs) * x_mask
      x = torch.cat([x0, x1], 1)
      return x
requirements.txt
ADDED
@@ -0,0 +1,11 @@
git+https://github.com/Edresson/Coqui-TTS@multilingual-torchaudio-SE
torchaudio==0.9.0
pydub
ffmpeg-normalize==1.21.0
numpy
scipy
torch
transformers
librosa==0.8.1
webrtcvad==2.0.10
gradio==3.22.1
sample_inputs/ntr.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b75c56ba7545d0a96bf6a12c02ef38edc4beded66fd4d32d1b92543045e43617
size 1940444
sample_inputs/out.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:935604738717e5bcdfe59dc5fd2842143cfb9a6073228b06c98d53f5b4ec9bec
size 103738
sample_inputs/p225_001.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b15dc9bbf0ea3cb0f0f02cf20bb60538a77fc5c6b87769e20cfece81891c78d4
size 52058
sample_inputs/p226_002.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1dc155a7dbfab1a3150a321082789f610b9c9328e84e0273c3b648273acf0a56
size 138084
sample_inputs/reference.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1dbdd762a5bca492f82244838df309e94149236a7c659ef192cd6779d2b69983
size 640078
sample_inputs/target.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a43db3ea5e4ffa59b9a6675940661160a255ba9d3bef4133855298a1ff38ee8e
size 704078
sample_inputs/timcast1.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fb4d35e5e20c59e6deb69694da0bd403f80704e2f3d9b8d4c4d1a5b558bc6c1
size 1764044
speaker_encoder/__init__.py
ADDED
File without changes
speaker_encoder/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (138 Bytes)
speaker_encoder/__pycache__/audio.cpython-39.pyc
ADDED
Binary file (3.75 kB)
speaker_encoder/__pycache__/hparams.cpython-39.pyc
ADDED
Binary file (497 Bytes)
speaker_encoder/__pycache__/params_data.cpython-39.pyc
ADDED
Binary file (445 Bytes)
speaker_encoder/__pycache__/voice_encoder.cpython-39.pyc
ADDED
Binary file (8.28 kB)
speaker_encoder/audio.py
ADDED
@@ -0,0 +1,107 @@
from scipy.ndimage.morphology import binary_dilation
from speaker_encoder.params_data import *
from pathlib import Path
from typing import Optional, Union
import numpy as np
import webrtcvad
import librosa
import struct

int16_max = (2 ** 15) - 1


def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
                   source_sr: Optional[int] = None):
    """
    Applies the preprocessing operations used in training the Speaker Encoder to a waveform
    either on disk or in memory. The waveform will be resampled to match the data hyperparameters.

    :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
    just .wav), either the waveform as a numpy array of floats.
    :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
    preprocessing. After preprocessing, the waveform's sampling rate will match the data
    hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
    this argument will be ignored.
    """
    # Load the wav from disk if needed
    if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
        wav, source_sr = librosa.load(fpath_or_wav, sr=None)
    else:
        wav = fpath_or_wav

    # Resample the wav if needed
    if source_sr is not None and source_sr != sampling_rate:
        wav = librosa.resample(wav, source_sr, sampling_rate)

    # Apply the preprocessing: normalize volume and shorten long silences
    wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
    wav = trim_long_silences(wav)

    return wav


def wav_to_mel_spectrogram(wav):
    """
    Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
    Note: this not a log-mel spectrogram.
    """
    frames = librosa.feature.melspectrogram(
        y=wav,
        sr=sampling_rate,
        n_fft=int(sampling_rate * mel_window_length / 1000),
        hop_length=int(sampling_rate * mel_window_step / 1000),
        n_mels=mel_n_channels
    )
    return frames.astype(np.float32).T


def trim_long_silences(wav):
    """
    Ensures that segments without voice in the waveform remain no longer than a
    threshold determined by the VAD parameters in params.py.

    :param wav: the raw waveform as a numpy array of floats
    :return: the same waveform with silences trimmed away (length <= original wav length)
    """
    # Compute the voice detection window size
    samples_per_window = (vad_window_length * sampling_rate) // 1000

    # Trim the end of the audio to have a multiple of the window size
    wav = wav[:len(wav) - (len(wav) % samples_per_window)]

    # Convert the float waveform to 16-bit mono PCM
    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))

    # Perform voice activation detection
    voice_flags = []
    vad = webrtcvad.Vad(mode=3)
    for window_start in range(0, len(wav), samples_per_window):
        window_end = window_start + samples_per_window
        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
                                         sample_rate=sampling_rate))
    voice_flags = np.array(voice_flags)

    # Smooth the voice detection with a moving average
    def moving_average(array, width):
        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
        ret = np.cumsum(array_padded, dtype=float)
        ret[width:] = ret[width:] - ret[:-width]
        return ret[width - 1:] / width

    audio_mask = moving_average(voice_flags, vad_moving_average_width)
    audio_mask = np.round(audio_mask).astype(np.bool)

    # Dilate the voiced regions
    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
    audio_mask = np.repeat(audio_mask, samples_per_window)

    return wav[audio_mask == True]


def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
    if increase_only and decrease_only:
        raise ValueError("Both increase only and decrease only are set")
    dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
    if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
        return wav
    return wav * (10 ** (dBFS_change / 20))
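Taken together, preprocess_wav and wav_to_mel_spectrogram turn an arbitrary audio file into the 40-channel mel frames the encoder consumes. An illustrative sketch (not part of the commit), using one of the sample files added above:

from speaker_encoder.audio import preprocess_wav, wav_to_mel_spectrogram

wav = preprocess_wav("sample_inputs/p225_001.wav")  # load, resample to 16 kHz, normalize, trim silences
mel = wav_to_mel_spectrogram(wav)                   # float32 array of shape (n_frames, 40)
print(wav.shape, mel.shape)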
speaker_encoder/ckpt/pretrained_bak_5805000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bc7ff82ef75becd495aab2ede3a8220da393a717f178ae9534df355a6173bbca
size 17090379
speaker_encoder/compute_embed.py
ADDED
@@ -0,0 +1,40 @@
from speaker_encoder import inference as encoder
from multiprocessing.pool import Pool
from functools import partial
from pathlib import Path
from tqdm import tqdm
import numpy as np
# from utils import logmmse
# import librosa


def embed_utterance(fpaths, encoder_model_fpath):
    if not encoder.is_loaded():
        encoder.load_model(encoder_model_fpath)

    # Compute the speaker embedding of the utterance
    wav_fpath, embed_fpath = fpaths
    wav = np.load(wav_fpath)
    wav = encoder.preprocess_wav(wav)
    embed = encoder.embed_utterance(wav)
    np.save(embed_fpath, embed, allow_pickle=False)


def create_embeddings(outdir_root: Path, wav_dir: Path, encoder_model_fpath: Path, n_processes: int):
    # outdir_root is the preprocessed synthesizer directory (contains audio/, embeds/ and train.txt)
    wav_dir = outdir_root.joinpath("audio")
    metadata_fpath = outdir_root.joinpath("train.txt")
    assert wav_dir.exists() and metadata_fpath.exists()
    embed_dir = outdir_root.joinpath("embeds")
    embed_dir.mkdir(exist_ok=True)

    # Gather the input wave filepath and the target output embed filepath
    with metadata_fpath.open("r") as metadata_file:
        metadata = [line.split("|") for line in metadata_file]
        fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]

    # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
    # Embed the utterances in separate threads
    func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
    job = Pool(n_processes).imap(func, fpaths)
    list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
speaker_encoder/config.py
ADDED
@@ -0,0 +1,45 @@
librispeech_datasets = {
    "train": {
        "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
        "other": ["LibriSpeech/train-other-500"]
    },
    "test": {
        "clean": ["LibriSpeech/test-clean"],
        "other": ["LibriSpeech/test-other"]
    },
    "dev": {
        "clean": ["LibriSpeech/dev-clean"],
        "other": ["LibriSpeech/dev-other"]
    },
}
libritts_datasets = {
    "train": {
        "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
        "other": ["LibriTTS/train-other-500"]
    },
    "test": {
        "clean": ["LibriTTS/test-clean"],
        "other": ["LibriTTS/test-other"]
    },
    "dev": {
        "clean": ["LibriTTS/dev-clean"],
        "other": ["LibriTTS/dev-other"]
    },
}
voxceleb_datasets = {
    "voxceleb1" : {
        "train": ["VoxCeleb1/wav"],
        "test": ["VoxCeleb1/test_wav"]
    },
    "voxceleb2" : {
        "train": ["VoxCeleb2/dev/aac"],
        "test": ["VoxCeleb2/test_wav"]
    }
}

other_datasets = [
    "LJSpeech-1.1",
    "VCTK-Corpus/wav48",
]

anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
speaker_encoder/data_objects/__init__.py
ADDED
@@ -0,0 +1,2 @@
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
speaker_encoder/data_objects/random_cycler.py
ADDED
@@ -0,0 +1,37 @@
import random

class RandomCycler:
    """
    Creates an internal copy of a sequence and allows access to its items in a constrained random
    order. For a source sequence of n items and one or several consecutive queries of a total
    of m items, the following guarantees hold (one implies the other):
        - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
        - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
    """

    def __init__(self, source):
        if len(source) == 0:
            raise Exception("Can't create RandomCycler from an empty collection")
        self.all_items = list(source)
        self.next_items = []

    def sample(self, count: int):
        shuffle = lambda l: random.sample(l, len(l))

        out = []
        while count > 0:
            if count >= len(self.all_items):
                out.extend(shuffle(list(self.all_items)))
                count -= len(self.all_items)
                continue
            n = min(count, len(self.next_items))
            out.extend(self.next_items[:n])
            count -= n
            self.next_items = self.next_items[n:]
            if len(self.next_items) == 0:
                self.next_items = shuffle(list(self.all_items))
        return out

    def __next__(self):
        return self.sample(1)[0]
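An illustrative check (not part of the commit) of the sampling guarantee documented in the class docstring, on a toy sequence:

from speaker_encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])
batch = cycler.sample(7)     # each of the 3 items appears 2 or 3 times (between m // n and ((m - 1) // n) + 1)
print(batch, next(cycler))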
speaker_encoder/data_objects/speaker.py
ADDED
@@ -0,0 +1,40 @@
from speaker_encoder.data_objects.random_cycler import RandomCycler
from speaker_encoder.data_objects.utterance import Utterance
from pathlib import Path

# Contains the set of utterances of a single speaker
class Speaker:
    def __init__(self, root: Path):
        self.root = root
        self.name = root.name
        self.utterances = None
        self.utterance_cycler = None

    def _load_utterances(self):
        with self.root.joinpath("_sources.txt").open("r") as sources_file:
            sources = [l.split(",") for l in sources_file]
        sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
        self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
        self.utterance_cycler = RandomCycler(self.utterances)

    def random_partial(self, count, n_frames):
        """
        Samples a batch of <count> unique partial utterances from the disk in a way that all
        utterances come up at least once every two cycles and in a random order every time.

        :param count: The number of partial utterances to sample from the set of utterances from
        that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
        the number of utterances available.
        :param n_frames: The number of frames in the partial utterance.
        :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
        frames are the frames of the partial utterances and range is the range of the partial
        utterance with regard to the complete utterance.
        """
        if self.utterances is None:
            self._load_utterances()

        utterances = self.utterance_cycler.sample(count)

        a = [(u,) + u.random_partial(n_frames) for u in utterances]

        return a
speaker_encoder/data_objects/speaker_batch.py
ADDED
@@ -0,0 +1,12 @@
import numpy as np
from typing import List
from speaker_encoder.data_objects.speaker import Speaker

class SpeakerBatch:
    def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
        self.speakers = speakers
        self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}

        # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
        # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
        self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
speaker_encoder/data_objects/speaker_verification_dataset.py
ADDED
@@ -0,0 +1,56 @@
from speaker_encoder.data_objects.random_cycler import RandomCycler
from speaker_encoder.data_objects.speaker_batch import SpeakerBatch
from speaker_encoder.data_objects.speaker import Speaker
from speaker_encoder.params_data import partials_n_frames
from torch.utils.data import Dataset, DataLoader
from pathlib import Path

# TODO: improve with a pool of speakers for data efficiency

class SpeakerVerificationDataset(Dataset):
    def __init__(self, datasets_root: Path):
        self.root = datasets_root
        speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
        if len(speaker_dirs) == 0:
            raise Exception("No speakers found. Make sure you are pointing to the directory "
                            "containing all preprocessed speaker directories.")
        self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
        self.speaker_cycler = RandomCycler(self.speakers)

    def __len__(self):
        return int(1e10)

    def __getitem__(self, index):
        return next(self.speaker_cycler)

    def get_logs(self):
        log_string = ""
        for log_fpath in self.root.glob("*.txt"):
            with log_fpath.open("r") as log_file:
                log_string += "".join(log_file.readlines())
        return log_string


class SpeakerVerificationDataLoader(DataLoader):
    def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
                 batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
                 worker_init_fn=None):
        self.utterances_per_speaker = utterances_per_speaker

        super().__init__(
            dataset=dataset,
            batch_size=speakers_per_batch,
            shuffle=False,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            collate_fn=self.collate,
            pin_memory=pin_memory,
            drop_last=False,
            timeout=timeout,
            worker_init_fn=worker_init_fn
        )

    def collate(self, speakers):
        return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
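A hedged sketch (not part of the commit) of wiring the dataset and loader together; the directory path is a placeholder for a folder of preprocessed speakers produced by speaker_encoder/preprocess.py:

from pathlib import Path
from speaker_encoder.data_objects import SpeakerVerificationDataset, SpeakerVerificationDataLoader

dataset = SpeakerVerificationDataset(Path("datasets/encoder_preprocessed"))   # placeholder path
loader = SpeakerVerificationDataLoader(dataset, speakers_per_batch=4, utterances_per_speaker=5)
batch = next(iter(loader))   # a SpeakerBatch; batch.data has shape (4 * 5, 160, 40)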
speaker_encoder/data_objects/utterance.py
ADDED
@@ -0,0 +1,26 @@
import numpy as np


class Utterance:
    def __init__(self, frames_fpath, wave_fpath):
        self.frames_fpath = frames_fpath
        self.wave_fpath = wave_fpath

    def get_frames(self):
        return np.load(self.frames_fpath)

    def random_partial(self, n_frames):
        """
        Crops the frames into a partial utterance of n_frames

        :param n_frames: The number of frames of the partial utterance
        :return: the partial utterance frames and a tuple indicating the start and end of the
        partial utterance in the complete utterance.
        """
        frames = self.get_frames()
        if frames.shape[0] == n_frames:
            start = 0
        else:
            start = np.random.randint(0, frames.shape[0] - n_frames)
        end = start + n_frames
        return frames[start:end], (start, end)
speaker_encoder/hparams.py
ADDED
@@ -0,0 +1,31 @@
## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms


## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30


## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3
speaker_encoder/inference.py
ADDED
@@ -0,0 +1,177 @@
from speaker_encoder.params_data import *
from speaker_encoder.model import SpeakerEncoder
from speaker_encoder.audio import preprocess_wav   # We want to expose this function from here
from matplotlib import cm
from speaker_encoder import audio
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch

_model = None   # type: SpeakerEncoder
_device = None  # type: torch.device


def load_model(weights_fpath: Path, device=None):
    """
    Loads the model in memory. If this function is not explicitely called, it will be run on the
    first call to embed_frames() with the default weights file.

    :param weights_fpath: the path to saved model weights.
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
    model will be loaded and will run on this device. Outputs will however always be on the cpu.
    If None, will default to your GPU if it"s available, otherwise your CPU.
    """
    # TODO: I think the slow loading of the encoder might have something to do with the device it
    #   was saved on. Worth investigating.
    global _model, _device
    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    _model = SpeakerEncoder(_device, torch.device("cpu"))
    checkpoint = torch.load(weights_fpath)
    _model.load_state_dict(checkpoint["model_state"])
    _model.eval()
    print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))


def is_loaded():
    return _model is not None


def embed_frames_batch(frames_batch):
    """
    Computes embeddings for a batch of mel spectrogram.

    :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
    (batch_size, n_frames, n_channels)
    :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
    """
    if _model is None:
        raise Exception("Model was not loaded. Call load_model() before inference.")

    frames = torch.from_numpy(frames_batch).to(_device)
    embed = _model.forward(frames).detach().cpu().numpy()
    return embed


def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
                           min_pad_coverage=0.75, overlap=0.5):
    """
    Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
    partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
    spectrogram slices are returned, so as to make each partial utterance waveform correspond to
    its spectrogram. This function assumes that the mel spectrogram parameters used are those
    defined in params_data.py.

    The returned ranges may be indexing further than the length of the waveform. It is
    recommended that you pad the waveform with zeros up to wave_slices[-1].stop.

    :param n_samples: the number of samples in the waveform
    :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
    utterance
    :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
    enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
    then the last partial utterance will be considered, as if we padded the audio. Otherwise,
    it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
    utterance, this parameter is ignored so that the function always returns at least 1 slice.
    :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
    utterances are entirely disjoint.
    :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
    respectively the waveform and the mel spectrogram with these slices to obtain the partial
    utterances.
    """
    assert 0 <= overlap < 1
    assert 0 < min_pad_coverage <= 1

    samples_per_frame = int((sampling_rate * mel_window_step / 1000))
    n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
    frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)

    # Compute the slices
    wav_slices, mel_slices = [], []
    steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
    for i in range(0, steps, frame_step):
        mel_range = np.array([i, i + partial_utterance_n_frames])
        wav_range = mel_range * samples_per_frame
        mel_slices.append(slice(*mel_range))
        wav_slices.append(slice(*wav_range))

    # Evaluate whether extra padding is warranted or not
    last_wav_range = wav_slices[-1]
    coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
    if coverage < min_pad_coverage and len(mel_slices) > 1:
        mel_slices = mel_slices[:-1]
        wav_slices = wav_slices[:-1]

    return wav_slices, mel_slices


def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
    """
    Computes an embedding for a single utterance.

    # TODO: handle multiple wavs to benefit from batching on GPU
    :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
    :param using_partials: if True, then the utterance is split in partial utterances of
    <partial_utterance_n_frames> frames and the utterance embedding is computed from their
    normalized average. If False, the utterance is instead computed from feeding the entire
    spectogram to the network.
    :param return_partials: if True, the partial embeddings will also be returned along with the
    wav slices that correspond to the partial embeddings.
    :param kwargs: additional arguments to compute_partial_splits()
    :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
    <return_partials> is True, the partial utterances as a numpy array of float32 of shape
    (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
    returned. If <using_partials> is simultaneously set to False, both these values will be None
    instead.
    """
    # Process the entire utterance if not using partials
    if not using_partials:
        frames = audio.wav_to_mel_spectrogram(wav)
        embed = embed_frames_batch(frames[None, ...])[0]
        if return_partials:
            return embed, None, None
        return embed

    # Compute where to split the utterance into partials and pad if necessary
    wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
    max_wave_length = wave_slices[-1].stop
    if max_wave_length >= len(wav):
        wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

    # Split the utterance into partials
    frames = audio.wav_to_mel_spectrogram(wav)
    frames_batch = np.array([frames[s] for s in mel_slices])
    partial_embeds = embed_frames_batch(frames_batch)

    # Compute the utterance embedding from the partial embeddings
    raw_embed = np.mean(partial_embeds, axis=0)
    embed = raw_embed / np.linalg.norm(raw_embed, 2)

    if return_partials:
        return embed, partial_embeds, wave_slices
    return embed


def embed_speaker(wavs, **kwargs):
    raise NotImplemented()


def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
    if ax is None:
        ax = plt.gca()

    if shape is None:
        height = int(np.sqrt(len(embed)))
        shape = (height, -1)
    embed = embed.reshape(shape)

    cmap = cm.get_cmap()
    mappable = ax.imshow(embed, cmap=cmap)
    cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_clim(*color_range)

    ax.set_xticks([]), ax.set_yticks([])
    ax.set_title(title)
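A minimal inference sketch (not part of the commit), pairing this module with the checkpoint and a sample wav added elsewhere in this commit:

from pathlib import Path
from speaker_encoder import inference as encoder

encoder.load_model(Path("speaker_encoder/ckpt/pretrained_bak_5805000.pt"))
wav = encoder.preprocess_wav("sample_inputs/p226_002.wav")
embed = encoder.embed_utterance(wav)   # L2-normalized numpy vector of shape (256,)
print(embed.shape)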
speaker_encoder/model.py
ADDED
@@ -0,0 +1,135 @@
from speaker_encoder.params_model import *
from speaker_encoder.params_data import *
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve
from torch.nn.utils import clip_grad_norm_
from scipy.optimize import brentq
from torch import nn
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, device, loss_device):
        super().__init__()
        self.loss_device = loss_device

        # Network defition
        self.lstm = nn.LSTM(input_size=mel_n_channels,      # 40
                            hidden_size=model_hidden_size,  # 256
                            num_layers=model_num_layers,    # 3
                            batch_first=True).to(device)
        self.linear = nn.Linear(in_features=model_hidden_size,
                                out_features=model_embedding_size).to(device)
        self.relu = torch.nn.ReLU().to(device)

        # Cosine similarity scaling (with fixed initial parameter values)
        self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
        self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)

        # Loss
        self.loss_fn = nn.CrossEntropyLoss().to(loss_device)

    def do_gradient_ops(self):
        # Gradient scale
        self.similarity_weight.grad *= 0.01
        self.similarity_bias.grad *= 0.01

        # Gradient clipping
        clip_grad_norm_(self.parameters(), 3, norm_type=2)

    def forward(self, utterances, hidden_init=None):
        """
        Computes the embeddings of a batch of utterance spectrograms.

        :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
        (batch_size, n_frames, n_channels)
        :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
        batch_size, hidden_size). Will default to a tensor of zeros if None.
        :return: the embeddings as a tensor of shape (batch_size, embedding_size)
        """
        # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
        # and the final cell state.
        out, (hidden, cell) = self.lstm(utterances, hidden_init)

        # We take only the hidden state of the last layer
        embeds_raw = self.relu(self.linear(hidden[-1]))

        # L2-normalize it
        embeds = embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

        return embeds

    def similarity_matrix(self, embeds):
        """
        Computes the similarity matrix according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the similarity matrix as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, speakers_per_batch)
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
        centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
        centroids_incl = centroids_incl.clone() / torch.norm(centroids_incl, dim=2, keepdim=True)

        # Exclusive centroids (1 per utterance)
        centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
        centroids_excl /= (utterances_per_speaker - 1)
        centroids_excl = centroids_excl.clone() / torch.norm(centroids_excl, dim=2, keepdim=True)

        # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
        # product of these vectors (which is just an element-wise multiplication reduced by a sum).
        # We vectorize the computation for efficiency.
        sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
                                 speakers_per_batch).to(self.loss_device)
        mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
        for j in range(speakers_per_batch):
            mask = np.where(mask_matrix[j])[0]
            sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
            sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)

        ## Even more vectorized version (slower maybe because of transpose)
        # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
        #                           ).to(self.loss_device)
        # eye = np.eye(speakers_per_batch, dtype=np.int)
        # mask = np.where(1 - eye)
        # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
        # mask = np.where(eye)
        # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
        # sim_matrix2 = sim_matrix2.transpose(1, 2)

        sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
        return sim_matrix

    def loss(self, embeds):
        """
        Computes the softmax loss according the section 2.1 of GE2E.

        :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
        utterances_per_speaker, embedding_size)
        :return: the loss and the EER for this batch of embeddings.
        """
        speakers_per_batch, utterances_per_speaker = embeds.shape[:2]

        # Loss
        sim_matrix = self.similarity_matrix(embeds)
        sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
                                         speakers_per_batch))
        ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
        target = torch.from_numpy(ground_truth).long().to(self.loss_device)
        loss = self.loss_fn(sim_matrix, target)

        # EER (not backpropagated)
        with torch.no_grad():
            inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
            labels = np.array([inv_argmax(i) for i in ground_truth])
            preds = sim_matrix.detach().cpu().numpy()

            # Snippet from https://yangcha.github.io/EER-ROC/
            fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
            eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

        return loss, eer
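An illustrative, CPU-only check (not part of the commit) that the GE2E-style loss runs on random embeddings; it assumes a NumPy version in which np.int is still available, as the module itself requires:

import torch
from speaker_encoder.model import SpeakerEncoder

device = torch.device("cpu")
model = SpeakerEncoder(device, loss_device=device)
embeds = torch.randn(4, 5, 256)                      # (speakers, utterances per speaker, embedding size)
embeds = embeds / embeds.norm(dim=2, keepdim=True)   # the encoder emits L2-normalized embeddings
loss, eer = model.loss(embeds)
print(float(loss), eer)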
speaker_encoder/params_data.py
ADDED
@@ -0,0 +1,29 @@

## Mel-filterbank
mel_window_length = 25  # In milliseconds
mel_window_step = 10    # In milliseconds
mel_n_channels = 40


## Audio
sampling_rate = 16000
# Number of spectrogram frames in a partial utterance
partials_n_frames = 160     # 1600 ms
# Number of spectrogram frames at inference
inference_n_frames = 80     #  800 ms


## Voice Activation Detection
# Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
# This sets the granularity of the VAD. Should not need to be changed.
vad_window_length = 30  # In milliseconds
# Number of frames to average together when performing the moving average smoothing.
# The larger this value, the larger the VAD variations must be to not get smoothed out.
vad_moving_average_width = 8
# Maximum number of consecutive silent frames a segment can have.
vad_max_silence_length = 6


## Audio volume normalization
audio_norm_target_dBFS = -30
speaker_encoder/params_model.py
ADDED
@@ -0,0 +1,11 @@

## Model parameters
model_hidden_size = 256
model_embedding_size = 256
model_num_layers = 3


## Training parameters
learning_rate_init = 1e-4
speakers_per_batch = 64
utterances_per_speaker = 10
speaker_encoder/preprocess.py
ADDED
@@ -0,0 +1,285 @@
from multiprocess.pool import ThreadPool
from speaker_encoder.params_data import *
from speaker_encoder.config import librispeech_datasets, anglophone_nationalites
from datetime import datetime
from speaker_encoder import audio
from pathlib import Path
from tqdm import tqdm
import numpy as np


class DatasetLog:
    """
    Registers metadata about the dataset in a text file.
    """
    def __init__(self, root, name):
        self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
        self.sample_data = dict()

        start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Creating dataset %s on %s" % (name, start_time))
        self.write_line("-----")
        self._log_params()

    def _log_params(self):
        from speaker_encoder import params_data
        self.write_line("Parameter values:")
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            self.write_line("\t%s: %s" % (param_name, value))
        self.write_line("-----")

    def write_line(self, line):
        self.text_file.write("%s\n" % line)

    def add_sample(self, **kwargs):
        for param_name, value in kwargs.items():
            if not param_name in self.sample_data:
                self.sample_data[param_name] = []
            self.sample_data[param_name].append(value)

    def finalize(self):
        self.write_line("Statistics:")
        for param_name, values in self.sample_data.items():
            self.write_line("\t%s:" % param_name)
            self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
            self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
        self.write_line("-----")
        end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
        self.write_line("Finished on %s" % end_time)
        self.text_file.close()


def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
    dataset_root = datasets_root.joinpath(dataset_name)
    if not dataset_root.exists():
        print("Couldn\'t find %s, skipping this dataset." % dataset_root)
        return None, None
    return dataset_root, DatasetLog(out_dir, dataset_name)


def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
                             skip_existing, logger):
    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))

    # Function to preprocess utterances for one speaker
    def preprocess_speaker(speaker_dir: Path):
        # Give a name to the speaker that includes its dataset
        speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

        # Create an output directory with that name, as well as a txt file containing a
        # reference to each source file.
        speaker_out_dir = out_dir.joinpath(speaker_name)
        speaker_out_dir.mkdir(exist_ok=True)
        sources_fpath = speaker_out_dir.joinpath("_sources.txt")

        # There's a possibility that the preprocessing was interrupted earlier, check if
        # there already is a sources file.
        if sources_fpath.exists():
            try:
                with sources_fpath.open("r") as sources_file:
                    existing_fnames = {line.split(",")[0] for line in sources_file}
            except:
                existing_fnames = {}
        else:
            existing_fnames = {}

        # Gather all audio files for that speaker recursively
        sources_file = sources_fpath.open("a" if skip_existing else "w")
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            logger.add_sample(duration=len(wav) / sampling_rate)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))

        sources_file.close()

    # Process the utterances for each speaker
    with ThreadPool(8) as pool:
        list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
                  unit="speakers"))
    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


# Function to preprocess utterances for one speaker
def __preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, extension: str, skip_existing: bool):
    # Give a name to the speaker that includes its dataset
    speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

    # Create an output directory with that name, as well as a txt file containing a
    # reference to each source file.
    speaker_out_dir = out_dir.joinpath(speaker_name)
    speaker_out_dir.mkdir(exist_ok=True)
    sources_fpath = speaker_out_dir.joinpath("_sources.txt")

    # There's a possibility that the preprocessing was interrupted earlier, check if
    # there already is a sources file.
    # if sources_fpath.exists():
    #     try:
    #         with sources_fpath.open("r") as sources_file:
    #             existing_fnames = {line.split(",")[0] for line in sources_file}
    #     except:
    #         existing_fnames = {}
    # else:
    #     existing_fnames = {}
    existing_fnames = {}
    # Gather all audio files for that speaker recursively
    sources_file = sources_fpath.open("a" if skip_existing else "w")

    for in_fpath in speaker_dir.glob("**/*.%s" % extension):
        # Check if the target output file already exists
        out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
        out_fname = out_fname.replace(".%s" % extension, ".npy")
        if skip_existing and out_fname in existing_fnames:
            continue

        # Load and preprocess the waveform
        wav = audio.preprocess_wav(in_fpath)
        if len(wav) == 0:
            continue

        # Create the mel spectrogram, discard those that are too short
        frames = audio.wav_to_mel_spectrogram(wav)
        if len(frames) < partials_n_frames:
            continue

        out_fpath = speaker_out_dir.joinpath(out_fname)
        np.save(out_fpath, frames)
        # logger.add_sample(duration=len(wav) / sampling_rate)
        sources_file.write("%s,%s\n" % (out_fname, in_fpath))

    sources_file.close()
    return len(wav)

def _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
                                  skip_existing, logger):
    # from multiprocessing import Pool, cpu_count
    from pathos.multiprocessing import ProcessingPool as Pool
    # Function to preprocess utterances for one speaker
    def __preprocess_speaker(speaker_dir: Path):
        # Give a name to the speaker that includes its dataset
        speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)

        # Create an output directory with that name, as well as a txt file containing a
        # reference to each source file.
        speaker_out_dir = out_dir.joinpath(speaker_name)
        speaker_out_dir.mkdir(exist_ok=True)
        sources_fpath = speaker_out_dir.joinpath("_sources.txt")

        existing_fnames = {}
        # Gather all audio files for that speaker recursively
        sources_file = sources_fpath.open("a" if skip_existing else "w")
        wav_lens = []
        for in_fpath in speaker_dir.glob("**/*.%s" % extension):
            # Check if the target output file already exists
            out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
            out_fname = out_fname.replace(".%s" % extension, ".npy")
            if skip_existing and out_fname in existing_fnames:
                continue

            # Load and preprocess the waveform
            wav = audio.preprocess_wav(in_fpath)
            if len(wav) == 0:
                continue

            # Create the mel spectrogram, discard those that are too short
            frames = audio.wav_to_mel_spectrogram(wav)
            if len(frames) < partials_n_frames:
                continue

            out_fpath = speaker_out_dir.joinpath(out_fname)
            np.save(out_fpath, frames)
            # logger.add_sample(duration=len(wav) / sampling_rate)
            sources_file.write("%s,%s\n" % (out_fname, in_fpath))
            wav_lens.append(len(wav))
        sources_file.close()
        return wav_lens

    print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
    # Process the utterances for each speaker
    # with ThreadPool(8) as pool:
    #     list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
    #               unit="speakers"))
    pool = Pool(processes=20)
    for i, wav_lens in enumerate(pool.map(__preprocess_speaker, speaker_dirs), 1):
        for wav_len in wav_lens:
            logger.add_sample(duration=wav_len / sampling_rate)
        print(f'{i}/{len(speaker_dirs)} \r')

    logger.finalize()
    print("Done preprocessing %s.\n" % dataset_name)


def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
    for dataset_name in librispeech_datasets["train"]["other"]:
        # Initialize the preprocessing
        dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
        if not dataset_root:
            return

        # Preprocess all speakers
        speaker_dirs = list(dataset_root.glob("*"))
        _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
                                 skip_existing, logger)


def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb1"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the contents of the meta file
    with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
        metadata = [line.split("\t") for line in metafile][1:]

    # Select the ID and the nationality, filter out non-anglophone speakers
    nationalities = {line[0]: line[3] for line in metadata}
    # keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
    #                     nationality.lower() in anglophone_nationalites]
    keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items()]
    print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
          (len(keep_speaker_ids), len(nationalities)))

    # Get the speaker directories for anglophone speakers only
    speaker_dirs = dataset_root.joinpath("wav").glob("*")
    speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
                    speaker_dir.name in keep_speaker_ids]
    print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
          (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))

    # Preprocess all speakers
    _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
                             skip_existing, logger)


def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
    # Initialize the preprocessing
    dataset_name = "VoxCeleb2"
    dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
    if not dataset_root:
        return

    # Get the speaker directories
    # Preprocess all speakers
    speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
    _preprocess_speaker_dirs_vox2(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
                                  skip_existing, logger)
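A hypothetical invocation sketch (not part of the commit); the dataset root and output directory are placeholders and must already contain the corresponding corpora laid out as in speaker_encoder/config.py:

from pathlib import Path
from speaker_encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1

datasets_root = Path("datasets")                  # placeholder corpus root
out_dir = Path("datasets/encoder_preprocessed")   # placeholder output directory
out_dir.mkdir(parents=True, exist_ok=True)
preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)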
speaker_encoder/train.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from speaker_encoder.visualizations import Visualizations
from speaker_encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
from speaker_encoder.params_model import *
from speaker_encoder.model import SpeakerEncoder
from utils.profiler import Profiler
from pathlib import Path
import torch

def sync(device: torch.device):
    # FIXME
    return
    # For correct profiling (cuda operations are async)
    if device.type == "cuda":
        torch.cuda.synchronize(device)

def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
          backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
          no_visdom: bool):
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,  # 64
        utterances_per_speaker,  # 10
        num_workers=8,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print("Found existing model \"%s\", loading it and resuming training." % run_id)
            checkpoint = torch.load(state_fpath)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." % run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(exist_ok=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
            torch.save({
                "step": step + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }, backup_fpath)

        profiler.tick("Extras (visualizations, saving)")
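The entry point above expects preprocessed speaker directories (the output of the preprocess step) plus an optional Visdom server for monitoring. A minimal invocation sketch, under the assumptions that the paths below are placeholders and that the utils.profiler module this file imports is available on the path:

from pathlib import Path
from speaker_encoder.train import train

# Hypothetical paths; point these at the preprocessed mels and a model directory.
train(run_id="my_encoder",
      clean_data_root=Path("data/SV2TTS/encoder"),
      models_dir=Path("saved_models"),
      umap_every=500, save_every=1000, backup_every=5000, vis_every=10,
      force_restart=False, visdom_server="http://localhost", no_visdom=True)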
speaker_encoder/visualizations.py
ADDED
@@ -0,0 +1,178 @@
from speaker_encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
from datetime import datetime
from time import perf_counter as timer
import matplotlib.pyplot as plt
import numpy as np
# import webbrowser
import visdom
import umap

colormap = np.array([
    [76, 255, 0],
    [0, 127, 70],
    [255, 0, 0],
    [255, 217, 38],
    [0, 135, 255],
    [165, 0, 165],
    [255, 167, 255],
    [0, 255, 255],
    [255, 96, 38],
    [142, 76, 0],
    [33, 0, 127],
    [0, 0, 0],
    [183, 183, 183],
], dtype=np.float) / 255


class Visualizations:
    def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
        # Tracking data
        self.last_update_timestamp = timer()
        self.update_every = update_every
        self.step_times = []
        self.losses = []
        self.eers = []
        print("Updating the visualizations every %d steps." % update_every)

        # If visdom is disabled TODO: use a better paradigm for that
        self.disabled = disabled
        if self.disabled:
            return

        # Set the environment name
        now = str(datetime.now().strftime("%d-%m %Hh%M"))
        if env_name is None:
            self.env_name = now
        else:
            self.env_name = "%s (%s)" % (env_name, now)

        # Connect to visdom and open the corresponding window in the browser
        try:
            self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
        except ConnectionError:
            raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
                            "start it.")
        # webbrowser.open("http://localhost:8097/env/" + self.env_name)

        # Create the windows
        self.loss_win = None
        self.eer_win = None
        # self.lr_win = None
        self.implementation_win = None
        self.projection_win = None
        self.implementation_string = ""

    def log_params(self):
        if self.disabled:
            return
        from speaker_encoder import params_data
        from speaker_encoder import params_model
        param_string = "<b>Model parameters</b>:<br>"
        for param_name in (p for p in dir(params_model) if not p.startswith("__")):
            value = getattr(params_model, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        param_string += "<b>Data parameters</b>:<br>"
        for param_name in (p for p in dir(params_data) if not p.startswith("__")):
            value = getattr(params_data, param_name)
            param_string += "\t%s: %s<br>" % (param_name, value)
        self.vis.text(param_string, opts={"title": "Parameters"})

    def log_dataset(self, dataset: SpeakerVerificationDataset):
        if self.disabled:
            return
        dataset_string = ""
        dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
        dataset_string += "\n" + dataset.get_logs()
        dataset_string = dataset_string.replace("\n", "<br>")
        self.vis.text(dataset_string, opts={"title": "Dataset"})

    def log_implementation(self, params):
        if self.disabled:
            return
        implementation_string = ""
        for param, value in params.items():
            implementation_string += "<b>%s</b>: %s\n" % (param, value)
            implementation_string = implementation_string.replace("\n", "<br>")
        self.implementation_string = implementation_string
        self.implementation_win = self.vis.text(
            implementation_string,
            opts={"title": "Training implementation"}
        )

    def update(self, loss, eer, step):
        # Update the tracking data
        now = timer()
        self.step_times.append(1000 * (now - self.last_update_timestamp))
        self.last_update_timestamp = now
        self.losses.append(loss)
        self.eers.append(eer)
        print(".", end="")

        # Update the plots every <update_every> steps
        if step % self.update_every != 0:
            return
        time_string = "Step time: mean: %5dms  std: %5dms" % \
                      (int(np.mean(self.step_times)), int(np.std(self.step_times)))
        print("\nStep %6d   Loss: %.4f   EER: %.4f   %s" %
              (step, np.mean(self.losses), np.mean(self.eers), time_string))
        if not self.disabled:
            self.loss_win = self.vis.line(
                [np.mean(self.losses)],
                [step],
                win=self.loss_win,
                update="append" if self.loss_win else None,
                opts=dict(
                    legend=["Avg. loss"],
                    xlabel="Step",
                    ylabel="Loss",
                    title="Loss",
                )
            )
            self.eer_win = self.vis.line(
                [np.mean(self.eers)],
                [step],
                win=self.eer_win,
                update="append" if self.eer_win else None,
                opts=dict(
                    legend=["Avg. EER"],
                    xlabel="Step",
                    ylabel="EER",
                    title="Equal error rate"
                )
            )
            if self.implementation_win is not None:
                self.vis.text(
                    self.implementation_string + ("<b>%s</b>" % time_string),
                    win=self.implementation_win,
                    opts={"title": "Training implementation"},
                )

        # Reset the tracking
        self.losses.clear()
        self.eers.clear()
        self.step_times.clear()

    def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
                         max_speakers=10):
        max_speakers = min(max_speakers, len(colormap))
        embeds = embeds[:max_speakers * utterances_per_speaker]

        n_speakers = len(embeds) // utterances_per_speaker
        ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
        colors = [colormap[i] for i in ground_truth]

        reducer = umap.UMAP()
        projected = reducer.fit_transform(embeds)
        plt.scatter(projected[:, 0], projected[:, 1], c=colors)
        plt.gca().set_aspect("equal", "datalim")
        plt.title("UMAP projection (step %d)" % step)
        if not self.disabled:
            self.projection_win = self.vis.matplot(plt, win=self.projection_win)
        if out_fpath is not None:
            plt.savefig(out_fpath)
        plt.clf()

    def save(self):
        if not self.disabled:
            self.vis.save([self.env_name])
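Note that update() only buffers loss/EER per step and pushes an averaged point to Visdom every update_every steps. The class can also be exercised without a Visdom server by passing disabled=True, in which case it just prints the running averages. A small sketch with made-up values (not part of the repo):

from speaker_encoder.visualizations import Visualizations

vis = Visualizations(env_name="debug", update_every=10, disabled=True)
for step in range(1, 101):
    fake_loss, fake_eer = 1.0 / step, 0.5 / step   # stand-in numbers for illustration
    vis.update(fake_loss, fake_eer, step)          # prints an averaged summary every 10 steps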
speaker_encoder/voice_encoder.py
ADDED
@@ -0,0 +1,173 @@
from speaker_encoder.hparams import *
from speaker_encoder import audio
from pathlib import Path
from typing import Union, List
from torch import nn
from time import perf_counter as timer
import numpy as np
import torch


class SpeakerEncoder(nn.Module):
    def __init__(self, weights_fpath, device: Union[str, torch.device]=None, verbose=True):
        """
        :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
        If None, defaults to cuda if it is available on your machine, otherwise the model will
        run on cpu. Outputs are always returned on the cpu, as numpy arrays.
        """
        super().__init__()

        # Define the network
        self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
        self.linear = nn.Linear(model_hidden_size, model_embedding_size)
        self.relu = nn.ReLU()

        # Get the target device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(device, str):
            device = torch.device(device)
        self.device = device

        # Load the pretrained model'speaker weights
        # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
        # if not weights_fpath.exists():
        #     raise Exception("Couldn't find the voice encoder pretrained model at %s." %
        #                     weights_fpath)

        start = timer()
        checkpoint = torch.load(weights_fpath, map_location="cpu")

        self.load_state_dict(checkpoint["model_state"], strict=False)
        self.to(device)

        if verbose:
            print("Loaded the voice encoder model on %s in %.2f seconds." %
                  (device.type, timer() - start))

    def forward(self, mels: torch.FloatTensor):
        """
        Computes the embeddings of a batch of utterance spectrograms.
        :param mels: a batch of mel spectrograms of same duration as a float32 tensor of shape
        (batch_size, n_frames, n_channels)
        :return: the embeddings as a float 32 tensor of shape (batch_size, embedding_size).
        Embeddings are positive and L2-normed, thus they lay in the range [0, 1].
        """
        # Pass the input through the LSTM layers and retrieve the final hidden state of the last
        # layer. Apply a cutoff to 0 for negative values and L2 normalize the embeddings.
        _, (hidden, _) = self.lstm(mels)
        embeds_raw = self.relu(self.linear(hidden[-1]))
        return embeds_raw / torch.norm(embeds_raw, dim=1, keepdim=True)

    @staticmethod
    def compute_partial_slices(n_samples: int, rate, min_coverage):
        """
        Computes where to split an utterance waveform and its corresponding mel spectrogram to
        obtain partial utterances of <partials_n_frames> each. Both the waveform and the
        mel spectrogram slices are returned, so as to make each partial utterance waveform
        correspond to its spectrogram.

        The returned ranges may be indexing further than the length of the waveform. It is
        recommended that you pad the waveform with zeros up to wav_slices[-1].stop.

        :param n_samples: the number of samples in the waveform
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
        respectively the waveform and the mel spectrogram with these slices to obtain the partial
        utterances.
        """
        assert 0 < min_coverage <= 1

        # Compute how many frames separate two partial utterances
        samples_per_frame = int((sampling_rate * mel_window_step / 1000))
        n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
        frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
        assert 0 < frame_step, "The rate is too high"
        assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
            (sampling_rate / (samples_per_frame * partials_n_frames))

        # Compute the slices
        wav_slices, mel_slices = [], []
        steps = max(1, n_frames - partials_n_frames + frame_step + 1)
        for i in range(0, steps, frame_step):
            mel_range = np.array([i, i + partials_n_frames])
            wav_range = mel_range * samples_per_frame
            mel_slices.append(slice(*mel_range))
            wav_slices.append(slice(*wav_range))

        # Evaluate whether extra padding is warranted or not
        last_wav_range = wav_slices[-1]
        coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
        if coverage < min_coverage and len(mel_slices) > 1:
            mel_slices = mel_slices[:-1]
            wav_slices = wav_slices[:-1]

        return wav_slices, mel_slices

    def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_coverage=0.75):
        """
        Computes an embedding for a single utterance. The utterance is divided in partial
        utterances and an embedding is computed for each. The complete utterance embedding is the
        L2-normed average embedding of the partial utterances.

        TODO: independent batched version of this function

        :param wav: a preprocessed utterance waveform as a numpy array of float32
        :param return_partials: if True, the partial embeddings will also be returned along with
        the wav slices corresponding to each partial utterance.
        :param rate: how many partial utterances should occur per second. Partial utterances must
        cover the span of the entire utterance, thus the rate should not be lower than the inverse
        of the duration of a partial utterance. By default, partial utterances are 1.6s long and
        the minimum rate is thus 0.625.
        :param min_coverage: when reaching the last partial utterance, it may or may not have
        enough frames. If at least <min_pad_coverage> of <partials_n_frames> are present,
        then the last partial utterance will be considered by zero-padding the audio. Otherwise,
        it will be discarded. If there aren't enough frames for one partial utterance,
        this parameter is ignored so that the function always returns at least one slice.
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
        (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
        returned.
        """
        # Compute where to split the utterance into partials and pad the waveform with zeros if
        # the partial utterances cover a larger range.
        wav_slices, mel_slices = self.compute_partial_slices(len(wav), rate, min_coverage)
        max_wave_length = wav_slices[-1].stop
        if max_wave_length >= len(wav):
            wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")

        # Split the utterance into partials and forward them through the model
        mel = audio.wav_to_mel_spectrogram(wav)
        mels = np.array([mel[s] for s in mel_slices])
        with torch.no_grad():
            mels = torch.from_numpy(mels).to(self.device)
            partial_embeds = self(mels).cpu().numpy()

        # Compute the utterance embedding from the partial embeddings
        raw_embed = np.mean(partial_embeds, axis=0)
        embed = raw_embed / np.linalg.norm(raw_embed, 2)

        if return_partials:
            return embed, partial_embeds, wav_slices
        return embed

    def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
        """
        Compute the embedding of a collection of wavs (presumably from the same speaker) by
        averaging their embedding and L2-normalizing it.

        :param wavs: list of wavs a numpy arrays of float32.
        :param kwargs: extra arguments to embed_utterance()
        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
        """
        raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) \
                             for wav in wavs], axis=0)
        return raw_embed / np.linalg.norm(raw_embed, 2)
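In short, embed_utterance slices the waveform into fixed-length partials, embeds each with the LSTM, and L2-norms the mean of the partial embeddings. A minimal sketch of computing one speaker embedding; the checkpoint path and the 16 kHz sample rate are assumptions (the expected rate comes from speaker_encoder.hparams, not shown here), and the waveform would normally be preprocessed by the speaker_encoder.audio helpers first:

import librosa
from speaker_encoder.voice_encoder import SpeakerEncoder

# Hypothetical paths for illustration only.
wav, _ = librosa.load("path/to/utterance.wav", sr=16000)            # float32 waveform
encoder = SpeakerEncoder("path/to/encoder_checkpoint.pt", device="cpu")
embed = encoder.embed_utterance(wav)                                 # (model_embedding_size,)
print(embed.shape)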
utils.py
ADDED
@@ -0,0 +1,306 @@
import os
import sys
import argparse
import logging
import json
import subprocess
import numpy as np
from scipy.io.wavfile import read
import torch
from torch.nn import functional as F
from commons import sequence_mask
from wavlm import WavLM, WavLMConfig

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
logger = logging


def get_cmodel(rank):
    checkpoint = torch.load('wavlm/WavLM-Large.pt')
    cfg = WavLMConfig(checkpoint['cfg'])
    cmodel = WavLM(cfg)
    cmodel.load_state_dict(checkpoint['model'])
    cmodel.eval()
    return cmodel


def get_content(cmodel, y):
    with torch.no_grad():
        c = cmodel.extract_features(y.squeeze(1))[0]
    c = c.transpose(1, 2)
    return c


def get_vocoder(rank):
    with open("hifigan/config.json", "r") as f:
        config = json.load(f)
    config = hifigan.AttrDict(config)
    vocoder = hifigan.Generator(config)
    ckpt = torch.load("hifigan/generator_v1")
    vocoder.load_state_dict(ckpt["generator"])
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.cuda(rank)
    return vocoder


def transform(mel, height):  # 68-92
    #r = np.random.random()
    #rate = r * 0.3 + 0.85 # 0.85-1.15
    #height = int(mel.size(-2) * rate)
    tgt = torchvision.transforms.functional.resize(mel, (height, mel.size(-1)))
    if height >= mel.size(-2):
        return tgt[:, :mel.size(-2), :]
    else:
        silence = tgt[:, -1:, :].repeat(1, mel.size(-2) - height, 1)
        silence += torch.randn_like(silence) / 10
        return torch.cat((tgt, silence), 1)


def stretch(mel, width):  # 0.5-2
    return torchvision.transforms.functional.resize(mel, (mel.size(-2), width))


def load_checkpoint(checkpoint_path, model, optimizer=None):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    saved_state_dict = checkpoint_dict['model']
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, 'module'):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info("Saving model and optimizer state at iteration {} to {}".format(
        iteration, checkpoint_path))
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save({'model': state_dict,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, checkpoint_path)


def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats='HWC')
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    x = f_list[-1]
    print(x)
    return x


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger('matplotlib')
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding='utf-8') as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def get_hparams(init=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
                        help='JSON file for configuration')
    parser.add_argument('-m', '--model', type=str, required=True,
                        help='Model name')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
            source_dir
        ))
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        saved_hash = open(path).read()
        if saved_hash != cur_hash:
            logger.warn("git hash values are different. {}(saved) != {}(current)".format(
                saved_hash[:8], cur_hash[:8]))
    else:
        open(path, "w").write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams():
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if type(v) == dict:
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
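get_cmodel and get_content are the content-feature helpers: WavLM-Large is loaded once, extract_features is run on a batch of waveforms, and the result is transposed to (batch, feature_dim, frames). A small sketch, assuming a mono waveform tensor shaped (batch, 1, samples) as implied by the squeeze(1)/transpose calls above, and dummy audio in place of a real recording:

import torch
import utils

cmodel = utils.get_cmodel(rank=0)        # loads wavlm/WavLM-Large.pt; stays on CPU in this helper
wav = torch.randn(1, 1, 16000)           # one second of dummy 16 kHz audio, shape (B, 1, T)
content = utils.get_content(cmodel, wav) # shape (B, feature_dim, n_frames)
print(content.shape)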
wavlm/WavLM-Large.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f
size 1261965425
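The file above is a Git LFS pointer, not the weights themselves; after `git lfs pull` (or a manual download from the URL in the next file) the real checkpoint should match the recorded oid and size. A small verification sketch using only the values shown above:

import hashlib, os

path = "wavlm/WavLM-Large.pt"
expected_oid = "6fb4b3c3e6aa567f0a997b30855859cb81528ee8078802af439f7b2da0bf100f"
assert os.path.getsize(path) == 1261965425, "unexpected file size"
h = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)
assert h.hexdigest() == expected_oid, "checksum mismatch"
print("WavLM-Large.pt matches the LFS pointer")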
wavlm/WavLM-Large.pt.txt
ADDED
@@ -0,0 +1 @@
https://github.com/microsoft/unilm/tree/master/wavlm
wavlm/WavLM.py
ADDED
@@ -0,0 +1,742 @@
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import logging
|
12 |
+
from typing import List, Optional, Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from wavlm.modules import (
|
21 |
+
Fp32GroupNorm,
|
22 |
+
Fp32LayerNorm,
|
23 |
+
GradMultiply,
|
24 |
+
MultiheadAttention,
|
25 |
+
SamePad,
|
26 |
+
init_bert_params,
|
27 |
+
get_activation_fn,
|
28 |
+
TransposeLast,
|
29 |
+
GLU_Linear,
|
30 |
+
)
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def compute_mask_indices(
|
36 |
+
shape: Tuple[int, int],
|
37 |
+
padding_mask: Optional[torch.Tensor],
|
38 |
+
mask_prob: float,
|
39 |
+
mask_length: int,
|
40 |
+
mask_type: str = "static",
|
41 |
+
mask_other: float = 0.0,
|
42 |
+
min_masks: int = 0,
|
43 |
+
no_overlap: bool = False,
|
44 |
+
min_space: int = 0,
|
45 |
+
) -> np.ndarray:
|
46 |
+
"""
|
47 |
+
Computes random mask spans for a given shape
|
48 |
+
|
49 |
+
Args:
|
50 |
+
shape: the the shape for which to compute masks.
|
51 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
52 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
53 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
54 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
55 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
56 |
+
mask_type: how to compute mask lengths
|
57 |
+
static = fixed size
|
58 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
59 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
60 |
+
poisson = sample from possion distribution with lambda = mask length
|
61 |
+
min_masks: minimum number of masked spans
|
62 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
63 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
64 |
+
"""
|
65 |
+
|
66 |
+
bsz, all_sz = shape
|
67 |
+
mask = np.full((bsz, all_sz), False)
|
68 |
+
|
69 |
+
all_num_mask = int(
|
70 |
+
# add a random number for probabilistic rounding
|
71 |
+
mask_prob * all_sz / float(mask_length)
|
72 |
+
+ np.random.rand()
|
73 |
+
)
|
74 |
+
|
75 |
+
all_num_mask = max(min_masks, all_num_mask)
|
76 |
+
|
77 |
+
mask_idcs = []
|
78 |
+
for i in range(bsz):
|
79 |
+
if padding_mask is not None:
|
80 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
81 |
+
num_mask = int(
|
82 |
+
# add a random number for probabilistic rounding
|
83 |
+
mask_prob * sz / float(mask_length)
|
84 |
+
+ np.random.rand()
|
85 |
+
)
|
86 |
+
num_mask = max(min_masks, num_mask)
|
87 |
+
else:
|
88 |
+
sz = all_sz
|
89 |
+
num_mask = all_num_mask
|
90 |
+
|
91 |
+
if mask_type == "static":
|
92 |
+
lengths = np.full(num_mask, mask_length)
|
93 |
+
elif mask_type == "uniform":
|
94 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
95 |
+
elif mask_type == "normal":
|
96 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
97 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
98 |
+
elif mask_type == "poisson":
|
99 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
100 |
+
lengths = [int(round(x)) for x in lengths]
|
101 |
+
else:
|
102 |
+
raise Exception("unknown mask selection " + mask_type)
|
103 |
+
|
104 |
+
if sum(lengths) == 0:
|
105 |
+
lengths[0] = min(mask_length, sz - 1)
|
106 |
+
|
107 |
+
if no_overlap:
|
108 |
+
mask_idc = []
|
109 |
+
|
110 |
+
def arrange(s, e, length, keep_length):
|
111 |
+
span_start = np.random.randint(s, e - length)
|
112 |
+
mask_idc.extend(span_start + i for i in range(length))
|
113 |
+
|
114 |
+
new_parts = []
|
115 |
+
if span_start - s - min_space >= keep_length:
|
116 |
+
new_parts.append((s, span_start - min_space + 1))
|
117 |
+
if e - span_start - keep_length - min_space > keep_length:
|
118 |
+
new_parts.append((span_start + length + min_space, e))
|
119 |
+
return new_parts
|
120 |
+
|
121 |
+
parts = [(0, sz)]
|
122 |
+
min_length = min(lengths)
|
123 |
+
for length in sorted(lengths, reverse=True):
|
124 |
+
lens = np.fromiter(
|
125 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
126 |
+
np.int,
|
127 |
+
)
|
128 |
+
l_sum = np.sum(lens)
|
129 |
+
if l_sum == 0:
|
130 |
+
break
|
131 |
+
probs = lens / np.sum(lens)
|
132 |
+
c = np.random.choice(len(parts), p=probs)
|
133 |
+
s, e = parts.pop(c)
|
134 |
+
parts.extend(arrange(s, e, length, min_length))
|
135 |
+
mask_idc = np.asarray(mask_idc)
|
136 |
+
else:
|
137 |
+
min_len = min(lengths)
|
138 |
+
if sz - min_len <= num_mask:
|
139 |
+
min_len = sz - num_mask - 1
|
140 |
+
|
141 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
142 |
+
|
143 |
+
mask_idc = np.asarray(
|
144 |
+
[
|
145 |
+
mask_idc[j] + offset
|
146 |
+
for j in range(len(mask_idc))
|
147 |
+
for offset in range(lengths[j])
|
148 |
+
]
|
149 |
+
)
|
150 |
+
|
151 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
152 |
+
|
153 |
+
min_len = min([len(m) for m in mask_idcs])
|
154 |
+
for i, mask_idc in enumerate(mask_idcs):
|
155 |
+
if len(mask_idc) > min_len:
|
156 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
157 |
+
mask[i, mask_idc] = True
|
158 |
+
|
159 |
+
return mask
|
160 |
+
|
161 |
+
|
162 |
+
class WavLMConfig:
|
163 |
+
def __init__(self, cfg=None):
|
164 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
165 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
166 |
+
|
167 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
168 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
169 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
170 |
+
self.activation_fn: str = "gelu" # activation function to use
|
171 |
+
|
172 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
173 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
174 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
175 |
+
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
176 |
+
|
177 |
+
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
178 |
+
|
179 |
+
# dropouts
|
180 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
181 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
182 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
183 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
184 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
185 |
+
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
186 |
+
|
187 |
+
# masking
|
188 |
+
self.mask_length: int = 10 # mask length
|
189 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
190 |
+
self.mask_selection: str = "static" # how to choose mask length
|
191 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
192 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
193 |
+
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
194 |
+
|
195 |
+
# channel masking
|
196 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
197 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
198 |
+
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
199 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
200 |
+
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
201 |
+
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
202 |
+
|
203 |
+
# positional embeddings
|
204 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
205 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
206 |
+
|
207 |
+
# relative position embedding
|
208 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
209 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
210 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
211 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
212 |
+
|
213 |
+
if cfg is not None:
|
214 |
+
self.update(cfg)
|
215 |
+
|
216 |
+
def update(self, cfg: dict):
|
217 |
+
self.__dict__.update(cfg)
|
218 |
+
|
219 |
+
|
220 |
+
class WavLM(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
cfg: WavLMConfig,
|
224 |
+
) -> None:
|
225 |
+
super().__init__()
|
226 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
227 |
+
|
228 |
+
self.cfg = cfg
|
229 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
230 |
+
self.embed = feature_enc_layers[-1][0]
|
231 |
+
|
232 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
233 |
+
conv_layers=feature_enc_layers,
|
234 |
+
dropout=0.0,
|
235 |
+
mode=cfg.extractor_mode,
|
236 |
+
conv_bias=cfg.conv_bias,
|
237 |
+
)
|
238 |
+
|
239 |
+
self.post_extract_proj = (
|
240 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
241 |
+
if self.embed != cfg.encoder_embed_dim
|
242 |
+
else None
|
243 |
+
)
|
244 |
+
|
245 |
+
self.mask_prob = cfg.mask_prob
|
246 |
+
self.mask_selection = cfg.mask_selection
|
247 |
+
self.mask_other = cfg.mask_other
|
248 |
+
self.mask_length = cfg.mask_length
|
249 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
250 |
+
self.mask_min_space = cfg.mask_min_space
|
251 |
+
|
252 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
253 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
254 |
+
self.mask_channel_other = cfg.mask_channel_other
|
255 |
+
self.mask_channel_length = cfg.mask_channel_length
|
256 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
257 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
258 |
+
|
259 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
260 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
261 |
+
|
262 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
263 |
+
|
264 |
+
self.mask_emb = nn.Parameter(
|
265 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
266 |
+
)
|
267 |
+
|
268 |
+
self.encoder = TransformerEncoder(cfg)
|
269 |
+
self.layer_norm = LayerNorm(self.embed)
|
270 |
+
|
271 |
+
def apply_mask(self, x, padding_mask):
|
272 |
+
B, T, C = x.shape
|
273 |
+
if self.mask_prob > 0:
|
274 |
+
mask_indices = compute_mask_indices(
|
275 |
+
(B, T),
|
276 |
+
padding_mask,
|
277 |
+
self.mask_prob,
|
278 |
+
self.mask_length,
|
279 |
+
self.mask_selection,
|
280 |
+
self.mask_other,
|
281 |
+
min_masks=2,
|
282 |
+
no_overlap=self.no_mask_overlap,
|
283 |
+
min_space=self.mask_min_space,
|
284 |
+
)
|
285 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
286 |
+
x[mask_indices] = self.mask_emb
|
287 |
+
else:
|
288 |
+
mask_indices = None
|
289 |
+
|
290 |
+
if self.mask_channel_prob > 0:
|
291 |
+
mask_channel_indices = compute_mask_indices(
|
292 |
+
(B, C),
|
293 |
+
None,
|
294 |
+
self.mask_channel_prob,
|
295 |
+
self.mask_channel_length,
|
296 |
+
self.mask_channel_selection,
|
297 |
+
self.mask_channel_other,
|
298 |
+
no_overlap=self.no_mask_channel_overlap,
|
299 |
+
min_space=self.mask_channel_min_space,
|
300 |
+
)
|
301 |
+
mask_channel_indices = (
|
302 |
+
torch.from_numpy(mask_channel_indices)
|
303 |
+
.to(x.device)
|
304 |
+
.unsqueeze(1)
|
305 |
+
.expand(-1, T, -1)
|
306 |
+
)
|
307 |
+
x[mask_channel_indices] = 0
|
308 |
+
|
309 |
+
return x, mask_indices
|
310 |
+
|
311 |
+
def forward_padding_mask(
|
312 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
313 |
+
) -> torch.Tensor:
|
314 |
+
extra = padding_mask.size(1) % features.size(1)
|
315 |
+
if extra > 0:
|
316 |
+
padding_mask = padding_mask[:, :-extra]
|
317 |
+
padding_mask = padding_mask.view(
|
318 |
+
padding_mask.size(0), features.size(1), -1
|
319 |
+
)
|
320 |
+
#padding_mask = padding_mask.all(-1)
|
321 |
+
padding_mask = padding_mask.any(-1)
|
322 |
+
return padding_mask
|
323 |
+
|
324 |
+
def extract_features(
|
325 |
+
self,
|
326 |
+
source: torch.Tensor,
|
327 |
+
padding_mask: Optional[torch.Tensor] = None,
|
328 |
+
mask: bool = False,
|
329 |
+
ret_conv: bool = False,
|
330 |
+
output_layer: Optional[int] = None,
|
331 |
+
ret_layer_results: bool = False,
|
332 |
+
):
|
333 |
+
if self.feature_grad_mult > 0:
|
334 |
+
features = self.feature_extractor(source)
|
335 |
+
if self.feature_grad_mult != 1.0:
|
336 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
337 |
+
else:
|
338 |
+
with torch.no_grad():
|
339 |
+
features = self.feature_extractor(source)
|
340 |
+
|
341 |
+
features = features.transpose(1, 2)
|
342 |
+
features = self.layer_norm(features)
|
343 |
+
|
344 |
+
if padding_mask is not None:
|
345 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
346 |
+
|
347 |
+
if self.post_extract_proj is not None:
|
348 |
+
features = self.post_extract_proj(features)
|
349 |
+
|
350 |
+
features = self.dropout_input(features)
|
351 |
+
|
352 |
+
if mask:
|
353 |
+
x, mask_indices = self.apply_mask(
|
354 |
+
features, padding_mask
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
x = features
|
358 |
+
|
359 |
+
# feature: (B, T, D), float
|
360 |
+
# target: (B, T), long
|
361 |
+
# x: (B, T, D), float
|
362 |
+
# padding_mask: (B, T), bool
|
363 |
+
# mask_indices: (B, T), bool
|
364 |
+
x, layer_results = self.encoder(
|
365 |
+
x,
|
366 |
+
padding_mask=padding_mask,
|
367 |
+
layer=None if output_layer is None else output_layer - 1
|
368 |
+
)
|
369 |
+
|
370 |
+
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
371 |
+
|
372 |
+
feature = res["features"] if ret_conv else res["x"]
|
373 |
+
if ret_layer_results:
|
374 |
+
feature = (feature, res["layer_results"])
|
375 |
+
return feature, res["padding_mask"]
|
376 |
+
|
377 |
+
|
378 |
+
class ConvFeatureExtractionModel(nn.Module):
|
379 |
+
def __init__(
|
380 |
+
self,
|
381 |
+
conv_layers: List[Tuple[int, int, int]],
|
382 |
+
dropout: float = 0.0,
|
383 |
+
mode: str = "default",
|
384 |
+
conv_bias: bool = False,
|
385 |
+
conv_type: str = "default"
|
386 |
+
):
|
387 |
+
super().__init__()
|
388 |
+
|
389 |
+
assert mode in {"default", "layer_norm"}
|
390 |
+
|
391 |
+
def block(
|
392 |
+
n_in,
|
393 |
+
n_out,
|
394 |
+
k,
|
395 |
+
stride,
|
396 |
+
is_layer_norm=False,
|
397 |
+
is_group_norm=False,
|
398 |
+
conv_bias=False,
|
399 |
+
):
|
400 |
+
def make_conv():
|
401 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
402 |
+
nn.init.kaiming_normal_(conv.weight)
|
403 |
+
return conv
|
404 |
+
|
405 |
+
assert (
|
406 |
+
is_layer_norm and is_group_norm
|
407 |
+
) == False, "layer norm and group norm are exclusive"
|
408 |
+
|
409 |
+
if is_layer_norm:
|
410 |
+
return nn.Sequential(
|
411 |
+
make_conv(),
|
412 |
+
nn.Dropout(p=dropout),
|
413 |
+
nn.Sequential(
|
414 |
+
TransposeLast(),
|
415 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
416 |
+
TransposeLast(),
|
417 |
+
),
|
418 |
+
nn.GELU(),
|
419 |
+
)
|
420 |
+
elif is_group_norm:
|
421 |
+
return nn.Sequential(
|
422 |
+
make_conv(),
|
423 |
+
nn.Dropout(p=dropout),
|
424 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
425 |
+
nn.GELU(),
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
429 |
+
|
430 |
+
self.conv_type = conv_type
|
431 |
+
if self.conv_type == "default":
|
432 |
+
in_d = 1
|
433 |
+
self.conv_layers = nn.ModuleList()
|
434 |
+
for i, cl in enumerate(conv_layers):
|
435 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
436 |
+
(dim, k, stride) = cl
|
437 |
+
|
438 |
+
self.conv_layers.append(
|
439 |
+
block(
|
440 |
+
in_d,
|
441 |
+
dim,
|
442 |
+
k,
|
443 |
+
                        stride,
                        is_layer_norm=mode == "layer_norm",
                        is_group_norm=mode == "default" and i == 0,
                        conv_bias=conv_bias,
                    )
                )
                in_d = dim
        elif self.conv_type == "conv2d":
            in_d = 1
            self.conv_layers = nn.ModuleList()
            for i, cl in enumerate(conv_layers):
                assert len(cl) == 3
                (dim, k, stride) = cl

                self.conv_layers.append(
                    torch.nn.Conv2d(in_d, dim, k, stride)
                )
                self.conv_layers.append(torch.nn.ReLU())
                in_d = dim
        elif self.conv_type == "custom":
            in_d = 1
            idim = 80
            self.conv_layers = nn.ModuleList()
            for i, cl in enumerate(conv_layers):
                assert len(cl) == 3
                (dim, k, stride) = cl
                self.conv_layers.append(
                    torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
                )
                self.conv_layers.append(
                    torch.nn.LayerNorm([dim, idim])
                )
                self.conv_layers.append(torch.nn.ReLU())
                in_d = dim
                if (i + 1) % 2 == 0:
                    self.conv_layers.append(
                        torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
                    )
                    idim = int(math.ceil(idim / 2))
        else:
            pass

    def forward(self, x, mask=None):

        # BxT -> BxCxT
        x = x.unsqueeze(1)
        if self.conv_type == "custom":
            for conv in self.conv_layers:
                if isinstance(conv, nn.LayerNorm):
                    x = x.transpose(1, 2)
                    x = conv(x).transpose(1, 2)
                else:
                    x = conv(x)
            x = x.transpose(2, 3).contiguous()
            x = x.view(x.size(0), -1, x.size(-1))
        else:
            for conv in self.conv_layers:
                x = conv(x)
            if self.conv_type == "conv2d":
                b, c, t, f = x.size()
                x = x.transpose(2, 3).contiguous().view(b, c * f, t)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    has_relative_attention_bias=(self.relative_position_embedding and i == 0),
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                )
                for i in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):

        if padding_mask is not None:
            x[padding_mask] = 0

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x += x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
                                       self_attn_mask=streaming_mask, pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
    ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        pos_bias=None
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            if self.activation_name == "glu":
                x = self.fc1(x)
            else:
                x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
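A note on the encoder API above: extract_features keeps per-layer outputs in T x B x C layout and stops early once tgt_layer is reached, which is how downstream code can tap a single intermediate WavLM layer. A minimal sketch (not part of the commit), assuming `encoder` is an already-built TransformerEncoder and `feats` is a B x T x C feature tensor, both hypothetical names:

    x, layer_results = encoder.extract_features(feats, padding_mask=None, tgt_layer=6)
    hidden_6 = layer_results[-1][0].transpose(0, 1)  # layer-6 output, back in B x T x C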
wavlm/__init__.py
ADDED
@@ -0,0 +1 @@
from wavlm.WavLM import WavLM, WavLMConfig
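The one-line re-export above is what the rest of the Space imports. A minimal feature-extraction sketch (not part of the commit), assuming the WavLM-Large checkpoint bundled with this repo at wavlm/WavLM-Large.pt stores 'cfg' and 'model' entries as in the official WavLM release:

    import torch
    import torch.nn.functional as F
    from wavlm import WavLM, WavLMConfig

    ckpt = torch.load('wavlm/WavLM-Large.pt', map_location='cpu')
    cfg = WavLMConfig(ckpt['cfg'])
    model = WavLM(cfg)
    model.load_state_dict(ckpt['model'])
    model.eval()

    wav = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
    if cfg.normalize:
        wav = F.layer_norm(wav, wav.shape)
    with torch.no_grad():
        features = model.extract_features(wav)[0]  # B x T' x C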
wavlm/__pycache__/WavLM.cpython-39.pyc
ADDED
Binary file (16.5 kB)
wavlm/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (191 Bytes)
wavlm/__pycache__/modules.cpython-39.pyc
ADDED
Binary file (19.3 kB)
wavlm/modules.py
ADDED
@@ -0,0 +1,827 @@
# --------------------------------------------------------
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

import math
import warnings
from typing import Dict, Optional, Tuple
import torch
from torch import Tensor, nn
from torch.nn import Parameter
import torch.nn.functional as F


class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)


class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None


class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        else:
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x


class Swish(nn.Module):
    """Swish function
    """

    def __init__(self):
        """Construct a Swish activation object."""
        super(Swish, self).__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        return x * self.act(x)


class GLU_Linear(nn.Module):
    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()

        if bias_in_glu:
            self.linear = nn.Linear(input_dim, output_dim * 2, True)
        else:
            self.linear = nn.Linear(input_dim, output_dim * 2, False)

    def forward(self, x):
        # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
        x = self.linear(x)

        if self.glu_type == "bilinear":
            x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
        else:
            x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))

        return x

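# Example (not part of the original file): GLU_Linear projects to 2 * output_dim and
# gates the first half with an activation of the second half, so only the last
# dimension changes. A minimal shape check, assuming a (batch=2, time=5, dim=16) input:
#
#   glu = GLU_Linear(16, 32, glu_type="swish")
#   y = glu(torch.randn(2, 5, 16))
#   assert y.shape == (2, 5, 32)
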
def gelu_accurate(x):
    if not hasattr(gelu_accurate, "_a"):
        gelu_accurate._a = math.sqrt(2 / math.pi)
    return (
        0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
    )


def gelu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x.float()).type_as(x)


def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`"""

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        warnings.warn(
            "--activation-fn=gelu_fast has been renamed to gelu_accurate"
        )
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "glu":
        return lambda x: x
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))


def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.
    This overrides the default initializations depending on the specified arguments.
    1. If normal_init_linear_weights is set then weights of linear
       layer will be initialized using the normal distribution and
       bias will be set to the specified value.
    2. If normal_init_embed_weights is set then weights of embedding
       layer will be initialized using the normal distribution.
    3. If normal_init_proj_weights is set then weights of
       in_project_weight for MultiHeadAttention initialized using
       the normal distribution (to be validated).
    """

    def normal_(data):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        data.copy_(
            data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
        )

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)

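# Example (not part of the original file): init_bert_params is intended for
# nn.Module.apply, which visits every submodule, e.g.:
#
#   layer = nn.Linear(768, 3072)
#   layer.apply(init_bert_params)   # weight ~ N(0, 0.02), bias zeroed
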
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module

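# Example (not part of the original file): quant_noise returns the module unchanged
# for p <= 0; otherwise it registers a forward pre-hook that, in training mode only,
# zeroes a random fraction p of block_size-wide weight blocks and rescales by 1/(1-p):
#
#   proj = quant_noise(nn.Linear(512, 512), p=0.1, block_size=8)
#   proj.train()
#   _ = proj(torch.randn(4, 512))   # weights are block-dropped on this forward pass
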
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position,
            bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

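    # Example (not part of the original file): with the defaults num_buckets=32 and
    # max_distance=128 (bidirectional), offsets with |offset| < 8 get their own bucket,
    # larger offsets share log-spaced buckets up to 128, and the sign of the offset
    # selects the upper half of the bucket range. compute_bias then returns a
    # (num_heads, q_len, k_len) tensor:
    #
    #   mha = MultiheadAttention(768, 12, self_attention=True,
    #                            has_relative_attention_bias=True)
    #   bias = mha.compute_bias(50, 50)   # shape: (12, 50, 50)
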
    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert (src_len, bsz) == value.shape[:2]

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if (
            not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and self.q_head_dim == self.head_dim
        ):
            assert key is not None and value is not None
            assert attn_mask is None

            attn_mask_rel_pos = None
            if position_bias is not None:
                attn_mask_rel_pos = position_bias
                if self.gru_rel_pos:
                    query_layer = query.transpose(0, 1)
                    new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
                    query_layer = query_layer.view(*new_x_shape)
                    query_layer = query_layer.permute(0, 2, 1, 3)
                    _B, _H, _L, __ = query_layer.size()

                    gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                        _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                    gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                    attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

                attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
            k_proj_bias = self.k_proj.bias
            if k_proj_bias is None:
                k_proj_bias = torch.zeros_like(self.q_proj.bias)

            x, attn = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training,
                # self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask_rel_pos,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )
            return x, attn, position_bias

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)

        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            if self.gru_rel_pos == 1:
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias

            position_bias = position_bias.view(attn_weights.size())

            attn_weights = attn_weights + position_bias

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        return attn_weights
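Taken together, modules.py is a self-contained (fairseq-free) copy of the attention stack that WavLM.py builds on. A quick smoke test of the attention module in isolation, with hypothetical sizes (not part of the commit):

    import torch
    from wavlm.modules import MultiheadAttention

    mha = MultiheadAttention(
        embed_dim=768, num_heads=12, self_attention=True,
        has_relative_attention_bias=True, num_buckets=32, max_distance=128,
    )
    mha.eval()
    x = torch.randn(50, 2, 768)   # T x B x C, as the encoder layers expect
    out, attn, pos_bias = mha(query=x, key=x, value=x, need_weights=False)
    print(out.shape, pos_bias.shape)   # torch.Size([50, 2, 768]) torch.Size([24, 50, 50])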