diff --git a/.gitattributes b/.gitattributes index e76c1b5458b542c84c8e4da5d3558b711037b313..c4630d0c9d98e7209d22036f66d8118f49818287 100644 --- a/.gitattributes +++ b/.gitattributes @@ -21,3 +21,7 @@ pretrain/hubert-soft-0d54a1f4.pt filter=lfs diff=lfs merge=lfs -text pretrain/medium.pt filter=lfs diff=lfs merge=lfs -text pretrain/nsf_hifigan/model filter=lfs diff=lfs merge=lfs -text pretrain/rmvpe.pt filter=lfs diff=lfs merge=lfs -text +results/audio_Shengshuyan_12key_sovits_pm_1.wav filter=lfs diff=lfs merge=lfs -text +results/audio_Shengshuyan_12key_sovits_pm.wav filter=lfs diff=lfs merge=lfs -text +results/tts_Shengshuyan_auto_sovits_pm_2.wav filter=lfs diff=lfs merge=lfs -text +results/vocals_Shengshuyan_0key_sovits_pm.wav filter=lfs diff=lfs merge=lfs -text diff --git a/results/audio_Shengshuyan_12key_sovits_pm.wav b/results/audio_Shengshuyan_12key_sovits_pm.wav new file mode 100644 index 0000000000000000000000000000000000000000..2d1d6444c222d68d05ed20e465048f20749ff703 --- /dev/null +++ b/results/audio_Shengshuyan_12key_sovits_pm.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:802153e84b5b9415e88fead81f26dfa39276411458d1d7b46bfa158ae2268084 +size 1603520 diff --git a/results/audio_Shengshuyan_12key_sovits_pm_1.wav b/results/audio_Shengshuyan_12key_sovits_pm_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..f375e0201ed4fba34e8d6056a6a1fe87e1b494ac --- /dev/null +++ b/results/audio_Shengshuyan_12key_sovits_pm_1.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8b0907cc03021cdd9b396a5f25158e02ba36186733f03e1d7d9e59048dc1450 +size 1534724 diff --git a/results/tts_Shengshuyan_0key_sovits_pm.wav b/results/tts_Shengshuyan_0key_sovits_pm.wav new file mode 100644 index 0000000000000000000000000000000000000000..304c6114f644c3e0fd2263c61354027088240e5f Binary files /dev/null and b/results/tts_Shengshuyan_0key_sovits_pm.wav differ diff --git a/results/tts_Shengshuyan_0key_sovits_pm_1.wav b/results/tts_Shengshuyan_0key_sovits_pm_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..ad88b73c1d8735f303f978a0b6c4f8be05214ba7 Binary files /dev/null and b/results/tts_Shengshuyan_0key_sovits_pm_1.wav differ diff --git a/results/tts_Shengshuyan_0key_sovits_pm_2.wav b/results/tts_Shengshuyan_0key_sovits_pm_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..ad88b73c1d8735f303f978a0b6c4f8be05214ba7 Binary files /dev/null and b/results/tts_Shengshuyan_0key_sovits_pm_2.wav differ diff --git a/results/tts_Shengshuyan_0key_sovits_pm_3.wav b/results/tts_Shengshuyan_0key_sovits_pm_3.wav new file mode 100644 index 0000000000000000000000000000000000000000..5093dd32224ad7020dc7eae1563f5677ffcae1d8 Binary files /dev/null and b/results/tts_Shengshuyan_0key_sovits_pm_3.wav differ diff --git a/results/tts_Shengshuyan_12key_sovits_pm.wav b/results/tts_Shengshuyan_12key_sovits_pm.wav new file mode 100644 index 0000000000000000000000000000000000000000..50e996a53d9a154fc0832a408266fd995aebe873 Binary files /dev/null and b/results/tts_Shengshuyan_12key_sovits_pm.wav differ diff --git a/results/tts_Shengshuyan_auto_sovits_pm.wav b/results/tts_Shengshuyan_auto_sovits_pm.wav new file mode 100644 index 0000000000000000000000000000000000000000..4299ee443672c189781e7911ca077a7c7414eedf Binary files /dev/null and b/results/tts_Shengshuyan_auto_sovits_pm.wav differ diff --git a/results/tts_Shengshuyan_auto_sovits_pm_1.wav b/results/tts_Shengshuyan_auto_sovits_pm_1.wav new file mode 100644 index 
0000000000000000000000000000000000000000..4299ee443672c189781e7911ca077a7c7414eedf Binary files /dev/null and b/results/tts_Shengshuyan_auto_sovits_pm_1.wav differ diff --git a/results/tts_Shengshuyan_auto_sovits_pm_2.wav b/results/tts_Shengshuyan_auto_sovits_pm_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..bff524512a28aac6107ede16a514218986b88649 --- /dev/null +++ b/results/tts_Shengshuyan_auto_sovits_pm_2.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:008ab4c60f5e1b31466e4b60d97ababbf8ec476c01d0e2c1c330972d72fda1fd +size 1172752 diff --git a/results/tts_Shengshuyan_auto_sovits_pm_3.wav b/results/tts_Shengshuyan_auto_sovits_pm_3.wav new file mode 100644 index 0000000000000000000000000000000000000000..3bb17a3334b9ee3709bb44f6308fa7ae8e062b45 Binary files /dev/null and b/results/tts_Shengshuyan_auto_sovits_pm_3.wav differ diff --git a/results/tts_Shengshuyan_auto_sovits_pm_4.wav b/results/tts_Shengshuyan_auto_sovits_pm_4.wav new file mode 100644 index 0000000000000000000000000000000000000000..784c2a10ee8337fc5d06be75a6f63200583f741a Binary files /dev/null and b/results/tts_Shengshuyan_auto_sovits_pm_4.wav differ diff --git a/results/vocals_Shengshuyan_0key_sovits_pm.wav b/results/vocals_Shengshuyan_0key_sovits_pm.wav new file mode 100644 index 0000000000000000000000000000000000000000..b2862c4bff9131ee79deddab7e68c21cd15d6e1e --- /dev/null +++ b/results/vocals_Shengshuyan_0key_sovits_pm.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d4057ce12f656c6325fa616e97b1fe82e392e3e0bafb518384c5f548dd549ca8 +size 18330192 diff --git a/trained/put_trained_checkpoints_here b/trained/put_trained_checkpoints_here new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vdecoder/__init__.py b/vdecoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vdecoder/__pycache__/__init__.cpython-38.pyc b/vdecoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ce51834c9a81637d20eee2dc7e30dcb23468c03 Binary files /dev/null and b/vdecoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/env.cpython-38.pyc b/vdecoder/hifigan/__pycache__/env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f03b083666a5e083064e2720184da6598bda0e30 Binary files /dev/null and b/vdecoder/hifigan/__pycache__/env.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/models.cpython-38.pyc b/vdecoder/hifigan/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd75f1a49892d52b8fea59b20923d8dc6194d967 Binary files /dev/null and b/vdecoder/hifigan/__pycache__/models.cpython-38.pyc differ diff --git a/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc b/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5595402babbb22f9311cd42feb48f8399121ca8b Binary files /dev/null and b/vdecoder/hifigan/__pycache__/utils.cpython-38.pyc differ diff --git a/vdecoder/hifigan/env.py b/vdecoder/hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/hifigan/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + 
super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/hifigan/models.py b/vdecoder/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..107553368ff1798f72df21c6d5a965260f5a60fd --- /dev/null +++ b/vdecoder/hifigan/models.py @@ -0,0 +1,557 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + 
SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-waveform (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_threshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SineGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.onnx = False + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The integer part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segment is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantaneous phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segment + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0, upp=None): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ### the % 1 means the n_har harmonic products cannot be optimized away in post-processing + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 ##### a % 1 here would prevent the following cumsum from being optimized + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # .
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp=None): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h["sampling_rate"], + harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u +1 ) // 2))) + if i + 1 < len(h["upsample_rates"]): # + stride_f0 = np.prod(h["upsample_rates"][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = 
weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True + + def forward(self, x, f0, g=None): + # print(1,x.shape,f0.shape,f0[:, None].shape) + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # print(2,f0.shape) + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + x = x + self.cond(g) + # print(124,x.shape,har_source.shape) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + # print(3,x.shape) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + # print(4,x_source.shape,har_source.shape,x.shape) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else 
spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/hifigan/nvSTFT.py b/vdecoder/hifigan/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 --- /dev/null +++ b/vdecoder/hifigan/nvSTFT.py @@ -0,0 +1,109 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 
+ except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 32000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 32000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + if fmax not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + # print(222,spec) + spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git 
a/vdecoder/hifigan/utils.py b/vdecoder/hifigan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e519e2b7ed8fe5f93266d21d727a30173699f88b --- /dev/null +++ b/vdecoder/hifigan/utils.py @@ -0,0 +1,68 @@ +import glob +import os + +# matplotlib.use("Agg") +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vdecoder/hifiganwithsnake/alias/__init__.py b/vdecoder/hifiganwithsnake/alias/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..be97a33248ae6378c6736586774abda11cfbdeba --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +from .act import * # noqa: F403 +from .filter import * # noqa: F403 +from .resample import * # noqa: F403 diff --git a/vdecoder/hifiganwithsnake/alias/act.py b/vdecoder/hifiganwithsnake/alias/act.py new file mode 100644 index 0000000000000000000000000000000000000000..e46b3467b73b90df51c1d19032b90d26595aca6e --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/act.py @@ -0,0 +1,130 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import pow, sin +from torch.nn import Parameter + +from .resample import DownSample1d, UpSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.beta = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.beta = Parameter(torch.ones(in_features) * alpha) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta = x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze( + 0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + return x + + +class Mish(nn.Module): + """ + Mish activation function is proposed in "Mish: A Self + Regularized Non-Monotonic Neural Activation Function" + paper, https://arxiv.org/abs/1908.08681. 
+ """ + + def __init__(self): + super().__init__() + + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class SnakeAlias(nn.Module): + def __init__(self, + channels, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12, + C = None): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = SnakeBeta(channels, alpha_logscale=True) + self.upsample = UpSample1d(up_ratio, up_kernel_size, C) + self.downsample = DownSample1d(down_ratio, down_kernel_size, C) + + # x: [B,C,T] + def forward(self, x, C=None): + x = self.upsample(x, C) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/filter.py b/vdecoder/hifiganwithsnake/alias/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3942eb3ae547a2f500d5c47defdd70cd29ea4655 --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/filter.py @@ -0,0 +1,110 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12, + C=None): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. 
+ super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + self.conv1d_block = None + if C is not None: + self.conv1d_block = [nn.Conv1d(C,C,kernel_size,stride=self.stride, groups=C, bias=False),] + self.conv1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1)) + self.conv1d_block[0].requires_grad_(False) + + #input [B, C, T] + def forward(self, x): + if self.conv1d_block[0].weight.device != x.device: + self.conv1d_block[0] = self.conv1d_block[0].to(x.device) + if self.conv1d_block is None: + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + else: + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = self.conv1d_block[0](x) + + return out \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/alias/resample.py b/vdecoder/hifiganwithsnake/alias/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..a364403f0977bc8bcffbb4764081e4bd3619467a --- /dev/null +++ b/vdecoder/hifiganwithsnake/alias/resample.py @@ -0,0 +1,72 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from torch.nn import functional as F + +from .filter import LowPassFilter1d, kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + self.conv_transpose1d_block = None + if C is not None: + self.conv_transpose1d_block = [nn.ConvTranspose1d(C, + C, + kernel_size=self.kernel_size, + stride=self.stride, + groups=C, + bias=False + ),] + self.conv_transpose1d_block[0].weight = nn.Parameter(self.filter.expand(C, -1, -1).clone()) + self.conv_transpose1d_block[0].requires_grad_(False) + + + + # x: [B, C, T] + def forward(self, x, C=None): + if self.conv_transpose1d_block[0].weight.device != x.device: + self.conv_transpose1d_block[0] = self.conv_transpose1d_block[0].to(x.device) + if self.conv_transpose1d_block is None: + if C is None: + _, C, _ = x.shape + # print("snake.conv_t.in:",x.shape) + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + # print("snake.conv_t.out:",x.shape) + x = x[..., self.pad_left:-self.pad_right] + else: + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * self.conv_transpose1d_block[0](x) + x = x[..., self.pad_left:-self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, C=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + C=C) + + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vdecoder/hifiganwithsnake/env.py b/vdecoder/hifiganwithsnake/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ b/vdecoder/hifiganwithsnake/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/hifiganwithsnake/models.py b/vdecoder/hifiganwithsnake/models.py new file mode 100644 index 0000000000000000000000000000000000000000..08bbda9b77b095d81ca8d8a9e5e8ebe20fa9bcfa --- /dev/null +++ b/vdecoder/hifiganwithsnake/models.py @@ -0,0 +1,576 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from vdecoder.hifiganwithsnake.alias.act import SnakeAlias + +from .env import AttrDict +from .utils import get_padding, init_weights + 
+LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + global h + json_config = json.loads(data) + h = AttrDict(json_config) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), C=None): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len(self.convs2) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x, DIM) + xt = c1(xt) + xt = a2(xt, DIM) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), C=None): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) + self.activations = nn.ModuleList([ + SnakeAlias(channels, C=C) for _ in range(self.num_layers) + ]) + + def forward(self, x, DIM=None): + for c,a in zip(self.convs, self.activations): + xt = a(x, DIM) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +def padDiff(x): + return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0) + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when 
flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + self.onnx = False + + def _f02uv(self, f0): + # generate uv signal + uv = (f0 > self.voiced_threshold).type(torch.float32) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The integer part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (padDiff(tmp_over_one)) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segment is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantaneous phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segment + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0, upp=None): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + + if self.onnx: + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = (f0_buf / self.sampling_rate) % 1 ### the % 1 means the n_har harmonic products cannot be optimized away in post-processing + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum(rad_values, 1) # % 1 ##### a % 1 here would prevent the following cumsum from being optimized + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=upp, + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=upp, mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + else: + with torch.no_grad(): + # fundamental component + fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)) + + # generate sine waveforms + sine_waves = self._f02sine(fn) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # .
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp=None): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(self.l_linear.weight.dtype))) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + + self.num_kernels = len(h["resblock_kernel_sizes"]) + self.num_upsamples = len(h["upsample_rates"]) + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h["upsample_rates"])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h["sampling_rate"], + harmonic_num=8) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h["inter_channels"], h["upsample_initial_channel"], 7, 1, padding=3)) + resblock = ResBlock1 if h["resblock"] == '1' else ResBlock2 + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h["upsample_rates"], h["upsample_kernel_sizes"])): + c_cur = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h["upsample_initial_channel"] // (2 ** i), h["upsample_initial_channel"] // (2 ** (i + 1)), + k, u, padding=(k - u + 1) // 2))) + if i + 1 < len(h["upsample_rates"]): # + stride_f0 = np.prod(h["upsample_rates"][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+ 1) // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + self.snakes = nn.ModuleList() + for i in range(len(self.ups)): + ch = h["upsample_initial_channel"] // (2 ** (i + 1)) + self.snakes.append(SnakeAlias(h["upsample_initial_channel"] // (2 ** (i)), C = h["upsample_initial_channel"] >> i)) + for j, (k, d) in 
enumerate(zip(h["resblock_kernel_sizes"], h["resblock_dilation_sizes"])): + self.resblocks.append(resblock(h, ch, k, d, C = h["upsample_initial_channel"] >> (i + 1))) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.snake_post = SnakeAlias(ch, C = h["upsample_initial_channel"] >> len(self.ups)) + self.cond = nn.Conv1d(h['gin_channels'], h['upsample_initial_channel'], 1) + self.upp = np.prod(h["upsample_rates"]) + self.onnx = False + + def OnnxExport(self): + self.onnx = True + self.m_source.l_sin_gen.onnx = True + + def forward(self, x, f0, g=None): + # print(1,x.shape,f0.shape,f0[:, None].shape) + if not self.onnx: + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # print(2,f0.shape) + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + x = self.conv_pre(x) + x = x + self.cond(g) + # print(124,x.shape,har_source.shape) + for i in range(self.num_upsamples): + # print(f"self.snakes.{i}.pre:", x.shape) + x = self.snakes[i](x) + # print(f"self.snakes.{i}.after:", x.shape) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + # print(4,x_source.shape,har_source.shape,x.shape) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + # print(f"self.resblocks.{i}.after:", xs.shape) + x = xs / self.num_kernels + x = self.snake_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + 
y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/hifiganwithsnake/nvSTFT.py b/vdecoder/hifiganwithsnake/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..b3321b2ee3da28f43c2650ea011e14d5e1cdcc94 --- /dev/null +++ b/vdecoder/hifiganwithsnake/nvSTFT.py @@ -0,0 +1,109 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. 
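+        # always_2d=True yields a (samples, channels) array even for mono files; only the first channel is kept below.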
+ except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 32000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 32000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + if fmax not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + # print(222,spec) + spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git 
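For reference, a minimal self-contained sketch of the peak-normalisation rule used by load_wav_to_torch above (NumPy/PyTorch only; the sample values are made up): integer audio is divided by the magnitude of its dtype's minimum value, while float audio whose peak exceeds 1.01 is treated as un-normalised 16- or 32-bit data.

import numpy as np
import torch

def normalize_like_load_wav(data: np.ndarray) -> torch.Tensor:
    if np.issubdtype(data.dtype, np.integer):
        max_mag = -np.iinfo(data.dtype).min            # e.g. 32768 for int16
    else:
        max_mag = max(np.amax(data), -np.amin(data))
        max_mag = (2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0)
    return torch.FloatTensor(data.astype(np.float32)) / max_mag

print(normalize_like_load_wav(np.array([16384, -32768], dtype=np.int16)))  # tensor([ 0.5000, -1.0000])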
a/vdecoder/hifiganwithsnake/utils.py b/vdecoder/hifiganwithsnake/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e519e2b7ed8fe5f93266d21d727a30173699f88b --- /dev/null +++ b/vdecoder/hifiganwithsnake/utils.py @@ -0,0 +1,68 @@ +import glob +import os + +# matplotlib.use("Agg") +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ffd7330cd081266b6af94ff4d329f0e0e84297e Binary files /dev/null and b/vdecoder/nsf_hifigan/__pycache__/env.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..713841f649c1107b07c6dd76593b36c8d8e91440 Binary files /dev/null and b/vdecoder/nsf_hifigan/__pycache__/models.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a48a576591650766115e1485f1e9799452057531 Binary files /dev/null and b/vdecoder/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc b/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5a7c868fcfa303445aea455dfb596537b460cd6 Binary files /dev/null and b/vdecoder/nsf_hifigan/__pycache__/utils.cpython-38.pyc differ diff --git a/vdecoder/nsf_hifigan/env.py b/vdecoder/nsf_hifigan/env.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdbc95d4f7a8bad8fd4f5eef657e2b51d946056 --- /dev/null +++ 
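As a usage sketch of the checkpoint helpers in utils.py above (the directory and file names are hypothetical): checkpoints are expected to be named <prefix> plus an 8-character, zero-padded step count, so a plain lexical sort orders them by training step and scan_checkpoint can simply take the last match.

import glob
import os
import tempfile

cp_dir = tempfile.mkdtemp()
for step in (10000, 20000, 30000):
    open(os.path.join(cp_dir, f"g_{step:08d}"), "w").close()

# Same pattern the helpers use: prefix followed by exactly eight wildcard characters.
latest = sorted(glob.glob(os.path.join(cp_dir, "g_" + "????????")))[-1]
print(os.path.basename(latest))  # g_00030000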
b/vdecoder/nsf_hifigan/env.py @@ -0,0 +1,15 @@ +import os +import shutil + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def build_env(config, config_name, path): + t_path = os.path.join(path, config_name) + if config != t_path: + os.makedirs(path, exist_ok=True) + shutil.copyfile(config, os.path.join(path, config_name)) diff --git a/vdecoder/nsf_hifigan/models.py b/vdecoder/nsf_hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..8a35b134d814008c3990d019d1de502ff10dd86f --- /dev/null +++ b/vdecoder/nsf_hifigan/models.py @@ -0,0 +1,441 @@ +import json +import os + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .env import AttrDict +from .utils import get_padding, init_weights + +LRELU_SLOPE = 0.1 + + +def load_model(model_path, device='cuda'): + h = load_config(model_path) + + generator = Generator(h).to(device) + + cp_dict = torch.load(model_path, map_location=device) + generator.load_state_dict(cp_dict['generator']) + generator.eval() + generator.remove_weight_norm() + del cp_dict + return generator, h + +def load_config(model_path): + config_file = os.path.join(os.path.split(model_path)[0], 'config.json') + with open(config_file) as f: + data = f.read() + + json_config = json.loads(data) + h = AttrDict(json_config) + return h + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def 
remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + @torch.no_grad() + def forward(self, f0, upp): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + f0 = f0.unsqueeze(-1) + fn = torch.multiply(f0, torch.arange(1, self.dim + 1, device=f0.device).reshape((1, 1, -1))) + rad_values = (fn / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand(fn.shape[0], fn.shape[2], device=fn.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + is_half = rad_values.dtype is not torch.float32 + tmp_over_one = torch.cumsum(rad_values.double(), 1) # % 1 #####%1意味着后面的cumsum无法再优化 + if is_half: + tmp_over_one = tmp_over_one.half() + else: + tmp_over_one = tmp_over_one.float() + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), scale_factor=upp, + mode='linear', align_corners=True + ).transpose(2, 1) + rad_values = F.interpolate(rad_values.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + rad_values = rad_values.double() + cumsum_shift = cumsum_shift.double() + sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi) + if is_half: + sine_waves = sine_waves.half() + else: + sine_waves = sine_waves.float() + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate(uv.transpose(2, 1), scale_factor=upp, mode='nearest').transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + 
voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x, upp): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.m_source = SourceModuleHnNSF( + sampling_rate=h.sampling_rate, + harmonic_num=8 + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + c_cur = h.upsample_initial_channel // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), + k, u, padding=(k - u) // 2))) + if i + 1 < len(h.upsample_rates): # + stride_f0 = int(np.prod(h.upsample_rates[i + 1:])) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + self.resblocks = nn.ModuleList() + ch = h.upsample_initial_channel + for i in range(len(self.ups)): + ch //= 2 + for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = int(np.prod(h.upsample_rates)) + + def forward(self, x, f0): + har_source = self.m_source(f0, self.upp).transpose(1, 2) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), 
(stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, periods=None): + super(MultiPeriodDiscriminator, self).__init__() + self.periods = periods if periods is not None else [2, 3, 5, 7, 11] + self.discriminators = nn.ModuleList() + for period in self.periods: + self.discriminators.append(DiscriminatorP(period)) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=2), + AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, 
disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + loss += (r_loss + g_loss) + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses diff --git a/vdecoder/nsf_hifigan/nvSTFT.py b/vdecoder/nsf_hifigan/nvSTFT.py new file mode 100644 index 0000000000000000000000000000000000000000..e756cca561a45bde435f36447e6681bfa17e34aa --- /dev/null +++ b/vdecoder/nsf_hifigan/nvSTFT.py @@ -0,0 +1,132 @@ +import os + +import librosa +import numpy as np +import soundfile as sf +import torch +import torch.nn.functional as F +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn + +os.environ["LRU_CACHE_CAPACITY"] = "3" + +def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False): + sampling_rate = None + try: + data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile. + except Exception as ex: + print(f"'{full_path}' failed to load.\nException:") + print(ex) + if return_empty_on_exception: + return [], sampling_rate or target_sr or 48000 + else: + raise Exception(ex) + + if len(data.shape) > 1: + data = data[:, 0] + assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension) + + if np.issubdtype(data.dtype, np.integer): # if audio data is type int + max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX + else: # if audio data is type fp32 + max_mag = max(np.amax(data), -np.amin(data)) + max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32 + + data = torch.FloatTensor(data.astype(np.float32))/max_mag + + if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. 
return_empty_on_exception will return empty arr instead of except + return [], sampling_rate or target_sr or 48000 + if target_sr is not None and sampling_rate != target_sr: + data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr)) + sampling_rate = target_sr + + return data, sampling_rate + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + +class STFT(): + def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5): + self.target_sr = sr + + self.n_mels = n_mels + self.n_fft = n_fft + self.win_size = win_size + self.hop_length = hop_length + self.fmin = fmin + self.fmax = fmax + self.clip_val = clip_val + self.mel_basis = {} + self.hann_window = {} + + def get_mel(self, y, keyshift=0, speed=1, center=False): + sampling_rate = self.target_sr + n_mels = self.n_mels + n_fft = self.n_fft + win_size = self.win_size + hop_length = self.hop_length + fmin = self.fmin + fmax = self.fmax + clip_val = self.clip_val + + factor = 2 ** (keyshift / 12) + n_fft_new = int(np.round(n_fft * factor)) + win_size_new = int(np.round(win_size * factor)) + hop_length_new = int(np.round(hop_length * speed)) + + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + mel_basis_key = str(fmax)+'_'+str(y.device) + if mel_basis_key not in self.mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) + + keyshift_key = str(keyshift)+'_'+str(y.device) + if keyshift_key not in self.hann_window: + self.hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device) + + pad_left = (win_size_new - hop_length_new) //2 + pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left) + if pad_right < y.size(-1): + mode = 'reflect' + else: + mode = 'constant' + y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode) + y = y.squeeze(1) + + spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=self.hann_window[keyshift_key], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + # print(111,spec) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + if keyshift != 0: + size = n_fft // 2 + 1 + resize = spec.size(1) + if resize < size: + spec = F.pad(spec, (0, 0, 0, size-resize)) + spec = spec[:, :size, :] * win_size / win_size_new + + # print(222,spec) + spec = torch.matmul(self.mel_basis[mel_basis_key], spec) + # print(333,spec) + spec = dynamic_range_compression_torch(spec, clip_val=clip_val) + # print(444,spec) + return spec + + def __call__(self, audiopath): + audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr) + spect = self.get_mel(audio.unsqueeze(0)).squeeze(0) + return spect + +stft = STFT() diff --git a/vdecoder/nsf_hifigan/utils.py b/vdecoder/nsf_hifigan/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58d0e701d377e318fe0302743c27bdb4d6e089ec --- /dev/null +++ b/vdecoder/nsf_hifigan/utils.py @@ -0,0 +1,70 @@ +import glob 
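To make the key-shift path of get_mel above concrete, here is a small sketch using the STFT class defaults: a shift of k semitones scales frequency by 2**(k/12), so the FFT and window sizes are scaled by that factor before the resulting spectrum is cropped back to n_fft // 2 + 1 bins and rescaled by win_size / win_size_new.

import numpy as np

n_fft, win_size, hop_length = 1024, 1024, 256    # defaults of the STFT class above
keyshift, speed = 2, 1.0                         # shift up by two semitones
factor = 2 ** (keyshift / 12)
n_fft_new = int(np.round(n_fft * factor))
win_size_new = int(np.round(win_size * factor))
hop_length_new = int(np.round(hop_length * speed))
print(n_fft_new, win_size_new, hop_length_new)   # 1149 1149 256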
+import os + +import matplotlib +import matplotlib.pylab as plt +import torch +from torch.nn.utils import weight_norm + +matplotlib.use("Agg") + + +def plot_spectrogram(spectrogram): + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + + fig.canvas.draw() + plt.close() + + return fig + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size*dilation - dilation)/2) + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print("Complete.") + return checkpoint_dict + + +def save_checkpoint(filepath, obj): + print("Saving checkpoint to {}".format(filepath)) + torch.save(obj, filepath) + print("Complete.") + + +def del_old_checkpoints(cp_dir, prefix, n_models=2): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) # get checkpoint paths + cp_list = sorted(cp_list)# sort by iter + if len(cp_list) > n_models: # if more than n_models models are found + for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models + open(cp, 'w').close()# empty file contents + os.unlink(cp)# delete file (move to trash when using Colab) + + +def scan_checkpoint(cp_dir, prefix): + pattern = os.path.join(cp_dir, prefix + '????????') + cp_list = glob.glob(pattern) + if len(cp_list) == 0: + return None + return sorted(cp_list)[-1] + diff --git a/vencoder/CNHubertLarge.py b/vencoder/CNHubertLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..f43694762f92c5d839d358825f157f5d1a4ff6f6 --- /dev/null +++ b/vencoder/CNHubertLarge.py @@ -0,0 +1,36 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class CNHubertLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/chinese-hubert-large-fairseq-ckpt.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 1024 + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device) + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) \ No newline at end of file diff --git a/vencoder/ContentVec256L12_Onnx.py b/vencoder/ContentVec256L12_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..466e6c128b88acdfb94392662086e6752d503a27 --- /dev/null +++ b/vencoder/ContentVec256L12_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L12_Onnx(SpeechEncoder): + def __init__(self, 
vec_path="pretrain/vec-256-layer-12.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec256L9.py b/vencoder/ContentVec256L9.py new file mode 100644 index 0000000000000000000000000000000000000000..c973090dd4cdaa3d8ca07d9007c26633883c36a7 --- /dev/null +++ b/vencoder/ContentVec256L9.py @@ -0,0 +1,38 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L9(SpeechEncoder): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + feats = self.model.final_proj(logits[0]) + return feats.transpose(1, 2) diff --git a/vencoder/ContentVec256L9_Onnx.py b/vencoder/ContentVec256L9_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..a27e1f76655d9dc9fcc41d05d11b4a1ac5d85b90 --- /dev/null +++ b/vencoder/ContentVec256L9_Onnx.py @@ -0,0 +1,32 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec256L9_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/vec-256-layer-9.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + if device == 'cpu' or device == torch.device("cpu") or device is None: + providers = ['CPUExecutionProvider'] + elif device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, 
onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) + \ No newline at end of file diff --git a/vencoder/ContentVec768L12.py b/vencoder/ContentVec768L12.py new file mode 100644 index 0000000000000000000000000000000000000000..066b824b68447b5c860730c9f11b7be415068b46 --- /dev/null +++ b/vencoder/ContentVec768L12.py @@ -0,0 +1,37 @@ +import torch +from fairseq import checkpoint_utils + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L12(SpeechEncoder): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [vec_path], + suffix="", + ) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.model = models[0].to(self.dev) + self.model.eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav.device), + "padding_mask": padding_mask.to(wav.device), + "output_layer": 12, # layer 12 + } + with torch.no_grad(): + logits = self.model.extract_features(**inputs) + return logits[0].transpose(1, 2) diff --git a/vencoder/ContentVec768L12_Onnx.py b/vencoder/ContentVec768L12_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..e737594526fd09f19353b85c11d4c357a325af48 --- /dev/null +++ b/vencoder/ContentVec768L12_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L12_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/vec-768-layer-12.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/ContentVec768L9_Onnx.py b/vencoder/ContentVec768L9_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd0f337bbf5fa261ea43adfab2377fced7c9e7c --- /dev/null +++ b/vencoder/ContentVec768L9_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class ContentVec768L9_Onnx(SpeechEncoder): + def __init__(self,vec_path = "pretrain/vec-768-layer-9.onnx",device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 768 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = 
['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/DPHubert.py b/vencoder/DPHubert.py new file mode 100644 index 0000000000000000000000000000000000000000..130064ff3ea5c24017be2f0faa204fc4c7dbd078 --- /dev/null +++ b/vencoder/DPHubert.py @@ -0,0 +1,29 @@ +import torch + +from vencoder.dphubert.model import wav2vec2_model +from vencoder.encoder import SpeechEncoder + + +class DPHubert(SpeechEncoder): + def __init__(self, vec_path="pretrain/DPHuBERT-sp0.75.pth", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + ckpt = torch.load(vec_path) + self.hidden_dim = 768 + self.model = wav2vec2_model(**ckpt["config"]).to(self.dev) + self.model.load_state_dict(ckpt["state_dict"], strict=False) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats[None, :] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model(feats)[0] + return units.transpose(1,2) diff --git a/vencoder/HubertSoft.py b/vencoder/HubertSoft.py new file mode 100644 index 0000000000000000000000000000000000000000..423c159c44f0e5cb820a911a47b71ae1478d725d --- /dev/null +++ b/vencoder/HubertSoft.py @@ -0,0 +1,28 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.hubert import hubert_model + + +class HubertSoft(SpeechEncoder): + def __init__(self, vec_path="pretrain/hubert-soft-0d54a1f4.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + hubert_soft = hubert_model.hubert_soft(vec_path) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = 256 + self.model = hubert_soft.to(self.dev) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats[None,None,:] + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.units(feats) + return units.transpose(1,2) diff --git a/vencoder/HubertSoft_Onnx.py b/vencoder/HubertSoft_Onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..038d78e8ffa0804cb63b146f8122b3f2bba2f637 --- /dev/null +++ b/vencoder/HubertSoft_Onnx.py @@ -0,0 +1,33 @@ +import onnxruntime +import torch + +from vencoder.encoder import SpeechEncoder + + +class HubertSoft_Onnx(SpeechEncoder): + def __init__(self, vec_path="pretrain/hubert-soft.onnx", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + self.hidden_dim = 256 + if device is None: + self.dev = torch.device("cpu") + else: + self.dev = torch.device(device) + + if device == 'cuda' or device == torch.device("cuda"): + providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + else: + providers = ['CPUExecutionProvider'] + + self.model = onnxruntime.InferenceSession(vec_path, 
providers=providers) + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + feats = feats.unsqueeze(0).cpu().detach().numpy() + onnx_input = {self.model.get_inputs()[0].name: feats} + logits = self.model.run(None, onnx_input) + return torch.tensor(logits[0]).transpose(1, 2).to(self.dev) diff --git a/vencoder/WavLMBasePlus.py b/vencoder/WavLMBasePlus.py new file mode 100644 index 0000000000000000000000000000000000000000..99df15be73c0c4774cea83a376f79fb68405bfa1 --- /dev/null +++ b/vencoder/WavLMBasePlus.py @@ -0,0 +1,32 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.wavlm.WavLM import WavLM, WavLMConfig + + +class WavLMBasePlus(SpeechEncoder): + def __init__(self, vec_path="pretrain/WavLM-Base+.pt", device=None): + super().__init__() + print("load model(s) from {}".format(vec_path)) + checkpoint = torch.load(vec_path) + self.cfg = WavLMConfig(checkpoint['cfg']) + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + self.hidden_dim = self.cfg.encoder_embed_dim + self.model = WavLM(self.cfg) + self.model.load_state_dict(checkpoint['model']) + self.model.to(self.dev).eval() + + def encoder(self, wav): + feats = wav + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + if self.cfg.normalize: + feats = torch.nn.functional.layer_norm(feats, feats.shape) + with torch.no_grad(): + with torch.inference_mode(): + units = self.model.extract_features(feats[None, :])[0] + return units.transpose(1, 2) diff --git a/vencoder/WhisperPPG.py b/vencoder/WhisperPPG.py new file mode 100644 index 0000000000000000000000000000000000000000..86af53e69b5f60f143a4acce0949c24812e327d1 --- /dev/null +++ b/vencoder/WhisperPPG.py @@ -0,0 +1,31 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper + + +class WhisperPPG(SpeechEncoder): + def __init__(self, vec_path="pretrain/medium.pt", device=None): + super().__init__() + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + checkpoint = torch.load(vec_path, map_location=device) + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + self.hidden_dim = dims + self.model = model.to(self.dev) + + def encoder(self, wav): + audio = wav + audln = audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio).to(self.dev) + with torch.no_grad(): + ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/WhisperPPGLarge.py b/vencoder/WhisperPPGLarge.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d3ea212bff50c11c2711077c67800b06318e3a --- /dev/null +++ b/vencoder/WhisperPPGLarge.py @@ -0,0 +1,31 @@ +import torch + +from vencoder.encoder import SpeechEncoder +from vencoder.whisper.audio import log_mel_spectrogram, pad_or_trim +from vencoder.whisper.model import ModelDimensions, Whisper + + +class WhisperPPGLarge(SpeechEncoder): + def __init__(self, vec_path="pretrain/large-v2.pt", device=None): + 
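+        # Same as WhisperPPG above but loading the large-v2 checkpoint; encoder() again crops Whisper's encoder output to audio_len // 320 frames (one PPG frame per 20 ms at 16 kHz).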
super().__init__() + if device is None: + self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + self.dev = torch.device(device) + checkpoint = torch.load(vec_path, map_location=device) + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + self.hidden_dim = dims + self.model = model.to(self.dev) + + def encoder(self, wav): + audio = wav + audln = audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio).to(self.dev) + with torch.no_grad(): + ppg = self.model.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() + ppg = torch.FloatTensor(ppg[:ppgln, ]).to(self.dev) + return ppg[None, :, :].transpose(1, 2) diff --git a/vencoder/__init__.py b/vencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc b/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8fc8558506b8c5e2c3dfb7a97807e67e85d18a1 Binary files /dev/null and b/vencoder/__pycache__/ContentVec256L9.cpython-38.pyc differ diff --git a/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc b/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7e52b508144faffae93c6b5316a9cc3ccc34aec Binary files /dev/null and b/vencoder/__pycache__/ContentVec768L12.cpython-38.pyc differ diff --git a/vencoder/__pycache__/HubertSoft.cpython-38.pyc b/vencoder/__pycache__/HubertSoft.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dabc52d3acc38ab96e6c32ef508773c615315da4 Binary files /dev/null and b/vencoder/__pycache__/HubertSoft.cpython-38.pyc differ diff --git a/vencoder/__pycache__/WhisperPPG.cpython-38.pyc b/vencoder/__pycache__/WhisperPPG.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2be57b190df7c4e942215830002d076eeb287441 Binary files /dev/null and b/vencoder/__pycache__/WhisperPPG.cpython-38.pyc differ diff --git a/vencoder/__pycache__/__init__.cpython-38.pyc b/vencoder/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4534db51c4ba0062bf334ce4fa8a086ee9a60e3a Binary files /dev/null and b/vencoder/__pycache__/__init__.cpython-38.pyc differ diff --git a/vencoder/__pycache__/encoder.cpython-38.pyc b/vencoder/__pycache__/encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434e6b38e0b56c61d72d06d5f1aeb4e7f61a89d6 Binary files /dev/null and b/vencoder/__pycache__/encoder.cpython-38.pyc differ diff --git a/vencoder/dphubert/__init__.py b/vencoder/dphubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/components.py b/vencoder/dphubert/components.py new file mode 100644 index 0000000000000000000000000000000000000000..be5cc8ce28f11f4f1339578a9d2658740f103283 --- /dev/null +++ b/vencoder/dphubert/components.py @@ -0,0 +1,1410 @@ +"""Building blocks for speech SSL models supporting pruning. 
+ +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py + +""" + +import math +from collections import defaultdict +from typing import List, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn import Module + +from .hardconcrete import HardConcrete +from .pruning_utils import ( + prune_conv1d_layer, + prune_layer_norm, + prune_linear_layer, +) + + +def _init_transformer_params(module): + """ + Initialize the weights of Transformer module in Wav2Vec2/HuBERT. + + If the module is ``nn.Linear``, normalize the weight with mean 0 and standard deviation 0.02. + If ``bias`` is set to ``True`` in the module, set ``bias`` to 0. + + If the module is ``nn.Embedding``, normalize the weight with mean 0 and standard deviation 0.02. + If ``padding_idx`` is not None, set the weight of padding to 0. + + Note: + Ths method corresponds to + `init_bert_params + `__ + in the original ``fairseq`` implementation. + """ + + def normal_(data): + data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device)) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class LayerNorm(nn.LayerNorm): + """Layer norm with transpose""" + + def forward(self, input: Tensor) -> Tensor: + x = input.transpose(-2, -1) + x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = x.transpose(-2, -1) + return x + + +class ConvLayerBlock(Module): + """Convolution unit of FeatureExtractor""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + bias: bool, + layer_norm: Optional[Module], + prune_conv_channels: bool = False, + ): + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.layer_norm = layer_norm + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + ) + + if prune_conv_channels: + self.hard_concrete = HardConcrete(n_in=out_channels, init_mean=0.01) + else: + self.hard_concrete = None + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Shape: ``[batch, in_channels, in_frame]``. + length (Tensor or None, optional): Shape ``[batch, ]``. + Returns: + Tensor: Shape ``[batch, out_channels, out_frames]``. + Optional[Tensor]: Shape ``[batch, ]``. + """ + x = self.conv(x) + if self.layer_norm is not None: + x = self.layer_norm(x) + x = nn.functional.gelu(x) + + if self.hard_concrete is not None: + channel_mask = self.hard_concrete() # hard concrete mask, (out_channels,) + x = x * channel_mask.unsqueeze(-1) + + if length is not None: + length = torch.div(length - self.kernel_size, self.stride, rounding_mode="floor") + 1 + # When input length is 0, the resulting length can be negative. So fix it here. 
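+            # e.g. kernel_size=10, stride=5 maps length 400 to floor((400 - 10) / 5) + 1 = 79 frames, while length 3 gives -1, hence the clamp below.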
+ length = torch.max(torch.zeros_like(length), length) + return x, length + + def get_num_params_and_out_channels(self, in_channels): + if self.hard_concrete is not None: + out_channels = self.hard_concrete.l0_norm() + else: + out_channels = self.conv.out_channels + + num_params = in_channels * out_channels * self.kernel_size + if self.conv.bias is not None: + num_params += out_channels + if self.layer_norm is not None: + num_params += out_channels * 2 + + return num_params, out_channels + + +class FeatureExtractor(Module): + """Extract features from audio + + Args: + conv_layers (nn.ModuleList): + convolution layers + """ + + def __init__( + self, + conv_layers: nn.ModuleList, + ): + super().__init__() + self.conv_layers = conv_layers + + # NOTE: a dummy weight used to save the soft mask of the last conv layer + self.dummy_weight = nn.Parameter( + torch.ones(conv_layers[-1].conv.out_channels, dtype=torch.float32), + requires_grad=False + ) + + def forward( + self, + x: Tensor, + length: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): + Input Tensor representing a batch of audio, + shape: ``[batch, time]``. + length (Tensor or None, optional): + Valid length of each input sample. shape: ``[batch, ]``. + + Returns: + Tensor: + The resulting feature, shape: ``[batch, frame, feature]`` + Optional[Tensor]: + Valid length of each output sample. shape: ``[batch, ]``. + """ + if x.ndim != 2: + raise ValueError("Expected the input Tensor to be 2D (batch, time), " "but received {list(x.shape)}") + + x = x.unsqueeze(1) # (batch, channel==1, frame) + for layer in self.conv_layers: + x, length = layer(x, length) # (batch, feature, frame) + x = x.transpose(1, 2) # (batch, frame, feature) + x = x * self.dummy_weight + return x, length + + def get_num_params_and_final_out_channels(self): + in_channels = 1 + num_params = 0 + for layer in self.conv_layers: + layer_params, in_channels = layer.get_num_params_and_out_channels(in_channels) + num_params += layer_params + + num_params += in_channels # dummy weight + + return num_params, in_channels + + def prune(self): + """"Prune conv layers and dummy weight based on hardconcrete parameters. + This is an in-place operation. + """ + new_config = [] # [(output_channel, kernel_size, stride), ...] + for idx, layer in enumerate(self.conv_layers): + if layer.hard_concrete is not None: + assert not layer.hard_concrete.training + mask = layer.hard_concrete() # (out_features,) + index = mask.nonzero().squeeze(-1) # 2D -> 1D + assert len(index) > 0, f"Conv channels pruned to zero at index {idx}" + new_config.append( + (len(index), layer.kernel_size, layer.stride) + ) + + # prune the current layer + prune_conv1d_layer(layer.conv, index, "output") + if layer.layer_norm is not None: + prune_layer_norm(layer.layer_norm, index) + + # prune the next layer + if idx == len(self.conv_layers) - 1: + self.dummy_weight.data *= mask + self.dummy_weight = nn.Parameter( + self.dummy_weight.index_select(0, index).clone().detach(), requires_grad=False + ) + else: + self.conv_layers[idx+1].conv.weight.data *= mask.unsqueeze(-1) + prune_conv1d_layer(self.conv_layers[idx+1].conv, index, dim="input") + + layer.hard_concrete = None + else: + new_config.append( + (layer.conv.out_channels, layer.kernel_size, layer.stride) + ) + index = torch.arange(layer.conv.out_channels, dtype=torch.long) + + return new_config, index + + +class FeatureProjection(Module): + """Layer that connects FeatureExtractor and Encoder + + Projects features to encoder dimension. 
+ + Args: + in_features (int): Input feature dim. + out_features (int): Output feature dim. + dropout (float): Dropout probability. + """ + + def __init__( + self, + in_features: int, + out_features: int, + dropout: float, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(in_features) + self.projection = nn.Linear( + in_features, + out_features, + ) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + """ + Args: + x (Tensor): + Feature Tensor. shape: ``[batch, frame, in_feature]`` + Returns: + Tensor: Projected features. ``[batch, frame, out_feature]``. + """ + x = self.layer_norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + def get_num_params(self, in_features): + return in_features * 2 + (in_features + 1) * self.projection.out_features + + +class ConvolutionalPositionalEmbedding(Module): + """Positional embedding which is placed at the beginning of Transformer. + + Args: + embed_dim (int): Feature dimension of the input Tensor. + kernel_size (int): The number of frames to be use. + groups (int): The number of groups in feature dimensions. + """ + + def __init__( + self, + embed_dim: int, + kernel_size: int, + groups: int, + ): + super().__init__() + self.embed_dim = embed_dim + self.kernel_size = kernel_size + self.conv = nn.Conv1d( + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + ) + + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + self.num_remove: int = 1 if kernel_size % 2 == 0 else 0 + + def __prepare_scriptable__(self): + for hook in self.conv._forward_pre_hooks.values(): + # The hook we want to remove is an instance of WeightNorm class, so + # normally we would do `if isinstance(...)` but this class is not accessible + # because of shadowing, so we check the module name directly. + # https://github.com/pytorch/pytorch/blob/be0ca00c5ce260eb5bcec3237357f7a30cc08983/torch/nn/utils/__init__.py#L3 + if hook.__module__ == "torch.nn.utils.weight_norm" and hook.__class__.__name__ == "WeightNorm": + torch.nn.utils.remove_weight_norm(self.conv) + return self + + def forward(self, x): + """ + Args: + x (Tensor): shape ``[batch, frame, feature]``. + + Returns: + Tensor: The resulting feature. Shape ``[batch, frame, feature]``. + """ + x = x.transpose(-2, -1) + x = self.conv(x) + if self.num_remove > 0: + x = x[..., : -self.num_remove] + x = torch.nn.functional.gelu(x) + x = x.transpose(-2, -1) + return x + + +class SelfAttention(Module): + """Multihead Self Attention module + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): + Dropout probability on attn_output_weights. 
Default: ``0.0`` + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + prune_heads: bool = False, # whether to prune attention heads + prune_layer: bool = False, # whether to prune entire attention layers + ): + super().__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = torch.nn.Dropout(dropout) + + self.scaling = self.head_dim**-0.5 + + self.k_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.v_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.q_proj = nn.Linear(embed_dim, num_heads * head_dim, bias=True) + self.out_proj = nn.Linear(num_heads * head_dim, embed_dim, bias=True) + + if prune_heads: + self.hard_concrete_for_heads = HardConcrete(n_in=num_heads, init_mean=0.01) + else: + self.hard_concrete_for_heads = None + + if prune_layer: + self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``. + attention_mask (Tensor or ``None``, optional): + shape: ``[batch_size, 1, sequence_length, sequence_length]`` + position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`. + key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with + :py:class:`WavLMSelfAttention`. + Returns: + (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility + with :py:class:`WavLMSelAttention`). + Attention output shape: ``[batch, sequence_length, embed_dim]``. + """ + if x.ndim != 3 or x.shape[2] != self.embed_dim: + raise ValueError( + f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). " f"Found {x.shape}." + ) + batch_size, length, embed_dim = x.size() + + shape = (batch_size, length, self.num_heads, self.head_dim) + q = self.q_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1) # B, nH, Hd, L + v = self.v_proj(x).view(*shape).transpose(2, 1) # B, nH, L, Hd + + # scale down q to avoid value overflow. + weights = (self.scaling * q) @ k # B, nH, L, L + if attention_mask is not None: + weights += attention_mask + # subtracting a constant value from the tensor won't change the output of softmax. + # apply the subtraction to avoid value overflow in torch.nn.functional.softmax. 
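+        # (softmax is shift-invariant: softmax(x - c) == softmax(x) for any constant c,
+        #  so subtracting the row-wise max below leaves the attention weights unchanged.)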
+ # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778 + weights = weights - weights.max(dim=-1, keepdim=True)[0] + + weights = torch.nn.functional.softmax(weights, dim=-1) + weights = self.dropout(weights) + + output = weights @ v # B, nH, L, Hd + + if self.hard_concrete_for_heads is not None: + head_mask = self.hard_concrete_for_heads() # (nH,) + output = output * head_mask.unsqueeze(-1).unsqueeze(-1) + + output = output.transpose(2, 1).reshape(batch_size, length, self.num_heads * self.head_dim) + + output = self.out_proj(output) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + output = output * layer_mask + + return output, None # Necessary for compatibility with WavLMSelAttention + + def get_num_params(self): + if self.hard_concrete_for_heads is not None: + num_heads = self.hard_concrete_for_heads.l0_norm() + else: + num_heads = self.num_heads + num_params = (self.embed_dim + 1) * num_heads * self.head_dim * 3 \ + + (num_heads * self.head_dim + 1) * self.embed_dim + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_attention": True, + "num_heads": self.num_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["num_heads"] = len(head_mask.nonzero()) + if new_config["num_heads"] == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class WavLMSelfAttention(SelfAttention): + """Multi-headed self-attention for WavLM model :cite:`chen2022wavlm`. + + Args: + embed_dim (int): Total dimension of the model. + num_heads (int): The number of heads. + dropout (float, optional): Dropout probability on attn_output_weights. (Default: to ``0.0``) + bias (bool, optional): If ``True``, add bias to input / output projection layers. (Default: ``True``) + has_relative_attention_bias (bool, optional): If ``True``, apply relative position embedding. + Necessary in the first encoder layer, but not in the subsequent ones. (Default: ``False``) + num_buckets (int, optional): Number of buckets for relative position embedding. (Default: ``32``) + max_distance (int, optional): Naximum distance for relative position embedding. (Default: ``128``) + gru_rel_pos (bool, optional): If ``True``, apply gated relative position embedding. 
(Default: ``False``) + """ + + def __init__( + self, + embed_dim: int, + total_num_heads: int, + remaining_heads: Optional[List[int]] = None, + dropout: float = 0.0, + bias: bool = True, + has_relative_attention_bias: bool = False, + num_buckets: int = 32, + max_distance: int = 128, + gru_rel_pos: bool = True, + prune_heads: bool = False, + prune_layer: bool = False, + ): + self.total_num_heads = total_num_heads + if remaining_heads is None: + self.remaining_heads = list(range(total_num_heads)) + else: + self.remaining_heads = remaining_heads # list of indices + + self.head_dim = embed_dim // total_num_heads + + super().__init__(embed_dim, len(self.remaining_heads), self.head_dim, dropout, prune_heads, prune_layer) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + + if has_relative_attention_bias: + self.rel_attn_embed = nn.Embedding(num_buckets, total_num_heads) + else: + self.rel_attn_embed = None + + # override linear layers to customize bias + self.k_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, len(self.remaining_heads) * self.head_dim, bias=bias) + self.out_proj = nn.Linear(len(self.remaining_heads) * self.head_dim, embed_dim, bias=bias) + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8) + self.gru_rel_pos_const = nn.Parameter(torch.ones(1, total_num_heads, 1, 1)) + self.has_position_bias = True + + def compute_bias(self, query_length: int, key_length: int) -> Tensor: + """Compute relative position embeddings for WavLM model. + Args: + query_length (int): Query position can take values between 0 and ``query_length - 1``. + key_length (int): Key position can take values between 0 and ``key_length - 1``. + Returns: + Tensor of shape `(num_heads, query_length, key_length)`, relative positions embeddings + """ + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position # Shape (query_length, key_length) + relative_position_bucket = self._relative_positions_bucket(relative_position, bidirectional=True) + relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device) + values = self.rel_attn_embed(relative_position_bucket) # Shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]) + return values + + def _relative_positions_bucket(self, relative_positions: Tensor, bidirectional: bool = True): + """Compute relative position buckets for WavLM model. Computation similar to formula (5) in WavLM + paper :cite:`chen2022wavlm`. + Args: + relative_positions (Tensor): Relative offsets between query and key positions, + of shape ``(query_length, key_length)``. + bidirectional (bool): If ``True``, values will be filled both above and below the diagonal in the resulting + matrix. If ``False``, the elements above the diagonal (i.e. with negative relative offsets) will be set + to zero. (Default ``True``) + Returns: + Tensor of shape ``(query_length, key_length)`` filled bucketed values of with relative positions. 
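+
+        For instance, with the defaults ``num_buckets=32``, ``max_distance=128`` and
+        ``bidirectional=True``: an offset of ``-3`` falls into bucket ``3``, an offset of ``+3``
+        into bucket ``19`` (positive offsets are shifted by ``num_buckets // 2 = 16``), and a
+        distant offset of ``+100`` into bucket ``31`` (the logarithmic part is capped at ``15``
+        per direction).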
+ """ + num_buckets = self.num_buckets + max_distance = self.max_distance + # Shape (query_length, key_length) + relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long) + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def forward( + self, + query: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + query (Tensor): Input of shape ``(batch_size, src_len, embed_dim)``. + key_padding_mask (Tensor or None, optional): Mask to exclude keys that are pads, of shape + `(batch, src_len)`, where padding elements are indicated by 1s. (Default: ``None``) + attn_mask: Needs to be ``None``. The argument exists for compatibility with + ``EncoderLayer``. (Default: ``None``) + position_bias (Tensor or None, optional): Position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. When used inside WavLM model encoder, will be + generated in the first layer and then passed from each encoder layer to the next one. + (Default: ``None``) + Returns: + attn_output (Tensor): Attention output of shape ``(batch_size, src_len, embed_dim)``. + position_bias (Tensor or None): Position bias of shape ``(batch_size * num_heads, src_len, src_len)``. 
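+
+        Example (an illustrative sketch; the 768-dim / 12-head sizes are assumptions, not tied
+        to a specific checkpoint)::
+
+            >>> attn = WavLMSelfAttention(embed_dim=768, total_num_heads=12,
+            ...                           has_relative_attention_bias=True)
+            >>> x = torch.randn(2, 50, 768)                   # (batch, src_len, embed_dim)
+            >>> out, pos_bias = attn(x)                       # first layer computes the bias
+            >>> out2, _ = attn(x, position_bias=pos_bias)     # subsequent layers can reuse it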
+ """ + bsz, seq_len, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert key_padding_mask is None + + # only for the first layer + if self.rel_attn_embed is not None and position_bias is None: + position_bias = self.compute_bias(seq_len, seq_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.total_num_heads, seq_len, seq_len) + + attn_mask_rel_pos: Optional[Tensor] = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: # Apply gating on relative position bias + query_layer = query.view(bsz, seq_len, self.total_num_heads, -1) + query_layer = query_layer.permute(0, 2, 1, 3) + + gate_a, gate_b = torch.sigmoid( + self.gru_rel_pos_linear(query_layer).view(bsz, self.total_num_heads, seq_len, 2, 4).sum(-1, keepdim=False) + ).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.total_num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, seq_len, seq_len)) + attn_mask_rel_pos = attn_mask_rel_pos.reshape(bsz, self.total_num_heads, seq_len, seq_len)[:, self.remaining_heads, :, :] + + attn_mask = attn_mask_rel_pos + if attention_mask is not None: + attn_mask = attn_mask + attention_mask + if key_padding_mask is not None: + attn_mask = attn_mask.masked_fill( + key_padding_mask.reshape(bsz, 1, 1, seq_len), + float("-inf") + ) + attn_output, _ = super().forward(query, attention_mask=attn_mask) + + return attn_output, position_bias + + def prune(self): + new_config = { + "use_attention": True, + "remaining_heads": self.remaining_heads, + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() # (1,) + self.out_proj.weight.data *= layer_mask + self.out_proj.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_attention"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_heads is not None: + assert not self.hard_concrete_for_heads.training + head_mask = self.hard_concrete_for_heads() # (num_heads,) + new_config["remaining_heads"] = head_mask.nonzero().squeeze(-1).tolist() + if len(new_config["remaining_heads"]) == 0: + new_config["use_attention"] = False + else: + full_mask = head_mask.repeat_interleave(self.head_dim) + full_index = full_mask.nonzero().squeeze(-1) # 1D + + prune_linear_layer(self.k_proj, full_index, "output") + prune_linear_layer(self.v_proj, full_index, "output") + prune_linear_layer(self.q_proj, full_index, "output") + + self.out_proj.weight.data *= full_mask + prune_linear_layer(self.out_proj, full_index, "input") + self.hard_concrete_for_heads = None + + return new_config + + +class FeedForward(Module): + """Layer that follows attention layer in encoder layer.""" + + def __init__( + self, + io_features: int, + intermediate_features: int, + intermediate_dropout: float, + output_dropout: float, + prune_intermediate: bool = False, + prune_layer: bool = False, + ): + super().__init__() + self.intermediate_dense = nn.Linear(io_features, intermediate_features) + self.intermediate_dropout = nn.Dropout(intermediate_dropout) + self.output_dense = nn.Linear(intermediate_features, io_features) + self.output_dropout = nn.Dropout(output_dropout) + + if prune_intermediate: + self.hard_concrete_for_intermediate = HardConcrete( + n_in=intermediate_features, init_mean=0.5 + ) + else: + self.hard_concrete_for_intermediate = None + + if prune_layer: + 
self.hard_concrete_for_layer = HardConcrete(n_in=1, init_mean=0.01) + else: + self.hard_concrete_for_layer = None + + def forward(self, x): + """ + Args: + x (Tensor): shape: `(batch, sequence_length, io_features)` + Returns: + x (Tensor): shape: `(batch, sequence_length, io_features)` + """ + x = self.intermediate_dense(x) + x = torch.nn.functional.gelu(x) + x = self.intermediate_dropout(x) + + if self.hard_concrete_for_intermediate is not None: + intermediate_mask = self.hard_concrete_for_intermediate() # (intermediate_features,) + x = x * intermediate_mask + + x = self.output_dense(x) + x = self.output_dropout(x) + + if self.hard_concrete_for_layer is not None: + layer_mask = self.hard_concrete_for_layer() # (1,) + x = x * layer_mask + + return x + + def get_num_params(self): + io_features = self.intermediate_dense.in_features + if self.hard_concrete_for_intermediate is not None: + intermediate_features = self.hard_concrete_for_intermediate.l0_norm() + else: + intermediate_features = self.intermediate_dense.out_features + num_params = (io_features + 1) * intermediate_features + (intermediate_features + 1) * io_features + + if self.hard_concrete_for_layer is not None: + num_params *= self.hard_concrete_for_layer.l0_norm() + + return num_params + + def prune(self): + new_config = { + "use_feed_forward": True, + "ff_interm_features": self.intermediate_dense.out_features + } + if self.hard_concrete_for_layer is not None: + assert not self.hard_concrete_for_layer.training + layer_mask = self.hard_concrete_for_layer() + self.output_dense.weight.data *= layer_mask + self.output_dense.bias.data *= layer_mask + if layer_mask == 0: + new_config["use_feed_forward"] = False + self.hard_concrete_for_layer = None + + if self.hard_concrete_for_intermediate is not None: + assert not self.hard_concrete_for_intermediate.training + interm_mask = self.hard_concrete_for_intermediate() + interm_index = interm_mask.nonzero().squeeze(-1) # NOTE: must specify dim=-1 + new_config["ff_interm_features"] = len(interm_index) + if new_config["ff_interm_features"] == 0: + new_config["use_feed_forward"] = False + else: + prune_linear_layer(self.intermediate_dense, interm_index, "output") + + self.output_dense.weight.data *= interm_mask + prune_linear_layer(self.output_dense, interm_index, "input") + self.hard_concrete_for_intermediate = None + + return new_config + + +class EncoderLayer(Module): + """A layer unit in encoder. Combines multihead self attention and feed forward.""" + + def __init__( + self, + attention: Optional[Module], # can be None if the entire layer is pruned + dropout: float, + layer_norm_first: bool, + feed_forward: Optional[Module], # can be None if the entire layer is pruned + embed_dim: int, + ): + super().__init__() + self.attention = attention + self.dropout = nn.Dropout(dropout) + self.layer_norm = nn.LayerNorm(embed_dim) + self.layer_norm_first = layer_norm_first + self.feed_forward = feed_forward + self.final_layer_norm = nn.LayerNorm(embed_dim) + self.embed_dim = embed_dim + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``. + attention_mask (Tensor or ``None``, optional): attention mask + of shape ``(batch, 1, sequence_length, sequence_length)``. 
(Default: ``None``) + position_bias (Tensor or ``None``, optional): position bias of shape + ``(batch_size * num_heads, src_len, src_len)``. + Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``) + key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``. + Only used for WavLM model, ignored otherwise. (Default: ``None``) + Returns: + (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model, + ``None`` otherwise. + """ + if self.attention is not None: + residual = x + + if self.layer_norm_first: + x = self.layer_norm(x) + + x, position_bias = self.attention( + x, attention_mask=attention_mask, position_bias=position_bias, key_padding_mask=key_padding_mask + ) + + x = self.dropout(x) + x = residual + x + + if self.layer_norm_first: + if self.feed_forward is not None: + x = x + self.feed_forward(self.final_layer_norm(x)) + else: + # NOTE: for post norm, the layer norms should always be applied even if the layers are pruned. + x = self.layer_norm(x) + if self.feed_forward is not None: + x = x + self.feed_forward(x) + x = self.final_layer_norm(x) + return x, position_bias + + def get_num_params(self): + num_params = self.embed_dim * 2 * 2 # two layer norms + if self.attention is not None: + num_params += self.attention.get_num_params() + if self.feed_forward is not None: + num_params += self.feed_forward.get_num_params() + return num_params + + +class Transformer(Module): + def __init__( + self, + pos_conv_embed: Module, + dropout: float, + layers: Module, + layer_norm_first: bool, + layer_drop: float, + ): + super().__init__() + self.pos_conv_embed = pos_conv_embed + self.layer_norm = nn.LayerNorm(pos_conv_embed.embed_dim) + self.layer_norm_first = layer_norm_first + self.layer_drop = layer_drop + self.dropout = nn.Dropout(dropout) + self.layers = layers + + def _preprocess(self, x: Tensor): + x = x + self.pos_conv_embed(x) + + if self.layer_norm_first: + x = self.layer_norm(x) + + x = self.dropout(x) + return x + + def forward( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + position_bias: Optional[Tensor] = None, + ) -> Tensor: + x = self._preprocess(x) + for layer in self.layers: + if not (self.training and torch.rand(1).item() <= self.layer_drop): + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + + if not self.layer_norm_first: + x = self.layer_norm(x) + return x + + def get_intermediate_outputs( + self, + x: Tensor, + attention_mask: Optional[Tensor] = None, + num_layers: Optional[int] = None, + position_bias: Optional[Tensor] = None, + ) -> List[Tensor]: + if num_layers is not None: + if not 0 < num_layers <= len(self.layers): + raise ValueError(f"`num_layers` must be between [1, {len(self.layers)}]") + + ret: List[Tensor] = [] + x = self._preprocess(x) + for layer in self.layers: + x, position_bias = layer(x, attention_mask, position_bias=position_bias) + ret.append(x) + if num_layers is not None and len(ret) >= num_layers: + return ret + return ret + + def get_num_params(self): + # pos_conv_embed and layer_norm + num_params = sum(p.numel() for p in self.pos_conv_embed.parameters()) + self.pos_conv_embed.embed_dim * 2 + for layer in self.layers: + num_params += layer.get_num_params() + return num_params + + def prune(self): + new_config = defaultdict(list) + for layer in self.layers: + attention_config = layer.attention.prune() + new_config["use_attention"].append(attention_config["use_attention"]) + if "remaining_heads" in 
attention_config: + new_config["remaining_heads"].append(attention_config["remaining_heads"]) + else: + new_config["num_heads"].append(attention_config["num_heads"]) + + if not attention_config["use_attention"]: + layer.attention = None + + ff_config = layer.feed_forward.prune() + new_config["use_feed_forward"].append(ff_config["use_feed_forward"]) + new_config["ff_interm_features"].append(ff_config["ff_interm_features"]) + if not ff_config["use_feed_forward"]: + layer.feed_forward = None + + return new_config + + +class Encoder(Module): + def __init__( + self, + feature_projection: Module, + transformer: Module, + ): + super().__init__() + self.feature_projection = feature_projection + self.transformer = transformer + + def _preprocess( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + x = self.feature_projection(features) + + mask: Optional[Tensor] = None + if lengths is not None: + batch_size, max_len, _ = x.shape + # create mask for padded elements and zero-out them + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + x[mask] = 0.0 + # extend the mask to attention shape and set weight + mask = -10000.0 * mask[:, None, None, :].to(dtype=features.dtype) + mask = mask.expand(batch_size, 1, max_len, max_len) + return x, mask + + def forward( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tensor: + x, mask = self._preprocess(features, lengths) + x = self.transformer(x, attention_mask=mask) + return x + + def extract_features( + self, + features: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> List[Tensor]: + x, masks = self._preprocess(features, lengths) + interm = self.transformer.get_intermediate_outputs(x, attention_mask=masks, num_layers=num_layers) + return [x] + interm + + def get_num_params(self, in_features): + """Calculate the current model size.""" + feature_projection_size = self.feature_projection.get_num_params(in_features) + transformer_size = self.transformer.get_num_params() + return feature_projection_size + transformer_size + + def prune(self, conv_out_index): + """In-place pruning of submodules.""" + prune_layer_norm(self.feature_projection.layer_norm, conv_out_index) + prune_linear_layer(self.feature_projection.projection, conv_out_index, "input") + transformer_config = self.transformer.prune() + return transformer_config + + +################################################################################ +def _get_feature_extractor( + norm_mode: str, + shapes: List[Tuple[int, int, int]], + bias: bool, + prune_conv_channels: bool = False, +) -> FeatureExtractor: + """ + Args: + norm_mode (str): + Either "group_norm" or "layer_norm". + If "group_norm", then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + This option corresponds to "extractor_mode" from fairseq. + Expected values are "group_norm" for Base arch, and + "layer_norm" for Large arch. + shapes (list of tuple of int): + Configuration of convolution layers. List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + This option corresponds to "conv_feature_layers" from fairseq. + Expected values are + ``[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2`` + for all the architectures. + bias (bool): + Whether to include bias term to each convolution operation. + This option corresponds to "conv_bias" from fairseq. 
+ Expected values are False for Base arch, and True for Large arch. + + See Also: + * Original implementation + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L666-L733 + * "extractor_mode" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L38-L45 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L52 + * "conv_feature_layers" + - Def, base and large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L94-L100 + * "conv_bias" + - Def and base: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L101-L103 + - Large: + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L61 + """ + if norm_mode not in ["group_norm", "layer_norm"]: + raise ValueError("Invalid norm mode") + blocks = [] + in_channels = 1 + for i, (out_channels, kernel_size, stride) in enumerate(shapes): + normalization = None + if norm_mode == "group_norm" and i == 0: + normalization = nn.GroupNorm( + num_groups=out_channels, + num_channels=out_channels, + affine=True, + ) + elif norm_mode == "layer_norm": + normalization = LayerNorm( + normalized_shape=out_channels, + elementwise_affine=True, + ) + blocks.append( + ConvLayerBlock( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + bias=bias, + layer_norm=normalization, + prune_conv_channels=prune_conv_channels, + ) + ) + in_channels = out_channels + return FeatureExtractor(nn.ModuleList(blocks)) + + +def _get_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + num_heads: List[int], + head_dim: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Args: + in_features (int): The number of input features. + embed_dim (int): + The dimension of embedding. + This option corresponds to "encoder_embed_dim" from fairseq. + Expected values are 768 for Base arch, and 1024 for Large arch. + dropout_input (float): + The dropout probability applied after the input feature is projected + to ``embed_dim``. + This option corresponds to "dropout_input" from fairseq. + Expected values are 0.1 for both Base and Large arch. + pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + This option corresponds to "conv_pos" from fairseq. + Expected values are 128 for both Base and Large arch. + pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + This option corresponds to "conv_pos_groups" from fairseq. + Expected values are 16 for both Base and Large arch. + num_layers (int): + The number of self attention layers in transformer block. + This option corresponds to "encoder_layers" from fairseq. + Expected values are 12 for Base and 24 for Large arch. 
+ num_heads (int): + The number of heads in self attention layers. + This option corresponds to "encoder_attention_heads" from fairseq. + Expected values are 12 for Base and 16 for Large arch. + attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + This option corresponds to "attention_dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + ff_interm_features (int): + The dimension of hidden features in feed forward layer. + This option corresponds to "encoder_ffn_embed_dim" from fairseq. + Expected values are 3072 for Base and 4096 for Large arch. + ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + This option correspinds to "activation_dropout" from fairseq. + Expected values are 0.1 for both Base and Large arch. + dropout (float): + The dropout probability applied at the end of feed forward layer. + This option corresponds to "dropout" from fairseq. + Expected values are 0.1 for Base and 0.0 for Large arch. + layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + This option corresponds to "layer_norm_first" from fairseq. + Expected values are False for Base and True for Large arch. + layer_drop (float): + Probability to drop each encoder layer during training. + This option corresponds to "layerdrop" from fairseq. + Expected values are 0.1 for both Base and Large arch. + + See Also: + * "encoder_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L49-L51 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L64 + * "dropout_input" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L75-L78 + * "conv_pos" + - Def, base and large + NOTE: The description is wrong. 
+ https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L204-L207 + - Usage + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L756 + * "conv_pos_groups" + - Def, base and large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L208-L211 + * "encoder_layers" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L46-L48 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L63 + * "encoder_attention_heads" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L55-L57 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L66 + * "attention_dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L66-L68 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L60 + * "encoder_ffn_embed_dim" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L52-L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L65 + * "activation_dropout" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L69-L71 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L55 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L55 + * "dropout" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L63-L65 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L59 + * "layer_norm_first" + - Def and base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L91-L93 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml#L53 + * "layerdrop" + - Def + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L72-L74 + - Base + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/base_960h.yaml#L54 + - Large + https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/examples/wav2vec/config/finetuning/vox_960h.yaml#L54 + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # 
https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for idx in range(num_layers): + if use_attention[idx]: + attention = SelfAttention( + embed_dim=embed_dim, + num_heads=num_heads[idx], + head_dim=head_dim, + dropout=attention_dropout, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[idx]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[idx], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_wavlm_encoder( + in_features: int, + embed_dim: int, + dropout_input: float, + pos_conv_kernel: int, + pos_conv_groups: int, + num_layers: int, + use_attention: List[bool], + use_feed_forward: List[bool], + total_num_heads: List[int], + remaining_heads: List[List[int]], + num_buckets: int, + max_distance: int, + attention_dropout: float, + ff_interm_features: List[int], + ff_interm_dropout: float, + dropout: float, + layer_norm_first: bool, + layer_drop: float, + prune_attention_heads: bool = False, + prune_attention_layer: bool = False, + prune_feed_forward_intermediate: bool = False, + prune_feed_forward_layer: bool = False, +) -> Encoder: + """ + Construct encoder for WavLM model :cite:`chen2022wavlm`. The structure of the encoder and most of the argments are + the same as in :py:func:`_get_encoder` so refer there for documentation. The only difference from Wav2Vec2 encoder + is usage of `WavLMSelfAttention` instead of `SelfAttention` and two additional parameters: `num_buckets` and + `max_distance`. + Args: + in_features (int): See :py:func:`_get_encoder`. + embed_dim (int): See :py:func:`_get_encoder`. + dropout_input (float): See :py:func:`_get_encoder`. + pos_conv_kernel (int): See :py:func:`_get_encoder`. + pos_conv_groups (int): See :py:func:`_get_encoder`. + num_layers (int): See :py:func:`_get_encoder`. + num_heads (int): See :py:func:`_get_encoder`. + num_buckets (int): Number of buckets for relative position embedding. + max_distance (int): Maximum distance for relative position embedding. + attention_dropout (float): See :py:func:`_get_encoder`. + ff_interm_features (int): See :py:func:`_get_encoder`. + ff_interm_dropout (float): See :py:func:`_get_encoder`. + dropout (float): See :py:func:`_get_encoder`. + layer_norm_first (bool): See :py:func:`_get_encoder`. + layer_drop (float): See :py:func:`_get_encoder`. 
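+
+    Note:
+        In this pruning-aware variant the per-layer options (``use_attention``, ``use_feed_forward``,
+        ``total_num_heads``, ``remaining_heads``, ``ff_interm_features``) are lists with one entry per
+        encoder layer. Illustratively, a 12-layer Base-sized WavLM would use
+        ``use_attention=[True] * 12``, ``remaining_heads=[list(range(12))] * 12`` and
+        ``ff_interm_features=[3072] * 12``.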
+ + """ + feature_projection = FeatureProjection(in_features, embed_dim, dropout_input) + pos_conv = ConvolutionalPositionalEmbedding(embed_dim, pos_conv_kernel, pos_conv_groups) + + # Original impl + # https://github.com/pytorch/fairseq/blob/425c36eafff535fe7337f8bdd5ace22ebacc78cb/fairseq/models/wav2vec/wav2vec2.py#L768-L782 + encoder_layers = nn.ModuleList() + for i in range(num_layers): + if use_attention[i]: + attention = WavLMSelfAttention( + embed_dim=embed_dim, + total_num_heads=total_num_heads[i], + remaining_heads=remaining_heads[i], + dropout=attention_dropout, + has_relative_attention_bias=(i == 0), # Position embedding is only necessary in the first layer. + num_buckets=num_buckets, + max_distance=max_distance, + prune_heads=prune_attention_heads, + prune_layer=prune_attention_layer, + ) + else: + attention = None + if use_feed_forward[i]: + feed_forward = FeedForward( + io_features=embed_dim, + intermediate_features=ff_interm_features[i], + intermediate_dropout=ff_interm_dropout, + output_dropout=dropout, + prune_intermediate=prune_feed_forward_intermediate, + prune_layer=prune_feed_forward_layer, + ) + else: + feed_forward = None + encoder_layers.append( + EncoderLayer( + attention=attention, + dropout=dropout, + layer_norm_first=layer_norm_first, + feed_forward=feed_forward, + embed_dim=embed_dim, + ) + ) + transformer = Transformer( + pos_conv_embed=pos_conv, + dropout=dropout, + layers=encoder_layers, + layer_norm_first=not layer_norm_first, + layer_drop=layer_drop, + ) + return Encoder(feature_projection, transformer) + + +def _get_padding_mask(input: Tensor, lengths: Tensor) -> Tensor: + """Generate the padding mask given the padded input and the lengths Tensors. + Args: + input (Tensor): The padded Tensor of dimension `[batch, max_len, frequency]`. + lengths (Tensor): The lengths Tensor of dimension `[batch,]`. + + Returns: + (Tensor): The padding mask. + """ + batch_size, max_len, _ = input.shape + mask = torch.arange(max_len, device=lengths.device).expand(batch_size, max_len) >= lengths[:, None] + return mask + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None diff --git a/vencoder/dphubert/hardconcrete.py b/vencoder/dphubert/hardconcrete.py new file mode 100644 index 0000000000000000000000000000000000000000..468a30d1eccdf20ee7493e71792c46e48449c4e3 --- /dev/null +++ b/vencoder/dphubert/hardconcrete.py @@ -0,0 +1,122 @@ +"""Implementation of the hard Concrete distribution. + +Originally from: +https://github.com/asappresearch/flop/blob/master/flop/hardconcrete.py + +""" + +import math + +import torch +import torch.nn as nn + + +class HardConcrete(nn.Module): + """A HarcConcrete module. + Use this module to create a mask of size N, which you can + then use to perform L0 regularization. + + To obtain a mask, simply run a forward pass through the module + with no input data. The mask is sampled in training mode, and + fixed during evaluation mode, e.g.: + + >>> module = HardConcrete(n_in=100) + >>> mask = module() + >>> norm = module.l0_norm() + """ + + def __init__( + self, + n_in: int, + init_mean: float = 0.5, + init_std: float = 0.01, + temperature: float = 2/3, # from CoFi + stretch: float = 0.1, + eps: float = 1e-6 + ) -> None: + """Initialize the HardConcrete module. + Parameters + ---------- + n_in : int + The number of hard concrete variables in this mask. 
+ init_mean : float, optional + Initial drop rate for hard concrete parameter, + by default 0.5., + init_std: float, optional + Used to initialize the hard concrete parameters, + by default 0.01. + temperature : float, optional + Temperature used to control the sharpness of the + distribution, by default 1.0 + stretch : float, optional + Stretch the sampled value from [0, 1] to the interval + [-stretch, 1 + stretch], by default 0.1. + """ + super().__init__() + + self.n_in = n_in + self.limit_l = -stretch + self.limit_r = 1.0 + stretch + self.log_alpha = nn.Parameter(torch.zeros(n_in)) + self.beta = temperature + self.init_mean = init_mean + self.init_std = init_std + self.bias = -self.beta * math.log(-self.limit_l / self.limit_r) + + self.eps = eps + self.compiled_mask = None + self.reset_parameters() + + def reset_parameters(self): + """Reset the parameters of this module.""" + self.compiled_mask = None + mean = math.log(1 - self.init_mean) - math.log(self.init_mean) + self.log_alpha.data.normal_(mean, self.init_std) + + def l0_norm(self) -> torch.Tensor: + """Compute the expected L0 norm of this mask. + Returns + ------- + torch.Tensor + The expected L0 norm. + """ + return (self.log_alpha + self.bias).sigmoid().sum() + + def forward(self) -> torch.Tensor: + """Sample a hard concrete mask. + Returns + ------- + torch.Tensor + The sampled binary mask + """ + if self.training: + # Reset the compiled mask + self.compiled_mask = None + # Sample mask dynamically + u = self.log_alpha.new(self.n_in).uniform_(self.eps, 1 - self.eps) + s = torch.sigmoid((torch.log(u / (1 - u)) + self.log_alpha) / self.beta) + s = s * (self.limit_r - self.limit_l) + self.limit_l + mask = s.clamp(min=0., max=1.) + + else: + # Compile new mask if not cached + if self.compiled_mask is None: + # Get expected sparsity + expected_num_zeros = self.n_in - self.l0_norm().item() + num_zeros = round(expected_num_zeros) + # Approximate expected value of each mask variable z; + # We use an empirically validated magic number 0.8 + soft_mask = torch.sigmoid(self.log_alpha / self.beta * 0.8) + # Prune small values to set to 0 + _, indices = torch.topk(soft_mask, k=num_zeros, largest=False) + soft_mask[indices] = 0. + self.compiled_mask = soft_mask + mask = self.compiled_mask + + return mask + + def extra_repr(self) -> str: + return str(self.n_in) + + def __repr__(self) -> str: + return "{}({})".format(self.__class__.__name__, self.extra_repr()) diff --git a/vencoder/dphubert/model.py b/vencoder/dphubert/model.py new file mode 100644 index 0000000000000000000000000000000000000000..348ede2c3edc3e5588ee75760085dee9eafd9d68 --- /dev/null +++ b/vencoder/dphubert/model.py @@ -0,0 +1,966 @@ +"""Speech SSL models supporting pruning. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/model.py + +""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module + +from . import components + + +class Wav2Vec2Model(Module): + """Acoustic model used in *wav2vec 2.0* :cite:`baevski2020wav2vec`. + + Note: + To build the model, please use one of the factory functions. + :py:func:`wav2vec2_model`, :py:func:`wav2vec2_base`, :py:func:`wav2vec2_large`, + :py:func:`wav2vec2_large_lv60k`, :py:func:`hubert_base`, :py:func:`hubert_large`, + and :py:func:`hubert_xlarge`. 
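+
+    Example:
+        An illustrative sketch of building a Base-sized, prune-aware model directly with
+        :py:func:`wav2vec2_model` and extracting features (the values below are the usual
+        wav2vec 2.0 Base hyper-parameters, not those of a particular released checkpoint):
+
+        .. code-block:: python
+
+            model = wav2vec2_model(
+                extractor_mode="group_norm",
+                extractor_conv_layer_config=None,   # fall back to the default 7-layer conv stack
+                extractor_conv_bias=False,
+                encoder_embed_dim=768,
+                encoder_projection_dropout=0.1,
+                encoder_pos_conv_kernel=128,
+                encoder_pos_conv_groups=16,
+                encoder_num_layers=12,
+                encoder_use_attention=[True] * 12,
+                encoder_use_feed_forward=[True] * 12,
+                encoder_num_heads=[12] * 12,
+                encoder_head_dim=64,
+                encoder_attention_dropout=0.1,
+                encoder_ff_interm_features=[3072] * 12,
+                encoder_ff_interm_dropout=0.1,
+                encoder_dropout=0.1,
+                encoder_layer_norm_first=False,
+                encoder_layer_drop=0.1,
+                aux_num_out=None,
+                normalize_waveform=False,
+            )
+            waveform = torch.randn(1, 16000)                    # (batch, time), e.g. 1 s at 16 kHz
+            features, lengths = model.extract_features(waveform)
+            last_hidden = features[-1]                          # (batch, frame, feature)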
+ + See Also: + * :class:`torchaudio.pipelines.Wav2Vec2Bundle`: Pretrained models (without fine-tuning) + * :class:`torchaudio.pipelines.Wav2Vec2ASRBundle`: ASR pipelines with pretrained models. + + Args: + feature_extractor (torch.nn.Module): + Feature extractor that extracts feature vectors from raw audio Tensor. + + encoder (torch.nn.Module): + Encoder that converts the audio features into the sequence of probability + distribution (in negative log-likelihood) over labels. + + aux (torch.nn.Module or None, optional): + Auxiliary module. If provided, the output from encoder is passed to this module. + """ # noqa: E501 + + def __init__( + self, + normalize_waveform: bool, + feature_extractor: Module, + encoder: Module, + aux: Optional[Module] = None, + ): + super().__init__() + self.normalize_waveform = normalize_waveform + self.feature_extractor = feature_extractor + self.encoder = encoder + self.aux = aux + + @torch.jit.export + def extract_features( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + num_layers: Optional[int] = None, + ) -> Tuple[List[Tensor], Optional[Tensor]]: + """Extract feature vectors from raw waveforms + + This returns the list of outputs from the intermediate layers of + transformer block in encoder. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that the entire audio waveform + length is valid. + num_layers (int or None, optional): + If given, limit the number of intermediate layers to go through. + Providing `1` will stop the computation after going through one + intermediate layers. If not given, the outputs from all the + intermediate layers are returned. + + Returns: + (List[Tensor], Optional[Tensor]): + List of Tensors + Features from requested layers. + Each Tensor is of shape: `(batch, time frame, feature dimension)` + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of each feature Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder.extract_features(x, lengths, num_layers) # (num_layers+1,), including the input + return x, lengths + + def get_num_params(self): + """Calculate the current size.""" + feature_extractor_size, encoder_in_features = self.feature_extractor.get_num_params_and_final_out_channels() + encoder_size = self.encoder.get_num_params(encoder_in_features) + return feature_extractor_size + encoder_size + + def prune(self): + self.eval() # must be in eval mode + conv_config, conv_out_index = self.feature_extractor.prune() # [(output_channel, kernel_size, stride), ...] 
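+        # `conv_out_index` lists the surviving conv output channels; the encoder prune below uses it
+        # to shrink the feature projection input dimension to match the pruned extractor.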
+ transformer_config = self.encoder.prune(conv_out_index) # NOTE: this is a defaultdict(list) + use_attention = transformer_config["use_attention"] + use_feed_forward = transformer_config["use_feed_forward"] + num_heads = transformer_config["num_heads"] # can be [] + remaining_heads = transformer_config["remaining_heads"] # can be [] + ff_interm_features = transformer_config["ff_interm_features"] + + return conv_config, use_attention, use_feed_forward, num_heads, remaining_heads, ff_interm_features + + def forward( + self, + waveforms: Tensor, + lengths: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + """Compute the sequence of probability distribution over labels. + + Args: + waveforms (Tensor): Audio tensor of shape `(batch, frames)`. + lengths (Tensor or None, optional): + Indicates the valid length of each audio in the batch. + Shape: `(batch, )`. + When the ``waveforms`` contains audios with different durations, + by providing ``lengths`` argument, the model will compute + the corresponding valid output lengths and apply proper mask in + transformer attention layer. + If ``None``, it is assumed that all the audio in ``waveforms`` + have valid length. Default: ``None``. + + Returns: + (Tensor, Optional[Tensor]): + Tensor + The sequences of probability distribution (in logit) over labels. + Shape: `(batch, frames, num labels)`. + Tensor or None + If ``lengths`` argument was provided, a Tensor of shape `(batch, )` + is returned. + It indicates the valid length in time axis of the output Tensor. + """ + if self.normalize_waveform: + if lengths is not None: + waveforms = [ + F.layer_norm(wave[:length], (length,)) for wave, length in zip(waveforms, lengths) + ] + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + else: + waveforms = F.layer_norm(waveforms, waveforms.shape[-1:]) + + x, lengths = self.feature_extractor(waveforms, lengths) + x = self.encoder(x, lengths) + if self.aux is not None: + x = self.aux(x) + return x, lengths + + +def wav2vec2_model(**configs) -> Wav2Vec2Model: + """Wraps the original wav2vec2_model and wavlm_model.""" + + if "encoder_remaining_heads" in configs: + return wavlm_model(**configs) + + return wav2vec2_model_original(**configs) + + +def wav2vec2_model_original( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_num_heads: List[int], + encoder_head_dim: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom :class:`~torchaudio.models.Wav2Vec2Model`. + + Note: + The "feature extractor" below corresponds to + `ConvFeatureExtractionModel `__ + in the original ``fairseq`` implementation. + This is referred as "(convolutional) feature encoder" in the *wav2vec 2.0* + :cite:`baevski2020wav2vec` paper. 
+ + The "encoder" below corresponds to `TransformerEncoder `__, + and this is referred as "Transformer" in the paper. + + Args: + extractor_mode (str): Operation mode of feature extractor. + Valid values are ``"group_norm"`` or ``"layer_norm"``. + If ``"group_norm"``, then a single normalization is applied + in the first convolution block. Otherwise, all the convolution + blocks will have layer normalization. + + This option corresponds to ``extractor_mode`` from ``fairseq``. + extractor_conv_layer_config (list of integer tuples or None): + Configuration of convolution layers in feature extractor. + List of convolution configuration, + i.e. ``[(output_channel, kernel_size, stride), ...]`` + + If ``None`` is provided, then the following default value is used. + + .. code-block:: python + + [ + (512, 10, 5), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 3, 2), + (512, 2, 2), + (512, 2, 2), + ] + + This option corresponds to ``conv_feature_layers`` from ``fairseq``. + + extractor_conv_bias (bool): + Whether to include bias term to each convolution operation. + + This option corresponds to ``conv_bias`` from ``fairseq``. + + encoder_embed_dim (int): + The dimension of embedding in encoder. + + This option corresponds to ``encoder_embed_dim`` from ``fairseq``. + + encoder_projection_dropout (float): + The dropout probability applied after the input feature is projected + to ``encoder_embed_dim``. + + This option corresponds to ``dropout_input`` from ``fairseq``. + + encoder_pos_conv_kernel (int): + The kernel size of convolutional positional embeddings. + + This option corresponds to ``conv_pos`` from ``fairseq``. + + encoder_pos_conv_groups (int): + The number of groups of convolutional positional embeddings. + + This option corresponds to ``conv_pos_groups`` from ``fairseq``. + + encoder_num_layers (int): + The number of self attention layers in transformer block. + + This option corresponds to ``encoder_layers`` from ``fairseq``. + + encoder_num_heads (int): + The number of heads in self attention layers. + + This option corresponds to ``encoder_attention_heads`` from ``fairseq``. + + encoder_attention_dropout (float): + The dropout probability applied after softmax in self-attention layer. + + This option corresponds to ``attention_dropout`` from ``fairseq``. + + encoder_ff_interm_features (int): + The dimension of hidden features in feed forward layer. + + This option corresponds to ``encoder_ffn_embed_dim`` from ``fairseq``. + + encoder_ff_interm_dropout (float): + The dropout probability applied in feedforward layer. + + This option correspinds to ``activation_dropout`` from ``fairseq``. + + encoder_dropout (float): + The dropout probability applied at the end of feed forward layer. + + This option corresponds to ``dropout`` from ``fairseq``. + + encoder_layer_norm_first (bool): + Control the order of layer norm in transformer layer and each encoder layer. + If True, in transformer layer, layer norm is applied before features are fed + to encoder layers. In encoder layer, two layer norms are applied before and after + self attention. + If False, in transformer layer, layer norm is applied after features are fed + to encoder layers. In encoder layer, two layer norms are applied after self + attention, before and after feed forward. + + This option corresponds to ``layer_norm_first`` from ``fairseq``. + + encoder_layer_drop (float): + Probability to drop each encoder layer during training. + + This option corresponds to ``layerdrop`` from ``fairseq``. 
+ + aux_num_out (int or None): + When provided, attach an extra linear layer on top of encoder, which can be + used for fine-tuning. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + num_heads=encoder_num_heads, + head_dim=encoder_head_dim, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wav2vec2_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
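# --- Quick arithmetic check of the default extractor_conv_layer_config shown
# above (pure Python, independent of this patch): the conv stack fixes the
# model's frame rate and receptive field.
conv_cfg = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2

total_stride = 1
for _, _, stride in conv_cfg:
    total_stride *= stride
# total_stride == 320 samples, i.e. one output frame every 20 ms at 16 kHz.

receptive_field = 1
for _, kernel, stride in reversed(conv_cfg):
    receptive_field = (receptive_field - 1) * stride + kernel
# receptive_field == 400 samples, i.e. each frame summarizes 25 ms of audio.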
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def wav2vec2_large_lv60k( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large lv-60k" :class:`~torchaudio.models.Wav2Vec2Model` from *wav2vec 2.0* :cite:`baevski2020wav2vec` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=True, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.05, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "base" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_use_attention=[True] * 12, + encoder_use_feed_forward=[True] * 12, + encoder_num_heads=[12] * 12, + encoder_head_dim=64, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=[3072] * 12, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_large( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
+ """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def hubert_xlarge( + encoder_projection_dropout: float = 0.0, + encoder_attention_dropout: float = 0.0, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.0, + encoder_layer_drop: float = 0.0, + aux_num_out: Optional[int] = None, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds "extra large" :class:`HuBERT ` from *HuBERT* :cite:`hsu2021hubert` + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int or None, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ # noqa: E501 + return wav2vec2_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1280, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=48, + encoder_num_heads=16, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=5120, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + extractor_prune_conv_channels=extractor_prune_conv_channels, + encoder_prune_attention_heads=encoder_prune_attention_heads, + encoder_prune_attention_layer=encoder_prune_attention_layer, + encoder_prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + encoder_prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + + +def _init_hubert_pretrain_model(module): + if isinstance(module, components.LayerNorm): + torch.nn.init.kaiming_normal_(module.conv.weight) + elif isinstance(module, components.ConvolutionalPositionalEmbedding): + # normalize the weight to normal distribution. 
+ std = math.sqrt(4.0 / (module.embed_dim * module.kernel_size)) + torch.nn.init.normal_(module.conv.weight, mean=0.0, std=std) + torch.nn.init.constant_(module.conv.bias, 0.0) + elif isinstance(module, components.SelfAttention): + # normalize the query, key, value, and out_proj parameters in self attention module. + torch.nn.init.xavier_uniform_(module.k_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.v_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.q_proj.weight, gain=1 / math.sqrt(2)) + torch.nn.init.xavier_uniform_(module.out_proj.weight) + torch.nn.init.constant_(module.out_proj.bias, 0.0) + elif isinstance(module, components.Transformer): + module.apply(components._init_transformer_params) + else: + pass + + +def wavlm_model( + extractor_mode: str, + extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]], + extractor_conv_bias: bool, + encoder_embed_dim: int, + encoder_projection_dropout: float, + encoder_pos_conv_kernel: int, + encoder_pos_conv_groups: int, + encoder_num_layers: int, + encoder_use_attention: List[bool], + encoder_use_feed_forward: List[bool], + encoder_total_num_heads: List[int], + encoder_remaining_heads: List[List[int]], + encoder_num_buckets: int, + encoder_max_distance: int, + encoder_attention_dropout: float, + encoder_ff_interm_features: List[int], + encoder_ff_interm_dropout: float, + encoder_dropout: float, + encoder_layer_norm_first: bool, + encoder_layer_drop: float, + aux_num_out: Optional[int], + normalize_waveform: bool, + extractor_prune_conv_channels: bool = False, + encoder_prune_attention_heads: bool = False, + encoder_prune_attention_layer: bool = False, + encoder_prune_feed_forward_intermediate: bool = False, + encoder_prune_feed_forward_layer: bool = False, +) -> Wav2Vec2Model: + """Builds custom WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output object is + :class:`~torchaudio.models.Wav2Vec2Model`. Most of the arguments have the same meaning + as in :py:func:`wav2vec2_model` so please refer there for documentation. + + Args: + extractor_mode (str): Operation mode of feature extractor. + See :py:func:`wav2vec2_model`. + + extractor_conv_layer_config (list of integer tuples or None): + See :py:func:`wav2vec2_model`. + + extractor_conv_bias (bool): + See :py:func:`wav2vec2_model`. + + encoder_embed_dim (int): + See :py:func:`wav2vec2_model`. + + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_kernel (int): + See :py:func:`wav2vec2_model`. + + encoder_pos_conv_groups (int): + See :py:func:`wav2vec2_model`. + + encoder_num_layers (int): + See :py:func:`wav2vec2_model`. + + encoder_num_heads (int): + See :py:func:`wav2vec2_model`. + + encoder_num_buckets (int): + Number of buckets for relative position embedding. + encoder_max_distance (int): + Maximum distance for relative position embedding. + + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_features (int): + See :py:func:`wav2vec2_model`. + + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + + encoder_layer_norm_first (bool): + See :py:func:`wav2vec2_model`. + + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + + aux_num_out (int or None): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. 
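# --- Numeric check of the positional-convolution initialisation used in
# _init_hubert_pretrain_model() above, for the "base" configuration
# (embed_dim=768, kernel_size=128):
import math

std = math.sqrt(4.0 / (768 * 128))
# std ~= 0.00638, so the positional convolution starts out contributing only a
# small perturbation on top of the projected features.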
+ """ + if extractor_conv_layer_config is None: + extractor_conv_layer_config = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2 + + feature_extractor = components._get_feature_extractor( + extractor_mode, extractor_conv_layer_config, extractor_conv_bias, + prune_conv_channels=extractor_prune_conv_channels, + ) + encoder = components._get_wavlm_encoder( + in_features=extractor_conv_layer_config[-1][0], + embed_dim=encoder_embed_dim, + dropout_input=encoder_projection_dropout, + pos_conv_kernel=encoder_pos_conv_kernel, + pos_conv_groups=encoder_pos_conv_groups, + num_layers=encoder_num_layers, + use_attention=encoder_use_attention, + use_feed_forward=encoder_use_feed_forward, + total_num_heads=encoder_total_num_heads, + remaining_heads=encoder_remaining_heads, + num_buckets=encoder_num_buckets, + max_distance=encoder_max_distance, + attention_dropout=encoder_attention_dropout, + ff_interm_features=encoder_ff_interm_features, + ff_interm_dropout=encoder_ff_interm_dropout, + dropout=encoder_dropout, + layer_norm_first=encoder_layer_norm_first, + layer_drop=encoder_layer_drop, + prune_attention_heads=encoder_prune_attention_heads, + prune_attention_layer=encoder_prune_attention_layer, + prune_feed_forward_intermediate=encoder_prune_feed_forward_intermediate, + prune_feed_forward_layer=encoder_prune_feed_forward_layer, + ) + aux = None + if aux_num_out is not None: + aux = torch.nn.Linear(in_features=encoder_embed_dim, out_features=aux_num_out) + return Wav2Vec2Model(normalize_waveform, feature_extractor, encoder, aux) + + +def wavlm_base( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.1, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "base" WaveLM model :cite:`chen2022wavlm`. The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="group_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=768, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=12, + encoder_num_heads=12, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=3072, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=False, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) + + +def wavlm_large( + encoder_projection_dropout: float = 0.1, + encoder_attention_dropout: float = 0.1, + encoder_ff_interm_dropout: float = 0.0, + encoder_dropout: float = 0.1, + encoder_layer_drop: float = 0.1, + aux_num_out: Optional[int] = None, +) -> Wav2Vec2Model: + """Builds "large" WaveLM model :cite:`chen2022wavlm`. 
The architecture is compatible + with Wav2Vec2 model :cite:`baevski2020wav2vec`, and so the output class is + :class:`~torchaudio.models.Wav2Vec2Model`. + + Args: + encoder_projection_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_attention_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_ff_interm_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_dropout (float): + See :py:func:`wav2vec2_model`. + encoder_layer_drop (float): + See :py:func:`wav2vec2_model`. + aux_num_out (int, optional): + See :py:func:`wav2vec2_model`. + + Returns: + Wav2Vec2Model: + The resulting model. + """ + return wavlm_model( + extractor_mode="layer_norm", + extractor_conv_layer_config=None, + extractor_conv_bias=False, + encoder_embed_dim=1024, + encoder_projection_dropout=encoder_projection_dropout, + encoder_pos_conv_kernel=128, + encoder_pos_conv_groups=16, + encoder_num_layers=24, + encoder_num_heads=16, + encoder_num_buckets=320, + encoder_max_distance=800, + encoder_attention_dropout=encoder_attention_dropout, + encoder_ff_interm_features=4096, + encoder_ff_interm_dropout=encoder_ff_interm_dropout, + encoder_dropout=encoder_dropout, + encoder_layer_norm_first=True, + encoder_layer_drop=encoder_layer_drop, + aux_num_out=aux_num_out, + ) diff --git a/vencoder/dphubert/pruning_utils.py b/vencoder/dphubert/pruning_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac185980c2c3da716bf3ce402a541ffe70776acf --- /dev/null +++ b/vencoder/dphubert/pruning_utils.py @@ -0,0 +1,51 @@ +"""Utility functions for pruning.""" + +from typing import Union + +import torch +import torch.nn as nn + + +def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: str): + "Prune linear layer in place." + # NOTE: weight: (out_features, in_features), bias: (out_features,) + if dim == "input": + dim = 1 + layer.in_features = len(index) + elif dim == "output": + dim = 0 + layer.out_features = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_conv1d_layer(layer: nn.Conv1d, index: torch.LongTensor, dim: str): + """Prune conv1d in place.""" + # NOTE: weight: (out_channels, in_channels, kernel_size), bias: (out_channels,) + if dim == "input": + dim = 1 + layer.in_channels = len(index) + elif dim == "output": + dim = 0 + layer.out_channels = len(index) + else: + raise ValueError + + layer.weight = nn.Parameter(layer.weight.index_select(dim, index).clone().detach()) + if layer.bias is not None and dim == 0: + layer.bias = nn.Parameter(layer.bias.index_select(0, index).clone().detach()) + + +def prune_layer_norm(layernorm: Union[nn.LayerNorm, nn.GroupNorm], index: torch.LongTensor): + """Prune layer norm or group norm in place.""" + layernorm.weight = nn.Parameter(layernorm.weight.index_select(0, index).clone().detach()) + layernorm.bias = nn.Parameter(layernorm.bias.index_select(0, index).clone().detach()) + if isinstance(layernorm, nn.LayerNorm): + layernorm.normalized_shape = (len(index),) + elif isinstance(layernorm, nn.GroupNorm): + layernorm.num_groups = len(index) + layernorm.num_channels = len(index) diff --git a/vencoder/dphubert/utils/__init__.py b/vencoder/dphubert/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/dphubert/utils/import_huggingface_wavlm.py 
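# --- Small self-contained demonstration of prune_linear_layer() from
# pruning_utils.py above: keeping two of four output units shrinks the weight
# and bias in place while leaving the input dimension untouched.
import torch
import torch.nn as nn

from vencoder.dphubert.pruning_utils import prune_linear_layer

layer = nn.Linear(8, 4)
keep = torch.LongTensor([0, 2])          # indices of the output units to keep
prune_linear_layer(layer, keep, dim="output")
assert layer.weight.shape == (2, 8)
assert layer.bias.shape == (2,)
assert layer.out_features == 2
# prune_conv1d_layer() and prune_layer_norm() follow the same in-place pattern
# for convolutions and (layer/group) normalisation layers.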
b/vencoder/dphubert/utils/import_huggingface_wavlm.py new file mode 100644 index 0000000000000000000000000000000000000000..24a3f38ae9cc08e19010b2876b19dc9082873377 --- /dev/null +++ b/vencoder/dphubert/utils/import_huggingface_wavlm.py @@ -0,0 +1,129 @@ +"""Import Hugging Face transformers's wav2vec2.0 pretrained weights to torchaudios's format. + +Originally from: +https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/utils/import_huggingface.py + +""" + +import logging +from typing import Any, Dict + +from torch.nn import Module + +from ..model import Wav2Vec2Model, wav2vec2_model, wavlm_model + +_LG = logging.getLogger(__name__) + + +def _get_config(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_num_heads": cfg.num_attention_heads, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": cfg.intermediate_size, + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + } + return config + + +def _get_config_wavlm(cfg): + config = { + "extractor_mode": f"{cfg.feat_extract_norm}_norm", + "extractor_conv_layer_config": list(zip(cfg.conv_dim, cfg.conv_kernel, cfg.conv_stride)), + "extractor_conv_bias": cfg.conv_bias, + "encoder_embed_dim": cfg.hidden_size, + "encoder_projection_dropout": cfg.feat_proj_dropout, + "encoder_pos_conv_kernel": cfg.num_conv_pos_embeddings, + "encoder_pos_conv_groups": cfg.num_conv_pos_embedding_groups, + "encoder_num_layers": cfg.num_hidden_layers, + "encoder_use_attention": [True] * cfg.num_hidden_layers, + "encoder_use_feed_forward": [True] * cfg.num_hidden_layers, + "encoder_total_num_heads": [cfg.num_attention_heads for _ in range(cfg.num_hidden_layers)], + "encoder_remaining_heads": [list(range(cfg.num_attention_heads)) for _ in range(cfg.num_hidden_layers)], + "encoder_num_buckets": cfg.num_buckets, + "encoder_max_distance": cfg.max_bucket_distance, + "encoder_attention_dropout": cfg.attention_dropout, + "encoder_ff_interm_features": [cfg.intermediate_size for _ in range(cfg.num_hidden_layers)], + "encoder_ff_interm_dropout": cfg.activation_dropout, + "encoder_dropout": cfg.hidden_dropout, + "encoder_layer_norm_first": cfg.do_stable_layer_norm, + "encoder_layer_drop": cfg.layerdrop, + "normalize_waveform": cfg.feat_extract_norm == "layer", + } + return config + + +def _build(config, original): + is_for_ctc = original.__class__.__name__ in ["Wav2Vec2ForCTC", "WavLMForCTC"] + if is_for_ctc: + aux_num_out = original.config.vocab_size + wav2vec2 = original.wav2vec2 + else: + _LG.warning( + "The model is not an instance of Wav2Vec2ForCTC or WavLMForCTC. " '"lm_head" module is not imported.' 
+ ) + aux_num_out = None + wav2vec2 = original + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + imported = wavlm_model(**config, aux_num_out=aux_num_out) + else: + imported = wav2vec2_model(**config, aux_num_out=aux_num_out) + print(imported.feature_extractor.load_state_dict(wav2vec2.feature_extractor.state_dict(), strict=False)) + print(imported.encoder.feature_projection.load_state_dict(wav2vec2.feature_projection.state_dict(), strict=False)) + encoder_state_dict = wav2vec2.encoder.state_dict() + if is_wavlm: # Rename paramaters of linear transformations for compatibility with the HF model + transform_wavlm_encoder_state(encoder_state_dict, config["encoder_num_layers"]) + print(imported.encoder.transformer.load_state_dict(encoder_state_dict, strict=False)) + if is_for_ctc: + imported.aux.load_state_dict(original.lm_head.state_dict()) + return imported + + +def transform_wavlm_encoder_state(state: Dict[str, Any], encoder_num_layers: int): + """Converts WavLM encoder state from HuggingFace format. In particular, concatenates linear projection weights and + biases to align with the structure of ``torch.nn.MultiheadAttention``. + """ + pass + + +def import_huggingface_model(original: Module) -> Wav2Vec2Model: + """Builds :class:`Wav2Vec2Model` from the corresponding model object of + `Transformers `_. + + Args: + original (torch.nn.Module): An instance of ``Wav2Vec2ForCTC`` from ``transformers``. + + Returns: + Wav2Vec2Model: Imported model. + + Example + >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model + >>> + >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + >>> model = import_huggingface_model(original) + >>> + >>> waveforms, _ = torchaudio.load("audio.wav") + >>> logits, _ = model(waveforms) + """ + _LG.info("Importing model.") + _LG.info("Loading model configuration.") + is_wavlm = original.__class__.__name__ in ["WavLMModel", "WavLMForCTC"] + if is_wavlm: + config = _get_config_wavlm(original.config) + else: + config = _get_config(original.config) + _LG.debug(" - config: %s", config) + _LG.info("Building model.") + imported = _build(config, original) + return imported diff --git a/vencoder/encoder.py b/vencoder/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad120da34893d64b47b8ebeeaaed1f822a2e0be --- /dev/null +++ b/vencoder/encoder.py @@ -0,0 +1,13 @@ +class SpeechEncoder(object): + def __init__(self, vec_path="pretrain/checkpoint_best_legacy_500.pt", device=None): + self.model = None # This is Model + self.hidden_dim = 768 + pass + + + def encoder(self, wav): + """ + input: wav:[signal_length] + output: embedding:[batchsize,hidden_dim,wav_frame] + """ + pass diff --git a/vencoder/hubert/__init__.py b/vencoder/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/hubert/__pycache__/__init__.cpython-38.pyc b/vencoder/hubert/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd7309fa140a6c763ca7a619b8d05c42d666cc83 Binary files /dev/null and b/vencoder/hubert/__pycache__/__init__.cpython-38.pyc differ diff --git a/vencoder/hubert/__pycache__/hubert_model.cpython-38.pyc b/vencoder/hubert/__pycache__/hubert_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e519a942f14d95ca0e77e8f24f5e1945c82a02 Binary files /dev/null and 
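# --- Usage sketch for import_huggingface_model() above, mirroring its
# docstring example. The Hugging Face model id is the standard published
# checkpoint name and is assumed to be available locally or downloadable.
# Note that WavLMModel / WavLMForCTC inputs are routed through the WavLM
# branch, which also calls transform_wavlm_encoder_state(); that helper is
# still a stub in this patch, so the attention projections may not be remapped
# until it is filled in.
import torch
from transformers import WavLMModel

from vencoder.dphubert.utils.import_huggingface_wavlm import import_huggingface_model

original = WavLMModel.from_pretrained("microsoft/wavlm-base-plus")
imported = import_huggingface_model(original).eval()

waveform = torch.randn(1, 16000)          # placeholder 1 s of 16 kHz mono audio
with torch.inference_mode():
    features, _ = imported(waveform)      # (1, frames, 768) for the base model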
b/vencoder/hubert/__pycache__/hubert_model.cpython-38.pyc differ diff --git a/vencoder/hubert/hubert_model.py b/vencoder/hubert/hubert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..7fb642d89b07ca60792debab18e3454f52d8f357 --- /dev/null +++ b/vencoder/hubert/hubert_model.py @@ -0,0 +1,222 @@ +import copy +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as t_func +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x, mask = self.encode(x) + x = self.proj(x) + logits = self.logits(x) + return logits, mask + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = t_func.gelu(self.norm0(self.conv0(x))) + x = t_func.gelu(self.conv1(x)) + x = t_func.gelu(self.conv2(x)) + x = t_func.gelu(self.conv3(x)) + x = t_func.gelu(self.conv4(x)) + x = t_func.gelu(self.conv5(x)) + x = t_func.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = 
self.dropout(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = t_func.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 
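# --- Shape check for HubertSoft.units() above: soft units come out at roughly
# 50 frames per second of 16 kHz audio (320-sample hop from the convolutional
# front end), each projected to 256 dimensions. Randomly initialised weights
# are enough for a shape check; real use goes through hubert_soft() below.
import torch

from vencoder.hubert.hubert_model import HubertSoft

model = HubertSoft().eval()
wav = torch.randn(1, 1, 16000)            # (batch, 1, samples) placeholder audio
units = model.units(wav)
# units.shape == torch.Size([1, 50, 256])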
+ Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = torch.load(path) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert diff --git a/vencoder/hubert/hubert_model_onnx.py b/vencoder/hubert/hubert_model_onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..d18f3c2a0fc29592a573a9780308d38f059640b9 --- /dev/null +++ b/vencoder/hubert/hubert_model_onnx.py @@ -0,0 +1,217 @@ +import copy +import random +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as t_func +from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + +class Hubert(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = nn.LayerNorm(768) + self.dropout = nn.Dropout(0.1) + self.encoder = TransformerEncoder( + nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = nn.Linear(768, 256) + + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + def logits(self, x: torch.Tensor) -> torch.Tensor: + logits = torch.cosine_similarity( + x.unsqueeze(2), + self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + dim=-1, + ) + return logits / 0.1 + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = t_func.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + def forward(self, x): + return self.units(x) + +class FeatureExtractor(nn.Module): + def __init__(self): + super().__init__() + self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = nn.GroupNorm(512, 512) + self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = t_func.gelu(self.norm0(self.conv0(x))) + x = t_func.gelu(self.conv1(x)) + x = t_func.gelu(self.conv2(x)) + x = t_func.gelu(self.conv3(x)) + x = t_func.gelu(self.conv4(x)) + x = t_func.gelu(self.conv5(x)) + x = t_func.gelu(self.conv6(x)) + return x + + +class FeatureProjection(nn.Module): + def __init__(self): + super().__init__() + self.norm = nn.LayerNorm(512) + self.projection = nn.Linear(512, 768) + self.dropout = nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = 
self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class PositionalConvEmbedding(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x.transpose(1, 2)) + x = t_func.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + +class TransformerEncoder(nn.Module): + def __init__( + self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +def _compute_mask( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 
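# --- The ONNX variant above differs from hubert_model.py mainly in that
# forward() maps straight to units(), which makes the module traceable. A
# hedged export sketch follows; the output file name, axis names, and opset
# version are assumptions, and whether every op exports cleanly depends on the
# installed torch version.
import torch

from vencoder.hubert.hubert_model_onnx import HubertSoft

model = HubertSoft().eval()
dummy = torch.randn(1, 1, 16000)
torch.onnx.export(
    model,
    dummy,
    "hubert_soft.onnx",
    input_names=["audio"],
    output_names=["units"],
    dynamic_axes={"audio": {2: "samples"}, "units": {1: "frames"}},
    opset_version=16,
)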
+ Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = torch.load(path) + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert diff --git a/vencoder/wavlm/WavLM.py b/vencoder/wavlm/WavLM.py new file mode 100644 index 0000000000000000000000000000000000000000..5a3986fdcc00033a9e8f1bfcd25df3799f40ed90 --- /dev/null +++ b/vencoder/wavlm/WavLM.py @@ -0,0 +1,741 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import logging +import math +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import LayerNorm + +from vencoder.wavlm.modules import ( + Fp32GroupNorm, + Fp32LayerNorm, + GLU_Linear, + GradMultiply, + MultiheadAttention, + SamePad, + TransposeLast, + get_activation_fn, + init_bert_params, +) + +logger = logging.getLogger(__name__) + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return mask + + +class WavLMConfig: + def __init__(self, cfg=None): + self.extractor_mode: str = "default" # mode for feature extractor. 
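# --- Quick sanity check of compute_mask_indices() above with the default
# "static" mask type: roughly mask_prob of the timesteps end up masked, a bit
# less once overlapping spans collapse.
from vencoder.wavlm.WavLM import compute_mask_indices

mask = compute_mask_indices(
    shape=(2, 100),          # (batch, timesteps)
    padding_mask=None,
    mask_prob=0.65,
    mask_length=10,
    mask_type="static",
)
# mask is a (2, 100) boolean numpy array; mask.mean(axis=1) typically lands a
# bit below mask_prob (around 0.5) because spans are allowed to overlap.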
default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) + self.encoder_layers: int = 12 # num encoder layers in the transformer + + self.encoder_embed_dim: int = 768 # encoder embedding dimension + self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN + self.encoder_attention_heads: int = 12 # num encoder attention heads + self.activation_fn: str = "gelu" # activation function to use + + self.layer_norm_first: bool = False # apply layernorm first in the transformer + self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] + self.conv_bias: bool = False # include bias in conv encoder + self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this + + self.normalize: bool = False # normalize input to have 0 mean and unit variance during training + + # dropouts + self.dropout: float = 0.1 # dropout probability for the transformer + self.attention_dropout: float = 0.1 # dropout probability for attention weights + self.activation_dropout: float = 0.0 # dropout probability after activation in FFN + self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer + self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) + self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) + + # masking + self.mask_length: int = 10 # mask length + self.mask_prob: float = 0.65 # probability of replacing a token with mask + self.mask_selection: str = "static" # how to choose mask length + self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh + self.no_mask_overlap: bool = False # whether to allow masks to overlap + self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # channel masking + self.mask_channel_length: int = 10 # length of the mask for features (channels) + self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 + self.mask_channel_selection: str = "static" # how to choose mask length for channel masking + self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices + self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap + self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) + + # positional embeddings + self.conv_pos: int = 128 # number of filters for convolutional positional embeddings + self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding + + # relative position embedding + self.relative_position_embedding: bool = False # apply relative position embedding + self.num_buckets: int = 320 # number of buckets for relative position embedding + self.max_distance: int = 1280 # maximum distance for relative position embedding + self.gru_rel_pos: bool = False # apply gated relative position embedding + + if cfg is not None: + self.update(cfg) + + def update(self, cfg: dict): + self.__dict__.update(cfg) + + +class WavLM(nn.Module): + def __init__( + self, + cfg: WavLMConfig, + ) -> None: + super().__init__() + logger.info(f"WavLM Config: {cfg.__dict__}") + + self.cfg = cfg + feature_enc_layers = eval(cfg.conv_feature_layers) + 
self.embed = feature_enc_layers[-1][0] + + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + def apply_mask(self, x, padding_mask): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ret_layer_results: bool = False, + ): + + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = self.feature_extractor(source) + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + + if mask: + x, mask_indices = self.apply_mask( + features, padding_mask + ) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), 
bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} + + feature = res["features"] if ret_conv else res["x"] + if ret_layer_results: + feature = (feature, res["layer_results"]) + return feature, res["padding_mask"] + + +class ConvFeatureExtractionModel(nn.Module): + def __init__( + self, + conv_layers: List[Tuple[int, int, int]], + dropout: float = 0.0, + mode: str = "default", + conv_bias: bool = False, + conv_type: str = "default" + ): + super().__init__() + + assert mode in {"default", "layer_norm"} + + def block( + n_in, + n_out, + k, + stride, + is_layer_norm=False, + is_group_norm=False, + conv_bias=False, + ): + def make_conv(): + conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) + nn.init.kaiming_normal_(conv.weight) + return conv + + assert (is_layer_norm and is_group_norm) is False, "layer norm and group norm are exclusive" + + if is_layer_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + nn.Sequential( + TransposeLast(), + Fp32LayerNorm(dim, elementwise_affine=True), + TransposeLast(), + ), + nn.GELU(), + ) + elif is_group_norm: + return nn.Sequential( + make_conv(), + nn.Dropout(p=dropout), + Fp32GroupNorm(dim, dim, affine=True), + nn.GELU(), + ) + else: + return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) + + self.conv_type = conv_type + if self.conv_type == "default": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3, "invalid conv definition: " + str(cl) + (dim, k, stride) = cl + + self.conv_layers.append( + block( + in_d, + dim, + k, + stride, + is_layer_norm=mode == "layer_norm", + is_group_norm=mode == "default" and i == 0, + conv_bias=conv_bias, + ) + ) + in_d = dim + elif self.conv_type == "conv2d": + in_d = 1 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + elif self.conv_type == "custom": + in_d = 1 + idim = 80 + self.conv_layers = nn.ModuleList() + for i, cl in enumerate(conv_layers): + assert len(cl) == 3 + (dim, k, stride) = cl + self.conv_layers.append( + torch.nn.Conv2d(in_d, dim, k, stride, padding=1) + ) + self.conv_layers.append( + torch.nn.LayerNorm([dim, idim]) + ) + self.conv_layers.append(torch.nn.ReLU()) + in_d = dim + if (i + 1) % 2 == 0: + self.conv_layers.append( + torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) + ) + idim = int(math.ceil(idim / 2)) + else: + pass + + def forward(self, x, mask=None): + + # BxT -> BxCxT + x = x.unsqueeze(1) + if self.conv_type == "custom": + for conv in self.conv_layers: + if isinstance(conv, nn.LayerNorm): + x = x.transpose(1, 2) + x = conv(x).transpose(1, 2) + else: + x = conv(x) + x = x.transpose(2, 3).contiguous() + x = x.view(x.size(0), -1, x.size(-1)) + else: + for conv in self.conv_layers: + x = conv(x) + if self.conv_type == "conv2d": + b, c, t, f = x.size() + x = x.transpose(2, 3).contiguous().view(b, c * f, t) + return x + + +class TransformerEncoder(nn.Module): + def __init__(self, args): + super().__init__() + + self.dropout = args.dropout + self.embedding_dim = args.encoder_embed_dim + + self.pos_conv = nn.Conv1d( + self.embedding_dim, + self.embedding_dim, + 
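Each block in ConvFeatureExtractionModel above is a plain strided Conv1d, so the frame count follows floor((L_in - kernel) / stride) + 1 per layer; with the default layer list the effective hop is 320 samples, roughly 50 frames per second at 16 kHz. A quick check of that arithmetic:

layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2  # the default spec above

def conv_out_len(n_samples: int) -> int:
    length = n_samples
    for _dim, kernel, stride in layers:
        length = (length - kernel) // stride + 1
    return length

print(conv_out_len(16000))   # 49 frames for one second of audio (~320-sample hop)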
kernel_size=args.conv_pos, + padding=args.conv_pos // 2, + groups=args.conv_pos_groups, + ) + dropout = 0 + std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) + nn.init.normal_(self.pos_conv.weight, mean=0, std=std) + nn.init.constant_(self.pos_conv.bias, 0) + + self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) + self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) + + if hasattr(args, "relative_position_embedding"): + self.relative_position_embedding = args.relative_position_embedding + self.num_buckets = args.num_buckets + self.max_distance = args.max_distance + else: + self.relative_position_embedding = False + self.num_buckets = 0 + self.max_distance = 0 + + self.layers = nn.ModuleList( + [ + TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + has_relative_attention_bias=(self.relative_position_embedding and i == 0), + num_buckets=self.num_buckets, + max_distance=self.max_distance, + gru_rel_pos=args.gru_rel_pos, + ) + for i in range(args.encoder_layers) + ] + ) + + self.layer_norm_first = args.layer_norm_first + self.layer_norm = LayerNorm(self.embedding_dim) + self.layerdrop = args.encoder_layerdrop + + self.apply(init_bert_params) + + def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): + x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) + + if self.layer_norm_first and layer is None: + x = self.layer_norm(x) + + return x, layer_results + + def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): + + if padding_mask is not None: + x[padding_mask] = 0 + + x_conv = self.pos_conv(x.transpose(1, 2)) + x_conv = x_conv.transpose(1, 2) + x = x + x_conv + + if not self.layer_norm_first: + x = self.layer_norm(x) + + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + layer_results = [] + z = None + if tgt_layer is not None: + layer_results.append((x, z)) + r = None + pos_bias = None + for i, layer in enumerate(self.layers): + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, + self_attn_mask=streaming_mask, pos_bias=pos_bias) + if tgt_layer is not None: + layer_results.append((x, z)) + if i == tgt_layer: + r = x + break + + if r is not None: + x = r + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, layer_results + + +class TransformerSentenceEncoderLayer(nn.Module): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
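The convolutional positional embedding in TransformerEncoder above uses kernel_size=128 with padding=64; an even kernel plus symmetric padding yields one extra frame, which is exactly what SamePad trims so the positional signal can be added back onto x. A standalone length check (plain Conv1d, no weight norm, just the bookkeeping):

import torch
import torch.nn as nn

T, C = 100, 768
pos_conv = nn.Conv1d(C, C, kernel_size=128, padding=64, groups=16)
y = pos_conv(torch.randn(1, C, T))
print(y.shape[-1])   # 101: one frame too many
y = y[:, :, :-1]     # what SamePad(128) removes
print(y.shape[-1])   # 100, matches the input length again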
+ """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: float = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + has_relative_attention_bias: bool = False, + num_buckets: int = 0, + max_distance: int = 0, + rescale_init: bool = False, + gru_rel_pos: bool = False, + ) -> None: + + super().__init__() + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_name = activation_fn + self.activation_fn = get_activation_fn(activation_fn) + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + has_relative_attention_bias=has_relative_attention_bias, + num_buckets=num_buckets, + max_distance=max_distance, + rescale_init=rescale_init, + gru_rel_pos=gru_rel_pos, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + + if self.activation_name == "glu": + self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") + else: + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + pos_bias=None + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. 
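A shape-level sketch of the layer defined above: it expects (time, batch, channel) input because TransformerEncoder transposes before the layer loop, and it returns the transformed sequence plus (optional) attention weights and the relative-position bias. The import path and the post-norm defaults are assumptions of this sketch:

import torch
from vencoder.wavlm.WavLM import TransformerSentenceEncoderLayer  # module path assumed

layer = TransformerSentenceEncoderLayer(
    embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=12,
    activation_fn="gelu", layer_norm_first=False,
)
x = torch.randn(100, 2, 768)      # (T, B, C)
out, attn, pos_bias = layer(x)    # attn is None because need_weights defaults to False
print(out.shape)                  # torch.Size([100, 2, 768])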
+ """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + x = self.dropout1(x) + x = residual + x + + residual = x + x = self.final_layer_norm(x) + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + else: + x, attn, pos_bias = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=need_weights, + attn_mask=self_attn_mask, + position_bias=pos_bias + ) + + x = self.dropout1(x) + x = residual + x + + x = self.self_attn_layer_norm(x) + + residual = x + if self.activation_name == "glu": + x = self.fc1(x) + else: + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + x = self.dropout3(x) + x = residual + x + x = self.final_layer_norm(x) + + return x, attn, pos_bias + diff --git a/vencoder/wavlm/modules.py b/vencoder/wavlm/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..add4a1aa0042cbcbf5c3b28d4d72f017b507717d --- /dev/null +++ b/vencoder/wavlm/modules.py @@ -0,0 +1,828 @@ +# -------------------------------------------------------- +# WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) +# Github source: https://github.com/microsoft/unilm/tree/master/wavlm +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Based on fairseq code bases +# https://github.com/pytorch/fairseq +# -------------------------------------------------------- + +import math +import warnings +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn import Parameter + + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None): + super().__init__() + self.deconstruct_idx = deconstruct_idx + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(-2, -1) + + +class Fp32LayerNorm(nn.LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.layer_norm( + input.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class Fp32GroupNorm(nn.GroupNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, input): + output = F.group_norm( + input.float(), + self.num_groups, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ) + return output.type_as(input) + + +class GradMultiply(torch.autograd.Function): + @staticmethod + def forward(ctx, x, scale): + ctx.scale = scale + res = x.new(x) + return res + + @staticmethod + def backward(ctx, grad): + return grad * ctx.scale, None + + +class SamePad(nn.Module): + def __init__(self, kernel_size, causal=False): + super().__init__() + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + if self.remove > 0: + x = x[:, :, : -self.remove] + return x + + +class Swish(nn.Module): 
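GradMultiply above is an identity in the forward pass but scales the gradient on the way back; this is how feature_grad_mult damps updates to the conv front end without touching its activations. A minimal check:

import torch
from vencoder.wavlm.modules import GradMultiply

x = torch.ones(3, requires_grad=True)
y = GradMultiply.apply(x, 0.1)   # forward: same values as x
y.sum().backward()
print(x.grad)                    # tensor([0.1000, 0.1000, 0.1000])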
+ """Swish function + """ + + def __init__(self): + """Construct an MultiHeadedAttention object.""" + super(Swish, self).__init__() + self.act = torch.nn.Sigmoid() + + def forward(self, x): + return x * self.act(x) + + +class GLU_Linear(nn.Module): + def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): + super(GLU_Linear, self).__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + + if glu_type == "sigmoid": + self.glu_act = torch.nn.Sigmoid() + elif glu_type == "swish": + self.glu_act = Swish() + elif glu_type == "relu": + self.glu_act = torch.nn.ReLU() + elif glu_type == "gelu": + self.glu_act = torch.nn.GELU() + + if bias_in_glu: + self.linear = nn.Linear(input_dim, output_dim * 2, True) + else: + self.linear = nn.Linear(input_dim, output_dim * 2, False) + + def forward(self, x): + # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = self.linear(x) + + if self.glu_type == "bilinear": + x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) + else: + x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) + + return x + + +def gelu_accurate(x): + if not hasattr(gelu_accurate, "_a"): + gelu_accurate._a = math.sqrt(2 / math.pi) + return ( + 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) + ) + + +def gelu(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x.float()).type_as(x) + + +def get_activation_fn(activation: str): + """Returns the activation function corresponding to `activation`""" + + if activation == "relu": + return F.relu + elif activation == "gelu": + return gelu + elif activation == "gelu_fast": + warnings.warn( + "--activation-fn=gelu_fast has been renamed to gelu_accurate" + ) + return gelu_accurate + elif activation == "gelu_accurate": + return gelu_accurate + elif activation == "tanh": + return torch.tanh + elif activation == "linear": + return lambda x: x + elif activation == "glu": + return lambda x: x + else: + raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +def init_bert_params(module): + """ + Initialize the weights specific to the BERT Model. + This overrides the default initializations depending on the specified arguments. + 1. If normal_init_linear_weights is set then weights of linear + layer will be initialized using the normal distribution and + bais will be set to the specified value. + 2. If normal_init_embed_weights is set then weights of embedding + layer will be initialized using the normal distribution. + 3. If normal_init_proj_weights is set then weights of + in_project_weight for MultiHeadAttention initialized using + the normal distribution (to be validated). 
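GLU_Linear above projects to 2 x output_dim and multiplies the first half by an activation of the second half; when activation_fn is "glu", fc1 in the encoder layer becomes this gated projection and get_activation_fn returns the identity so the gate is not applied twice. A shape sketch:

import torch
from vencoder.wavlm.modules import GLU_Linear

glu = GLU_Linear(768, 3072, glu_type="swish")
x = torch.randn(4, 100, 768)        # channel-last, as the comment in forward() notes
print(glu.linear.weight.shape)      # torch.Size([6144, 768]): two stacked halves of 3072
print(glu(x).shape)                 # torch.Size([4, 100, 3072])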
+ """ + + def normal_(data): + # with FSDP, module params will be on CUDA, so we cast them back to CPU + # so that the RNG is consistent with and without FSDP + data.copy_( + data.cpu().normal_(mean=0.0, std=0.02).to(data.device) + ) + + if isinstance(module, nn.Linear): + normal_(module.weight.data) + if module.bias is not None: + module.bias.data.zero_() + if isinstance(module, nn.Embedding): + normal_(module.weight.data) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + if isinstance(module, MultiheadAttention): + normal_(module.q_proj.weight.data) + normal_(module.k_proj.weight.data) + normal_(module.v_proj.weight.data) + + +def quant_noise(module, p, block_size): + """ + Wraps modules and applies quantization noise to the weights for + subsequent quantization with Iterative Product Quantization as + described in "Training with Quantization Noise for Extreme Model Compression" + + Args: + - module: nn.Module + - p: amount of Quantization Noise + - block_size: size of the blocks for subsequent quantization with iPQ + + Remarks: + - Module weights must have the right sizes wrt the block size + - Only Linear, Embedding and Conv2d modules are supported for the moment + - For more detail on how to quantize by blocks with convolutional weights, + see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" + - We implement the simplest form of noise here as stated in the paper + which consists in randomly dropping blocks + """ + + # if no quantization noise, don't register hook + if p <= 0: + return module + + # supported modules + assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) + + # test whether module.weight has the right sizes wrt block_size + is_conv = module.weight.ndim == 4 + + # 2D matrix + if not is_conv: + assert ( + module.weight.size(1) % block_size == 0 + ), "Input features must be a multiple of block sizes" + + # 4D matrix + else: + # 1x1 convolutions + if module.kernel_size == (1, 1): + assert ( + module.in_channels % block_size == 0 + ), "Input channels must be a multiple of block sizes" + # regular convolutions + else: + k = module.kernel_size[0] * module.kernel_size[1] + assert k % block_size == 0, "Kernel size must be a multiple of block size" + + def _forward_pre_hook(mod, input): + # no noise for evaluation + if mod.training: + if not is_conv: + # gather weight and sizes + weight = mod.weight + in_features = weight.size(1) + out_features = weight.size(0) + + # split weight matrix into blocks and randomly drop selected blocks + mask = torch.zeros( + in_features // block_size * out_features, device=weight.device + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) + + else: + # gather weight and sizes + weight = mod.weight + in_channels = mod.in_channels + out_channels = mod.out_channels + + # split weight matrix into blocks and randomly drop selected blocks + if mod.kernel_size == (1, 1): + mask = torch.zeros( + int(in_channels // block_size * out_channels), + device=weight.device, + ) + mask.bernoulli_(p) + mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) + else: + mask = torch.zeros( + weight.size(0), weight.size(1), device=weight.device + ) + mask.bernoulli_(p) + mask = ( + mask.unsqueeze(2) + .unsqueeze(3) + .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) + ) + + # scale weights and apply mask + mask = mask.to( + torch.bool + ) # x.bool() is not currently supported in TorchScript + s = 1 / (1 - p) + mod.weight.data = s * 
weight.masked_fill(mask, 0) + + module.register_forward_pre_hook(_forward_pre_hook) + return module + + +class MultiheadAttention(nn.Module): + """Multi-headed attention. + + See "Attention Is All You Need" for more details. + """ + + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + has_relative_attention_bias=False, + num_buckets=32, + max_distance=128, + gru_rel_pos=False, + rescale_init=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout_module = nn.Dropout(dropout) + + self.has_relative_attention_bias = has_relative_attention_bias + self.num_buckets = num_buckets + self.max_distance = max_distance + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + self.head_dim = embed_dim // num_heads + self.q_head_dim = self.head_dim + self.k_head_dim = self.head_dim + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ( + "Self-attention requires query, key and " "value to be of the same size" + ) + + k_bias = True + if rescale_init: + k_bias = False + + k_embed_dim = embed_dim + q_embed_dim = embed_dim + + self.k_proj = quant_noise( + nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size + ) + self.v_proj = quant_noise( + nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size + ) + self.q_proj = quant_noise( + nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size + ) + + self.out_proj = quant_noise( + nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size + ) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.gru_rel_pos = gru_rel_pos + if self.gru_rel_pos: + self.grep_linear = nn.Linear(self.q_head_dim, 8) + self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1)) + + self.reset_parameters() + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.k_proj.weight) + nn.init.xavier_uniform_(self.v_proj.weight) + nn.init.xavier_uniform_(self.q_proj.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + if self.has_relative_attention_bias: + nn.init.xavier_normal_(self.relative_attention_bias.weight) + + def _relative_positions_bucket(self, relative_positions, bidirectional=True): + num_buckets = 
self.num_buckets + max_distance = self.max_distance + relative_buckets = 0 + + if bidirectional: + num_buckets = num_buckets // 2 + relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets + relative_positions = torch.abs(relative_positions) + else: + relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) + + max_exact = num_buckets // 2 + is_small = relative_positions < max_exact + + relative_postion_if_large = max_exact + ( + torch.log(relative_positions.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_postion_if_large = torch.min( + relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length): + context_position = torch.arange(query_length, dtype=torch.long)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + relative_position_bucket = self._relative_positions_bucket( + relative_position, + bidirectional=True + ) + relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) + values = self.relative_attention_bias(relative_position_bucket) + values = values.permute([2, 0, 1]) + return values + + def forward( + self, + query, + key: Optional[Tensor], + value: Optional[Tensor], + key_padding_mask: Optional[Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + position_bias: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert key_bsz == bsz + assert value is not None + assert src_len, bsz == value.shape[:2] + + if self.has_relative_attention_bias and position_bias is None: + position_bias = self.compute_bias(tgt_len, src_len) + position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) + + if ( + not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. 
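compute_bias above buckets signed query/key offsets T5-style and looks them up in a (num_buckets, num_heads) embedding, giving one scalar bias per head per position pair; the result is later tiled to (B * num_heads, tgt_len, src_len) and added to the attention logits. A shape-level sketch:

import torch
from vencoder.wavlm.modules import MultiheadAttention

attn = MultiheadAttention(
    embed_dim=768, num_heads=12, self_attention=True,
    has_relative_attention_bias=True, num_buckets=320, max_distance=1280,
)
bias = attn.compute_bias(query_length=50, key_length=50)
print(bias.shape)    # torch.Size([12, 50, 50])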
Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + and self.q_head_dim == self.head_dim + ): + assert key is not None and value is not None + assert attn_mask is None + + attn_mask_rel_pos = None + if position_bias is not None: + attn_mask_rel_pos = position_bias + if self.gru_rel_pos: + query_layer = query.transpose(0, 1) + new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) + query_layer = query_layer.view(*new_x_shape) + query_layer = query_layer.permute(0, 2, 1, 3) + _B, _H, _L, __ = query_layer.size() + + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) + k_proj_bias = self.k_proj.bias + if k_proj_bias is None: + k_proj_bias = torch.zeros_like(self.q_proj.bias) + + x, attn = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training, + # self.training or self.dropout_module.apply_during_inference, + key_padding_mask, + need_weights, + attn_mask_rel_pos, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + return x, attn, position_bias + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + key_padding_mask.new_zeros(key_padding_mask.size(0), 1), + ], + dim=1, + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.q_head_dim) + .transpose(0, 1) + ) + if k is not None: + k = ( + k.contiguous() + .view(-1, bsz * self.num_heads, self.k_head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + prev_key = _prev_key.view(bsz * 
self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat( + [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 + ) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [ + key_padding_mask, + torch.zeros(key_padding_mask.size(0), 1).type_as( + key_padding_mask + ), + ], + dim=1, + ) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v, position_bias + + if position_bias is not None: + if self.gru_rel_pos == 1: + query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) + _B, _H, _L, __ = query_layer.size() + gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( + _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) + gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 + position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias + + position_bias = position_bias.view(attn_weights.size()) + + attn_weights = 
attn_weights + position_bias + + attn_weights_float = F.softmax( + attn_weights, dim=-1 + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights, position_bias + + @staticmethod + def _append_prev_key_padding_mask( + key_padding_mask: Optional[Tensor], + prev_key_padding_mask: Optional[Tensor], + batch_size: int, + src_len: int, + static_kv: bool, + ) -> Optional[Tensor]: + # saved key padding masks have shape (bsz, seq_len) + if prev_key_padding_mask is not None and static_kv: + new_key_padding_mask = prev_key_padding_mask + elif prev_key_padding_mask is not None and key_padding_mask is not None: + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 + ) + # During incremental decoding, as the padding token enters and + # leaves the frame, there will be a time when prev or current + # is None + elif prev_key_padding_mask is not None: + if src_len > prev_key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - prev_key_padding_mask.size(1)), + device=prev_key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [prev_key_padding_mask.float(), filler.float()], dim=1 + ) + else: + new_key_padding_mask = prev_key_padding_mask.float() + elif key_padding_mask is not None: + if src_len > key_padding_mask.size(1): + filler = torch.zeros( + (batch_size, src_len - key_padding_mask.size(1)), + device=key_padding_mask.device, + ) + new_key_padding_mask = torch.cat( + [filler.float(), key_padding_mask.float()], dim=1 + ) + else: + new_key_padding_mask = key_padding_mask.float() + else: + new_key_padding_mask = prev_key_padding_mask + return new_key_padding_mask + + def _get_input_buffer( + self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] + ) -> Dict[str, Optional[Tensor]]: + result = self.get_incremental_state(incremental_state, "attn_state") + if result is not None: + return result + else: + empty_result: Dict[str, Optional[Tensor]] = {} + return empty_result + + def _set_input_buffer( + self, + incremental_state: Dict[str, Dict[str, Optional[Tensor]]], + buffer: Dict[str, Optional[Tensor]], + ): + return self.set_incremental_state(incremental_state, "attn_state", buffer) + + def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): + return attn_weights diff --git a/vencoder/whisper/__init__.py b/vencoder/whisper/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vencoder/whisper/__pycache__/__init__.cpython-38.pyc b/vencoder/whisper/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65afdea99c3827ee46bee4ba6ce8bd0bb68817d1 Binary files /dev/null and b/vencoder/whisper/__pycache__/__init__.cpython-38.pyc differ diff --git a/vencoder/whisper/__pycache__/audio.cpython-38.pyc b/vencoder/whisper/__pycache__/audio.cpython-38.pyc new file mode 100644 index 
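End to end, the attention module above consumes (time, batch, channel) tensors and, without relative bias, routes through torch's F.multi_head_attention_forward fast path. A minimal self-attention call:

import torch
from vencoder.wavlm.modules import MultiheadAttention

attn = MultiheadAttention(embed_dim=768, num_heads=12, self_attention=True)
x = torch.randn(100, 2, 768)                     # (T, B, C)
out, weights, pos_bias = attn(query=x, key=x, value=x, need_weights=True)
print(out.shape)       # torch.Size([100, 2, 768])
print(weights.shape)   # torch.Size([2, 100, 100]), averaged over heads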
0000000000000000000000000000000000000000..6069e7d1f48a241bdd2106d46e47ec121da77ba1 Binary files /dev/null and b/vencoder/whisper/__pycache__/audio.cpython-38.pyc differ diff --git a/vencoder/whisper/__pycache__/decoding.cpython-38.pyc b/vencoder/whisper/__pycache__/decoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b33859f8b950e62eac5368a45c9392ebbf0211e Binary files /dev/null and b/vencoder/whisper/__pycache__/decoding.cpython-38.pyc differ diff --git a/vencoder/whisper/__pycache__/model.cpython-38.pyc b/vencoder/whisper/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1cae15ed5b02d77a82e61e97b0584b7e7b944e3 Binary files /dev/null and b/vencoder/whisper/__pycache__/model.cpython-38.pyc differ diff --git a/vencoder/whisper/__pycache__/tokenizer.cpython-38.pyc b/vencoder/whisper/__pycache__/tokenizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad3fa610f46d865fa460c1aa9bc1dbfede98c941 Binary files /dev/null and b/vencoder/whisper/__pycache__/tokenizer.cpython-38.pyc differ diff --git a/vencoder/whisper/__pycache__/utils.cpython-38.pyc b/vencoder/whisper/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ded08d43ccdf79b73ad8dc439fbd0f909b313e1 Binary files /dev/null and b/vencoder/whisper/__pycache__/utils.cpython-38.pyc differ diff --git a/vencoder/whisper/audio.py b/vencoder/whisper/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..05890dc195a376181c21072eb0a8af24cf29928a --- /dev/null +++ b/vencoder/whisper/audio.py @@ -0,0 +1,123 @@ +from functools import lru_cache +from typing import Union + +import ffmpeg +import numpy as np +import torch +import torch.nn.functional as F +from librosa.filters import mel as librosa_mel_fn + +from .utils import exact_div + +# hard-coded audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + try: + # This launches a subprocess to decode audio while down-mixing and resampling as necessary. + # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed. + out, _ = ( + ffmpeg.input(file, threads=0) + .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr) + .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True) + ) + except ffmpeg.Error as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
+ """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(dim=axis, index=torch.arange(length, device=array.device)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device) + + +def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec diff --git a/vencoder/whisper/decoding.py b/vencoder/whisper/decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..45e50b1c33c2c8f9ca6572e6175b8d6051ae02ee --- /dev/null +++ b/vencoder/whisper/decoding.py @@ -0,0 +1,712 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import Categorical + +from .audio import CHUNK_LENGTH +from .tokenizer import Tokenizer, get_tokenizer +from .utils import compression_ratio + +if TYPE_CHECKING: + from .model import Whisper + + +@torch.no_grad() +def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]: + """ + Detect the spoken language in the audio, and return them as list of strings, along with the ids + of the most probable language tokens and the probability distribution over all language tokens. + This is performed outside the main decode loop in order to not interfere with kv-caching. + + Returns + ------- + language_tokens : Tensor, shape = (n_audio,) + ids of the most probable language tokens, which appears after the startoftranscript token. 
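The audio helpers above pin the front end to 30-second windows: 30 s x 16 000 Hz = 480 000 samples per chunk, and a 160-sample hop gives exactly 3 000 mel frames. A minimal preprocessing sketch (the wav path is hypothetical; load_audio needs the ffmpeg CLI on PATH):

from vencoder.whisper.audio import N_FRAMES, load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("speech.wav")      # float32 mono at 16 kHz
audio = pad_or_trim(audio)            # exactly N_SAMPLES = 480000 samples
mel = log_mel_spectrogram(audio)      # torch.Tensor of shape (80, 3000)
assert mel.shape == (80, N_FRAMES)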
+ language_probs : List[Dict[str, float]], length = n_audio + list of dictionaries containing the probability distribution over all languages. + """ + if tokenizer is None: + tokenizer = get_tokenizer(model.is_multilingual) + if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence: + raise ValueError("This model doesn't have language tokens so it can't perform lang id") + + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + + # skip encoder forward pass if already-encoded audio features were given + if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state): + mel = model.encoder(mel) + + # forward pass using a single token, startoftranscript + n_audio = mel.shape[0] + x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] + logits = model.logits(x, mel)[:, 0] + + # collect detected languages; suppress all non-language tokens + mask = torch.ones(logits.shape[-1], dtype=torch.bool) + mask[list(tokenizer.all_language_tokens)] = False + logits[:, mask] = -np.inf + language_tokens = logits.argmax(dim=-1) + language_token_probs = logits.softmax(dim=-1).cpu() + language_probs = [ + { + c: language_token_probs[i, j].item() + for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes) + } + for i in range(n_audio) + ] + + if single: + language_tokens = language_tokens[0] + language_probs = language_probs[0] + + return language_tokens, language_probs + + +@dataclass(frozen=True) +class DecodingOptions: + task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate" + language: Optional[str] = None # language that the audio is in; uses detected language if None + + # sampling-related options + temperature: float = 0.0 + sample_len: Optional[int] = None # maximum number of tokens to sample + best_of: Optional[int] = None # number of independent samples to collect, when t > 0 + beam_size: Optional[int] = None # number of beams in beam search, when t == 0 + patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424) + + # options for ranking generations (either beams or best-of-N samples) + length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm + + # prompt, prefix, and token suppression + prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context + prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context + suppress_blank: bool = True # this will suppress blank outputs + + # list of tokens ids (or comma-separated token ids) to suppress + # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()` + suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1" + + # timestamp sampling options + without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only + max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this + + # implementation details + fp16: bool = True # use fp16 for most of the calculation + + +@dataclass(frozen=True) +class DecodingResult: + audio_features: Tensor + language: str + language_probs: Optional[Dict[str, float]] = None + tokens: List[int] = field(default_factory=list) + text: str = "" + avg_logprob: float = np.nan + no_speech_prob: float = np.nan + temperature: float = np.nan + compression_ratio: float = np.nan + + +class Inference: + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + """Perform a forward pass on the decoder 
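DecodingOptions above is a frozen dataclass, so a decoding configuration is just keyword arguments; beam_size and best_of are mutually exclusive (enforced later in _verify_options), and temperature 0 selects greedy or beam search rather than sampling. Two typical configurations as a sketch:

from vencoder.whisper.decoding import DecodingOptions

greedy = DecodingOptions(language="en", temperature=0.0, without_timestamps=True)
beam = DecodingOptions(language="en", temperature=0.0, beam_size=5, patience=1.0)
# frozen=True: trying greedy.temperature = 0.2 raises dataclasses.FrozenInstanceError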
and return per-token logits""" + raise NotImplementedError + + def rearrange_kv_cache(self, source_indices) -> None: + """Update the key-value cache according to the updated beams""" + raise NotImplementedError + + def cleanup_caching(self) -> None: + """Clean up any resources or hooks after decoding is finished""" + pass + + +class PyTorchInference(Inference): + def __init__(self, model: "Whisper", initial_token_length: int): + self.model: "Whisper" = model + self.initial_token_length = initial_token_length + self.kv_cache = {} + self.hooks = [] + + def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + if not self.kv_cache: + self.kv_cache, self.hooks = self.model.install_kv_cache_hooks() + + if tokens.shape[-1] > self.initial_token_length: + # only need to use the last token except in the first forward pass + tokens = tokens[:, -1:] + + return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) + + def cleanup_caching(self): + for hook in self.hooks: + hook.remove() + + self.kv_cache = {} + self.hooks = [] + + def rearrange_kv_cache(self, source_indices): + for module, tensor in self.kv_cache.items(): + # update the key/value cache to contain the selected sequences + self.kv_cache[module] = tensor[source_indices].detach() + + +class SequenceRanker: + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]: + """ + Given a list of groups of samples and their cumulative log probabilities, + return the indices of the samples in each group to select as the final result + """ + raise NotImplementedError + + +class MaximumLikelihoodRanker(SequenceRanker): + """ + Select the sample with the highest log probabilities, penalized using either + a simple length normalization or Google NMT paper's length penalty + """ + + def __init__(self, length_penalty: Optional[float]): + self.length_penalty = length_penalty + + def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]): + def scores(logprobs, lengths): + result = [] + for logprob, length in zip(logprobs, lengths): + if self.length_penalty is None: + penalty = length + else: + # from the Google NMT paper + penalty = ((5 + length) / 6) ** self.length_penalty + result.append(logprob / penalty) + return result + + # get the sequence with the highest score + lengths = [[len(t) for t in s] for s in tokens] + return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)] + + +class TokenDecoder: + def reset(self): + """Initialize any stateful variables for decoding a new sequence""" + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + """Specify how to select the next token, based on the current trace and logits + + Parameters + ---------- + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + sum_logprobs : Tensor, shape = (n_batch) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Tensor, shape = (n_batch, current_sequence_length + 1) + the tokens, appended with the selected next token + + completed : bool + True if all sequences has reached the end of text + + """ + raise NotImplementedError + + def finalize( + self, tokens: Tensor, sum_logprobs: Tensor + ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]: + """Finalize search and return the final candidate 
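MaximumLikelihoodRanker above scores each candidate as its summed log-probability divided by a penalty: the raw length when length_penalty is None, otherwise the Google NMT term ((5 + length) / 6) ** alpha. A worked comparison of the two:

from typing import Optional

def score(logprob_sum: float, length: int, alpha: Optional[float]) -> float:
    penalty = length if alpha is None else ((5 + length) / 6) ** alpha
    return logprob_sum / penalty

# Same average log-prob per token, different lengths:
print(score(-10.0, 10, None), score(-20.0, 20, None))   # -1.0 -1.0 under pure length norm
print(score(-10.0, 10, 0.6), score(-20.0, 20, 0.6))     # ~-5.77 vs ~-8.50: shorter wins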
sequences + + Parameters + ---------- + tokens : Tensor, shape = (n_audio, n_group, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence + + sum_logprobs : Tensor, shape = (n_audio, n_group) + cumulative log probabilities for each sequence + + Returns + ------- + tokens : Sequence[Sequence[Tensor]], length = n_audio + sequence of Tensors containing candidate token sequences, for each audio input + + sum_logprobs : List[List[float]], length = n_audio + sequence of cumulative log probabilities corresponding to the above + + """ + raise NotImplementedError + + +class GreedyDecoder(TokenDecoder): + def __init__(self, temperature: float, eot: int): + self.temperature = temperature + self.eot = eot + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + temperature = self.temperature + if temperature == 0: + next_tokens = logits.argmax(dim=-1) + else: + next_tokens = Categorical(logits=logits / temperature).sample() + + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens] + sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot) + + next_tokens[tokens[:, -1] == self.eot] = self.eot + tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1) + + completed = (tokens[:, -1] == self.eot).all() + return tokens, completed + + def finalize(self, tokens: Tensor, sum_logprobs: Tensor): + # make sure each sequence has at least one EOT token at the end + tokens = F.pad(tokens, (0, 1), value=self.eot) + return tokens, sum_logprobs.tolist() + + +class BeamSearchDecoder(TokenDecoder): + def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None): + self.beam_size = beam_size + self.eot = eot + self.inference = inference + self.patience = patience or 1.0 + self.max_candidates: int = round(beam_size * self.patience) + self.finished_sequences = None + + assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})" + + def reset(self): + self.finished_sequences = None + + def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]: + if tokens.shape[0] % self.beam_size != 0: + raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0") + + n_audio = tokens.shape[0] // self.beam_size + if self.finished_sequences is None: # for the first update + self.finished_sequences = [{} for _ in range(n_audio)] + + logprobs = F.log_softmax(logits.float(), dim=-1) + next_tokens, source_indices, finished_sequences = [], [], [] + for i in range(n_audio): + scores, sources, finished = {}, {}, {} + + # STEP 1: calculate the cumulative log probabilities for possible candidates + for j in range(self.beam_size): + idx = i * self.beam_size + j + prefix = tokens[idx].tolist() + for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)): + new_logprob = (sum_logprobs[idx] + logprob).item() + sequence = tuple(prefix + [token.item()]) + scores[sequence] = new_logprob + sources[sequence] = idx + + # STEP 2: rank the candidates and keep the top beam_size sequences for each audio + saved = 0 + for sequence in sorted(scores, key=scores.get, reverse=True): + if sequence[-1] == self.eot: + finished[sequence] = scores[sequence] + else: + sum_logprobs[len(next_tokens)] = scores[sequence] + next_tokens.append(sequence) + source_indices.append(sources[sequence]) + + saved += 1 + if saved == self.beam_size: + break + + finished_sequences.append(finished) + + tokens = 
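GreedyDecoder.update above appends the argmax token at temperature 0 (or a Categorical sample otherwise), accumulates log-probability only for rows that have not yet emitted EOT, and pins finished rows at EOT. A toy step with a 4-token vocabulary and EOT id 3:

import torch
from vencoder.whisper.decoding import GreedyDecoder

decoder = GreedyDecoder(temperature=0.0, eot=3)
tokens = torch.tensor([[0, 1], [0, 3]])        # second row already ended with EOT
logits = torch.tensor([[0.1, 2.0, 0.3, 0.0],
                       [9.0, 0.0, 0.0, 0.0]])  # (n_batch, vocab)
sum_logprobs = torch.zeros(2)
tokens, completed = decoder.update(tokens, logits, sum_logprobs)
print(tokens)       # tensor([[0, 1, 1], [0, 3, 3]]): finished row stays on EOT
print(completed)    # tensor(False): row 0 has not reached EOT yet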
torch.tensor(next_tokens, device=tokens.device) + self.inference.rearrange_kv_cache(source_indices) + + # add newly finished sequences to self.finished_sequences + assert len(self.finished_sequences) == len(finished_sequences) + for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences): + for seq in sorted(newly_finished, key=newly_finished.get, reverse=True): + if len(previously_finished) >= self.max_candidates: + break # the candidate list is full + previously_finished[seq] = newly_finished[seq] + + # mark as completed if all audio has enough number of samples + completed = all( + len(sequences) >= self.max_candidates for sequences in self.finished_sequences + ) + return tokens, completed + + def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): + # collect all finished sequences, including patience, and add unfinished ones if not enough + sum_logprobs = sum_logprobs.cpu() + for i, sequences in enumerate(self.finished_sequences): + if len(sequences) < self.beam_size: # when not enough sequences are finished + for j in list(np.argsort(sum_logprobs[i]))[::-1]: + sequence = preceding_tokens[i, j].tolist() + [self.eot] + sequences[tuple(sequence)] = sum_logprobs[i][j].item() + if len(sequences) >= self.beam_size: + break + + tokens: List[List[Tensor]] = [ + [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences + ] + sum_logprobs: List[List[float]] = [ + list(sequences.values()) for sequences in self.finished_sequences + ] + return tokens, sum_logprobs + + +class LogitFilter: + def apply(self, logits: Tensor, tokens: Tensor) -> None: + """Apply any filtering or masking to logits in-place + + Parameters + ---------- + logits : Tensor, shape = (n_batch, vocab_size) + per-token logits of the probability distribution at the current step + + tokens : Tensor, shape = (n_batch, current_sequence_length) + all tokens in the context so far, including the prefix and sot_sequence tokens + + """ + raise NotImplementedError + + +class SuppressBlank(LogitFilter): + def __init__(self, tokenizer: Tokenizer, sample_begin: int): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + + def apply(self, logits: Tensor, tokens: Tensor): + if tokens.shape[1] == self.sample_begin: + logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf + + +class SuppressTokens(LogitFilter): + def __init__(self, suppress_tokens: Sequence[int]): + self.suppress_tokens = list(suppress_tokens) + + def apply(self, logits: Tensor, tokens: Tensor): + logits[:, self.suppress_tokens] = -np.inf + + +class ApplyTimestampRules(LogitFilter): + def __init__( + self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int] + ): + self.tokenizer = tokenizer + self.sample_begin = sample_begin + self.max_initial_timestamp_index = max_initial_timestamp_index + + def apply(self, logits: Tensor, tokens: Tensor): + # suppress <|notimestamps|> which is handled by without_timestamps + if self.tokenizer.no_timestamps is not None: + logits[:, self.tokenizer.no_timestamps] = -np.inf + + # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + for k in range(tokens.shape[0]): + seq = [t for t in tokens[k, self.sample_begin :].tolist()] + last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + 
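The logit filters above mutate the logits in place before the decoder picks a token; SuppressTokens simply sets the listed ids to -inf so they can never be sampled. A tiny sketch:

import torch
from vencoder.whisper.decoding import SuppressTokens

logits = torch.zeros(1, 6)
tokens = torch.zeros(1, 3, dtype=torch.long)   # unused by this particular filter
SuppressTokens([2, 4]).apply(logits, tokens)
print(logits)   # ids 2 and 4 are now -inf, everything else untouched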
logits[k, self.tokenizer.timestamp_begin :] = -np.inf + else: # cannot be normal text tokens + logits[k, : self.tokenizer.eot] = -np.inf + + if tokens.shape[1] == self.sample_begin: + # suppress generating non-timestamp tokens at the beginning + logits[:, : self.tokenizer.timestamp_begin] = -np.inf + + # apply the `max_initial_timestamp` option + if self.max_initial_timestamp_index is not None: + last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index + logits[:, last_allowed + 1 :] = -np.inf + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = F.log_softmax(logits.float(), dim=-1) + for k in range(tokens.shape[0]): + timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1) + max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() + if timestamp_logprob > max_text_token_logprob: + logits[k, : self.tokenizer.timestamp_begin] = -np.inf + + +class DecodingTask: + inference: Inference + sequence_ranker: SequenceRanker + decoder: TokenDecoder + logit_filters: List[LogitFilter] + + def __init__(self, model: "Whisper", options: DecodingOptions): + self.model = model + + language = options.language or "en" + tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task) + self.tokenizer: Tokenizer = tokenizer + self.options: DecodingOptions = self._verify_options(options) + + self.n_group: int = options.beam_size or options.best_of or 1 + self.n_ctx: int = model.dims.n_text_ctx + self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2 + + self.sot_sequence: Tuple[int] = tokenizer.sot_sequence + if self.options.without_timestamps: + self.sot_sequence = tokenizer.sot_sequence_including_notimestamps + + self.initial_tokens: Tuple[int] = self._get_initial_tokens() + self.sample_begin: int = len(self.initial_tokens) + self.sot_index: int = self.initial_tokens.index(tokenizer.sot) + + # inference: implements the forward pass through the decoder, including kv caching + self.inference = PyTorchInference(model, len(self.initial_tokens)) + + # sequence ranker: implements how to rank a group of sampled sequences + self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) + + # decoder: implements how to select the next tokens, given the autoregressive distribution + if options.beam_size is not None: + self.decoder = BeamSearchDecoder( + options.beam_size, tokenizer.eot, self.inference, options.patience + ) + else: + self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) + + # logit filters: applies various rules to suppress or penalize certain tokens + self.logit_filters = [] + if self.options.suppress_blank: + self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin)) + if self.options.suppress_tokens: + self.logit_filters.append(SuppressTokens(self._get_suppress_tokens())) + if not options.without_timestamps: + precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds + max_initial_timestamp_index = None + if options.max_initial_timestamp: + max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision) + self.logit_filters.append( + ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index) + ) + + def _verify_options(self, options: DecodingOptions) -> DecodingOptions: + if options.beam_size is not None and options.best_of is not None: + raise ValueError("beam_size and best_of can't be given together") + if options.temperature == 0: + if options.best_of 
is not None: + raise ValueError("best_of with greedy sampling (T=0) is not compatible") + if options.patience is not None and options.beam_size is None: + raise ValueError("patience requires beam_size to be given") + if options.length_penalty is not None and not (0 <= options.length_penalty <= 1): + raise ValueError("length_penalty (alpha) should be a value between 0 and 1") + + return options + + def _get_initial_tokens(self) -> Tuple[int]: + tokens = list(self.sot_sequence) + prefix = self.options.prefix + prompt = self.options.prompt + + if prefix: + prefix_tokens = ( + self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix + ) + if self.sample_len is not None: + max_prefix_len = self.n_ctx // 2 - self.sample_len + prefix_tokens = prefix_tokens[-max_prefix_len:] + tokens = tokens + prefix_tokens + + if prompt: + prompt_tokens = ( + self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt + ) + tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens + + return tuple(tokens) + + def _get_suppress_tokens(self) -> Tuple[int]: + suppress_tokens = self.options.suppress_tokens + + if isinstance(suppress_tokens, str): + suppress_tokens = [int(t) for t in suppress_tokens.split(",")] + + if -1 in suppress_tokens: + suppress_tokens = [t for t in suppress_tokens if t >= 0] + suppress_tokens.extend(self.tokenizer.non_speech_tokens) + elif suppress_tokens is None or len(suppress_tokens) == 0: + suppress_tokens = [] # interpret empty string as an empty list + else: + assert isinstance(suppress_tokens, list), "suppress_tokens must be a list" + + suppress_tokens.extend( + [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm] + ) + if self.tokenizer.no_speech is not None: + # no-speech probability is collected separately + suppress_tokens.append(self.tokenizer.no_speech) + + return tuple(sorted(set(suppress_tokens))) + + def _get_audio_features(self, mel: Tensor): + if self.options.fp16: + mel = mel.half() + + if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state): + # encoded audio features are given; skip audio encoding + print("encoded audio features are given; skip audio encoding") + audio_features = mel + else: + print(mel.shape) + print("===============================") + audio_features = self.model.encoder(mel) + + if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32): + return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}") + + return audio_features + + def _detect_language(self, audio_features: Tensor, tokens: Tensor): + languages = [self.options.language] * audio_features.shape[0] + lang_probs = None + + if self.options.language is None or self.options.task == "lang_id": + lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer) + languages = [max(probs, key=probs.get) for probs in lang_probs] + if self.options.language is None: + tokens[:, self.sot_index + 1] = lang_tokens # write language tokens + + return languages, lang_probs + + def _main_loop(self, audio_features: Tensor, tokens: Tensor): + assert audio_features.shape[0] == tokens.shape[0] + n_batch = tokens.shape[0] + sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) + no_speech_probs = [np.nan] * n_batch + + try: + for i in range(self.sample_len): + logits = self.inference.logits(tokens, audio_features) + + if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs + probs_at_sot = 
logits[:, self.sot_index].float().softmax(dim=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + + # now we need to consider the logits at the last token only + logits = logits[:, -1] + + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logit_filter.apply(logits, tokens) + + # expand the tokens tensor with the selected next tokens + tokens, completed = self.decoder.update(tokens, logits, sum_logprobs) + + if completed or tokens.shape[-1] > self.n_ctx: + break + finally: + self.inference.cleanup_caching() + + return tokens, sum_logprobs, no_speech_probs + + @torch.no_grad() + def run(self, mel: Tensor) -> List[DecodingResult]: + self.decoder.reset() + tokenizer: Tokenizer = self.tokenizer + n_audio: int = mel.shape[0] + + audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass + tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1) + + # detect language if requested, overwriting the language token + languages, language_probs = self._detect_language(audio_features, tokens) + if self.options.task == "lang_id": + return [ + DecodingResult(audio_features=features, language=language, language_probs=probs) + for features, language, probs in zip(audio_features, languages, language_probs) + ] + + # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling + audio_features = audio_features.repeat_interleave(self.n_group, dim=0) + tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device) + + # call the main sampling loop + tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens) + + # reshape the tensors to have (n_audio, n_group) as the first two dimensions + audio_features = audio_features[:: self.n_group] + no_speech_probs = no_speech_probs[:: self.n_group] + assert audio_features.shape[0] == len(no_speech_probs) == n_audio + + tokens = tokens.reshape(n_audio, self.n_group, -1) + sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group) + + # get the final candidates for each group, and slice between the first sampled token and EOT + tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) + tokens: List[List[Tensor]] = [ + [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens + ] + + # select the top-ranked sample in each group + selected = self.sequence_ranker.rank(tokens, sum_logprobs) + tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)] + texts: List[str] = [tokenizer.decode(t).strip() for t in tokens] + + sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)] + avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)] + + fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs) + if len(set(map(len, fields))) != 1: + raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}") + + return [ + DecodingResult( + audio_features=features, + language=language, + tokens=tokens, + text=text, + avg_logprob=avg_logprob, + no_speech_prob=no_speech_prob, + temperature=self.options.temperature, + compression_ratio=compression_ratio(text), + ) + for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields) + ] + + +@torch.no_grad() +def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]: + """ + Performs decoding of 30-second audio 
segment(s), provided as Mel spectrogram(s). + + Parameters + ---------- + model: Whisper + the Whisper model instance + + mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000) + A tensor containing the Mel spectrogram(s) + + options: DecodingOptions + A dataclass that contains all necessary options for decoding 30-second segments + + Returns + ------- + result: Union[DecodingResult, List[DecodingResult]] + The result(s) of decoding contained in `DecodingResult` dataclass instance(s) + """ + single = mel.ndim == 2 + if single: + mel = mel.unsqueeze(0) + result = DecodingTask(model, options).run(mel) + + if single: + result = result[0] + + return result diff --git a/vencoder/whisper/model.py b/vencoder/whisper/model.py new file mode 100644 index 0000000000000000000000000000000000000000..f3de4d32cb9646964074401aad176dbef9ef2125 --- /dev/null +++ b/vencoder/whisper/model.py @@ -0,0 +1,268 @@ +from dataclasses import dataclass +from typing import Dict, Iterable, Optional + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .decoding import decode as decode_function +from .decoding import detect_language as detect_language_function + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +class LayerNorm(nn.LayerNorm): + def forward(self, x: Tensor) -> Tensor: + return super().forward(x.float()).type(x.dtype) + + +class Linear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + return F.linear( + x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) + ) + + +class Conv1d(nn.Conv1d): + def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + return super()._conv_forward( + x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) + ) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +class MultiHeadAttention(nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = Linear(n_state, n_state) + self.key = Linear(n_state, n_state, bias=False) + self.value = Linear(n_state, n_state) + self.out = Linear(n_state, n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
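+ # the cached key/value tensors read below are the ones written by the forward hooks that Whisper.install_kv_cache_hooks registers on these key/value projection modules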
+ k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = F.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None + self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: Tensor, + xa: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoder(nn.Module): + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = LayerNorm(n_state) + + def forward(self, x: Tensor): + """ + x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = F.gelu(self.conv1(x)) + x = F.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + len_x = x.shape[1] + len_e = self.positional_embedding.shape[0] + assert len_x <= len_e, "incorrect audio shape" + pos_e = self.positional_embedding[:len_x, :] + x = (x + pos_e).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class TextDecoder(nn.Module): + def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.token_embedding = nn.Embedding(n_vocab, n_state) + self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] + ) + self.ln = LayerNorm(n_state) + + mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) + self.register_buffer("mask", mask, persistent=False) + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): + """ + x : torch.LongTensor, shape = (batch_size, <= n_ctx) + the text tokens + xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) + the encoded audio features to be attended on + """ + offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 
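+ # the length of any cached tensor equals the number of tokens already decoded, so the positional embedding below continues from that offset during incremental decoding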
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] + x = x.to(xa.dtype) + + for block in self.blocks: + x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + + x = self.ln(x) + logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + + return logits + + +class Whisper(nn.Module): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + self.decoder = TextDecoder( + self.dims.n_vocab, + self.dims.n_text_ctx, + self.dims.n_text_state, + self.dims.n_text_head, + self.dims.n_text_layer, + ) + + def embed_audio(self, mel: torch.Tensor): + return self.encoder(mel) + + def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): + return self.decoder(tokens, audio_features) + + def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]: + return self.decoder(tokens, self.encoder(mel)) + + @property + def device(self): + return next(self.parameters()).device + + @property + def is_multilingual(self): + return self.dims.n_vocab == 51865 + + def install_kv_cache_hooks(self, cache: Optional[dict] = None): + """ + The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value + tensors calculated for the previous positions. This method returns a dictionary that stores + all caches, and the necessary hooks for the key and value projection modules that save the + intermediate tensors to be reused during later calculations. + + Returns + ------- + cache : Dict[nn.Module, torch.Tensor] + A dictionary object mapping the key/value projection modules to its cache + hooks : List[RemovableHandle] + List of PyTorch RemovableHandle objects to stop the hooks to be called + """ + cache = {**cache} if cache is not None else {} + hooks = [] + + def save_to_cache(module, _, output): + if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]: + cache[module] = output # save as-is, for the first token or cross attention + else: + cache[module] = torch.cat([cache[module], output], dim=1).detach() + return cache[module] + + def install_hooks(layer: nn.Module): + if isinstance(layer, MultiHeadAttention): + hooks.append(layer.key.register_forward_hook(save_to_cache)) + hooks.append(layer.value.register_forward_hook(save_to_cache)) + + self.decoder.apply(install_hooks) + return cache, hooks + + detect_language = detect_language_function + decode = decode_function diff --git a/vencoder/whisper/tokenizer.py b/vencoder/whisper/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b15645dc7e15ca9f601413076299b362293eae6d --- /dev/null +++ b/vencoder/whisper/tokenizer.py @@ -0,0 +1,331 @@ +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import GPT2TokenizerFast + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": 
"malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} + + +@dataclass(frozen=True) +class Tokenizer: + """A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens""" + + tokenizer: "GPT2TokenizerFast" + language: Optional[str] + sot_sequence: Tuple[int] + + def encode(self, text, **kwargs): + return self.tokenizer.encode(text, **kwargs) + + def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs): + return self.tokenizer.decode(token_ids, **kwargs) + + def decode_with_timestamps(self, tokens) -> str: + """ + Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. + This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 
+ """ + outputs = [[]] + for token in tokens: + if token >= self.timestamp_begin: + timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>" + outputs.append(timestamp) + outputs.append([]) + else: + outputs[-1].append(token) + outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs] + return "".join(outputs) + + @property + @lru_cache() + def eot(self) -> int: + return self.tokenizer.eos_token_id + + @property + @lru_cache() + def sot(self) -> int: + return self._get_single_token_id("<|startoftranscript|>") + + @property + @lru_cache() + def sot_lm(self) -> int: + return self._get_single_token_id("<|startoflm|>") + + @property + @lru_cache() + def sot_prev(self) -> int: + return self._get_single_token_id("<|startofprev|>") + + @property + @lru_cache() + def no_speech(self) -> int: + return self._get_single_token_id("<|nospeech|>") + + @property + @lru_cache() + def no_timestamps(self) -> int: + return self._get_single_token_id("<|notimestamps|>") + + @property + @lru_cache() + def timestamp_begin(self) -> int: + return self.tokenizer.all_special_ids[-1] + 1 + + @property + @lru_cache() + def language_token(self) -> int: + """Returns the token id corresponding to the value of the `language` field""" + if self.language is None: + raise ValueError("This tokenizer does not have language token configured") + + additional_tokens = dict( + zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, + ) + ) + candidate = f"<|{self.language}|>" + if candidate in additional_tokens: + return additional_tokens[candidate] + + raise KeyError(f"Language {self.language} not found in tokenizer.") + + @property + @lru_cache() + def all_language_tokens(self) -> Tuple[int]: + result = [] + for token, token_id in zip( + self.tokenizer.additional_special_tokens, + self.tokenizer.additional_special_tokens_ids, + ): + if token.strip("<|>") in LANGUAGES: + result.append(token_id) + return tuple(result) + + @property + @lru_cache() + def all_language_codes(self) -> Tuple[str]: + return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens) + + @property + @lru_cache() + def sot_sequence_including_notimestamps(self) -> Tuple[int]: + return tuple(list(self.sot_sequence) + [self.no_timestamps]) + + @property + @lru_cache() + def non_speech_tokens(self) -> Tuple[int]: + """ + Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech + annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. + + - ♪♪♪ + - ( SPEAKING FOREIGN LANGUAGE ) + - [DAVID] Hey there, + + keeping basic punctuations like commas, periods, question marks, exclamation points, etc. + """ + symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』") + symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() + + # symbols that may be a single token or multiple tokens depending on the tokenizer. + # In case they're multiple tokens, suppress the first token, which is safe because: + # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress + # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
+ miscellaneous = set("♩♪♫♬♭♮♯") + assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) + + # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]} + for symbol in symbols + list(miscellaneous): + for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]: + if len(tokens) == 1 or symbol in miscellaneous: + result.add(tokens[0]) + + return tuple(sorted(result)) + + def _get_single_token_id(self, text) -> int: + tokens = self.tokenizer.encode(text) + assert len(tokens) == 1, f"{text} is not encoded as a single token" + return tokens[0] + + +@lru_cache(maxsize=None) +def build_tokenizer(name: str = "gpt2"): + os.environ["TOKENIZERS_PARALLELISM"] = "false" + path = os.path.join(os.path.dirname(__file__), "assets", name) + tokenizer = GPT2TokenizerFast.from_pretrained(path) + + specials = [ + "<|startoftranscript|>", + *[f"<|{lang}|>" for lang in LANGUAGES.keys()], + "<|translate|>", + "<|transcribe|>", + "<|startoflm|>", + "<|startofprev|>", + "<|nospeech|>", + "<|notimestamps|>", + ] + + tokenizer.add_special_tokens(dict(additional_special_tokens=specials)) + return tokenizer + + +@lru_cache(maxsize=None) +def get_tokenizer( + multilingual: bool, + *, + task: Optional[str] = None, # Literal["transcribe", "translate", None] + language: Optional[str] = None, +) -> Tokenizer: + if language is not None: + language = language.lower() + if language not in LANGUAGES: + if language in TO_LANGUAGE_CODE: + language = TO_LANGUAGE_CODE[language] + else: + raise ValueError(f"Unsupported language: {language}") + + if multilingual: + tokenizer_name = "multilingual" + task = task or "transcribe" + language = language or "en" + else: + tokenizer_name = "gpt2" + task = None + language = None + + tokenizer = build_tokenizer(name=tokenizer_name) + all_special_ids: List[int] = tokenizer.all_special_ids + sot: int = all_special_ids[1] + translate: int = all_special_ids[-6] + transcribe: int = all_special_ids[-5] + + langs = tuple(LANGUAGES.keys()) + sot_sequence = [sot] + if language is not None: + sot_sequence.append(sot + 1 + langs.index(language)) + if task is not None: + sot_sequence.append(transcribe if task == "transcribe" else translate) + + return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence)) diff --git a/vencoder/whisper/utils.py b/vencoder/whisper/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5dacc173c40bcd6e999d728862e29a968000b12e --- /dev/null +++ b/vencoder/whisper/utils.py @@ -0,0 +1,163 @@ +import json +import os +import sys +import zlib +from typing import Callable, TextIO + +system_encoding = sys.getdefaultencoding() + +if system_encoding != "utf-8": + def make_safe(string): + # replaces any character not representable using the system default encoding with an '?', + # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 
+ return string.encode(system_encoding, errors="replace").decode(system_encoding) +else: + def make_safe(string): + # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding + return string + + +def exact_div(x, y): + assert x % y == 0 + return x // y + + +def str2bool(string): + str2val = {"True": True, "False": False} + if string in str2val: + return str2val[string] + else: + raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") + + +def optional_int(string): + return None if string == "None" else int(string) + + +def optional_float(string): + return None if string == "None" else float(string) + + +def compression_ratio(text) -> float: + text_bytes = text.encode("utf-8") + return len(text_bytes) / len(zlib.compress(text_bytes)) + + +def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'): + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + + +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__(self, result: dict, audio_path: str): + audio_basename = os.path.basename(audio_path) + output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f) + + def write_result(self, result: dict, file: TextIO): + raise NotImplementedError + + +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result(self, result: dict, file: TextIO): + for segment in result["segments"]: + print(segment['text'].strip(), file=file, flush=True) + + +class WriteVTT(ResultWriter): + extension: str = "vtt" + + def write_result(self, result: dict, file: TextIO): + print("WEBVTT\n", file=file) + for segment in result["segments"]: + print( + f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteSRT(ResultWriter): + extension: str = "srt" + + def write_result(self, result: dict, file: TextIO): + for i, segment in enumerate(result["segments"], start=1): + # write srt lines + print( + f"{i}\n" + f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> " + f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n" + f"{segment['text'].strip().replace('-->', '->')}\n", + file=file, + flush=True, + ) + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. 
+ """ + extension: str = "tsv" + + def write_result(self, result: dict, file: TextIO): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment['start']), file=file, end="\t") + print(round(1000 * segment['end']), file=file, end="\t") + print(segment['text'].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result(self, result: dict, file: TextIO): + json.dump(result, file) + + +def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]: + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "tsv": WriteTSV, + "json": WriteJSON, + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all(result: dict, file: TextIO): + for writer in all_writers: + writer(result, file) + + return write_all + + return writers[output_format](output_dir) +