|
import librosa |
|
import torch |
|
from torch import nn |
|
|
|
from functools import partial |
|
from math import prod |
|
from typing import Callable, Tuple, List |
|
|
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torch.nn import Conv1d |
|
from torch.nn.utils import weight_norm |
|
from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm |
|
from diffusers.models.modeling_utils import ModelMixin |
|
from diffusers.loaders import FromOriginalModelMixin |
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
|
|
try: |
|
from music_log_mel import LogMelSpectrogram |
|
except ImportError: |
|
from .music_log_mel import LogMelSpectrogram |
|
|
|
|
|
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Apply stochastic depth: zero the whole input for randomly chosen samples.

    One Bernoulli draw is made per entry along dim 0 of ``x`` and broadcast
    over all remaining dims, so a sample is either kept or dropped entirely.

    Args:
        x: Input tensor; dim 0 is treated as the per-sample axis.
        drop_prob: Probability of dropping a sample.
        training: The input is returned untouched unless this is true.
        scale_by_keep: If true (and the keep probability is positive), kept
            samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is to be dropped, otherwise ``x`` multiplied
        by the per-sample mask.
    """
    if not training or drop_prob == 0.0:
        return x

    survival = 1 - drop_prob
    # Shape (N, 1, ..., 1): one mask value per sample, broadcast elsewhere.
    mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask
|
|
|
|
|
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: Probability of dropping a sample's path.
        scale_by_keep: Forwarded to ``drop_path`` as ``scale_by_keep``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # ``self.training`` is the standard nn.Module train/eval flag.
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        # Shown inside the module's repr.
        return f"drop_prob={round(self.drop_prob,3):0.3f}"
|
|
|
|
|
class LayerNorm(nn.Module):
    r"""Layer normalization over the channel axis, for two memory layouts.

    * ``channels_last`` (default): channels are the trailing dimension, e.g.
      ``(batch, length, channels)``; delegated to ``F.layer_norm``.
    * ``channels_first``: channels are dimension 1, e.g.
      ``(batch, channels, length)`` or ``(batch, channels, height, width)``;
      mean and (biased) variance are computed over dim 1 by hand.

    Args:
        normalized_shape: Number of channels; size of ``weight`` and ``bias``.
        eps: Value added to the variance before the square root.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two
            supported values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Validate first so an invalid layout fails before allocating anything.
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(
                f"Unsupported data_format: {data_format!r}"
            )
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        # channels_first: normalize along dim 1.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Reshape the affine parameters to (C, 1, ..., 1) so they broadcast
        # against dim 1 whatever the number of trailing dimensions is. For
        # rank-3 input this is exactly the (C, 1) shape of ``weight[:, None]``.
        affine_shape = (-1,) + (1,) * (x.ndim - 2)
        return self.weight.view(affine_shape) * x + self.bias.view(affine_shape)
|
|
|
|
|
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Pipeline: depthwise ``Conv1d`` -> permute to (N, L, C) -> ``LayerNorm``
    (channels_last) -> ``Linear`` -> GELU -> ``Linear`` -> optional layer
    scale -> permute back to (N, C, L) -> drop path -> optional residual add.

    Args:
        dim (int): Number of input (and output) channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for the per-channel layer
            scale; a value <= 0 disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of the hidden width to ``dim``. Default: 4.0.
        kernel_size (int): Kernel size of the depthwise conv. Default: 7.
        dilation (int): Dilation of the depthwise conv. Default: 1.
    """

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        hidden_dim = int(mlp_ratio * dim)

        self.dwconv = nn.Conv1d(
            dim,
            dim,
            kernel_size=kernel_size,
            # This padding keeps the length unchanged for an odd kernel
            # *dilated by ``dilation``*, so the conv must receive the same
            # dilation; otherwise the output grows and the residual add fails.
            padding=int(dilation * (kernel_size - 1) / 2),
            dilation=dilation,
            groups=dim,  # depthwise: one filter per channel
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # "Pointwise convs" implemented as Linear layers on channels-last data.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Run the block on a rank-3 ``(batch, channels, length)`` tensor.

        Args:
            x: Input tensor.
            apply_residual: If true, the block input is added to the result.

        Returns:
            Tensor with the same layout as ``x``.
        """
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x
|
|
|
|
|
class ParallelConvNeXtBlock(nn.Module):
    """Several ``ConvNeXtBlock`` branches with different kernel sizes, summed.

    Args:
        kernel_sizes: One depthwise kernel size per branch.
        *args, **kwargs: Forwarded unchanged to every ``ConvNeXtBlock``.
    """

    def __init__(self, kernel_sizes: List[int], *args, **kwargs):
        super().__init__()
        branches = []
        for size in kernel_sizes:
            branches.append(ConvNeXtBlock(*args, kernel_size=size, **kwargs))
        self.blocks = nn.ModuleList(branches)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each branch runs without its own residual; the input is added once
        # by including it as an extra term of the sum.
        terms = [branch(x, apply_residual=False) for branch in self.blocks]
        terms.append(x)
        stacked = torch.stack(terms, dim=1)
        return stacked.sum(dim=1)
