Spaces:
Running
Running
Inital demo
Browse files- LICENSE +25 -0
- app.py +69 -0
- hifigan/config_v1_wavlm.json +40 -0
- hifigan/meldataset.py +208 -0
- hifigan/models.py +289 -0
- hifigan/train.py +335 -0
- hifigan/utils.py +73 -0
- hubconf.py +75 -0
- knnvc_utils.py +23 -0
- matcher.py +172 -0
- prematch_dataset.py +172 -0
- requirements.txt +5 -0
- wavlm/WavLM.py +743 -0
- wavlm/modules.py +827 -0
LICENSE
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 MediaLab, Department of Electrical & Electronic Engineering, Stellenbosch University
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software, and you spend at least 10 seconds
|
14 |
+
thinking about whether the idea of copyright for Software actually makes sense
|
15 |
+
the first time you download the Software.
|
16 |
+
|
17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
23 |
+
SOFTWARE.
|
24 |
+
|
25 |
+
|
app.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio
|
3 |
+
import spaces
|
4 |
+
from typing import List
|
5 |
+
import soundfile as sf
|
6 |
+
|
7 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
+
knn_vc = torch.hub.load('bshall/knn-vc', 'knn_vc', prematched=True, trust_repo=True, pretrained=True, device=device)
|
9 |
+
|
10 |
+
|
11 |
+
def convert_voice(src_wav_path:str, ref_wav_paths, top_k:int):
|
12 |
+
|
13 |
+
query_seq = knn_vc.get_features(src_wav_path)
|
14 |
+
matching_set = knn_vc.get_matching_set([ref_wav_paths])
|
15 |
+
out_wav = knn_vc.match(query_seq, matching_set, topk=int(top_k))
|
16 |
+
|
17 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as converted_file:
|
18 |
+
sf.write(converted_file.name, out_wav, 16000, "PCM_24")
|
19 |
+
|
20 |
+
return converted_file.name
|
21 |
+
|
22 |
+
|
23 |
+
title = """
|
24 |
+
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
|
25 |
+
<div
|
26 |
+
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
|
27 |
+
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
|
28 |
+
KNN Voice Conversion
|
29 |
+
</h1> </div>
|
30 |
+
</div>
|
31 |
+
"""
|
32 |
+
|
33 |
+
description = """
|
34 |
+
Voice Conversion With Just k-Nearest Neighbors. The source and reference utterance(s) are encoded into self-supervised features using WavLM.
|
35 |
+
Each source feature is assigned to the mean of the k closest features from the reference.
|
36 |
+
The resulting feature sequence is then vocoded with HiFi-GAN to arrive at the converted waveform output.
|
37 |
+
"""
|
38 |
+
|
39 |
+
article = """
|
40 |
+
If the model contributes to your research please cite the following work:
|
41 |
+
|
42 |
+
Baas, M., van Niekerk, B., & Kamper, H. (2023). Voice conversion with just nearest neighbors. arXiv preprint arXiv:2305.18975.
|
43 |
+
|
44 |
+
demo contributed by [@wetdog](https://github.com/wetdog)
|
45 |
+
"""
|
46 |
+
demo = gr.Blocks()
|
47 |
+
with demo:
|
48 |
+
gr.Markdown(title)
|
49 |
+
gr.Markdown(description)
|
50 |
+
gr.Interface(
|
51 |
+
fn=convert_voice,
|
52 |
+
inputs=[
|
53 |
+
gr.Audio(type='filepath'),
|
54 |
+
gr.Audio(type='filepath'),
|
55 |
+
gr.Slider(
|
56 |
+
3,
|
57 |
+
10,
|
58 |
+
value=4,
|
59 |
+
step=1,
|
60 |
+
label="Top-k",
|
61 |
+
info=f"These default settings provide pretty good results, but feel free to modify the kNN topk",
|
62 |
+
)],
|
63 |
+
outputs=[gr.Audio(type='filepath')],
|
64 |
+
allow_flagging=False,)
|
65 |
+
gr.Markdown(article)
|
66 |
+
|
67 |
+
demo.queue(max_size=10)
|
68 |
+
demo.launch(show_api=False, server_name="0.0.0.0", server_port=7860)
|
69 |
+
|
hifigan/config_v1_wavlm.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"resblock": "1",
|
3 |
+
"num_gpus": 0,
|
4 |
+
"batch_size": 16,
|
5 |
+
"learning_rate": 0.0002,
|
6 |
+
"adam_b1": 0.8,
|
7 |
+
"adam_b2": 0.99,
|
8 |
+
"lr_decay": 0.999,
|
9 |
+
"seed": 1234,
|
10 |
+
|
11 |
+
"upsample_rates": [10,8,2,2],
|
12 |
+
"upsample_kernel_sizes": [20,16,4,4],
|
13 |
+
"upsample_initial_channel": 512,
|
14 |
+
"resblock_kernel_sizes": [3,7,11],
|
15 |
+
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
|
16 |
+
|
17 |
+
"hubert_dim": 1024,
|
18 |
+
"hifi_dim": 512,
|
19 |
+
|
20 |
+
"segment_size": 7040,
|
21 |
+
"num_mels": 80,
|
22 |
+
"num_freq": 1025,
|
23 |
+
"n_fft": 1024,
|
24 |
+
"hop_size": 320,
|
25 |
+
"win_size": 1024,
|
26 |
+
|
27 |
+
"sampling_rate": 16000,
|
28 |
+
|
29 |
+
"fmin": 0,
|
30 |
+
"fmax": 8000,
|
31 |
+
"fmax_for_loss": null,
|
32 |
+
|
33 |
+
"num_workers": 4,
|
34 |
+
|
35 |
+
"dist_config": {
|
36 |
+
"dist_backend": "nccl",
|
37 |
+
"dist_url": "tcp://localhost:54321",
|
38 |
+
"world_size": 1
|
39 |
+
}
|
40 |
+
}
|
hifigan/meldataset.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import librosa
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torch.utils.data
|
12 |
+
import torchaudio
|
13 |
+
from librosa.filters import mel as librosa_mel_fn
|
14 |
+
from librosa.util import normalize
|
15 |
+
from scipy.io.wavfile import read
|
16 |
+
|
17 |
+
|
18 |
+
def load_wav(full_path):
|
19 |
+
#sampling_rate, data = read(full_path)
|
20 |
+
#return data, sampling_rate
|
21 |
+
data, sampling_rate = librosa.load(full_path, sr=None)
|
22 |
+
return data, sampling_rate
|
23 |
+
|
24 |
+
|
25 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
26 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
27 |
+
|
28 |
+
|
29 |
+
def dynamic_range_decompression(x, C=1):
|
30 |
+
return np.exp(x) / C
|
31 |
+
|
32 |
+
|
33 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
34 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
35 |
+
|
36 |
+
|
37 |
+
def dynamic_range_decompression_torch(x, C=1):
|
38 |
+
return torch.exp(x) / C
|
39 |
+
|
40 |
+
|
41 |
+
def spectral_normalize_torch(magnitudes):
|
42 |
+
output = dynamic_range_compression_torch(magnitudes)
|
43 |
+
return output
|
44 |
+
|
45 |
+
|
46 |
+
def spectral_de_normalize_torch(magnitudes):
|
47 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
48 |
+
return output
|
49 |
+
|
50 |
+
|
51 |
+
mel_basis = {}
|
52 |
+
hann_window = {}
|
53 |
+
|
54 |
+
class LogMelSpectrogram(torch.nn.Module):
|
55 |
+
def __init__(self, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
56 |
+
super().__init__()
|
57 |
+
self.melspctrogram = torchaudio.transforms.MelSpectrogram(
|
58 |
+
sample_rate=sampling_rate,
|
59 |
+
n_fft=n_fft,
|
60 |
+
win_length=win_size,
|
61 |
+
hop_length=hop_size,
|
62 |
+
center=center,
|
63 |
+
power=1.0,
|
64 |
+
norm="slaney",
|
65 |
+
onesided=True,
|
66 |
+
n_mels=num_mels,
|
67 |
+
mel_scale="slaney",
|
68 |
+
f_min=fmin,
|
69 |
+
f_max=fmax
|
70 |
+
)
|
71 |
+
self.n_fft = n_fft
|
72 |
+
self.hop_size = hop_size
|
73 |
+
|
74 |
+
def forward(self, wav):
|
75 |
+
wav = F.pad(wav, ((self.n_fft - self.hop_size) // 2, (self.n_fft - self.hop_size) // 2), "reflect")
|
76 |
+
mel = self.melspctrogram(wav)
|
77 |
+
logmel = torch.log(torch.clamp(mel, min=1e-5))
|
78 |
+
return logmel
|
79 |
+
|
80 |
+
|
81 |
+
def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
82 |
+
if torch.min(y) < -1.:
|
83 |
+
print('min value is ', torch.min(y))
|
84 |
+
if torch.max(y) > 1.:
|
85 |
+
print('max value is ', torch.max(y))
|
86 |
+
|
87 |
+
global mel_basis, hann_window
|
88 |
+
if fmax not in mel_basis:
|
89 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
90 |
+
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
91 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
92 |
+
|
93 |
+
# print("Padding by", int((n_fft - hop_size)/2), y.shape)
|
94 |
+
# pre-padding
|
95 |
+
n_pad = hop_size - ( y.shape[1] % hop_size )
|
96 |
+
y = F.pad(y.unsqueeze(1), (0, n_pad), mode='reflect').squeeze(1)
|
97 |
+
# print("intermediate:", y.shape)
|
98 |
+
|
99 |
+
y = F.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
100 |
+
y = y.squeeze(1)
|
101 |
+
|
102 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
103 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
104 |
+
spec = spec.abs().clamp_(3e-5)
|
105 |
+
# print("Post: ", y.shape, spec.shape)
|
106 |
+
|
107 |
+
spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
|
108 |
+
spec = spectral_normalize_torch(spec)
|
109 |
+
|
110 |
+
return spec
|
111 |
+
|
112 |
+
|
113 |
+
def get_dataset_filelist(a):
|
114 |
+
train_df = pd.read_csv(a.input_training_file)
|
115 |
+
valid_df = pd.read_csv(a.input_validation_file)
|
116 |
+
return train_df, valid_df
|
117 |
+
|
118 |
+
|
119 |
+
class MelDataset(torch.utils.data.Dataset):
|
120 |
+
def __init__(self, training_files, segment_size, n_fft, num_mels,
|
121 |
+
hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
|
122 |
+
device=None, fmax_loss=None, fine_tuning=False, audio_root_path=None, feat_root_path=None, use_alt_melcalc=False):
|
123 |
+
self.audio_files = training_files
|
124 |
+
if shuffle:
|
125 |
+
self.audio_files = self.audio_files.sample(frac=1, random_state=1234)
|
126 |
+
self.segment_size = segment_size
|
127 |
+
self.sampling_rate = sampling_rate
|
128 |
+
self.split = split
|
129 |
+
self.n_fft = n_fft
|
130 |
+
self.num_mels = num_mels
|
131 |
+
self.hop_size = hop_size
|
132 |
+
self.win_size = win_size
|
133 |
+
self.fmin = fmin
|
134 |
+
self.fmax = fmax
|
135 |
+
self.fmax_loss = fmax_loss
|
136 |
+
self.cached_wav = None
|
137 |
+
self.n_cache_reuse = n_cache_reuse
|
138 |
+
self._cache_ref_count = 0
|
139 |
+
self.device = device
|
140 |
+
self.fine_tuning = fine_tuning
|
141 |
+
self.audio_root_path = Path(audio_root_path)
|
142 |
+
self.feat_root_path = Path(feat_root_path)
|
143 |
+
self.alt_melspec = LogMelSpectrogram(n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax)
|
144 |
+
self.use_alt_melcalc = use_alt_melcalc
|
145 |
+
|
146 |
+
def __getitem__(self, index):
|
147 |
+
row = self.audio_files.iloc[index]
|
148 |
+
if self._cache_ref_count == 0:
|
149 |
+
audio, sampling_rate = load_wav(self.audio_root_path/row.audio_path)
|
150 |
+
if not self.fine_tuning:
|
151 |
+
audio = normalize(audio) * 0.95
|
152 |
+
self.cached_wav = audio
|
153 |
+
if sampling_rate != self.sampling_rate:
|
154 |
+
raise ValueError("{} SR doesn't match target {} SR".format(
|
155 |
+
sampling_rate, self.sampling_rate))
|
156 |
+
self._cache_ref_count = self.n_cache_reuse
|
157 |
+
else:
|
158 |
+
audio = self.cached_wav
|
159 |
+
self._cache_ref_count -= 1
|
160 |
+
|
161 |
+
audio = torch.tensor(audio, dtype=torch.float32)
|
162 |
+
audio = audio.unsqueeze(0)
|
163 |
+
|
164 |
+
if not self.fine_tuning:
|
165 |
+
if self.split:
|
166 |
+
if audio.size(1) >= self.segment_size:
|
167 |
+
max_audio_start = audio.size(1) - self.segment_size
|
168 |
+
audio_start = random.randint(0, max_audio_start)
|
169 |
+
audio = audio[:, audio_start:audio_start+self.segment_size]
|
170 |
+
else:
|
171 |
+
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
172 |
+
|
173 |
+
if self.use_alt_melcalc:
|
174 |
+
mel = self.alt_melspec(audio)
|
175 |
+
else:
|
176 |
+
mel1 = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
177 |
+
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
|
178 |
+
center=False)
|
179 |
+
|
180 |
+
mel = mel.permute(0, 2, 1) # (1, dim, seq_len) --> (1, seq_len, dim)
|
181 |
+
else:
|
182 |
+
mel = torch.load(self.feat_root_path/row.feat_path, map_location='cpu').float()
|
183 |
+
|
184 |
+
if len(mel.shape) < 3:
|
185 |
+
mel = mel.unsqueeze(0) # (1, seq_len, dim)
|
186 |
+
|
187 |
+
if self.split:
|
188 |
+
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
189 |
+
|
190 |
+
if audio.size(1) >= self.segment_size:
|
191 |
+
mel_start = random.randint(0, mel.size(1) - frames_per_seg - 1)
|
192 |
+
mel = mel[:, mel_start:mel_start + frames_per_seg, :]
|
193 |
+
audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
|
194 |
+
else:
|
195 |
+
mel = torch.nn.functional.pad(mel, (0, 0, 0, frames_per_seg - mel.size(2)), 'constant')
|
196 |
+
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
|
197 |
+
|
198 |
+
|
199 |
+
if self.use_alt_melcalc:
|
200 |
+
mel_loss = self.alt_melspec(audio)
|
201 |
+
else:
|
202 |
+
mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
|
203 |
+
self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
|
204 |
+
center=False)
|
205 |
+
return (mel.squeeze(), audio.squeeze(0), str(row.audio_path), mel_loss.squeeze())
|
206 |
+
|
207 |
+
def __len__(self):
|
208 |
+
return len(self.audio_files)
|
hifigan/models.py
ADDED
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
from .utils import init_weights, get_padding
|
7 |
+
|
8 |
+
LRELU_SLOPE = 0.1
|
9 |
+
|
10 |
+
|
11 |
+
class ResBlock1(torch.nn.Module):
|
12 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
13 |
+
super(ResBlock1, self).__init__()
|
14 |
+
self.h = h
|
15 |
+
self.convs1 = nn.ModuleList([
|
16 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
17 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
18 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
19 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
20 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
21 |
+
padding=get_padding(kernel_size, dilation[2])))
|
22 |
+
])
|
23 |
+
self.convs1.apply(init_weights)
|
24 |
+
|
25 |
+
self.convs2 = nn.ModuleList([
|
26 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
27 |
+
padding=get_padding(kernel_size, 1))),
|
28 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
29 |
+
padding=get_padding(kernel_size, 1))),
|
30 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
31 |
+
padding=get_padding(kernel_size, 1)))
|
32 |
+
])
|
33 |
+
self.convs2.apply(init_weights)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
37 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
38 |
+
xt = c1(xt)
|
39 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
40 |
+
xt = c2(xt)
|
41 |
+
x = xt + x
|
42 |
+
return x
|
43 |
+
|
44 |
+
def remove_weight_norm(self):
|
45 |
+
for l in self.convs1:
|
46 |
+
remove_weight_norm(l)
|
47 |
+
for l in self.convs2:
|
48 |
+
remove_weight_norm(l)
|
49 |
+
|
50 |
+
|
51 |
+
class ResBlock2(torch.nn.Module):
|
52 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
53 |
+
super(ResBlock2, self).__init__()
|
54 |
+
self.h = h
|
55 |
+
self.convs = nn.ModuleList([
|
56 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
57 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
58 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
59 |
+
padding=get_padding(kernel_size, dilation[1])))
|
60 |
+
])
|
61 |
+
self.convs.apply(init_weights)
|
62 |
+
|
63 |
+
def forward(self, x):
|
64 |
+
for c in self.convs:
|
65 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
66 |
+
xt = c(xt)
|
67 |
+
x = xt + x
|
68 |
+
return x
|
69 |
+
|
70 |
+
def remove_weight_norm(self):
|
71 |
+
for l in self.convs:
|
72 |
+
remove_weight_norm(l)
|
73 |
+
|
74 |
+
|
75 |
+
class Generator(torch.nn.Module):
|
76 |
+
def __init__(self, h):
|
77 |
+
super(Generator, self).__init__()
|
78 |
+
self.h = h
|
79 |
+
self.lin_pre = nn.Linear(h.hubert_dim, h.hifi_dim)
|
80 |
+
self.num_kernels = len(h.resblock_kernel_sizes)
|
81 |
+
self.num_upsamples = len(h.upsample_rates)
|
82 |
+
self.conv_pre = weight_norm(Conv1d(h.hifi_dim, h.upsample_initial_channel, 7, 1, padding=3))
|
83 |
+
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
84 |
+
|
85 |
+
self.ups = nn.ModuleList()
|
86 |
+
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
87 |
+
|
88 |
+
self.ups.append(weight_norm(
|
89 |
+
ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
|
90 |
+
k, u, padding=(k-u)//2)))
|
91 |
+
|
92 |
+
self.resblocks = nn.ModuleList()
|
93 |
+
for i in range(len(self.ups)):
|
94 |
+
ch = h.upsample_initial_channel//(2**(i+1))
|
95 |
+
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
96 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
97 |
+
|
98 |
+
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
|
99 |
+
self.ups.apply(init_weights)
|
100 |
+
self.conv_post.apply(init_weights)
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
""" `x` as (bs, seq_len, dim), regular hifi assumes input of shape (bs, n_mels, seq_len) """
|
104 |
+
x = self.lin_pre(x)
|
105 |
+
x = x.permute(0, 2, 1) # (bs, seq_len, dim) --> (bs, dim, seq_len)
|
106 |
+
|
107 |
+
x = self.conv_pre(x)
|
108 |
+
for i in range(self.num_upsamples):
|
109 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
110 |
+
x = self.ups[i](x)
|
111 |
+
xs = None
|
112 |
+
for j in range(self.num_kernels):
|
113 |
+
if xs is None:
|
114 |
+
xs = self.resblocks[i*self.num_kernels+j](x)
|
115 |
+
else:
|
116 |
+
xs += self.resblocks[i*self.num_kernels+j](x)
|
117 |
+
x = xs / self.num_kernels
|
118 |
+
x = F.leaky_relu(x)
|
119 |
+
x = self.conv_post(x)
|
120 |
+
x = torch.tanh(x)
|
121 |
+
|
122 |
+
return x
|
123 |
+
|
124 |
+
def remove_weight_norm(self):
|
125 |
+
print('Removing weight norm...')
|
126 |
+
for l in self.ups:
|
127 |
+
remove_weight_norm(l)
|
128 |
+
for l in self.resblocks:
|
129 |
+
l.remove_weight_norm()
|
130 |
+
remove_weight_norm(self.conv_pre)
|
131 |
+
remove_weight_norm(self.conv_post)
|
132 |
+
|
133 |
+
|
134 |
+
class DiscriminatorP(torch.nn.Module):
|
135 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
|
136 |
+
super(DiscriminatorP, self).__init__()
|
137 |
+
self.period = period
|
138 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
139 |
+
self.convs = nn.ModuleList([
|
140 |
+
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
141 |
+
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
142 |
+
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
143 |
+
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
144 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
145 |
+
])
|
146 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
fmap = []
|
150 |
+
|
151 |
+
# 1d to 2d
|
152 |
+
b, c, t = x.shape
|
153 |
+
if t % self.period != 0: # pad first
|
154 |
+
n_pad = self.period - (t % self.period)
|
155 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
156 |
+
t = t + n_pad
|
157 |
+
x = x.view(b, c, t // self.period, self.period)
|
158 |
+
|
159 |
+
for l in self.convs:
|
160 |
+
x = l(x)
|
161 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
162 |
+
fmap.append(x)
|
163 |
+
x = self.conv_post(x)
|
164 |
+
fmap.append(x)
|
165 |
+
x = torch.flatten(x, 1, -1)
|
166 |
+
|
167 |
+
return x, fmap
|
168 |
+
|
169 |
+
|
170 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
171 |
+
def __init__(self):
|
172 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
173 |
+
self.discriminators = nn.ModuleList([
|
174 |
+
DiscriminatorP(2),
|
175 |
+
DiscriminatorP(3),
|
176 |
+
DiscriminatorP(5),
|
177 |
+
DiscriminatorP(7),
|
178 |
+
DiscriminatorP(11),
|
179 |
+
])
|
180 |
+
|
181 |
+
def forward(self, y, y_hat):
|
182 |
+
y_d_rs = []
|
183 |
+
y_d_gs = []
|
184 |
+
fmap_rs = []
|
185 |
+
fmap_gs = []
|
186 |
+
for i, d in enumerate(self.discriminators):
|
187 |
+
y_d_r, fmap_r = d(y)
|
188 |
+
y_d_g, fmap_g = d(y_hat)
|
189 |
+
y_d_rs.append(y_d_r)
|
190 |
+
fmap_rs.append(fmap_r)
|
191 |
+
y_d_gs.append(y_d_g)
|
192 |
+
fmap_gs.append(fmap_g)
|
193 |
+
|
194 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
195 |
+
|
196 |
+
|
197 |
+
class DiscriminatorS(torch.nn.Module):
|
198 |
+
def __init__(self, use_spectral_norm=False):
|
199 |
+
super(DiscriminatorS, self).__init__()
|
200 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
201 |
+
self.convs = nn.ModuleList([
|
202 |
+
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
|
203 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
204 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
205 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
206 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
207 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
208 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
209 |
+
])
|
210 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
fmap = []
|
214 |
+
for l in self.convs:
|
215 |
+
x = l(x)
|
216 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
217 |
+
fmap.append(x)
|
218 |
+
x = self.conv_post(x)
|
219 |
+
fmap.append(x)
|
220 |
+
x = torch.flatten(x, 1, -1)
|
221 |
+
|
222 |
+
return x, fmap
|
223 |
+
|
224 |
+
|
225 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
226 |
+
def __init__(self):
|
227 |
+
super(MultiScaleDiscriminator, self).__init__()
|
228 |
+
self.discriminators = nn.ModuleList([
|
229 |
+
DiscriminatorS(use_spectral_norm=True),
|
230 |
+
DiscriminatorS(),
|
231 |
+
DiscriminatorS(),
|
232 |
+
])
|
233 |
+
self.meanpools = nn.ModuleList([
|
234 |
+
AvgPool1d(4, 2, padding=2),
|
235 |
+
AvgPool1d(4, 2, padding=2)
|
236 |
+
])
|
237 |
+
|
238 |
+
def forward(self, y, y_hat):
|
239 |
+
y_d_rs = []
|
240 |
+
y_d_gs = []
|
241 |
+
fmap_rs = []
|
242 |
+
fmap_gs = []
|
243 |
+
for i, d in enumerate(self.discriminators):
|
244 |
+
if i != 0:
|
245 |
+
y = self.meanpools[i-1](y)
|
246 |
+
y_hat = self.meanpools[i-1](y_hat)
|
247 |
+
y_d_r, fmap_r = d(y)
|
248 |
+
y_d_g, fmap_g = d(y_hat)
|
249 |
+
y_d_rs.append(y_d_r)
|
250 |
+
fmap_rs.append(fmap_r)
|
251 |
+
y_d_gs.append(y_d_g)
|
252 |
+
fmap_gs.append(fmap_g)
|
253 |
+
|
254 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
255 |
+
|
256 |
+
|
257 |
+
def feature_loss(fmap_r, fmap_g):
|
258 |
+
loss = 0
|
259 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
260 |
+
for rl, gl in zip(dr, dg):
|
261 |
+
loss += torch.mean(torch.abs(rl - gl))
|
262 |
+
|
263 |
+
return loss*2
|
264 |
+
|
265 |
+
|
266 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
267 |
+
loss = 0
|
268 |
+
r_losses = []
|
269 |
+
g_losses = []
|
270 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
271 |
+
r_loss = torch.mean((1-dr)**2)
|
272 |
+
g_loss = torch.mean(dg**2)
|
273 |
+
loss += (r_loss + g_loss)
|
274 |
+
r_losses.append(r_loss.item())
|
275 |
+
g_losses.append(g_loss.item())
|
276 |
+
|
277 |
+
return loss, r_losses, g_losses
|
278 |
+
|
279 |
+
|
280 |
+
def generator_loss(disc_outputs):
|
281 |
+
loss = 0
|
282 |
+
gen_losses = []
|
283 |
+
for dg in disc_outputs:
|
284 |
+
l = torch.mean((1-dg)**2)
|
285 |
+
gen_losses.append(l)
|
286 |
+
loss += l
|
287 |
+
|
288 |
+
return loss, gen_losses
|
289 |
+
|
hifigan/train.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.multiprocessing as mp
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from fastprogress import master_bar, progress_bar
|
11 |
+
from torch.cuda.amp.grad_scaler import GradScaler
|
12 |
+
from torch.distributed import init_process_group
|
13 |
+
from torch.nn.parallel import DistributedDataParallel
|
14 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
15 |
+
from torch.utils.tensorboard import SummaryWriter
|
16 |
+
|
17 |
+
from .meldataset import (LogMelSpectrogram, MelDataset, get_dataset_filelist,
|
18 |
+
mel_spectrogram)
|
19 |
+
from .models import (Generator, MultiPeriodDiscriminator,
|
20 |
+
MultiScaleDiscriminator, discriminator_loss, feature_loss,
|
21 |
+
generator_loss)
|
22 |
+
from .utils import (AttrDict, build_env, load_checkpoint, plot_spectrogram,
|
23 |
+
save_checkpoint, scan_checkpoint)
|
24 |
+
|
25 |
+
torch.backends.cudnn.benchmark = True
|
26 |
+
USE_ALT_MELCALC = True
|
27 |
+
|
28 |
+
|
29 |
+
def train(rank, a, h):
|
30 |
+
if h.num_gpus > 1:
|
31 |
+
init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
|
32 |
+
world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
|
33 |
+
|
34 |
+
torch.cuda.manual_seed(h.seed)
|
35 |
+
device = torch.device('cuda:{:d}'.format(rank))
|
36 |
+
|
37 |
+
generator = Generator(h).to(device)
|
38 |
+
mpd = MultiPeriodDiscriminator().to(device)
|
39 |
+
msd = MultiScaleDiscriminator().to(device)
|
40 |
+
|
41 |
+
if rank == 0:
|
42 |
+
print(generator)
|
43 |
+
os.makedirs(a.checkpoint_path, exist_ok=True)
|
44 |
+
print("checkpoints directory : ", a.checkpoint_path)
|
45 |
+
|
46 |
+
if os.path.isdir(a.checkpoint_path):
|
47 |
+
cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
|
48 |
+
cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
|
49 |
+
|
50 |
+
steps = 0
|
51 |
+
if cp_g is None or cp_do is None:
|
52 |
+
state_dict_do = None
|
53 |
+
last_epoch = -1
|
54 |
+
else:
|
55 |
+
state_dict_g = load_checkpoint(cp_g, device)
|
56 |
+
state_dict_do = load_checkpoint(cp_do, device)
|
57 |
+
generator.load_state_dict(state_dict_g['generator'])
|
58 |
+
mpd.load_state_dict(state_dict_do['mpd'])
|
59 |
+
msd.load_state_dict(state_dict_do['msd'])
|
60 |
+
steps = state_dict_do['steps'] + 1
|
61 |
+
last_epoch = state_dict_do['epoch']
|
62 |
+
print(f"Restored checkpoint from {cp_g} and {cp_do}")
|
63 |
+
|
64 |
+
if h.num_gpus > 1:
|
65 |
+
print("Multi-gpu detected")
|
66 |
+
generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
|
67 |
+
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
68 |
+
msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
|
69 |
+
|
70 |
+
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
71 |
+
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
|
72 |
+
h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
73 |
+
|
74 |
+
if state_dict_do is not None:
|
75 |
+
optim_g.load_state_dict(state_dict_do['optim_g'])
|
76 |
+
optim_d.load_state_dict(state_dict_do['optim_d'])
|
77 |
+
|
78 |
+
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
79 |
+
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
80 |
+
if a.fp16:
|
81 |
+
scaler_g = GradScaler()
|
82 |
+
scaler_d = GradScaler()
|
83 |
+
|
84 |
+
train_df, valid_df = get_dataset_filelist(a)
|
85 |
+
|
86 |
+
trainset = MelDataset(train_df, h.segment_size, h.n_fft, h.num_mels,
|
87 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
|
88 |
+
shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
|
89 |
+
fine_tuning=a.fine_tuning,
|
90 |
+
audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
|
91 |
+
use_alt_melcalc=USE_ALT_MELCALC)
|
92 |
+
|
93 |
+
train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
|
94 |
+
|
95 |
+
train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
|
96 |
+
sampler=train_sampler,
|
97 |
+
batch_size=h.batch_size,
|
98 |
+
pin_memory=True,
|
99 |
+
persistent_workers=True,
|
100 |
+
drop_last=True)
|
101 |
+
|
102 |
+
alt_melspec = LogMelSpectrogram(h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax).to(device)
|
103 |
+
|
104 |
+
if rank == 0:
|
105 |
+
validset = MelDataset(valid_df, h.segment_size, h.n_fft, h.num_mels,
|
106 |
+
h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
|
107 |
+
fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
|
108 |
+
audio_root_path=a.audio_root_path, feat_root_path=a.feature_root_path,
|
109 |
+
use_alt_melcalc=USE_ALT_MELCALC)
|
110 |
+
validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
|
111 |
+
sampler=None,
|
112 |
+
batch_size=1,
|
113 |
+
pin_memory=True,
|
114 |
+
persistent_workers=True,
|
115 |
+
drop_last=True)
|
116 |
+
|
117 |
+
sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
|
118 |
+
|
119 |
+
generator.train()
|
120 |
+
mpd.train()
|
121 |
+
msd.train()
|
122 |
+
|
123 |
+
if rank == 0: mb = master_bar(range(max(0, last_epoch), a.training_epochs))
|
124 |
+
else: mb = range(max(0, last_epoch), a.training_epochs)
|
125 |
+
|
126 |
+
for epoch in mb:
|
127 |
+
if rank == 0:
|
128 |
+
start = time.time()
|
129 |
+
mb.write("Epoch: {}".format(epoch+1))
|
130 |
+
|
131 |
+
if h.num_gpus > 1:
|
132 |
+
train_sampler.set_epoch(epoch)
|
133 |
+
|
134 |
+
if rank == 0: pb = progress_bar(enumerate(train_loader), total=len(train_loader), parent=mb)
|
135 |
+
else: pb = enumerate(train_loader)
|
136 |
+
|
137 |
+
|
138 |
+
for i, batch in pb:
|
139 |
+
if rank == 0:
|
140 |
+
start_b = time.time()
|
141 |
+
x, y, _, y_mel = batch
|
142 |
+
x = x.to(device, non_blocking=True)
|
143 |
+
y = y.to(device, non_blocking=True)
|
144 |
+
y_mel = y_mel.to(device, non_blocking=True)
|
145 |
+
y = y.unsqueeze(1)
|
146 |
+
|
147 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
148 |
+
y_g_hat = generator(x)
|
149 |
+
if USE_ALT_MELCALC:
|
150 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
151 |
+
else:
|
152 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
|
153 |
+
h.fmin, h.fmax_for_loss)
|
154 |
+
# print(x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
|
155 |
+
optim_d.zero_grad()
|
156 |
+
|
157 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
158 |
+
# MPD
|
159 |
+
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
160 |
+
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
161 |
+
|
162 |
+
# MSD
|
163 |
+
y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
|
164 |
+
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
165 |
+
|
166 |
+
loss_disc_all = loss_disc_s + loss_disc_f
|
167 |
+
|
168 |
+
if a.fp16:
|
169 |
+
scaler_d.scale(loss_disc_all).backward()
|
170 |
+
scaler_d.step(optim_d)
|
171 |
+
scaler_d.update()
|
172 |
+
else:
|
173 |
+
loss_disc_all.backward()
|
174 |
+
optim_d.step()
|
175 |
+
|
176 |
+
# Generator
|
177 |
+
optim_g.zero_grad()
|
178 |
+
|
179 |
+
with torch.cuda.amp.autocast(enabled=a.fp16):
|
180 |
+
# L1 Mel-Spectrogram Loss
|
181 |
+
loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
|
182 |
+
|
183 |
+
y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
|
184 |
+
y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
|
185 |
+
loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
|
186 |
+
loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
|
187 |
+
loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
|
188 |
+
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
189 |
+
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
190 |
+
|
191 |
+
if a.fp16:
|
192 |
+
scaler_g.scale(loss_gen_all).backward()
|
193 |
+
scaler_g.step(optim_g)
|
194 |
+
scaler_g.update()
|
195 |
+
else:
|
196 |
+
loss_gen_all.backward()
|
197 |
+
optim_g.step()
|
198 |
+
|
199 |
+
if rank == 0:
|
200 |
+
# STDOUT logging
|
201 |
+
if steps % a.stdout_interval == 0:
|
202 |
+
with torch.no_grad():
|
203 |
+
mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
|
204 |
+
|
205 |
+
mb.write('Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, sec/batch : {:4.3f}, peak mem: {:5.2f}GB'. \
|
206 |
+
format(steps, loss_gen_all, mel_error, time.time() - start_b, torch.cuda.max_memory_allocated()/1e9))
|
207 |
+
mb.child.comment = "Steps : {:,d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}". \
|
208 |
+
format(steps, loss_gen_all, mel_error)
|
209 |
+
|
210 |
+
|
211 |
+
# checkpointing
|
212 |
+
if steps % a.checkpoint_interval == 0 and steps != 0:
|
213 |
+
checkpoint_path = "{}/g_{:08d}.pt".format(a.checkpoint_path, steps)
|
214 |
+
save_checkpoint(checkpoint_path,
|
215 |
+
{'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
|
216 |
+
checkpoint_path = "{}/do_{:08d}.pt".format(a.checkpoint_path, steps)
|
217 |
+
save_checkpoint(checkpoint_path,
|
218 |
+
{'mpd': (mpd.module if h.num_gpus > 1
|
219 |
+
else mpd).state_dict(),
|
220 |
+
'msd': (msd.module if h.num_gpus > 1
|
221 |
+
else msd).state_dict(),
|
222 |
+
'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
|
223 |
+
'epoch': epoch})
|
224 |
+
|
225 |
+
# Tensorboard summary logging
|
226 |
+
if steps % a.summary_interval == 0:
|
227 |
+
sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
|
228 |
+
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
229 |
+
sw.add_scalar("training/disc_loss_total", loss_disc_all, steps)
|
230 |
+
|
231 |
+
# Validation
|
232 |
+
if steps % a.validation_interval == 0: # and steps != 0:
|
233 |
+
generator.eval()
|
234 |
+
torch.cuda.empty_cache()
|
235 |
+
val_err_tot = 0
|
236 |
+
with torch.no_grad():
|
237 |
+
for j, batch in progress_bar(enumerate(validation_loader), total=len(validation_loader), parent=mb):
|
238 |
+
x, y, _, y_mel = batch
|
239 |
+
y_g_hat = generator(x.to(device))
|
240 |
+
y_mel = y_mel.to(device, non_blocking=True)
|
241 |
+
if USE_ALT_MELCALC:
|
242 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
243 |
+
if y_g_hat_mel.shape[-1] != y_mel.shape[-1]:
|
244 |
+
# pad it
|
245 |
+
n_pad = h.hop_size
|
246 |
+
y_g_hat = F.pad(y_g_hat, (n_pad//2, n_pad - n_pad//2))
|
247 |
+
y_g_hat_mel = alt_melspec(y_g_hat.squeeze(1))
|
248 |
+
else:
|
249 |
+
y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
250 |
+
h.hop_size, h.win_size,
|
251 |
+
h.fmin, h.fmax_for_loss)
|
252 |
+
#print('valid', x.shape, y_g_hat.shape, y_g_hat_mel.shape, y_mel.shape, y.shape)
|
253 |
+
val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
|
254 |
+
|
255 |
+
if j <= 4:
|
256 |
+
if steps == 0:
|
257 |
+
sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
|
258 |
+
sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
|
259 |
+
|
260 |
+
sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
|
261 |
+
if USE_ALT_MELCALC:
|
262 |
+
y_hat_spec = alt_melspec(y_g_hat.squeeze(1))
|
263 |
+
else:
|
264 |
+
y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
|
265 |
+
h.hop_size, h.win_size,
|
266 |
+
h.fmin, h.fmax_for_loss)
|
267 |
+
|
268 |
+
sw.add_figure('generated/y_hat_spec_{}'.format(j),
|
269 |
+
plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
|
270 |
+
|
271 |
+
val_err = val_err_tot / (j+1)
|
272 |
+
sw.add_scalar("validation/mel_spec_error", val_err, steps)
|
273 |
+
mb.write(f"validation run complete at {steps:,d} steps. validation mel spec error: {val_err:5.4f}")
|
274 |
+
|
275 |
+
generator.train()
|
276 |
+
sw.add_scalar("memory/max_allocated_gb", torch.cuda.max_memory_allocated()/1e9, steps)
|
277 |
+
sw.add_scalar("memory/max_reserved_gb", torch.cuda.max_memory_reserved()/1e9, steps)
|
278 |
+
torch.cuda.reset_peak_memory_stats()
|
279 |
+
torch.cuda.reset_accumulated_memory_stats()
|
280 |
+
|
281 |
+
steps += 1
|
282 |
+
|
283 |
+
scheduler_g.step()
|
284 |
+
scheduler_d.step()
|
285 |
+
|
286 |
+
if rank == 0:
|
287 |
+
print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
|
288 |
+
|
289 |
+
|
290 |
+
def main():
|
291 |
+
print('Initializing Training Process..')
|
292 |
+
|
293 |
+
parser = argparse.ArgumentParser()
|
294 |
+
|
295 |
+
parser.add_argument('--group_name', default=None)
|
296 |
+
parser.add_argument('--audio_root_path', required=True)
|
297 |
+
parser.add_argument('--feature_root_path', required=True)
|
298 |
+
parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
|
299 |
+
parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
|
300 |
+
parser.add_argument('--checkpoint_path', default='cp_hifigan')
|
301 |
+
parser.add_argument('--config', default='')
|
302 |
+
parser.add_argument('--training_epochs', default=1500, type=int)
|
303 |
+
parser.add_argument('--stdout_interval', default=5, type=int)
|
304 |
+
parser.add_argument('--checkpoint_interval', default=5000, type=int)
|
305 |
+
parser.add_argument('--summary_interval', default=25, type=int)
|
306 |
+
parser.add_argument('--validation_interval', default=1000, type=int)
|
307 |
+
parser.add_argument('--fp16', default=False, type=bool)
|
308 |
+
parser.add_argument('--fine_tuning', action='store_true')
|
309 |
+
|
310 |
+
a = parser.parse_args()
|
311 |
+
print(a)
|
312 |
+
with open(a.config) as f:
|
313 |
+
data = f.read()
|
314 |
+
|
315 |
+
json_config = json.loads(data)
|
316 |
+
h = AttrDict(json_config)
|
317 |
+
build_env(a.config, 'config.json', a.checkpoint_path)
|
318 |
+
|
319 |
+
torch.manual_seed(h.seed)
|
320 |
+
if torch.cuda.is_available():
|
321 |
+
torch.cuda.manual_seed(h.seed)
|
322 |
+
h.num_gpus = torch.cuda.device_count()
|
323 |
+
h.batch_size = int(h.batch_size / h.num_gpus)
|
324 |
+
print('Batch size per GPU :', h.batch_size)
|
325 |
+
else:
|
326 |
+
pass
|
327 |
+
|
328 |
+
if h.num_gpus > 1:
|
329 |
+
mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
|
330 |
+
else:
|
331 |
+
train(0, a, h)
|
332 |
+
|
333 |
+
|
334 |
+
if __name__ == '__main__':
|
335 |
+
main()
|
hifigan/utils.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch.nn.utils import weight_norm
|
7 |
+
import json
|
8 |
+
|
9 |
+
|
10 |
+
def plot_spectrogram(spectrogram):
|
11 |
+
import matplotlib.pylab as plt
|
12 |
+
import matplotlib
|
13 |
+
matplotlib.use("Agg")
|
14 |
+
fig, ax = plt.subplots(figsize=(10, 2))
|
15 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
16 |
+
interpolation='none')
|
17 |
+
plt.colorbar(im, ax=ax)
|
18 |
+
|
19 |
+
fig.canvas.draw()
|
20 |
+
plt.close()
|
21 |
+
|
22 |
+
return fig
|
23 |
+
|
24 |
+
|
25 |
+
def init_weights(m, mean=0.0, std=0.01):
|
26 |
+
classname = m.__class__.__name__
|
27 |
+
if classname.find("Conv") != -1:
|
28 |
+
m.weight.data.normal_(mean, std)
|
29 |
+
|
30 |
+
|
31 |
+
def apply_weight_norm(m):
|
32 |
+
classname = m.__class__.__name__
|
33 |
+
if classname.find("Conv") != -1:
|
34 |
+
weight_norm(m)
|
35 |
+
|
36 |
+
|
37 |
+
def get_padding(kernel_size, dilation=1):
|
38 |
+
return int((kernel_size*dilation - dilation)/2)
|
39 |
+
|
40 |
+
|
41 |
+
def load_checkpoint(filepath, device):
|
42 |
+
assert os.path.isfile(filepath)
|
43 |
+
print("Loading '{}'".format(filepath))
|
44 |
+
checkpoint_dict = torch.load(filepath, map_location=device)
|
45 |
+
print("Complete.")
|
46 |
+
return checkpoint_dict
|
47 |
+
|
48 |
+
|
49 |
+
def save_checkpoint(filepath, obj):
|
50 |
+
print("Saving checkpoint to {}".format(filepath))
|
51 |
+
torch.save(obj, filepath)
|
52 |
+
print("Complete.")
|
53 |
+
|
54 |
+
|
55 |
+
def scan_checkpoint(cp_dir, prefix):
|
56 |
+
pattern = os.path.join(cp_dir, prefix + '*')
|
57 |
+
cp_list = glob.glob(pattern)
|
58 |
+
if len(cp_list) == 0:
|
59 |
+
return None
|
60 |
+
return sorted(cp_list)[-1]
|
61 |
+
|
62 |
+
|
63 |
+
class AttrDict(dict):
|
64 |
+
def __init__(self, *args, **kwargs):
|
65 |
+
super(AttrDict, self).__init__(*args, **kwargs)
|
66 |
+
self.__dict__ = self
|
67 |
+
|
68 |
+
|
69 |
+
def build_env(config, config_name, path):
|
70 |
+
t_path = os.path.join(path, config_name)
|
71 |
+
if config != t_path:
|
72 |
+
os.makedirs(path, exist_ok=True)
|
73 |
+
shutil.copyfile(config, os.path.join(path, config_name))
|
hubconf.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dependencies = ['torch', 'torchaudio', 'numpy']
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import logging
|
8 |
+
import json
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
|
12 |
+
from wavlm.WavLM import WavLM, WavLMConfig
|
13 |
+
from hifigan.models import Generator as HiFiGAN
|
14 |
+
from hifigan.utils import AttrDict
|
15 |
+
from matcher import KNeighborsVC
|
16 |
+
|
17 |
+
|
18 |
+
def knn_vc(pretrained=True, progress=True, prematched=True, device='cuda') -> KNeighborsVC:
|
19 |
+
""" Load kNN-VC (WavLM encoder and HiFiGAN decoder). Optionally use vocoder trained on `prematched` data. """
|
20 |
+
hifigan, hifigan_cfg = hifigan_wavlm(pretrained, progress, prematched, device)
|
21 |
+
wavlm = wavlm_large(pretrained, progress, device)
|
22 |
+
knnvc = KNeighborsVC(wavlm, hifigan, hifigan_cfg, device)
|
23 |
+
return knnvc
|
24 |
+
|
25 |
+
|
26 |
+
def hifigan_wavlm(pretrained=True, progress=True, prematched=True, device='cuda') -> HiFiGAN:
|
27 |
+
""" Load pretrained hifigan trained to vocode wavlm features. Optionally use weights trained on `prematched` data. """
|
28 |
+
cp = Path(__file__).parent.absolute()
|
29 |
+
|
30 |
+
with open(cp/'hifigan'/'config_v1_wavlm.json') as f:
|
31 |
+
data = f.read()
|
32 |
+
json_config = json.loads(data)
|
33 |
+
h = AttrDict(json_config)
|
34 |
+
device = torch.device(device)
|
35 |
+
|
36 |
+
generator = HiFiGAN(h).to(device)
|
37 |
+
|
38 |
+
if pretrained:
|
39 |
+
if prematched:
|
40 |
+
url = "https://github.com/bshall/knn-vc/releases/download/v0.1/prematch_g_02500000.pt"
|
41 |
+
else:
|
42 |
+
url = "https://github.com/bshall/knn-vc/releases/download/v0.1/g_02500000.pt"
|
43 |
+
state_dict_g = torch.hub.load_state_dict_from_url(
|
44 |
+
url,
|
45 |
+
map_location=device,
|
46 |
+
progress=progress
|
47 |
+
)
|
48 |
+
generator.load_state_dict(state_dict_g['generator'])
|
49 |
+
generator.eval()
|
50 |
+
generator.remove_weight_norm()
|
51 |
+
print(f"[HiFiGAN] Generator loaded with {sum([p.numel() for p in generator.parameters()]):,d} parameters.")
|
52 |
+
return generator, h
|
53 |
+
|
54 |
+
|
55 |
+
def wavlm_large(pretrained=True, progress=True, device='cuda') -> WavLM:
|
56 |
+
"""Load the WavLM large checkpoint from the original paper. See https://github.com/microsoft/unilm/tree/master/wavlm for details. """
|
57 |
+
if torch.cuda.is_available() == False:
|
58 |
+
if str(device) != 'cpu':
|
59 |
+
logging.warning(f"Overriding device {device} to cpu since no GPU is available.")
|
60 |
+
device = 'cpu'
|
61 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
62 |
+
"https://github.com/bshall/knn-vc/releases/download/v0.1/WavLM-Large.pt",
|
63 |
+
map_location=device,
|
64 |
+
progress=progress
|
65 |
+
)
|
66 |
+
|
67 |
+
cfg = WavLMConfig(checkpoint['cfg'])
|
68 |
+
device = torch.device(device)
|
69 |
+
model = WavLM(cfg)
|
70 |
+
if pretrained:
|
71 |
+
model.load_state_dict(checkpoint['model'])
|
72 |
+
model = model.to(device)
|
73 |
+
model.eval()
|
74 |
+
print(f"WavLM-Large loaded with {sum([p.numel() for p in model.parameters()]):,d} parameters.")
|
75 |
+
return model
|
knnvc_utils.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
def generate_matrix_from_index(A, len=25):
|
4 |
+
matrix = np.zeros(len, dtype=float)
|
5 |
+
matrix[A] = 1
|
6 |
+
return matrix
|
7 |
+
|
8 |
+
|
9 |
+
def retrieve_index_from_matrix(matrix):
|
10 |
+
A = np.where(matrix == 1)[0]
|
11 |
+
return A
|
12 |
+
|
13 |
+
if __name__ == '__main__':
|
14 |
+
# Generating a matrix from index A
|
15 |
+
A = 6
|
16 |
+
matrix = generate_matrix_from_index(A)
|
17 |
+
print("Generated Matrix:")
|
18 |
+
print(matrix)
|
19 |
+
|
20 |
+
# Retrieving index A from the matrix
|
21 |
+
retrieved_A = retrieve_index_from_matrix(matrix)
|
22 |
+
print("Retrieved Index A:")
|
23 |
+
print(retrieved_A)
|
matcher.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
import torchaudio.transforms as T
|
9 |
+
from hifigan.models import Generator as HiFiGAN
|
10 |
+
from hifigan.utils import AttrDict
|
11 |
+
from torch import Tensor
|
12 |
+
from torchaudio.sox_effects import apply_effects_tensor
|
13 |
+
from wavlm.WavLM import WavLM
|
14 |
+
from knnvc_utils import generate_matrix_from_index
|
15 |
+
|
16 |
+
|
17 |
+
SPEAKER_INFORMATION_LAYER = 6
|
18 |
+
SPEAKER_INFORMATION_WEIGHTS = generate_matrix_from_index(SPEAKER_INFORMATION_LAYER)
|
19 |
+
|
20 |
+
|
21 |
+
def fast_cosine_dist(source_feats: Tensor, matching_pool: Tensor, device: str = 'cpu') -> Tensor:
|
22 |
+
""" Like torch.cdist, but fixed dim=-1 and for cosine distance."""
|
23 |
+
source_norms = torch.norm(source_feats, p=2, dim=-1).to(device)
|
24 |
+
matching_norms = torch.norm(matching_pool, p=2, dim=-1)
|
25 |
+
dotprod = -torch.cdist(source_feats[None].to(device), matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
|
26 |
+
dotprod /= 2
|
27 |
+
|
28 |
+
dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
|
29 |
+
return dists
|
30 |
+
|
31 |
+
|
32 |
+
class KNeighborsVC(nn.Module):
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
wavlm: WavLM,
|
36 |
+
hifigan: HiFiGAN,
|
37 |
+
hifigan_cfg: AttrDict,
|
38 |
+
device='cuda'
|
39 |
+
) -> None:
|
40 |
+
""" kNN-VC matcher.
|
41 |
+
Arguments:
|
42 |
+
- `wavlm` : trained WavLM model
|
43 |
+
- `hifigan`: trained hifigan model
|
44 |
+
- `hifigan_cfg`: hifigan config to use for vocoding.
|
45 |
+
"""
|
46 |
+
super().__init__()
|
47 |
+
# set which features to extract from wavlm
|
48 |
+
self.weighting = torch.tensor(SPEAKER_INFORMATION_WEIGHTS, device=device)[:, None]
|
49 |
+
# load hifigan
|
50 |
+
self.hifigan = hifigan.eval()
|
51 |
+
self.h = hifigan_cfg
|
52 |
+
# store wavlm
|
53 |
+
self.wavlm = wavlm.eval()
|
54 |
+
self.device = torch.device(device)
|
55 |
+
self.sr = self.h.sampling_rate
|
56 |
+
self.hop_length = 320
|
57 |
+
|
58 |
+
def get_matching_set(self, wavs: list[Path] | list[Tensor], weights=None, vad_trigger_level=7) -> Tensor:
|
59 |
+
""" Get concatenated wavlm features for the matching set using all waveforms in `wavs`,
|
60 |
+
specified as either a list of paths or list of loaded waveform tensors of
|
61 |
+
shape (channels, T), assumed to be of 16kHz sample rate.
|
62 |
+
Optionally specify custom WavLM feature weighting with `weights`.
|
63 |
+
"""
|
64 |
+
feats = []
|
65 |
+
for p in wavs:
|
66 |
+
feats.append(self.get_features(p, weights=self.weighting if weights is None else weights, vad_trigger_level=vad_trigger_level))
|
67 |
+
|
68 |
+
feats = torch.concat(feats, dim=0).cpu()
|
69 |
+
return feats
|
70 |
+
|
71 |
+
|
72 |
+
@torch.inference_mode()
|
73 |
+
def vocode(self, c: Tensor) -> Tensor:
|
74 |
+
""" Vocode features with hifigan. `c` is of shape (bs, seq_len, c_dim) """
|
75 |
+
y_g_hat = self.hifigan(c)
|
76 |
+
y_g_hat = y_g_hat.squeeze(1)
|
77 |
+
return y_g_hat
|
78 |
+
|
79 |
+
|
80 |
+
@torch.inference_mode()
|
81 |
+
def get_features(self, path, weights=None, vad_trigger_level=0):
|
82 |
+
"""Returns features of `path` waveform as a tensor of shape (seq_len, dim), optionally perform VAD trimming
|
83 |
+
on start/end with `vad_trigger_level`.
|
84 |
+
"""
|
85 |
+
# load audio
|
86 |
+
if weights == None: weights = self.weighting
|
87 |
+
if type(path) in [str, Path]:
|
88 |
+
x, sr = torchaudio.load(path, normalize=True)
|
89 |
+
else:
|
90 |
+
x: Tensor = path
|
91 |
+
sr = self.sr
|
92 |
+
if x.dim() == 1: x = x[None]
|
93 |
+
|
94 |
+
if not sr == self.sr :
|
95 |
+
print(f"resample {sr} to {self.sr} in {path}")
|
96 |
+
x = torchaudio.functional.resample(x, orig_freq=sr, new_freq=self.sr)
|
97 |
+
sr = self.sr
|
98 |
+
|
99 |
+
# trim silence from front and back
|
100 |
+
if vad_trigger_level > 1e-3:
|
101 |
+
transform = T.Vad(sample_rate=sr, trigger_level=vad_trigger_level)
|
102 |
+
x_front_trim = transform(x)
|
103 |
+
# original way, disabled because it lacks windows support
|
104 |
+
#waveform_reversed, sr = apply_effects_tensor(x_front_trim, sr, [["reverse"]])
|
105 |
+
waveform_reversed = torch.flip(x_front_trim, (-1,))
|
106 |
+
waveform_reversed_front_trim = transform(waveform_reversed)
|
107 |
+
waveform_end_trim = torch.flip(waveform_reversed_front_trim, (-1,))
|
108 |
+
#waveform_end_trim, sr = apply_effects_tensor(
|
109 |
+
# waveform_reversed_front_trim, sr, [["reverse"]]
|
110 |
+
#)
|
111 |
+
x = waveform_end_trim
|
112 |
+
|
113 |
+
# extract the representation of each layer
|
114 |
+
wav_input_16khz = x.to(self.device)
|
115 |
+
if torch.allclose(weights, self.weighting):
|
116 |
+
# use fastpath
|
117 |
+
features = self.wavlm.extract_features(wav_input_16khz, output_layer=SPEAKER_INFORMATION_LAYER, ret_layer_results=False)[0]
|
118 |
+
features = features.squeeze(0)
|
119 |
+
else:
|
120 |
+
# use slower weighted
|
121 |
+
rep, layer_results = self.wavlm.extract_features(wav_input_16khz, output_layer=self.wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
|
122 |
+
features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
|
123 |
+
# save full sequence
|
124 |
+
features = ( features*weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
125 |
+
|
126 |
+
return features
|
127 |
+
|
128 |
+
|
129 |
+
@torch.inference_mode()
|
130 |
+
def match(self, query_seq: Tensor, matching_set: Tensor, synth_set: Tensor = None,
|
131 |
+
topk: int = 4, tgt_loudness_db: float | None = -16,
|
132 |
+
target_duration: float | None = None, device: str | None = None) -> Tensor:
|
133 |
+
""" Given `query_seq`, `matching_set`, and `synth_set` tensors of shape (N, dim), perform kNN regression matching
|
134 |
+
with k=`topk`. Inputs:
|
135 |
+
- `query_seq`: Tensor (N1, dim) of the input/source query features.
|
136 |
+
- `matching_set`: Tensor (N2, dim) of the matching set used as the 'training set' for the kNN algorithm.
|
137 |
+
- `synth_set`: optional Tensor (N2, dim) corresponding to the matching set. We use the matching set to assign each query
|
138 |
+
vector to a vector in the matching set, and then use the corresponding vector from the synth set during HiFiGAN synthesis.
|
139 |
+
By default, and for best performance, this should be identical to the matching set.
|
140 |
+
- `topk`: k in the kNN -- the number of nearest neighbors to average over.
|
141 |
+
- `tgt_loudness_db`: float db used to normalize the output volume. Set to None to disable.
|
142 |
+
- `target_duration`: if set to a float, interpolate resulting waveform duration to be equal to this value in seconds.
|
143 |
+
- `device`: if None, uses default device at initialization. Otherwise uses specified device
|
144 |
+
Returns:
|
145 |
+
- converted waveform of shape (T,)
|
146 |
+
"""
|
147 |
+
device = torch.device(device) if device is not None else self.device
|
148 |
+
if synth_set is None: synth_set = matching_set.to(device)
|
149 |
+
else: synth_set = synth_set.to(device)
|
150 |
+
matching_set = matching_set.to(device)
|
151 |
+
query_seq = query_seq.to(device)
|
152 |
+
|
153 |
+
if target_duration is not None:
|
154 |
+
target_samples = int(target_duration*self.sr)
|
155 |
+
scale_factor = (target_samples/self.hop_length) / query_seq.shape[0] # n_targ_feats / n_input_feats
|
156 |
+
query_seq = F.interpolate(query_seq.T[None], scale_factor=scale_factor, mode='linear')[0].T
|
157 |
+
|
158 |
+
dists = fast_cosine_dist(query_seq, matching_set, device=device)
|
159 |
+
best = dists.topk(k=topk, largest=False, dim=-1)
|
160 |
+
out_feats = synth_set[best.indices].mean(dim=1)
|
161 |
+
|
162 |
+
prediction = self.vocode(out_feats[None].to(device)).cpu().squeeze()
|
163 |
+
|
164 |
+
# normalization
|
165 |
+
if tgt_loudness_db is not None:
|
166 |
+
src_loudness = torchaudio.functional.loudness(prediction[None], self.h.sampling_rate)
|
167 |
+
tgt_loudness = tgt_loudness_db
|
168 |
+
pred_wav = torchaudio.functional.gain(prediction, tgt_loudness - src_loudness)
|
169 |
+
else: pred_wav = prediction
|
170 |
+
return pred_wav
|
171 |
+
|
172 |
+
|
prematch_dataset.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import time
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torchaudio
|
14 |
+
from fastprogress.fastprogress import master_bar, progress_bar
|
15 |
+
from torch import Tensor
|
16 |
+
|
17 |
+
from hubconf import wavlm_large
|
18 |
+
|
19 |
+
DOWNSAMPLE_FACTOR = 320
|
20 |
+
|
21 |
+
global feature_cache
|
22 |
+
feature_cache = {}
|
23 |
+
global synthesis_cache
|
24 |
+
synthesis_cache = {}
|
25 |
+
|
26 |
+
def make_librispeech_df(root_path: Path) -> pd.DataFrame:
|
27 |
+
all_files = []
|
28 |
+
folders = ['train-clean-100', 'dev-clean']
|
29 |
+
print(f"[LIBRISPEECH] Computing folders {folders}")
|
30 |
+
for f in folders:
|
31 |
+
all_files.extend(list((root_path/f).rglob('**/*.flac')))
|
32 |
+
speakers = ['ls-' + f.stem.split('-')[0] for f in all_files]
|
33 |
+
df = pd.DataFrame({'path': all_files, 'speaker': speakers})
|
34 |
+
return df
|
35 |
+
|
36 |
+
|
37 |
+
def main(args):
|
38 |
+
device = torch.device(args.device)
|
39 |
+
SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
|
40 |
+
MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().to(device)[:, None]
|
41 |
+
|
42 |
+
print(f"Matching weightings: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weightings: {SYNTH_WEIGHTINGS.squeeze()}")
|
43 |
+
ls_df = make_librispeech_df(Path(args.librispeech_path))
|
44 |
+
|
45 |
+
print(f"Loading wavlm.")
|
46 |
+
wavlm = wavlm_large(pretrained=True, progress=True, device=args.device)
|
47 |
+
|
48 |
+
np.random.seed(args.seed)
|
49 |
+
torch.manual_seed(args.seed)
|
50 |
+
extract(ls_df, wavlm, args.device, Path(args.librispeech_path), Path(args.out_path), SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS)
|
51 |
+
print("All done!", flush=True)
|
52 |
+
|
53 |
+
|
54 |
+
def path2pools(path: Path, wavlm: nn.Module(), match_weights: Tensor, synth_weights: Tensor, device):
|
55 |
+
"""Given a waveform `path`, compute the matching pool"""
|
56 |
+
|
57 |
+
uttrs_from_same_spk = sorted(list(path.parent.rglob('**/*.flac')))
|
58 |
+
uttrs_from_same_spk.remove(path)
|
59 |
+
matching_pool = []
|
60 |
+
synth_pool = []
|
61 |
+
for pth in uttrs_from_same_spk:
|
62 |
+
if pth in feature_cache and pth in synthesis_cache:
|
63 |
+
matching_feats = feature_cache[pth].float() # (seq_len, dim)
|
64 |
+
synth_feats = synthesis_cache[pth].float() # (seq_len, dim)
|
65 |
+
else:
|
66 |
+
feats = get_full_features(pth, wavlm, device)
|
67 |
+
matching_feats = ( feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
68 |
+
synth_feats = ( feats*synth_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
69 |
+
feature_cache[pth] = matching_feats.half().cpu()
|
70 |
+
synthesis_cache[pth] = synth_feats.half().cpu()
|
71 |
+
|
72 |
+
matching_pool.append(matching_feats.cpu())
|
73 |
+
synth_pool.append(synth_feats.cpu())
|
74 |
+
matching_pool = torch.concat(matching_pool, dim=0)
|
75 |
+
synth_pool = torch.concat(synth_pool, dim=0)
|
76 |
+
return matching_pool, synth_pool # (N, dim)
|
77 |
+
|
78 |
+
|
79 |
+
@torch.inference_mode()
|
80 |
+
def get_full_features(path, wavlm, device):
|
81 |
+
|
82 |
+
x, sr = torchaudio.load(path)
|
83 |
+
assert sr == 16000
|
84 |
+
# This does not work i.t.o the hifigan training.
|
85 |
+
# x = F.pad(x, (DOWNSAMPLE_FACTOR//2, DOWNSAMPLE_FACTOR - DOWNSAMPLE_FACTOR//2), value=0)
|
86 |
+
# This does.
|
87 |
+
n_pad = DOWNSAMPLE_FACTOR - (x.shape[-1] % DOWNSAMPLE_FACTOR)
|
88 |
+
x = F.pad(x, (0, n_pad), value=0)
|
89 |
+
|
90 |
+
# extract the representation of each layer
|
91 |
+
wav_input_16khz = x.to(device)
|
92 |
+
rep, layer_results = wavlm.extract_features(wav_input_16khz, output_layer=wavlm.cfg.encoder_layers, ret_layer_results=True)[0]
|
93 |
+
features = torch.cat([x.transpose(0, 1) for x, _ in layer_results], dim=0) # (n_layers, seq_len, dim)
|
94 |
+
|
95 |
+
return features
|
96 |
+
|
97 |
+
|
98 |
+
def fast_cosine_dist(source_feats, matching_pool):
|
99 |
+
source_norms = torch.norm(source_feats, p=2, dim=-1)
|
100 |
+
matching_norms = torch.norm(matching_pool, p=2, dim=-1)
|
101 |
+
dotprod = -torch.cdist(source_feats[None], matching_pool[None], p=2)[0]**2 + source_norms[:, None]**2 + matching_norms[None]**2
|
102 |
+
dotprod /= 2
|
103 |
+
|
104 |
+
dists = 1 - ( dotprod / (source_norms[:, None] * matching_norms[None]) )
|
105 |
+
return dists
|
106 |
+
|
107 |
+
|
108 |
+
@torch.inference_mode()
|
109 |
+
def extract(df: pd.DataFrame, wavlm: nn.Module, device, ls_path: Path, out_path: Path, synth_weights: Tensor, match_weights: Tensor):
|
110 |
+
|
111 |
+
pb = progress_bar(df.iterrows(), total=len(df))
|
112 |
+
|
113 |
+
for i, row in pb:
|
114 |
+
rel_path = Path(row.path).relative_to(ls_path)
|
115 |
+
targ_path = (out_path/rel_path).with_suffix('.pt')
|
116 |
+
if args.resume:
|
117 |
+
if targ_path.is_file(): continue
|
118 |
+
# if targ_path.is_file(): continue
|
119 |
+
os.makedirs(targ_path.parent, exist_ok=True)
|
120 |
+
|
121 |
+
if Path(row.path) in feature_cache:
|
122 |
+
source_feats = feature_cache[Path(row.path)].float()
|
123 |
+
else:
|
124 |
+
source_feats = get_full_features(row.path, wavlm, device)
|
125 |
+
source_feats = ( source_feats*match_weights[:, None] ).sum(dim=0) # (seq_len, dim)
|
126 |
+
|
127 |
+
matching_pool, synth_pool = path2pools(row.path, wavlm, match_weights, synth_weights, device)
|
128 |
+
|
129 |
+
if not args.prematch:
|
130 |
+
out_feats = source_feats.cpu()
|
131 |
+
else:
|
132 |
+
dists = fast_cosine_dist(source_feats.cpu(), matching_pool.cpu()).cpu()
|
133 |
+
best = dists.topk(k=args.topk, dim=-1, largest=False) # (src_len, 4)
|
134 |
+
out_feats = synth_pool[best.indices].mean(dim=1) # (N, dim)
|
135 |
+
|
136 |
+
# save matched sequence
|
137 |
+
if i < 3: print("Feature has shape: ", out_feats.shape, flush=True)
|
138 |
+
# 3. save
|
139 |
+
torch.save(out_feats.cpu().half(), str(targ_path))
|
140 |
+
if hasattr(pb, 'child'):
|
141 |
+
pb.child.comment = str(rel_path)
|
142 |
+
pb.child.wait_for = min(pb.child.wait_for, 10)
|
143 |
+
pb.main_bar.comment = str(rel_path)
|
144 |
+
else:
|
145 |
+
pb.wait_for = min(pb.wait_for, 10)
|
146 |
+
pb.comment = str(rel_path)
|
147 |
+
|
148 |
+
|
149 |
+
if i % 1000 == 0:
|
150 |
+
print(f"Done {i:,d}/{len(df):,d}", flush=True)
|
151 |
+
feature_cache.clear()
|
152 |
+
synthesis_cache.clear()
|
153 |
+
gc.collect()
|
154 |
+
time.sleep(4)
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == '__main__':
|
158 |
+
parser = argparse.ArgumentParser(description="Compute matched wavlm features for a librispeech dataset")
|
159 |
+
|
160 |
+
parser.add_argument('--librispeech_path', required=True, type=str)
|
161 |
+
parser.add_argument('--seed', default=123, type=int)
|
162 |
+
parser.add_argument('--out_path', required=True, type=str)
|
163 |
+
parser.add_argument('--device', default='cuda', type=str)
|
164 |
+
parser.add_argument('--topk', type=int, default=4)
|
165 |
+
parser.add_argument('--matching_layer', type=int, default=6)
|
166 |
+
parser.add_argument('--synthesis_layer', type=int, default=6)
|
167 |
+
parser.add_argument('--prematch', action='store_true', help='prematch')
|
168 |
+
parser.add_argument('--resume', action='store_true')
|
169 |
+
|
170 |
+
args = parser.parse_args()
|
171 |
+
main(args)
|
172 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchaudio
|
3 |
+
soundfile
|
4 |
+
gradio
|
5 |
+
spaces
|
wavlm/WavLM.py
ADDED
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import logging
|
12 |
+
from typing import List, Optional, Tuple
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch.nn import LayerNorm
|
20 |
+
from .modules import (
|
21 |
+
Fp32GroupNorm,
|
22 |
+
Fp32LayerNorm,
|
23 |
+
GradMultiply,
|
24 |
+
MultiheadAttention,
|
25 |
+
SamePad,
|
26 |
+
init_bert_params,
|
27 |
+
get_activation_fn,
|
28 |
+
TransposeLast,
|
29 |
+
GLU_Linear,
|
30 |
+
)
|
31 |
+
|
32 |
+
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
+
|
35 |
+
def compute_mask_indices(
|
36 |
+
shape: Tuple[int, int],
|
37 |
+
padding_mask: Optional[torch.Tensor],
|
38 |
+
mask_prob: float,
|
39 |
+
mask_length: int,
|
40 |
+
mask_type: str = "static",
|
41 |
+
mask_other: float = 0.0,
|
42 |
+
min_masks: int = 0,
|
43 |
+
no_overlap: bool = False,
|
44 |
+
min_space: int = 0,
|
45 |
+
) -> np.ndarray:
|
46 |
+
"""
|
47 |
+
Computes random mask spans for a given shape
|
48 |
+
|
49 |
+
Args:
|
50 |
+
shape: the the shape for which to compute masks.
|
51 |
+
should be of size 2 where first element is batch size and 2nd is timesteps
|
52 |
+
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
53 |
+
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
54 |
+
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
55 |
+
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
56 |
+
mask_type: how to compute mask lengths
|
57 |
+
static = fixed size
|
58 |
+
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
59 |
+
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
60 |
+
poisson = sample from possion distribution with lambda = mask length
|
61 |
+
min_masks: minimum number of masked spans
|
62 |
+
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
63 |
+
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
64 |
+
"""
|
65 |
+
|
66 |
+
bsz, all_sz = shape
|
67 |
+
mask = np.full((bsz, all_sz), False)
|
68 |
+
|
69 |
+
all_num_mask = int(
|
70 |
+
# add a random number for probabilistic rounding
|
71 |
+
mask_prob * all_sz / float(mask_length)
|
72 |
+
+ np.random.rand()
|
73 |
+
)
|
74 |
+
|
75 |
+
all_num_mask = max(min_masks, all_num_mask)
|
76 |
+
|
77 |
+
mask_idcs = []
|
78 |
+
for i in range(bsz):
|
79 |
+
if padding_mask is not None:
|
80 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
81 |
+
num_mask = int(
|
82 |
+
# add a random number for probabilistic rounding
|
83 |
+
mask_prob * sz / float(mask_length)
|
84 |
+
+ np.random.rand()
|
85 |
+
)
|
86 |
+
num_mask = max(min_masks, num_mask)
|
87 |
+
else:
|
88 |
+
sz = all_sz
|
89 |
+
num_mask = all_num_mask
|
90 |
+
|
91 |
+
if mask_type == "static":
|
92 |
+
lengths = np.full(num_mask, mask_length)
|
93 |
+
elif mask_type == "uniform":
|
94 |
+
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
95 |
+
elif mask_type == "normal":
|
96 |
+
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
97 |
+
lengths = [max(1, int(round(x))) for x in lengths]
|
98 |
+
elif mask_type == "poisson":
|
99 |
+
lengths = np.random.poisson(mask_length, size=num_mask)
|
100 |
+
lengths = [int(round(x)) for x in lengths]
|
101 |
+
else:
|
102 |
+
raise Exception("unknown mask selection " + mask_type)
|
103 |
+
|
104 |
+
if sum(lengths) == 0:
|
105 |
+
lengths[0] = min(mask_length, sz - 1)
|
106 |
+
|
107 |
+
if no_overlap:
|
108 |
+
mask_idc = []
|
109 |
+
|
110 |
+
def arrange(s, e, length, keep_length):
|
111 |
+
span_start = np.random.randint(s, e - length)
|
112 |
+
mask_idc.extend(span_start + i for i in range(length))
|
113 |
+
|
114 |
+
new_parts = []
|
115 |
+
if span_start - s - min_space >= keep_length:
|
116 |
+
new_parts.append((s, span_start - min_space + 1))
|
117 |
+
if e - span_start - keep_length - min_space > keep_length:
|
118 |
+
new_parts.append((span_start + length + min_space, e))
|
119 |
+
return new_parts
|
120 |
+
|
121 |
+
parts = [(0, sz)]
|
122 |
+
min_length = min(lengths)
|
123 |
+
for length in sorted(lengths, reverse=True):
|
124 |
+
lens = np.fromiter(
|
125 |
+
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
126 |
+
np.int,
|
127 |
+
)
|
128 |
+
l_sum = np.sum(lens)
|
129 |
+
if l_sum == 0:
|
130 |
+
break
|
131 |
+
probs = lens / np.sum(lens)
|
132 |
+
c = np.random.choice(len(parts), p=probs)
|
133 |
+
s, e = parts.pop(c)
|
134 |
+
parts.extend(arrange(s, e, length, min_length))
|
135 |
+
mask_idc = np.asarray(mask_idc)
|
136 |
+
else:
|
137 |
+
min_len = min(lengths)
|
138 |
+
if sz - min_len <= num_mask:
|
139 |
+
min_len = sz - num_mask - 1
|
140 |
+
|
141 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
142 |
+
|
143 |
+
mask_idc = np.asarray(
|
144 |
+
[
|
145 |
+
mask_idc[j] + offset
|
146 |
+
for j in range(len(mask_idc))
|
147 |
+
for offset in range(lengths[j])
|
148 |
+
]
|
149 |
+
)
|
150 |
+
|
151 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
152 |
+
|
153 |
+
min_len = min([len(m) for m in mask_idcs])
|
154 |
+
for i, mask_idc in enumerate(mask_idcs):
|
155 |
+
if len(mask_idc) > min_len:
|
156 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
157 |
+
mask[i, mask_idc] = True
|
158 |
+
|
159 |
+
return mask
|
160 |
+
|
161 |
+
|
162 |
+
class WavLMConfig:
|
163 |
+
def __init__(self, cfg=None):
|
164 |
+
self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
|
165 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
166 |
+
|
167 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
168 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
169 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
170 |
+
self.activation_fn: str = "gelu" # activation function to use
|
171 |
+
|
172 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
173 |
+
self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
|
174 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
175 |
+
self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
|
176 |
+
|
177 |
+
self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
|
178 |
+
|
179 |
+
# dropouts
|
180 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
181 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
182 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
183 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
184 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
185 |
+
self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
|
186 |
+
|
187 |
+
# masking
|
188 |
+
self.mask_length: int = 10 # mask length
|
189 |
+
self.mask_prob: float = 0.65 # probability of replacing a token with mask
|
190 |
+
self.mask_selection: str = "static" # how to choose mask length
|
191 |
+
self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh
|
192 |
+
self.no_mask_overlap: bool = False # whether to allow masks to overlap
|
193 |
+
self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
194 |
+
|
195 |
+
# channel masking
|
196 |
+
self.mask_channel_length: int = 10 # length of the mask for features (channels)
|
197 |
+
self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
|
198 |
+
self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
|
199 |
+
self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
|
200 |
+
self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
|
201 |
+
self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
|
202 |
+
|
203 |
+
# positional embeddings
|
204 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
205 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
206 |
+
|
207 |
+
# relative position embedding
|
208 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
209 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
210 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
211 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
212 |
+
|
213 |
+
if cfg is not None:
|
214 |
+
self.update(cfg)
|
215 |
+
|
216 |
+
def update(self, cfg: dict):
|
217 |
+
self.__dict__.update(cfg)
|
218 |
+
|
219 |
+
|
220 |
+
class WavLM(nn.Module):
|
221 |
+
def __init__(
|
222 |
+
self,
|
223 |
+
cfg: WavLMConfig,
|
224 |
+
) -> None:
|
225 |
+
super().__init__()
|
226 |
+
logger.info(f"WavLM Config: {cfg.__dict__}")
|
227 |
+
|
228 |
+
self.cfg = cfg
|
229 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
230 |
+
self.embed = feature_enc_layers[-1][0]
|
231 |
+
|
232 |
+
self.feature_extractor = ConvFeatureExtractionModel(
|
233 |
+
conv_layers=feature_enc_layers,
|
234 |
+
dropout=0.0,
|
235 |
+
mode=cfg.extractor_mode,
|
236 |
+
conv_bias=cfg.conv_bias,
|
237 |
+
)
|
238 |
+
|
239 |
+
self.post_extract_proj = (
|
240 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
241 |
+
if self.embed != cfg.encoder_embed_dim
|
242 |
+
else None
|
243 |
+
)
|
244 |
+
|
245 |
+
self.mask_prob = cfg.mask_prob
|
246 |
+
self.mask_selection = cfg.mask_selection
|
247 |
+
self.mask_other = cfg.mask_other
|
248 |
+
self.mask_length = cfg.mask_length
|
249 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
250 |
+
self.mask_min_space = cfg.mask_min_space
|
251 |
+
|
252 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
253 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
254 |
+
self.mask_channel_other = cfg.mask_channel_other
|
255 |
+
self.mask_channel_length = cfg.mask_channel_length
|
256 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
257 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
258 |
+
|
259 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
260 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
261 |
+
|
262 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
263 |
+
|
264 |
+
self.mask_emb = nn.Parameter(
|
265 |
+
torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
|
266 |
+
)
|
267 |
+
|
268 |
+
self.encoder = TransformerEncoder(cfg)
|
269 |
+
self.layer_norm = LayerNorm(self.embed)
|
270 |
+
|
271 |
+
def apply_mask(self, x, padding_mask):
|
272 |
+
B, T, C = x.shape
|
273 |
+
if self.mask_prob > 0:
|
274 |
+
mask_indices = compute_mask_indices(
|
275 |
+
(B, T),
|
276 |
+
padding_mask,
|
277 |
+
self.mask_prob,
|
278 |
+
self.mask_length,
|
279 |
+
self.mask_selection,
|
280 |
+
self.mask_other,
|
281 |
+
min_masks=2,
|
282 |
+
no_overlap=self.no_mask_overlap,
|
283 |
+
min_space=self.mask_min_space,
|
284 |
+
)
|
285 |
+
mask_indices = torch.from_numpy(mask_indices).to(x.device)
|
286 |
+
x[mask_indices] = self.mask_emb
|
287 |
+
else:
|
288 |
+
mask_indices = None
|
289 |
+
|
290 |
+
if self.mask_channel_prob > 0:
|
291 |
+
mask_channel_indices = compute_mask_indices(
|
292 |
+
(B, C),
|
293 |
+
None,
|
294 |
+
self.mask_channel_prob,
|
295 |
+
self.mask_channel_length,
|
296 |
+
self.mask_channel_selection,
|
297 |
+
self.mask_channel_other,
|
298 |
+
no_overlap=self.no_mask_channel_overlap,
|
299 |
+
min_space=self.mask_channel_min_space,
|
300 |
+
)
|
301 |
+
mask_channel_indices = (
|
302 |
+
torch.from_numpy(mask_channel_indices)
|
303 |
+
.to(x.device)
|
304 |
+
.unsqueeze(1)
|
305 |
+
.expand(-1, T, -1)
|
306 |
+
)
|
307 |
+
x[mask_channel_indices] = 0
|
308 |
+
|
309 |
+
return x, mask_indices
|
310 |
+
|
311 |
+
def forward_padding_mask(
|
312 |
+
self, features: torch.Tensor, padding_mask: torch.Tensor,
|
313 |
+
) -> torch.Tensor:
|
314 |
+
extra = padding_mask.size(1) % features.size(1)
|
315 |
+
if extra > 0:
|
316 |
+
padding_mask = padding_mask[:, :-extra]
|
317 |
+
padding_mask = padding_mask.view(
|
318 |
+
padding_mask.size(0), features.size(1), -1
|
319 |
+
)
|
320 |
+
padding_mask = padding_mask.all(-1)
|
321 |
+
return padding_mask
|
322 |
+
|
323 |
+
def extract_features(
|
324 |
+
self,
|
325 |
+
source: torch.Tensor,
|
326 |
+
padding_mask: Optional[torch.Tensor] = None,
|
327 |
+
mask: bool = False,
|
328 |
+
ret_conv: bool = False,
|
329 |
+
output_layer: Optional[int] = None,
|
330 |
+
ret_layer_results: bool = False,
|
331 |
+
):
|
332 |
+
|
333 |
+
if self.feature_grad_mult > 0:
|
334 |
+
features = self.feature_extractor(source)
|
335 |
+
if self.feature_grad_mult != 1.0:
|
336 |
+
features = GradMultiply.apply(features, self.feature_grad_mult)
|
337 |
+
else:
|
338 |
+
with torch.no_grad():
|
339 |
+
features = self.feature_extractor(source)
|
340 |
+
|
341 |
+
features = features.transpose(1, 2)
|
342 |
+
features = self.layer_norm(features)
|
343 |
+
|
344 |
+
if padding_mask is not None:
|
345 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
346 |
+
|
347 |
+
if self.post_extract_proj is not None:
|
348 |
+
features = self.post_extract_proj(features)
|
349 |
+
|
350 |
+
features = self.dropout_input(features)
|
351 |
+
|
352 |
+
if mask:
|
353 |
+
x, mask_indices = self.apply_mask(
|
354 |
+
features, padding_mask
|
355 |
+
)
|
356 |
+
else:
|
357 |
+
x = features
|
358 |
+
|
359 |
+
# feature: (B, T, D), float
|
360 |
+
# target: (B, T), long
|
361 |
+
# x: (B, T, D), float
|
362 |
+
# padding_mask: (B, T), bool
|
363 |
+
# mask_indices: (B, T), bool
|
364 |
+
x, layer_results = self.encoder(
|
365 |
+
x,
|
366 |
+
padding_mask=padding_mask,
|
367 |
+
layer=None if output_layer is None else output_layer - 1
|
368 |
+
)
|
369 |
+
|
370 |
+
res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
|
371 |
+
|
372 |
+
feature = res["features"] if ret_conv else res["x"]
|
373 |
+
if ret_layer_results:
|
374 |
+
feature = (feature, res["layer_results"])
|
375 |
+
return feature, res["padding_mask"]
|
376 |
+
|
377 |
+
|
378 |
+
class ConvFeatureExtractionModel(nn.Module):
|
379 |
+
def __init__(
|
380 |
+
self,
|
381 |
+
conv_layers: List[Tuple[int, int, int]],
|
382 |
+
dropout: float = 0.0,
|
383 |
+
mode: str = "default",
|
384 |
+
conv_bias: bool = False,
|
385 |
+
conv_type: str = "default"
|
386 |
+
):
|
387 |
+
super().__init__()
|
388 |
+
|
389 |
+
assert mode in {"default", "layer_norm"}
|
390 |
+
|
391 |
+
def block(
|
392 |
+
n_in,
|
393 |
+
n_out,
|
394 |
+
k,
|
395 |
+
stride,
|
396 |
+
is_layer_norm=False,
|
397 |
+
is_group_norm=False,
|
398 |
+
conv_bias=False,
|
399 |
+
):
|
400 |
+
def make_conv():
|
401 |
+
conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
|
402 |
+
nn.init.kaiming_normal_(conv.weight)
|
403 |
+
return conv
|
404 |
+
|
405 |
+
assert (
|
406 |
+
is_layer_norm and is_group_norm
|
407 |
+
) == False, "layer norm and group norm are exclusive"
|
408 |
+
|
409 |
+
if is_layer_norm:
|
410 |
+
return nn.Sequential(
|
411 |
+
make_conv(),
|
412 |
+
nn.Dropout(p=dropout),
|
413 |
+
nn.Sequential(
|
414 |
+
TransposeLast(),
|
415 |
+
Fp32LayerNorm(dim, elementwise_affine=True),
|
416 |
+
TransposeLast(),
|
417 |
+
),
|
418 |
+
nn.GELU(),
|
419 |
+
)
|
420 |
+
elif is_group_norm:
|
421 |
+
return nn.Sequential(
|
422 |
+
make_conv(),
|
423 |
+
nn.Dropout(p=dropout),
|
424 |
+
Fp32GroupNorm(dim, dim, affine=True),
|
425 |
+
nn.GELU(),
|
426 |
+
)
|
427 |
+
else:
|
428 |
+
return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
|
429 |
+
|
430 |
+
self.conv_type = conv_type
|
431 |
+
if self.conv_type == "default":
|
432 |
+
in_d = 1
|
433 |
+
self.conv_layers = nn.ModuleList()
|
434 |
+
for i, cl in enumerate(conv_layers):
|
435 |
+
assert len(cl) == 3, "invalid conv definition: " + str(cl)
|
436 |
+
(dim, k, stride) = cl
|
437 |
+
|
438 |
+
self.conv_layers.append(
|
439 |
+
block(
|
440 |
+
in_d,
|
441 |
+
dim,
|
442 |
+
k,
|
443 |
+
stride,
|
444 |
+
is_layer_norm=mode == "layer_norm",
|
445 |
+
is_group_norm=mode == "default" and i == 0,
|
446 |
+
conv_bias=conv_bias,
|
447 |
+
)
|
448 |
+
)
|
449 |
+
in_d = dim
|
450 |
+
elif self.conv_type == "conv2d":
|
451 |
+
in_d = 1
|
452 |
+
self.conv_layers = nn.ModuleList()
|
453 |
+
for i, cl in enumerate(conv_layers):
|
454 |
+
assert len(cl) == 3
|
455 |
+
(dim, k, stride) = cl
|
456 |
+
|
457 |
+
self.conv_layers.append(
|
458 |
+
torch.nn.Conv2d(in_d, dim, k, stride)
|
459 |
+
)
|
460 |
+
self.conv_layers.append(torch.nn.ReLU())
|
461 |
+
in_d = dim
|
462 |
+
elif self.conv_type == "custom":
|
463 |
+
in_d = 1
|
464 |
+
idim = 80
|
465 |
+
self.conv_layers = nn.ModuleList()
|
466 |
+
for i, cl in enumerate(conv_layers):
|
467 |
+
assert len(cl) == 3
|
468 |
+
(dim, k, stride) = cl
|
469 |
+
self.conv_layers.append(
|
470 |
+
torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
|
471 |
+
)
|
472 |
+
self.conv_layers.append(
|
473 |
+
torch.nn.LayerNorm([dim, idim])
|
474 |
+
)
|
475 |
+
self.conv_layers.append(torch.nn.ReLU())
|
476 |
+
in_d = dim
|
477 |
+
if (i + 1) % 2 == 0:
|
478 |
+
self.conv_layers.append(
|
479 |
+
torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
480 |
+
)
|
481 |
+
idim = int(math.ceil(idim / 2))
|
482 |
+
else:
|
483 |
+
pass
|
484 |
+
|
485 |
+
def forward(self, x, mask=None):
|
486 |
+
|
487 |
+
# BxT -> BxCxT
|
488 |
+
x = x.unsqueeze(1)
|
489 |
+
if self.conv_type == "custom":
|
490 |
+
for conv in self.conv_layers:
|
491 |
+
if isinstance(conv, nn.LayerNorm):
|
492 |
+
x = x.transpose(1, 2)
|
493 |
+
x = conv(x).transpose(1, 2)
|
494 |
+
else:
|
495 |
+
x = conv(x)
|
496 |
+
x = x.transpose(2, 3).contiguous()
|
497 |
+
x = x.view(x.size(0), -1, x.size(-1))
|
498 |
+
else:
|
499 |
+
for conv in self.conv_layers:
|
500 |
+
x = conv(x)
|
501 |
+
if self.conv_type == "conv2d":
|
502 |
+
b, c, t, f = x.size()
|
503 |
+
x = x.transpose(2, 3).contiguous().view(b, c * f, t)
|
504 |
+
return x
|
505 |
+
|
506 |
+
|
507 |
+
class TransformerEncoder(nn.Module):
|
508 |
+
def __init__(self, args):
|
509 |
+
super().__init__()
|
510 |
+
|
511 |
+
self.dropout = args.dropout
|
512 |
+
self.embedding_dim = args.encoder_embed_dim
|
513 |
+
|
514 |
+
self.pos_conv = nn.Conv1d(
|
515 |
+
self.embedding_dim,
|
516 |
+
self.embedding_dim,
|
517 |
+
kernel_size=args.conv_pos,
|
518 |
+
padding=args.conv_pos // 2,
|
519 |
+
groups=args.conv_pos_groups,
|
520 |
+
)
|
521 |
+
dropout = 0
|
522 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
523 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
524 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
525 |
+
|
526 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
527 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
528 |
+
|
529 |
+
if hasattr(args, "relative_position_embedding"):
|
530 |
+
self.relative_position_embedding = args.relative_position_embedding
|
531 |
+
self.num_buckets = args.num_buckets
|
532 |
+
self.max_distance = args.max_distance
|
533 |
+
else:
|
534 |
+
self.relative_position_embedding = False
|
535 |
+
self.num_buckets = 0
|
536 |
+
self.max_distance = 0
|
537 |
+
|
538 |
+
self.layers = nn.ModuleList(
|
539 |
+
[
|
540 |
+
TransformerSentenceEncoderLayer(
|
541 |
+
embedding_dim=self.embedding_dim,
|
542 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
543 |
+
num_attention_heads=args.encoder_attention_heads,
|
544 |
+
dropout=self.dropout,
|
545 |
+
attention_dropout=args.attention_dropout,
|
546 |
+
activation_dropout=args.activation_dropout,
|
547 |
+
activation_fn=args.activation_fn,
|
548 |
+
layer_norm_first=args.layer_norm_first,
|
549 |
+
has_relative_attention_bias=(self.relative_position_embedding and i == 0),
|
550 |
+
num_buckets=self.num_buckets,
|
551 |
+
max_distance=self.max_distance,
|
552 |
+
gru_rel_pos=args.gru_rel_pos,
|
553 |
+
)
|
554 |
+
for i in range(args.encoder_layers)
|
555 |
+
]
|
556 |
+
)
|
557 |
+
|
558 |
+
self.layer_norm_first = args.layer_norm_first
|
559 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
560 |
+
self.layerdrop = args.encoder_layerdrop
|
561 |
+
|
562 |
+
self.apply(init_bert_params)
|
563 |
+
|
564 |
+
def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
|
565 |
+
x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
|
566 |
+
|
567 |
+
if self.layer_norm_first and layer is None:
|
568 |
+
x = self.layer_norm(x)
|
569 |
+
|
570 |
+
return x, layer_results
|
571 |
+
|
572 |
+
def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
|
573 |
+
|
574 |
+
if padding_mask is not None:
|
575 |
+
x[padding_mask] = 0
|
576 |
+
|
577 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
578 |
+
x_conv = x_conv.transpose(1, 2)
|
579 |
+
x += x_conv
|
580 |
+
|
581 |
+
if not self.layer_norm_first:
|
582 |
+
x = self.layer_norm(x)
|
583 |
+
|
584 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
585 |
+
|
586 |
+
# B x T x C -> T x B x C
|
587 |
+
x = x.transpose(0, 1)
|
588 |
+
|
589 |
+
layer_results = []
|
590 |
+
z = None
|
591 |
+
if tgt_layer is not None:
|
592 |
+
layer_results.append((x, z))
|
593 |
+
r = None
|
594 |
+
pos_bias = None
|
595 |
+
for i, layer in enumerate(self.layers):
|
596 |
+
dropout_probability = np.random.random()
|
597 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
598 |
+
x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
|
599 |
+
self_attn_mask=streaming_mask, pos_bias=pos_bias)
|
600 |
+
if tgt_layer is not None:
|
601 |
+
layer_results.append((x, z))
|
602 |
+
if i == tgt_layer:
|
603 |
+
r = x
|
604 |
+
break
|
605 |
+
|
606 |
+
if r is not None:
|
607 |
+
x = r
|
608 |
+
|
609 |
+
# T x B x C -> B x T x C
|
610 |
+
x = x.transpose(0, 1)
|
611 |
+
|
612 |
+
return x, layer_results
|
613 |
+
|
614 |
+
|
615 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
616 |
+
"""
|
617 |
+
Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
|
618 |
+
models.
|
619 |
+
"""
|
620 |
+
|
621 |
+
def __init__(
|
622 |
+
self,
|
623 |
+
embedding_dim: float = 768,
|
624 |
+
ffn_embedding_dim: float = 3072,
|
625 |
+
num_attention_heads: float = 8,
|
626 |
+
dropout: float = 0.1,
|
627 |
+
attention_dropout: float = 0.1,
|
628 |
+
activation_dropout: float = 0.1,
|
629 |
+
activation_fn: str = "relu",
|
630 |
+
layer_norm_first: bool = False,
|
631 |
+
has_relative_attention_bias: bool = False,
|
632 |
+
num_buckets: int = 0,
|
633 |
+
max_distance: int = 0,
|
634 |
+
rescale_init: bool = False,
|
635 |
+
gru_rel_pos: bool = False,
|
636 |
+
) -> None:
|
637 |
+
|
638 |
+
super().__init__()
|
639 |
+
# Initialize parameters
|
640 |
+
self.embedding_dim = embedding_dim
|
641 |
+
self.dropout = dropout
|
642 |
+
self.activation_dropout = activation_dropout
|
643 |
+
|
644 |
+
# Initialize blocks
|
645 |
+
self.activation_name = activation_fn
|
646 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
647 |
+
self.self_attn = MultiheadAttention(
|
648 |
+
self.embedding_dim,
|
649 |
+
num_attention_heads,
|
650 |
+
dropout=attention_dropout,
|
651 |
+
self_attention=True,
|
652 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
653 |
+
num_buckets=num_buckets,
|
654 |
+
max_distance=max_distance,
|
655 |
+
rescale_init=rescale_init,
|
656 |
+
gru_rel_pos=gru_rel_pos,
|
657 |
+
)
|
658 |
+
|
659 |
+
self.dropout1 = nn.Dropout(dropout)
|
660 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
661 |
+
self.dropout3 = nn.Dropout(dropout)
|
662 |
+
|
663 |
+
self.layer_norm_first = layer_norm_first
|
664 |
+
|
665 |
+
# layer norm associated with the self attention layer
|
666 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
667 |
+
|
668 |
+
if self.activation_name == "glu":
|
669 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
670 |
+
else:
|
671 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
672 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
673 |
+
|
674 |
+
# layer norm associated with the position wise feed-forward NN
|
675 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
676 |
+
|
677 |
+
def forward(
|
678 |
+
self,
|
679 |
+
x: torch.Tensor,
|
680 |
+
self_attn_mask: torch.Tensor = None,
|
681 |
+
self_attn_padding_mask: torch.Tensor = None,
|
682 |
+
need_weights: bool = False,
|
683 |
+
pos_bias=None
|
684 |
+
):
|
685 |
+
"""
|
686 |
+
LayerNorm is applied either before or after the self-attention/ffn
|
687 |
+
modules similar to the original Transformer imlementation.
|
688 |
+
"""
|
689 |
+
residual = x
|
690 |
+
|
691 |
+
if self.layer_norm_first:
|
692 |
+
x = self.self_attn_layer_norm(x)
|
693 |
+
x, attn, pos_bias = self.self_attn(
|
694 |
+
query=x,
|
695 |
+
key=x,
|
696 |
+
value=x,
|
697 |
+
key_padding_mask=self_attn_padding_mask,
|
698 |
+
need_weights=False,
|
699 |
+
attn_mask=self_attn_mask,
|
700 |
+
position_bias=pos_bias
|
701 |
+
)
|
702 |
+
x = self.dropout1(x)
|
703 |
+
x = residual + x
|
704 |
+
|
705 |
+
residual = x
|
706 |
+
x = self.final_layer_norm(x)
|
707 |
+
if self.activation_name == "glu":
|
708 |
+
x = self.fc1(x)
|
709 |
+
else:
|
710 |
+
x = self.activation_fn(self.fc1(x))
|
711 |
+
x = self.dropout2(x)
|
712 |
+
x = self.fc2(x)
|
713 |
+
x = self.dropout3(x)
|
714 |
+
x = residual + x
|
715 |
+
else:
|
716 |
+
x, attn, pos_bias = self.self_attn(
|
717 |
+
query=x,
|
718 |
+
key=x,
|
719 |
+
value=x,
|
720 |
+
key_padding_mask=self_attn_padding_mask,
|
721 |
+
need_weights=need_weights,
|
722 |
+
attn_mask=self_attn_mask,
|
723 |
+
position_bias=pos_bias
|
724 |
+
)
|
725 |
+
|
726 |
+
x = self.dropout1(x)
|
727 |
+
x = residual + x
|
728 |
+
|
729 |
+
x = self.self_attn_layer_norm(x)
|
730 |
+
|
731 |
+
residual = x
|
732 |
+
if self.activation_name == "glu":
|
733 |
+
x = self.fc1(x)
|
734 |
+
else:
|
735 |
+
x = self.activation_fn(self.fc1(x))
|
736 |
+
x = self.dropout2(x)
|
737 |
+
x = self.fc2(x)
|
738 |
+
x = self.dropout3(x)
|
739 |
+
x = residual + x
|
740 |
+
x = self.final_layer_norm(x)
|
741 |
+
|
742 |
+
return x, attn, pos_bias
|
743 |
+
|
wavlm/modules.py
ADDED
@@ -0,0 +1,827 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/wavlm
|
4 |
+
# Copyright (c) 2021 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import warnings
|
12 |
+
from typing import Dict, Optional, Tuple
|
13 |
+
import torch
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.nn import Parameter
|
16 |
+
import torch.nn.functional as F
|
17 |
+
|
18 |
+
|
19 |
+
class TransposeLast(nn.Module):
|
20 |
+
def __init__(self, deconstruct_idx=None):
|
21 |
+
super().__init__()
|
22 |
+
self.deconstruct_idx = deconstruct_idx
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
if self.deconstruct_idx is not None:
|
26 |
+
x = x[self.deconstruct_idx]
|
27 |
+
return x.transpose(-2, -1)
|
28 |
+
|
29 |
+
|
30 |
+
class Fp32LayerNorm(nn.LayerNorm):
|
31 |
+
def __init__(self, *args, **kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
|
34 |
+
def forward(self, input):
|
35 |
+
output = F.layer_norm(
|
36 |
+
input.float(),
|
37 |
+
self.normalized_shape,
|
38 |
+
self.weight.float() if self.weight is not None else None,
|
39 |
+
self.bias.float() if self.bias is not None else None,
|
40 |
+
self.eps,
|
41 |
+
)
|
42 |
+
return output.type_as(input)
|
43 |
+
|
44 |
+
|
45 |
+
class Fp32GroupNorm(nn.GroupNorm):
|
46 |
+
def __init__(self, *args, **kwargs):
|
47 |
+
super().__init__(*args, **kwargs)
|
48 |
+
|
49 |
+
def forward(self, input):
|
50 |
+
output = F.group_norm(
|
51 |
+
input.float(),
|
52 |
+
self.num_groups,
|
53 |
+
self.weight.float() if self.weight is not None else None,
|
54 |
+
self.bias.float() if self.bias is not None else None,
|
55 |
+
self.eps,
|
56 |
+
)
|
57 |
+
return output.type_as(input)
|
58 |
+
|
59 |
+
|
60 |
+
class GradMultiply(torch.autograd.Function):
|
61 |
+
@staticmethod
|
62 |
+
def forward(ctx, x, scale):
|
63 |
+
ctx.scale = scale
|
64 |
+
res = x.new(x)
|
65 |
+
return res
|
66 |
+
|
67 |
+
@staticmethod
|
68 |
+
def backward(ctx, grad):
|
69 |
+
return grad * ctx.scale, None
|
70 |
+
|
71 |
+
|
72 |
+
class SamePad(nn.Module):
|
73 |
+
def __init__(self, kernel_size, causal=False):
|
74 |
+
super().__init__()
|
75 |
+
if causal:
|
76 |
+
self.remove = kernel_size - 1
|
77 |
+
else:
|
78 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
79 |
+
|
80 |
+
def forward(self, x):
|
81 |
+
if self.remove > 0:
|
82 |
+
x = x[:, :, : -self.remove]
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class Swish(nn.Module):
|
87 |
+
"""Swish function
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self):
|
91 |
+
"""Construct an MultiHeadedAttention object."""
|
92 |
+
super(Swish, self).__init__()
|
93 |
+
self.act = torch.nn.Sigmoid()
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
return x * self.act(x)
|
97 |
+
|
98 |
+
|
99 |
+
class GLU_Linear(nn.Module):
|
100 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
101 |
+
super(GLU_Linear, self).__init__()
|
102 |
+
|
103 |
+
self.glu_type = glu_type
|
104 |
+
self.output_dim = output_dim
|
105 |
+
|
106 |
+
if glu_type == "sigmoid":
|
107 |
+
self.glu_act = torch.nn.Sigmoid()
|
108 |
+
elif glu_type == "swish":
|
109 |
+
self.glu_act = Swish()
|
110 |
+
elif glu_type == "relu":
|
111 |
+
self.glu_act = torch.nn.ReLU()
|
112 |
+
elif glu_type == "gelu":
|
113 |
+
self.glu_act = torch.nn.GELU()
|
114 |
+
|
115 |
+
if bias_in_glu:
|
116 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
117 |
+
else:
|
118 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
119 |
+
|
120 |
+
def forward(self, x):
|
121 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
122 |
+
x = self.linear(x)
|
123 |
+
|
124 |
+
if self.glu_type == "bilinear":
|
125 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
126 |
+
else:
|
127 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
128 |
+
|
129 |
+
return x
|
130 |
+
|
131 |
+
|
132 |
+
def gelu_accurate(x):
|
133 |
+
if not hasattr(gelu_accurate, "_a"):
|
134 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
135 |
+
return (
|
136 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
137 |
+
)
|
138 |
+
|
139 |
+
|
140 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
141 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
142 |
+
|
143 |
+
|
144 |
+
def get_activation_fn(activation: str):
|
145 |
+
"""Returns the activation function corresponding to `activation`"""
|
146 |
+
|
147 |
+
if activation == "relu":
|
148 |
+
return F.relu
|
149 |
+
elif activation == "gelu":
|
150 |
+
return gelu
|
151 |
+
elif activation == "gelu_fast":
|
152 |
+
warnings.warn(
|
153 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
154 |
+
)
|
155 |
+
return gelu_accurate
|
156 |
+
elif activation == "gelu_accurate":
|
157 |
+
return gelu_accurate
|
158 |
+
elif activation == "tanh":
|
159 |
+
return torch.tanh
|
160 |
+
elif activation == "linear":
|
161 |
+
return lambda x: x
|
162 |
+
elif activation == "glu":
|
163 |
+
return lambda x: x
|
164 |
+
else:
|
165 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
166 |
+
|
167 |
+
|
168 |
+
def init_bert_params(module):
|
169 |
+
"""
|
170 |
+
Initialize the weights specific to the BERT Model.
|
171 |
+
This overrides the default initializations depending on the specified arguments.
|
172 |
+
1. If normal_init_linear_weights is set then weights of linear
|
173 |
+
layer will be initialized using the normal distribution and
|
174 |
+
bais will be set to the specified value.
|
175 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
176 |
+
layer will be initialized using the normal distribution.
|
177 |
+
3. If normal_init_proj_weights is set then weights of
|
178 |
+
in_project_weight for MultiHeadAttention initialized using
|
179 |
+
the normal distribution (to be validated).
|
180 |
+
"""
|
181 |
+
|
182 |
+
def normal_(data):
|
183 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
184 |
+
# so that the RNG is consistent with and without FSDP
|
185 |
+
data.copy_(
|
186 |
+
data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
|
187 |
+
)
|
188 |
+
|
189 |
+
if isinstance(module, nn.Linear):
|
190 |
+
normal_(module.weight.data)
|
191 |
+
if module.bias is not None:
|
192 |
+
module.bias.data.zero_()
|
193 |
+
if isinstance(module, nn.Embedding):
|
194 |
+
normal_(module.weight.data)
|
195 |
+
if module.padding_idx is not None:
|
196 |
+
module.weight.data[module.padding_idx].zero_()
|
197 |
+
if isinstance(module, MultiheadAttention):
|
198 |
+
normal_(module.q_proj.weight.data)
|
199 |
+
normal_(module.k_proj.weight.data)
|
200 |
+
normal_(module.v_proj.weight.data)
|
201 |
+
|
202 |
+
|
203 |
+
def quant_noise(module, p, block_size):
|
204 |
+
"""
|
205 |
+
Wraps modules and applies quantization noise to the weights for
|
206 |
+
subsequent quantization with Iterative Product Quantization as
|
207 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
208 |
+
|
209 |
+
Args:
|
210 |
+
- module: nn.Module
|
211 |
+
- p: amount of Quantization Noise
|
212 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
213 |
+
|
214 |
+
Remarks:
|
215 |
+
- Module weights must have the right sizes wrt the block size
|
216 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
217 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
218 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
219 |
+
- We implement the simplest form of noise here as stated in the paper
|
220 |
+
which consists in randomly dropping blocks
|
221 |
+
"""
|
222 |
+
|
223 |
+
# if no quantization noise, don't register hook
|
224 |
+
if p <= 0:
|
225 |
+
return module
|
226 |
+
|
227 |
+
# supported modules
|
228 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
229 |
+
|
230 |
+
# test whether module.weight has the right sizes wrt block_size
|
231 |
+
is_conv = module.weight.ndim == 4
|
232 |
+
|
233 |
+
# 2D matrix
|
234 |
+
if not is_conv:
|
235 |
+
assert (
|
236 |
+
module.weight.size(1) % block_size == 0
|
237 |
+
), "Input features must be a multiple of block sizes"
|
238 |
+
|
239 |
+
# 4D matrix
|
240 |
+
else:
|
241 |
+
# 1x1 convolutions
|
242 |
+
if module.kernel_size == (1, 1):
|
243 |
+
assert (
|
244 |
+
module.in_channels % block_size == 0
|
245 |
+
), "Input channels must be a multiple of block sizes"
|
246 |
+
# regular convolutions
|
247 |
+
else:
|
248 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
249 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
250 |
+
|
251 |
+
def _forward_pre_hook(mod, input):
|
252 |
+
# no noise for evaluation
|
253 |
+
if mod.training:
|
254 |
+
if not is_conv:
|
255 |
+
# gather weight and sizes
|
256 |
+
weight = mod.weight
|
257 |
+
in_features = weight.size(1)
|
258 |
+
out_features = weight.size(0)
|
259 |
+
|
260 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
261 |
+
mask = torch.zeros(
|
262 |
+
in_features // block_size * out_features, device=weight.device
|
263 |
+
)
|
264 |
+
mask.bernoulli_(p)
|
265 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
266 |
+
|
267 |
+
else:
|
268 |
+
# gather weight and sizes
|
269 |
+
weight = mod.weight
|
270 |
+
in_channels = mod.in_channels
|
271 |
+
out_channels = mod.out_channels
|
272 |
+
|
273 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
274 |
+
if mod.kernel_size == (1, 1):
|
275 |
+
mask = torch.zeros(
|
276 |
+
int(in_channels // block_size * out_channels),
|
277 |
+
device=weight.device,
|
278 |
+
)
|
279 |
+
mask.bernoulli_(p)
|
280 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
281 |
+
else:
|
282 |
+
mask = torch.zeros(
|
283 |
+
weight.size(0), weight.size(1), device=weight.device
|
284 |
+
)
|
285 |
+
mask.bernoulli_(p)
|
286 |
+
mask = (
|
287 |
+
mask.unsqueeze(2)
|
288 |
+
.unsqueeze(3)
|
289 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
290 |
+
)
|
291 |
+
|
292 |
+
# scale weights and apply mask
|
293 |
+
mask = mask.to(
|
294 |
+
torch.bool
|
295 |
+
) # x.bool() is not currently supported in TorchScript
|
296 |
+
s = 1 / (1 - p)
|
297 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
298 |
+
|
299 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
300 |
+
return module
|
301 |
+
|
302 |
+
|
303 |
+
class MultiheadAttention(nn.Module):
|
304 |
+
"""Multi-headed attention.
|
305 |
+
|
306 |
+
See "Attention Is All You Need" for more details.
|
307 |
+
"""
|
308 |
+
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
embed_dim,
|
312 |
+
num_heads,
|
313 |
+
kdim=None,
|
314 |
+
vdim=None,
|
315 |
+
dropout=0.0,
|
316 |
+
bias=True,
|
317 |
+
add_bias_kv=False,
|
318 |
+
add_zero_attn=False,
|
319 |
+
self_attention=False,
|
320 |
+
encoder_decoder_attention=False,
|
321 |
+
q_noise=0.0,
|
322 |
+
qn_block_size=8,
|
323 |
+
has_relative_attention_bias=False,
|
324 |
+
num_buckets=32,
|
325 |
+
max_distance=128,
|
326 |
+
gru_rel_pos=False,
|
327 |
+
rescale_init=False,
|
328 |
+
):
|
329 |
+
super().__init__()
|
330 |
+
self.embed_dim = embed_dim
|
331 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
332 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
333 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
334 |
+
|
335 |
+
self.num_heads = num_heads
|
336 |
+
self.dropout_module = nn.Dropout(dropout)
|
337 |
+
|
338 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
339 |
+
self.num_buckets = num_buckets
|
340 |
+
self.max_distance = max_distance
|
341 |
+
if self.has_relative_attention_bias:
|
342 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
343 |
+
|
344 |
+
self.head_dim = embed_dim // num_heads
|
345 |
+
self.q_head_dim = self.head_dim
|
346 |
+
self.k_head_dim = self.head_dim
|
347 |
+
assert (
|
348 |
+
self.head_dim * num_heads == self.embed_dim
|
349 |
+
), "embed_dim must be divisible by num_heads"
|
350 |
+
self.scaling = self.head_dim ** -0.5
|
351 |
+
|
352 |
+
self.self_attention = self_attention
|
353 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
354 |
+
|
355 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
356 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
357 |
+
)
|
358 |
+
|
359 |
+
k_bias = True
|
360 |
+
if rescale_init:
|
361 |
+
k_bias = False
|
362 |
+
|
363 |
+
k_embed_dim = embed_dim
|
364 |
+
q_embed_dim = embed_dim
|
365 |
+
|
366 |
+
self.k_proj = quant_noise(
|
367 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
368 |
+
)
|
369 |
+
self.v_proj = quant_noise(
|
370 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
371 |
+
)
|
372 |
+
self.q_proj = quant_noise(
|
373 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
374 |
+
)
|
375 |
+
|
376 |
+
self.out_proj = quant_noise(
|
377 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
378 |
+
)
|
379 |
+
|
380 |
+
if add_bias_kv:
|
381 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
382 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
383 |
+
else:
|
384 |
+
self.bias_k = self.bias_v = None
|
385 |
+
|
386 |
+
self.add_zero_attn = add_zero_attn
|
387 |
+
|
388 |
+
self.gru_rel_pos = gru_rel_pos
|
389 |
+
if self.gru_rel_pos:
|
390 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
391 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
392 |
+
|
393 |
+
self.reset_parameters()
|
394 |
+
|
395 |
+
def reset_parameters(self):
|
396 |
+
if self.qkv_same_dim:
|
397 |
+
# Empirically observed the convergence to be much better with
|
398 |
+
# the scaled initialization
|
399 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
400 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
401 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
402 |
+
else:
|
403 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
404 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
405 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
406 |
+
|
407 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
408 |
+
if self.out_proj.bias is not None:
|
409 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
410 |
+
if self.bias_k is not None:
|
411 |
+
nn.init.xavier_normal_(self.bias_k)
|
412 |
+
if self.bias_v is not None:
|
413 |
+
nn.init.xavier_normal_(self.bias_v)
|
414 |
+
if self.has_relative_attention_bias:
|
415 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
416 |
+
|
417 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
418 |
+
num_buckets = self.num_buckets
|
419 |
+
max_distance = self.max_distance
|
420 |
+
relative_buckets = 0
|
421 |
+
|
422 |
+
if bidirectional:
|
423 |
+
num_buckets = num_buckets // 2
|
424 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
425 |
+
relative_positions = torch.abs(relative_positions)
|
426 |
+
else:
|
427 |
+
relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
|
428 |
+
|
429 |
+
max_exact = num_buckets // 2
|
430 |
+
is_small = relative_positions < max_exact
|
431 |
+
|
432 |
+
relative_postion_if_large = max_exact + (
|
433 |
+
torch.log(relative_positions.float() / max_exact)
|
434 |
+
/ math.log(max_distance / max_exact)
|
435 |
+
* (num_buckets - max_exact)
|
436 |
+
).to(torch.long)
|
437 |
+
relative_postion_if_large = torch.min(
|
438 |
+
relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
|
439 |
+
)
|
440 |
+
|
441 |
+
relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
|
442 |
+
return relative_buckets
|
443 |
+
|
444 |
+
def compute_bias(self, query_length, key_length):
|
445 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
446 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
447 |
+
relative_position = memory_position - context_position
|
448 |
+
relative_position_bucket = self._relative_positions_bucket(
|
449 |
+
relative_position,
|
450 |
+
bidirectional=True
|
451 |
+
)
|
452 |
+
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
|
453 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
454 |
+
values = values.permute([2, 0, 1])
|
455 |
+
return values
|
456 |
+
|
457 |
+
def forward(
|
458 |
+
self,
|
459 |
+
query,
|
460 |
+
key: Optional[Tensor],
|
461 |
+
value: Optional[Tensor],
|
462 |
+
key_padding_mask: Optional[Tensor] = None,
|
463 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
464 |
+
need_weights: bool = True,
|
465 |
+
static_kv: bool = False,
|
466 |
+
attn_mask: Optional[Tensor] = None,
|
467 |
+
before_softmax: bool = False,
|
468 |
+
need_head_weights: bool = False,
|
469 |
+
position_bias: Optional[Tensor] = None
|
470 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
471 |
+
"""Input shape: Time x Batch x Channel
|
472 |
+
|
473 |
+
Args:
|
474 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
475 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
476 |
+
padding elements are indicated by 1s.
|
477 |
+
need_weights (bool, optional): return the attention weights,
|
478 |
+
averaged over heads (default: False).
|
479 |
+
attn_mask (ByteTensor, optional): typically used to
|
480 |
+
implement causal attention, where the mask prevents the
|
481 |
+
attention from looking forward in time (default: None).
|
482 |
+
before_softmax (bool, optional): return the raw attention
|
483 |
+
weights and values before the attention softmax.
|
484 |
+
need_head_weights (bool, optional): return the attention
|
485 |
+
weights for each head. Implies *need_weights*. Default:
|
486 |
+
return the average attention weights over all heads.
|
487 |
+
"""
|
488 |
+
if need_head_weights:
|
489 |
+
need_weights = True
|
490 |
+
|
491 |
+
is_tpu = query.device.type == "xla"
|
492 |
+
|
493 |
+
tgt_len, bsz, embed_dim = query.size()
|
494 |
+
src_len = tgt_len
|
495 |
+
assert embed_dim == self.embed_dim
|
496 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
497 |
+
if key is not None:
|
498 |
+
src_len, key_bsz, _ = key.size()
|
499 |
+
if not torch.jit.is_scripting():
|
500 |
+
assert key_bsz == bsz
|
501 |
+
assert value is not None
|
502 |
+
assert src_len, bsz == value.shape[:2]
|
503 |
+
|
504 |
+
if self.has_relative_attention_bias and position_bias is None:
|
505 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
506 |
+
position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
|
507 |
+
|
508 |
+
if (
|
509 |
+
not is_tpu # don't use PyTorch version on TPUs
|
510 |
+
and incremental_state is None
|
511 |
+
and not static_kv
|
512 |
+
# A workaround for quantization to work. Otherwise JIT compilation
|
513 |
+
# treats bias in linear module as method.
|
514 |
+
and not torch.jit.is_scripting()
|
515 |
+
and self.q_head_dim == self.head_dim
|
516 |
+
):
|
517 |
+
assert key is not None and value is not None
|
518 |
+
assert attn_mask is None
|
519 |
+
|
520 |
+
attn_mask_rel_pos = None
|
521 |
+
if position_bias is not None:
|
522 |
+
attn_mask_rel_pos = position_bias
|
523 |
+
if self.gru_rel_pos:
|
524 |
+
query_layer = query.transpose(0, 1)
|
525 |
+
new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
|
526 |
+
query_layer = query_layer.view(*new_x_shape)
|
527 |
+
query_layer = query_layer.permute(0, 2, 1, 3)
|
528 |
+
_B, _H, _L, __ = query_layer.size()
|
529 |
+
|
530 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
531 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
532 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
533 |
+
attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
534 |
+
|
535 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
|
536 |
+
k_proj_bias = self.k_proj.bias
|
537 |
+
if k_proj_bias is None:
|
538 |
+
k_proj_bias = torch.zeros_like(self.q_proj.bias)
|
539 |
+
|
540 |
+
x, attn = F.multi_head_attention_forward(
|
541 |
+
query,
|
542 |
+
key,
|
543 |
+
value,
|
544 |
+
self.embed_dim,
|
545 |
+
self.num_heads,
|
546 |
+
torch.empty([0]),
|
547 |
+
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
|
548 |
+
self.bias_k,
|
549 |
+
self.bias_v,
|
550 |
+
self.add_zero_attn,
|
551 |
+
self.dropout_module.p,
|
552 |
+
self.out_proj.weight,
|
553 |
+
self.out_proj.bias,
|
554 |
+
self.training,
|
555 |
+
# self.training or self.dropout_module.apply_during_inference,
|
556 |
+
key_padding_mask,
|
557 |
+
need_weights,
|
558 |
+
attn_mask_rel_pos,
|
559 |
+
use_separate_proj_weight=True,
|
560 |
+
q_proj_weight=self.q_proj.weight,
|
561 |
+
k_proj_weight=self.k_proj.weight,
|
562 |
+
v_proj_weight=self.v_proj.weight,
|
563 |
+
)
|
564 |
+
return x, attn, position_bias
|
565 |
+
|
566 |
+
if incremental_state is not None:
|
567 |
+
saved_state = self._get_input_buffer(incremental_state)
|
568 |
+
if saved_state is not None and "prev_key" in saved_state:
|
569 |
+
# previous time steps are cached - no need to recompute
|
570 |
+
# key and value if they are static
|
571 |
+
if static_kv:
|
572 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
573 |
+
key = value = None
|
574 |
+
else:
|
575 |
+
saved_state = None
|
576 |
+
|
577 |
+
if self.self_attention:
|
578 |
+
q = self.q_proj(query)
|
579 |
+
k = self.k_proj(query)
|
580 |
+
v = self.v_proj(query)
|
581 |
+
elif self.encoder_decoder_attention:
|
582 |
+
# encoder-decoder attention
|
583 |
+
q = self.q_proj(query)
|
584 |
+
if key is None:
|
585 |
+
assert value is None
|
586 |
+
k = v = None
|
587 |
+
else:
|
588 |
+
k = self.k_proj(key)
|
589 |
+
v = self.v_proj(key)
|
590 |
+
|
591 |
+
else:
|
592 |
+
assert key is not None and value is not None
|
593 |
+
q = self.q_proj(query)
|
594 |
+
k = self.k_proj(key)
|
595 |
+
v = self.v_proj(value)
|
596 |
+
q *= self.scaling
|
597 |
+
|
598 |
+
if self.bias_k is not None:
|
599 |
+
assert self.bias_v is not None
|
600 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
601 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
602 |
+
if attn_mask is not None:
|
603 |
+
attn_mask = torch.cat(
|
604 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
605 |
+
)
|
606 |
+
if key_padding_mask is not None:
|
607 |
+
key_padding_mask = torch.cat(
|
608 |
+
[
|
609 |
+
key_padding_mask,
|
610 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
611 |
+
],
|
612 |
+
dim=1,
|
613 |
+
)
|
614 |
+
|
615 |
+
q = (
|
616 |
+
q.contiguous()
|
617 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
618 |
+
.transpose(0, 1)
|
619 |
+
)
|
620 |
+
if k is not None:
|
621 |
+
k = (
|
622 |
+
k.contiguous()
|
623 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
624 |
+
.transpose(0, 1)
|
625 |
+
)
|
626 |
+
if v is not None:
|
627 |
+
v = (
|
628 |
+
v.contiguous()
|
629 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
630 |
+
.transpose(0, 1)
|
631 |
+
)
|
632 |
+
|
633 |
+
if saved_state is not None:
|
634 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
635 |
+
if "prev_key" in saved_state:
|
636 |
+
_prev_key = saved_state["prev_key"]
|
637 |
+
assert _prev_key is not None
|
638 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
639 |
+
if static_kv:
|
640 |
+
k = prev_key
|
641 |
+
else:
|
642 |
+
assert k is not None
|
643 |
+
k = torch.cat([prev_key, k], dim=1)
|
644 |
+
src_len = k.size(1)
|
645 |
+
if "prev_value" in saved_state:
|
646 |
+
_prev_value = saved_state["prev_value"]
|
647 |
+
assert _prev_value is not None
|
648 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
649 |
+
if static_kv:
|
650 |
+
v = prev_value
|
651 |
+
else:
|
652 |
+
assert v is not None
|
653 |
+
v = torch.cat([prev_value, v], dim=1)
|
654 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
655 |
+
if "prev_key_padding_mask" in saved_state:
|
656 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
657 |
+
assert k is not None and v is not None
|
658 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
659 |
+
key_padding_mask=key_padding_mask,
|
660 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
661 |
+
batch_size=bsz,
|
662 |
+
src_len=k.size(1),
|
663 |
+
static_kv=static_kv,
|
664 |
+
)
|
665 |
+
|
666 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
667 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
668 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
669 |
+
# In this branch incremental_state is never None
|
670 |
+
assert incremental_state is not None
|
671 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
672 |
+
assert k is not None
|
673 |
+
assert k.size(1) == src_len
|
674 |
+
|
675 |
+
# This is part of a workaround to get around fork/join parallelism
|
676 |
+
# not supporting Optional types.
|
677 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
678 |
+
key_padding_mask = None
|
679 |
+
|
680 |
+
if key_padding_mask is not None:
|
681 |
+
assert key_padding_mask.size(0) == bsz
|
682 |
+
assert key_padding_mask.size(1) == src_len
|
683 |
+
|
684 |
+
if self.add_zero_attn:
|
685 |
+
assert v is not None
|
686 |
+
src_len += 1
|
687 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
688 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
689 |
+
if attn_mask is not None:
|
690 |
+
attn_mask = torch.cat(
|
691 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
692 |
+
)
|
693 |
+
if key_padding_mask is not None:
|
694 |
+
key_padding_mask = torch.cat(
|
695 |
+
[
|
696 |
+
key_padding_mask,
|
697 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
698 |
+
key_padding_mask
|
699 |
+
),
|
700 |
+
],
|
701 |
+
dim=1,
|
702 |
+
)
|
703 |
+
|
704 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
705 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
706 |
+
|
707 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
708 |
+
|
709 |
+
if attn_mask is not None:
|
710 |
+
attn_mask = attn_mask.unsqueeze(0)
|
711 |
+
attn_weights += attn_mask
|
712 |
+
|
713 |
+
if key_padding_mask is not None:
|
714 |
+
# don't attend to padding symbols
|
715 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
716 |
+
if not is_tpu:
|
717 |
+
attn_weights = attn_weights.masked_fill(
|
718 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
719 |
+
float("-inf"),
|
720 |
+
)
|
721 |
+
else:
|
722 |
+
attn_weights = attn_weights.transpose(0, 2)
|
723 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
724 |
+
attn_weights = attn_weights.transpose(0, 2)
|
725 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
726 |
+
|
727 |
+
if before_softmax:
|
728 |
+
return attn_weights, v, position_bias
|
729 |
+
|
730 |
+
if position_bias is not None:
|
731 |
+
if self.gru_rel_pos == 1:
|
732 |
+
query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
733 |
+
_B, _H, _L, __ = query_layer.size()
|
734 |
+
gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
|
735 |
+
_B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
|
736 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
737 |
+
position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
|
738 |
+
|
739 |
+
position_bias = position_bias.view(attn_weights.size())
|
740 |
+
|
741 |
+
attn_weights = attn_weights + position_bias
|
742 |
+
|
743 |
+
attn_weights_float = F.softmax(
|
744 |
+
attn_weights, dim=-1
|
745 |
+
)
|
746 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
747 |
+
attn_probs = self.dropout_module(attn_weights)
|
748 |
+
|
749 |
+
assert v is not None
|
750 |
+
attn = torch.bmm(attn_probs, v)
|
751 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
752 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
753 |
+
attn = self.out_proj(attn)
|
754 |
+
attn_weights: Optional[Tensor] = None
|
755 |
+
if need_weights:
|
756 |
+
attn_weights = attn_weights_float.view(
|
757 |
+
bsz, self.num_heads, tgt_len, src_len
|
758 |
+
).transpose(1, 0)
|
759 |
+
if not need_head_weights:
|
760 |
+
# average attention weights over heads
|
761 |
+
attn_weights = attn_weights.mean(dim=0)
|
762 |
+
|
763 |
+
return attn, attn_weights, position_bias
|
764 |
+
|
765 |
+
@staticmethod
|
766 |
+
def _append_prev_key_padding_mask(
|
767 |
+
key_padding_mask: Optional[Tensor],
|
768 |
+
prev_key_padding_mask: Optional[Tensor],
|
769 |
+
batch_size: int,
|
770 |
+
src_len: int,
|
771 |
+
static_kv: bool,
|
772 |
+
) -> Optional[Tensor]:
|
773 |
+
# saved key padding masks have shape (bsz, seq_len)
|
774 |
+
if prev_key_padding_mask is not None and static_kv:
|
775 |
+
new_key_padding_mask = prev_key_padding_mask
|
776 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
777 |
+
new_key_padding_mask = torch.cat(
|
778 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
779 |
+
)
|
780 |
+
# During incremental decoding, as the padding token enters and
|
781 |
+
# leaves the frame, there will be a time when prev or current
|
782 |
+
# is None
|
783 |
+
elif prev_key_padding_mask is not None:
|
784 |
+
if src_len > prev_key_padding_mask.size(1):
|
785 |
+
filler = torch.zeros(
|
786 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
787 |
+
device=prev_key_padding_mask.device,
|
788 |
+
)
|
789 |
+
new_key_padding_mask = torch.cat(
|
790 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
791 |
+
)
|
792 |
+
else:
|
793 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
794 |
+
elif key_padding_mask is not None:
|
795 |
+
if src_len > key_padding_mask.size(1):
|
796 |
+
filler = torch.zeros(
|
797 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
798 |
+
device=key_padding_mask.device,
|
799 |
+
)
|
800 |
+
new_key_padding_mask = torch.cat(
|
801 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
802 |
+
)
|
803 |
+
else:
|
804 |
+
new_key_padding_mask = key_padding_mask.float()
|
805 |
+
else:
|
806 |
+
new_key_padding_mask = prev_key_padding_mask
|
807 |
+
return new_key_padding_mask
|
808 |
+
|
809 |
+
def _get_input_buffer(
|
810 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
811 |
+
) -> Dict[str, Optional[Tensor]]:
|
812 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
813 |
+
if result is not None:
|
814 |
+
return result
|
815 |
+
else:
|
816 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
817 |
+
return empty_result
|
818 |
+
|
819 |
+
def _set_input_buffer(
|
820 |
+
self,
|
821 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
822 |
+
buffer: Dict[str, Optional[Tensor]],
|
823 |
+
):
|
824 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
825 |
+
|
826 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
827 |
+
return attn_weights
|