#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright (c) 2021 Kazuhiro KOBAYASHI <root.4mac@gmail.com>
#
# Distributed under terms of the MIT license.

"""Incremental (streaming) inference modules for the neural homomorphic vocoder."""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .layer import ConvLayers, DFTLayer
from .model import NeuralHomomorphicVocoder
from .module import CCepLTVFilter, SinusoidsGenerator
class IncrementalCacheConvClass(nn.Module):
    def __init__(self):
        super().__init__()
        # keep hook handles so that old hooks can be removed later
        self.handles = []

    def _forward_without_cache(self, x):
        raise NotImplementedError("Please implement _forward_without_cache")

    def forward(self, caches, *inputs):
        self.caches = caches
        self.new_caches = []
        self.cache_num = 0
        x = self._forward(*inputs)
        return x, self.new_caches

    def reset_caches(self, *args, hop_size=128, batch_size=1):
        self.caches = []
        self.receptive_sizes = []
        self._initialize_caches(batch_size=batch_size, hop_size=hop_size)
        # set ordering hook
        self._set_pre_hooks(cache_ordering=True)
        # calculate order of inference
        _ = self._forward_without_cache(*args)
        # remove hook handles used for ordering
        for h in self.handles:
            h.remove()
        self.handles = []
        # set concatenation hook
        self._set_pre_hooks(cache_ordering=False)
        # zero-initialize the caches
        self.caches = [torch.zeros_like(c) for c in self.caches]
        # remove conv padding
        self._remove_padding()
        return self.caches

    def _initialize_caches(self, batch_size=1, hop_size=128):
        self.caches_dict = {}
        self.receptive_sizes_dict = {}
        for k, m in self.named_modules():
            if isinstance(m, nn.Conv1d):
                if m.kernel_size[0] > 1:
                    receptive_size = self._get_receptive_size_1d(m)
                    # NOTE(k2kobayashi): postfilter_fn needs to accept
                    # a hop_size-length input
                    if "postfilter_fn" in k:
                        receptive_size += hop_size - 1
                    self.caches_dict[id(m)] = torch.randn(
                        (batch_size, m.in_channels, receptive_size)
                    )
                    self.receptive_sizes_dict[id(m)] = receptive_size

    def _set_pre_hooks(self, cache_ordering=True):
        if cache_ordering:
            func = self._cache_ordering
        else:
            func = self._concat_cache
        for _, m in self.named_modules():
            if isinstance(m, nn.Conv1d):
                if m.kernel_size[0] > 1:
                    self.handles.append(m.register_forward_pre_hook(func))

    def _concat_cache(self, module, inputs):
        def __concat_cache(inputs, cache, receptive_size):
            # prepend cached samples and keep only the receptive field
            inputs = torch.cat([cache, inputs[0]], dim=-1)
            inputs = inputs[..., -receptive_size:]
            return inputs

        cache = self.caches[self.cache_num]
        receptive_size = self.receptive_sizes[self.cache_num]
        inputs = __concat_cache(inputs, cache, receptive_size)
        self.new_caches += [inputs]
        self.cache_num += 1
        return inputs

    def _cache_ordering(self, module, inputs):
        self.caches.append(self.caches_dict[id(module)])
        self.receptive_sizes.append(self.receptive_sizes_dict[id(module)])

    def _remove_padding(self):
        def __remove_padding(m):
            if isinstance(m, torch.nn.Conv1d):
                m.padding = (0,)
            if isinstance(m, torch.nn.Conv2d):
                m.padding = (0, 0)

        self.apply(__remove_padding)

    @staticmethod
    def _get_receptive_size_1d(m):
        # receptive-field length of a dilated 1-D convolution
        return (m.kernel_size[0] - 1) * m.dilation[0] + 1
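
# Usage sketch for the cache machinery above (not part of the original module;
# ``MyModel`` is a hypothetical subclass implementing ``_forward`` and
# ``_forward_without_cache``, and the tensor shapes are assumptions):
#
#     model = MyModel()
#     dummy = torch.randn(1, 80, 32)                 # shape depends on the subclass
#     conv_caches = model.reset_caches(dummy, hop_size=128, batch_size=1)
#     y, conv_caches = model(conv_caches, chunk)     # feed chunks in a loop
#
# ``reset_caches`` runs one ordinary forward pass with ordering hooks to record
# the firing order of every Conv1d, then swaps in concatenation hooks and strips
# conv padding so that each chunk is processed causally against the caches.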
class IncrementalNeuralHomomorphicVocoder(
    NeuralHomomorphicVocoder, IncrementalCacheConvClass
):
    fs = 24000
    fft_size = 1024
    hop_size = 256
    in_channels = 80
    conv_channels = 256
    ccep_size = 222
    out_channels = 1
    kernel_size = 3
    dilation_size = 1
    group_size = 8
    fmin = 80
    fmax = 7600
    roll_size = 24
    n_ltv_layers = 3
    n_postfilter_layers = 4
    n_ltv_postfilter_layers = 1
    use_causal = False
    use_reference_mag = False
    use_tanh = False
    use_uvmask = False
    use_weight_norm = True
    conv_type = "original"
    postfilter_type = "ddsconv"
    ltv_postfilter_type = "conv"
    ltv_postfilter_kernel_size = 128
    scaler_file = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        assert kwargs["use_causal"], "Require use_causal"
        self.impulse_generator = IncrementalSinusoidsGenerator(
            hop_size=self.hop_size, fs=self.fs, use_uvmask=self.use_uvmask
        )
        self.ltv_harmonic = IncrementalCCepLTVFilter(
            **self.ltv_params, feat2linear_fn=self.feat2linear_fn
        )
        self.ltv_noise = IncrementalCCepLTVFilter(**self.ltv_params)
        self.window_size = self.ltv_harmonic.window_size

    def _forward_without_cache(self, *inputs):
        super()._forward(*inputs)

    def forward(self, z, x, f0, uv, ltv_caches, conv_caches):
        self.caches = conv_caches
        self.new_caches = []
        self.cache_num = 0
        y, new_ltv_caches = self._incremental_forward(z, x, f0, uv, ltv_caches)
        return y, new_ltv_caches, self.new_caches

    def _incremental_forward(self, z, x, cf0, uv, ltv_caches):
        if self.feat_scaler_fn is not None:
            x = self.feat_scaler_fn(x)

        # impulse
        impulse, impulse_cache = self.impulse_generator.incremental_forward(
            cf0, uv, ltv_caches[0]
        )

        # ltv for harmonic
        harmonic = self._concat_ltv_input_cache(ltv_caches[1], impulse)
        ltv_harm = self.ltv_harmonic.incremental_forward(x, harmonic)
        # overlap-add the previous frame tail with the current frame head
        sig_harm = (
            ltv_caches[2][..., -self.hop_size :] + ltv_harm[..., : self.hop_size]
        )
        if self.ltv_harmonic.ltv_postfilter_fn is not None:
            sig_harm = self.ltv_harmonic.ltv_postfilter_fn(
                sig_harm.transpose(1, 2)
            ).transpose(1, 2)

        # ltv for noise
        noise = self._concat_ltv_input_cache(ltv_caches[3], z)
        ltv_noise = self.ltv_noise.incremental_forward(x, noise)
        sig_noise = (
            ltv_caches[4][..., -self.hop_size :] + ltv_noise[..., : self.hop_size]
        )
        if self.ltv_noise.ltv_postfilter_fn is not None:
            sig_noise = self.ltv_noise.ltv_postfilter_fn(
                sig_noise.transpose(1, 2)
            ).transpose(1, 2)

        # superimpose
        y = sig_harm + sig_noise
        if self.postfilter_fn is not None:
            y = self.postfilter_fn(y.transpose(1, 2)).transpose(1, 2)
        y = torch.tanh(y) if self.use_tanh else torch.clamp(y, -1, 1)

        new_ltv_caches = [impulse_cache, harmonic, ltv_harm, noise, ltv_noise]
        return y.reshape(1, self.out_channels, -1), new_ltv_caches

    def reset_ltv_caches(self):
        ltv_caches = []
        # impulse generator (phase accumulator)
        ltv_caches += [torch.zeros(1, 1, 1)]
        # ltv harmonic (filter input / previous output frame)
        ltv_caches += [torch.zeros(1, 1, self.window_size)]
        ltv_caches += [torch.zeros(1, 1, self.window_size)]
        # ltv noise (filter input / previous output frame)
        ltv_caches += [torch.zeros(1, 1, self.window_size)]
        ltv_caches += [torch.zeros(1, 1, self.window_size)]
        return ltv_caches

    def _concat_ltv_input_cache(self, cache, z):
        # append the new hop and drop the oldest hop_size samples
        z = torch.cat([cache, z], dim=-1)
        z = z[..., self.hop_size :]
        return z
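
# A streaming-loop sketch for the vocoder above (the reset_caches arguments
# and tensor shapes are assumptions; mlfb/cf0/uv are hypothetical per-frame
# inputs, one feature frame per hop):
#
#     model = IncrementalNeuralHomomorphicVocoder(use_causal=True)
#     ltv_caches = model.reset_ltv_caches()
#     conv_caches = model.reset_caches(z0, x0, cf0_0, uv0, hop_size=model.hop_size)
#     for mlfb, cf0, uv in frame_stream:
#         z = torch.randn(1, 1, model.hop_size)      # noise excitation
#         y, ltv_caches, conv_caches = model(
#             z, mlfb, cf0, uv, ltv_caches, conv_caches
#         )
#
# Each call yields hop_size samples; both cache lists must be threaded back
# into the next call so filter states and sinusoid phase stay continuous.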
class IncrementalSinusoidsGenerator(SinusoidsGenerator):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def incremental_forward(self, cf0, uv, cache):
        f0, uv = self.upsample(cf0.transpose(1, 2)), self.upsample(uv.transpose(1, 2))
        harmonic, new_cache = self.incremental_generate_sinusoids(f0, uv, cache)
        harmonic = self.harmonic_amp * harmonic.reshape(cf0.size(0), 1, -1)
        return harmonic, new_cache

    def incremental_generate_sinusoids(self, f0, uv, cache):
        mask = self.anti_aliacing_mask(f0 * self.harmonics)
        # prepend the cached phase accumulator so that phase stays
        # continuous across hops
        # f0[..., 0] = f0[..., 0] + cache
        f0 = torch.cat([cache, f0], dim=-1)
        cumsum = torch.cumsum(f0, dim=-1)[..., 1:]
        rads = cumsum * 2.0 * math.pi / self.fs * self.harmonics
        harmonic = torch.sum(torch.cos(rads) * mask, dim=1, keepdim=True)
        if self.use_uvmask:
            harmonic = uv * harmonic
        # carry over the accumulator, wrapped to one period
        new_cache = cumsum[..., -1:] % self.fs
        return harmonic, new_cache
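
# Why caching ``cumsum[..., -1:] % self.fs`` preserves phase: with
# rads = cumsum * 2*pi/fs * harmonics, shifting cumsum by an integer multiple
# of fs leaves cos(rads) unchanged whenever the harmonic multipliers are
# integers (an assumption about ``self.harmonics``). A chunked call should
# therefore match one long call, e.g.:
#
#     f0 = torch.full((1, 1, 512), 100.0)
#     uv = torch.ones(1, 1, 512)
#     zero = torch.zeros(1, 1, 1)
#     a, c = gen.incremental_generate_sinusoids(f0[..., :256], uv[..., :256], zero)
#     b, _ = gen.incremental_generate_sinusoids(f0[..., 256:], uv[..., 256:], c)
#     full, _ = gen.incremental_generate_sinusoids(f0, uv, zero)
#     # torch.allclose(torch.cat([a, b], dim=-1), full, atol=1e-4) should hold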
class IncrementalConvLayers(ConvLayers, IncrementalCacheConvClass):
    in_channels = 80
    conv_channels = 256
    out_channels = 222
    kernel_size = 3
    dilation_size = 1
    group_size = 8
    n_conv_layers = 3
    use_causal = False
    conv_type = "original"

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if k not in self.__class__.__dict__:
                raise ValueError(f"{k} not in arguments {self.__class__}.")
            setattr(self, k, v)
        assert kwargs["use_causal"], "Require use_causal"
        super().__init__(**kwargs)

    def _forward_without_cache(self, *inputs):
        super().forward(*inputs)

    def forward(self, x, conv_caches):
        self.caches = conv_caches
        self.new_caches = []
        self.cache_num = 0
        x = self.conv_layers(x)
        return x, self.new_caches
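
# Chunk-wise usage sketch for IncrementalConvLayers (sizes are assumptions):
#
#     net = IncrementalConvLayers(use_causal=True)
#     conv_caches = net.reset_caches(torch.randn(1, 80, 32), hop_size=128)
#     for chunk in feature_chunks:                   # each (1, 80, chunk_len)
#         out, conv_caches = net(chunk, conv_caches)
#
# Note that this ``forward`` takes (x, conv_caches) rather than the
# (caches, *inputs) order used by IncrementalCacheConvClass.forward.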
class IncrementalCCepLTVFilter(CCepLTVFilter):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.conv_dft = DFTLayer(n_fft=self.fft_size)
        self.conv_idft = DFTLayer(n_fft=self.fft_size + 1)
        self.padding = (self.fft_size - self.ccep_size) // 2

    def incremental_forward(self, x, z):
        """Input tensor sizes

        x: (1, 1, input_size)
        z: (1, 1, fft_size + hop_size)
        """
        # infer the complex cepstrum from acoustic features
        ccep = self.conv(x) / self.quef_norm
        log_mag = None if self.feat2linear_fn is None else self.feat2linear_fn(x)
        y = self._dft_ccep2impulse(ccep, ref=log_mag)
        # convolve the impulse response with one frame of the source
        z = F.pad(z, (self.fft_size // 2, self.fft_size // 2))
        z = F.conv1d(z, y)
        return z * self.win

    def _dft_ccep2impulse(self, ccep, ref=None):
        ccep = F.pad(ccep, (self.padding, self.padding))
        real, imag = self.conv_dft(ccep)
        if ref is not None:
            real = self._apply_ref_mag(real, ref)
        # interpret the real part as log magnitude and the imaginary part as phase
        mag, phase = torch.pow(10, real / 10), imag
        real, imag = mag * torch.cos(phase), mag * torch.sin(phase)
        real, _ = self.conv_idft(F.pad(real, (0, 1)), F.pad(imag, (0, 1)), inverse=True)
        return real
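
# The transform in _dft_ccep2impulse, spelled out (a sketch of the intended
# math; the torch.fft version below is a hypothetical stand-in, independent
# of DFTLayer):
#
#     def ccep_to_impulse(ccep, n_fft):
#         spec = torch.fft.fft(ccep, n=n_fft)        # DFT of the complex cepstrum
#         mag = 10 ** (spec.real / 10)               # real part: log magnitude
#         phase = spec.imag                          # imaginary part: phase
#         h = torch.fft.ifft(mag * torch.exp(1j * phase))
#         return h.real                              # real impulse response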