|
|
|
|
|
class ConvNeXtEncoder(nn.Module):
    """Stack of 1-D ConvNeXt stages, each preceded by a channel-changing layer.

    Stage ``i`` consists of ``depths[i]`` blocks of width ``dims[i]``. Before
    stage 0 a stem (kernel-7 conv + channels-first LayerNorm) maps
    ``input_channels`` to ``dims[0]``; before every later stage a
    LayerNorm + 1x1 conv maps ``dims[i - 1]`` to ``dims[i]``. No layer here
    uses a stride, so the length axis is not downsampled.

    Args:
        input_channels: Channels of the input tensor.
        depths: Number of blocks per stage.
        dims: Channel width per stage; must have the same length as ``depths``.
        drop_path_rate: Final stochastic-depth rate; per-block rates are spaced
            linearly from 0 to this value across all blocks.
        layer_scale_init_value: Passed to every block.
        kernel_sizes: A single entry selects ``ConvNeXtBlock`` with that
            kernel size; several entries select ``ParallelConvNeXtBlock``.
    """

    def __init__(
        self,
        input_channels=3,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        kernel_sizes: Tuple[int] = (7,),
    ):
        super().__init__()
        assert len(depths) == len(dims)

        num_stages = len(depths)

        # --- channel-changing layers: stem first, then one per later stage ---
        self.channel_layers = nn.ModuleList()
        self.channel_layers.append(
            nn.Sequential(
                nn.Conv1d(
                    input_channels,
                    dims[0],
                    kernel_size=7,
                    padding=3,
                    padding_mode="replicate",
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )
        for stage_idx in range(1, num_stages):
            prev_dim, next_dim = dims[stage_idx - 1], dims[stage_idx]
            self.channel_layers.append(
                nn.Sequential(
                    LayerNorm(prev_dim, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(prev_dim, next_dim, kernel_size=1),
                )
            )

        # --- block factory: single-kernel or parallel multi-kernel ---
        if len(kernel_sizes) == 1:
            block_fn = partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
        else:
            block_fn = partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)

        # --- stages, with a linearly increasing drop-path rate per block ---
        self.stages = nn.ModuleList()
        drop_path_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        offset = 0
        for stage_idx, depth in enumerate(depths):
            blocks = []
            for rate in drop_path_rates[offset:offset + depth]:
                blocks.append(
                    block_fn(
                        dim=dims[stage_idx],
                        drop_path=rate,
                        layer_scale_init_value=layer_scale_init_value,
                    )
                )
            self.stages.append(nn.Sequential(*blocks))
            offset += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for Conv1d / Linear layers."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run every (channel layer, stage) pair in order, then the final norm."""
        for to_width, stage in zip(self.channel_layers, self.stages):
            x = stage(to_width(x))
        return self.norm(x)
|
|
|
|
|
def init_weights(m, mean=0.0, std=0.01):
    """Fill ``m.weight`` with normal noise if the module's class name contains "Conv".

    Intended for use with ``nn.Module.apply``; modules whose class name does
    not contain ``"Conv"`` are left untouched.

    Args:
        m: Module to (maybe) initialize.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.
    """
    if "Conv" not in m.__class__.__name__:
        return
    m.weight.data.normal_(mean, std)
|
|
|
|
|
def get_padding(kernel_size, dilation=1):
    """Return the per-side padding ``dilation * (kernel_size - 1) // 2``.

    For an odd ``kernel_size`` with stride 1 this is the padding that keeps a
    dilated convolution's output the same length as its input.
    """
    receptive_span = dilation * (kernel_size - 1)
    return receptive_span // 2
|
|
|
|
|
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv -> plain conv) pairs.

    Each pair computes ``x = conv2(silu(conv1(silu(x)))) + x``. ``conv1`` uses
    one of the first three entries of ``dilation``; ``conv2`` always uses
    dilation 1. All six convolutions keep the channel count, use stride 1 and
    are wrapped in ``torch.nn.utils.weight_norm``.

    Args:
        channels: Number of input and output channels.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Sequence whose first three entries are the dilations of the
            three ``conv1`` layers.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def make_conv(rate):
            # Channel-preserving, stride-1 conv padded via get_padding.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        # Exactly three dilated convs, indexed explicitly so a too-short
        # ``dilation`` still fails loudly and extra entries are ignored.
        self.convs1 = nn.ModuleList(
            [make_conv(rate) for rate in (dilation[0], dilation[1], dilation[2])]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual conv pairs in sequence."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.silu(x)
            xt = c1(xt)
            xt = F.silu(xt)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm reparameterization from every convolution.

        The convs were wrapped with the hook-based ``torch.nn.utils.weight_norm``,
        so they must be unwrapped with its counterpart
        ``torch.nn.utils.remove_weight_norm``. The module-level name
        ``remove_weight_norm`` in this file is an alias of
        ``parametrize.remove_parametrizations``, which needs a ``tensor_name``
        argument and only handles parametrized modules; it cannot be used here.
        """
        from torch.nn.utils import remove_weight_norm as strip_weight_norm

        for conv in (*self.convs1, *self.convs2):
            strip_weight_norm(conv)
|
|
|
|
|
class HiFiGANGenerator(nn.Module):
    """HiFi-GAN style generator: transposed-conv upsampling + residual blocks.

    Structure: ``conv_pre`` -> for each upsample rate: SiLU, weight-normed
    ``ConvTranspose1d`` (channels halved each stage), optional template
    injection, mean over ``len(resblock_kernel_sizes)`` ``ResBlock1`` outputs
    -> post activation -> ``conv_post`` (1 output channel) -> tanh.

    Args (keyword-only):
        hop_length: Must equal the product of ``upsample_rates``.
        upsample_rates: Stride of each transposed convolution.
        upsample_kernel_sizes: Kernel size of each transposed convolution.
        resblock_kernel_sizes: Kernel size of each parallel res-block.
        resblock_dilation_sizes: Dilation triple of each parallel res-block.
        num_mels: Input channels of ``conv_pre``.
        upsample_initial_channel: Output channels of ``conv_pre``.
        use_template: If true, build ``noise_convs`` and add their output on
            the ``template`` argument after every upsampling stage.
        pre_conv_kernel_size: Kernel size of ``conv_pre``.
        post_conv_kernel_size: Kernel size of ``conv_post``.
        post_activation: Zero-argument factory for the activation module
            applied before ``conv_post``.

    Raises:
        AssertionError: If ``prod(upsample_rates) != hop_length``.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        use_template: bool = True,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        assert (
            prod(upsample_rates) == hop_length
        ), f"hop_length must be {prod(upsample_rates)}"

        self.conv_pre = weight_norm(
            nn.Conv1d(
                num_mels,
                upsample_initial_channel,
                pre_conv_kernel_size,
                1,
                padding=get_padding(pre_conv_kernel_size),
            )
        )

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.use_template = use_template
        self.ups = nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            c_prev = upsample_initial_channel // (2**i)
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        c_prev,
                        c_cur,
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

            if not use_template:
                continue

            if i + 1 < len(upsample_rates):
                # Remaining upsampling factor after this stage. The template
                # is strided down by it so it lines up with this stage's
                # output. math.prod keeps this a plain Python int (np.prod
                # would hand numpy scalars to the conv hyper-parameters).
                stride_f0 = prod(upsample_rates[i + 1:])
                self.noise_convs.append(
                    Conv1d(
                        1,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=stride_f0 // 2,
                    )
                )
            else:
                # Last stage: nothing left to upsample, so a 1x1 conv suffices.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        # Flat list: stage i owns entries [i * num_kernels, (i + 1) * num_kernels).
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(ResBlock1(ch, k, d))

        self.activation_post = post_activation()
        self.conv_post = weight_norm(
            nn.Conv1d(
                ch,  # channel count of the final upsampling stage
                1,
                post_conv_kernel_size,
                1,
                padding=get_padding(post_conv_kernel_size),
            )
        )
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, template=None):
        """Generate a single-channel signal from ``x``.

        Args:
            x: Input to ``conv_pre``; its channel dim must be ``num_mels``.
            template: Single-channel tensor fed to ``noise_convs``; must be
                provided when the generator was built with
                ``use_template=True`` and is ignored otherwise.

        Returns:
            The tanh of the 1-channel ``conv_post`` output.
        """
        x = self.conv_pre(x)

        for i in range(self.num_upsamples):
            x = F.silu(x, inplace=True)
            x = self.ups[i](x)

            if self.use_template:
                x = x + self.noise_convs[i](template)

            # Average the parallel res-blocks belonging to this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)

            x = xs / self.num_kernels

        x = self.activation_post(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every weight-normalized layer.

        Covers ``conv_pre``, ``ups``, the convolutions inside ``resblocks`` and
        ``conv_post``. The layers were wrapped with the hook-based
        ``torch.nn.utils.weight_norm``, so the matching
        ``torch.nn.utils.remove_weight_norm`` is used. The module-level name
        ``remove_weight_norm`` in this file is an alias of
        ``parametrize.remove_parametrizations``, which requires a
        ``tensor_name`` argument and only handles parametrized modules.
        """
        from torch.nn.utils import remove_weight_norm as strip_weight_norm

        # Hook-based weight norm registers ``weight_g`` / ``weight_v`` on the
        # wrapped layer; layers without them (e.g. noise_convs) are skipped.
        for module in list(self.modules()):
            if hasattr(module, "weight_g"):
                strip_weight_norm(module)
|
|
|
|
|
class ADaMoSHiFiGANV1(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    """Vocoder: ``ConvNeXtEncoder`` backbone followed by a ``HiFiGANGenerator`` head.

    Also owns a ``LogMelSpectrogram`` transform, exposed through ``encode``.
    The constructor arguments are registered as the model config and are
    forwarded as follows:

    * backbone: ``input_channels``, ``depths``, ``dims``, ``drop_path_rate``,
      ``kernel_sizes``
    * head: ``hop_length``, ``upsample_rates``, ``upsample_kernel_sizes``,
      ``resblock_kernel_sizes``, ``resblock_dilation_sizes``, ``num_mels``,
      ``upsample_initial_channel``, ``use_template``,
      ``pre_conv_kernel_size``, ``post_conv_kernel_size``
    * mel transform: ``sampling_rate``, ``n_fft``, ``win_length``,
      ``hop_length``, ``f_min``, ``f_max``, ``n_mels``

    The model is switched to eval mode at the end of construction.
    """

    @register_to_config
    def __init__(
        self,
        input_channels: int = 128,
        depths: List[int] = [3, 3, 9, 3],
        dims: List[int] = [128, 256, 384, 512],
        drop_path_rate: float = 0.0,
        kernel_sizes: Tuple[int] = (7,),
        upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
        upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
        resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
        resblock_dilation_sizes: Tuple[Tuple[int]] = (
            (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 512,
        upsample_initial_channel: int = 1024,
        use_template: bool = False,
        pre_conv_kernel_size: int = 13,
        post_conv_kernel_size: int = 13,
        sampling_rate: int = 44100,
        n_fft: int = 2048,
        win_length: int = 2048,
        hop_length: int = 512,
        f_min: int = 40,
        f_max: int = 16000,
        n_mels: int = 128,
    ):
        super().__init__()

        # Mel features -> latent features.
        encoder_kwargs = dict(
            input_channels=input_channels,
            depths=depths,
            dims=dims,
            drop_path_rate=drop_path_rate,
            kernel_sizes=kernel_sizes,
        )
        self.backbone = ConvNeXtEncoder(**encoder_kwargs)

        # Latent features -> waveform.
        generator_kwargs = dict(
            hop_length=hop_length,
            upsample_rates=upsample_rates,
            upsample_kernel_sizes=upsample_kernel_sizes,
            resblock_kernel_sizes=resblock_kernel_sizes,
            resblock_dilation_sizes=resblock_dilation_sizes,
            num_mels=num_mels,
            upsample_initial_channel=upsample_initial_channel,
            use_template=use_template,
            pre_conv_kernel_size=pre_conv_kernel_size,
            post_conv_kernel_size=post_conv_kernel_size,
        )
        self.head = HiFiGANGenerator(**generator_kwargs)

        self.sampling_rate = sampling_rate

        # Waveform -> log-mel features (used by ``encode``).
        self.mel_transform = LogMelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
        )
        self.eval()

    @torch.no_grad()
    def decode(self, mel):
        """Gradient-free synthesis: ``head(backbone(mel))``."""
        return self.head(self.backbone(mel))

    @torch.no_grad()
    def encode(self, x):
        """Gradient-free application of the mel transform to ``x``."""
        return self.mel_transform(x)

    def forward(self, mel):
        """Same computation as ``decode`` but with autograd left enabled."""
        latent = self.backbone(mel)
        return self.head(latent)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load audio, encode to mel, decode back, write the result.
    import soundfile as sf

    audio_path = "test_audio.flac"
    model = ADaMoSHiFiGANV1.from_pretrained("./checkpoints/music_vocoder", local_files_only=True)

    # Resample to the rate the model's mel transform was built with, instead
    # of a hard-coded 44100 that could drift from the checkpoint's config.
    wav, sr = librosa.load(audio_path, sr=model.sampling_rate, mono=True)
    wav = torch.from_numpy(wav).float()[None]  # add a leading batch dim
    mel = model.encode(wav)

    # First batch item, with its last two dims swapped for soundfile.
    wav = model.decode(mel)[0].mT
    sf.write("test_audio_vocoder_rec.flac", wav.cpu().numpy(), sr)
|
|