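"""Oobleck-style convolutional VAE for waveforms (`StableVAE`).

Defines the snake-beta activation, encoder/decoder blocks built from dilated
residual units, a variational bottleneck, and a `StableVAE` wrapper that loads
pretrained checkpoints and exposes `encode`/`decode`.
"""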
from typing import Any, Literal, Callable
import math
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm
import torchaudio
from alias_free_torch import Activation1d
from models.common import LoadPretrainedBase
from models.autoencoder.autoencoder_base import AutoEncoderBase
from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
# TorchScript compilation makes this ~1.4x faster and saves GPU memory.
@torch.jit.script
def snake_beta(x, alpha, beta):
    # snake-beta activation: x + (1 / beta) * sin^2(alpha * x),
    # with a small epsilon to avoid division by zero
    return x + (1.0 / (beta + 1e-9)) * torch.pow(torch.sin(x * alpha), 2)
class SnakeBeta(nn.Module):
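    """Snake-beta activation with trainable per-channel parameters.

    Computes x + (1 / beta) * sin^2(alpha * x), where alpha controls the
    frequency and beta the magnitude of the periodic component. With
    `alpha_logscale=True` the parameters are stored in log space and
    exponentiated in the forward pass.
    """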
    def __init__(
        self,
        in_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:
            # log-scale parameters initialized to zeros (exp(0) = 1 in forward)
            self.alpha = nn.Parameter(torch.zeros(in_features))
            self.beta = nn.Parameter(torch.zeros(in_features))
        else:
            # linear-scale parameters initialized to alpha
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
            self.beta = nn.Parameter(torch.ones(in_features) * alpha)
        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable
    def forward(self, x):
        # reshape per-channel parameters to [1, C, 1] to broadcast over x of shape [B, C, T]
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
if self.alpha_logscale:
alpha = torch.exp(alpha)
beta = torch.exp(beta)
x = snake_beta(x, alpha, beta)
return x
def WNConv1d(*args, **kwargs):
return weight_norm(nn.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
def get_activation(
activation: Literal["elu", "snake", "none"],
antialias=False,
channels=None
) -> nn.Module:
if activation == "elu":
act = nn.ELU()
    elif activation == "snake":
        # SnakeBeta needs the channel count for its per-channel parameters
        assert channels is not None, "channels is required for snake activation"
        act = SnakeBeta(channels)
elif activation == "none":
act = nn.Identity()
else:
raise ValueError(f"Unknown activation {activation}")
if antialias:
act = Activation1d(act)
return act
class ResidualUnit(nn.Module):
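    """Dilated residual unit: activation, dilated 7-tap conv, activation, 1x1 conv.

    The residual addition requires `in_channels == out_channels`, which is how
    all call sites in this file use it.
    """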
def __init__(
self,
in_channels,
out_channels,
dilation,
use_snake=False,
antialias_activation=False
):
super().__init__()
self.dilation = dilation
        # "same" padding for a dilated kernel of size 7
        padding = (dilation * (7 - 1)) // 2
self.layers = nn.Sequential(
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=out_channels
),
WNConv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=7,
dilation=dilation,
padding=padding
),
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=out_channels
),
WNConv1d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1
)
)
def forward(self, x):
res = x
        # gradient checkpointing could be applied here to save memory:
        # x = checkpoint(self.layers, x)
x = self.layers(x)
return x + res
class EncoderBlock(nn.Module):
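    """Three dilated residual units followed by a strided downsampling conv.

    Note that `antialias_activation` only applies to the activation before the
    strided convolution, not inside the residual units.
    """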
def __init__(
self,
in_channels,
out_channels,
stride,
use_snake=False,
antialias_activation=False
):
super().__init__()
self.layers = nn.Sequential(
ResidualUnit(
in_channels=in_channels,
out_channels=in_channels,
dilation=1,
use_snake=use_snake
),
ResidualUnit(
in_channels=in_channels,
out_channels=in_channels,
dilation=3,
use_snake=use_snake
),
ResidualUnit(
in_channels=in_channels,
out_channels=in_channels,
dilation=9,
use_snake=use_snake
),
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=in_channels
),
WNConv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2)
),
)
def forward(self, x):
return self.layers(x)
class DecoderBlock(nn.Module):
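    """Upsampling layer (transposed conv, or nearest-neighbor upsample plus
    conv) followed by three dilated residual units."""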
def __init__(
self,
in_channels,
out_channels,
stride,
use_snake=False,
antialias_activation=False,
use_nearest_upsample=False
):
super().__init__()
if use_nearest_upsample:
upsample_layer = nn.Sequential(
nn.Upsample(scale_factor=stride, mode="nearest"),
WNConv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2 * stride,
stride=1,
bias=False,
padding='same'
)
)
else:
upsample_layer = WNConvTranspose1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2)
)
self.layers = nn.Sequential(
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=in_channels
),
upsample_layer,
ResidualUnit(
in_channels=out_channels,
out_channels=out_channels,
dilation=1,
use_snake=use_snake
),
ResidualUnit(
in_channels=out_channels,
out_channels=out_channels,
dilation=3,
use_snake=use_snake
),
ResidualUnit(
in_channels=out_channels,
out_channels=out_channels,
dilation=9,
use_snake=use_snake
),
)
def forward(self, x):
return self.layers(x)
class OobleckEncoder(nn.Module):
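    """Waveform encoder: input conv, one `EncoderBlock` per stride, and a final
    projection to `latent_dim` channels.

    The total downsampling factor is the product of `strides`; with the default
    strides this is 2 * 4 * 8 * 8 = 512 samples per latent frame.
    """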
    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)
layers = [
WNConv1d(
in_channels=in_channels,
out_channels=c_mults[0] * channels,
kernel_size=7,
padding=3
)
]
for i in range(self.depth - 1):
layers += [
EncoderBlock(
in_channels=c_mults[i] * channels,
out_channels=c_mults[i + 1] * channels,
stride=strides[i],
use_snake=use_snake
)
]
layers += [
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=c_mults[-1] * channels
),
WNConv1d(
in_channels=c_mults[-1] * channels,
out_channels=latent_dim,
kernel_size=3,
padding=1
)
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class OobleckDecoder(nn.Module):
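    """Waveform decoder mirroring `OobleckEncoder`: latent projection, one
    `DecoderBlock` per stride (applied in reverse order), and an output conv,
    optionally followed by tanh.
    """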
    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()
        c_mults = [1] + list(c_mults)
        self.depth = len(c_mults)
layers = [
WNConv1d(
in_channels=latent_dim,
out_channels=c_mults[-1] * channels,
kernel_size=7,
padding=3
),
]
for i in range(self.depth - 1, 0, -1):
layers += [
DecoderBlock(
in_channels=c_mults[i] * channels,
out_channels=c_mults[i - 1] * channels,
stride=strides[i - 1],
use_snake=use_snake,
antialias_activation=antialias_activation,
use_nearest_upsample=use_nearest_upsample
)
]
layers += [
get_activation(
"snake" if use_snake else "elu",
antialias=antialias_activation,
channels=c_mults[0] * channels
),
WNConv1d(
in_channels=c_mults[0] * channels,
out_channels=out_channels,
kernel_size=7,
padding=3,
bias=False
),
nn.Tanh() if final_tanh else nn.Identity()
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class Bottleneck(nn.Module):
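    """Interface for the latent bottleneck between encoder and decoder."""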
def __init__(self, is_discrete: bool = False):
super().__init__()
self.is_discrete = is_discrete
def encode(self, x, return_info=False, **kwargs):
raise NotImplementedError
def decode(self, x):
raise NotImplementedError
# TorchScript-compiled reparameterized sampling from the posterior N(mean, stdev^2)
@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    # softplus keeps the standard deviation positive; the epsilon avoids log(0)
    stdev = nn.functional.softplus(scale) + 1e-4
    var = stdev * stdev
    logvar = torch.log(var)
    # reparameterization trick: z = mean + stdev * eps, eps ~ N(0, I)
    latents = torch.randn_like(mean) * stdev + mean
    # KL divergence to the standard normal (up to the conventional factor of 0.5)
    kl = (mean * mean + var - logvar - 1).sum(1).mean()
    return {"latents": latents, "kl": kl}
class VAEBottleneck(Bottleneck):
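    """Gaussian VAE bottleneck; sampling and KL are computed by `vae_sample`."""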
def __init__(self):
super().__init__(is_discrete=False)
    def encode(
        self,
        x,
        return_info=False,
        **kwargs
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]] | torch.Tensor:
        # the encoder output is split channel-wise into posterior mean and scale
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)
        if return_info:
            return sampled["latents"], {"kl": sampled["kl"]}
        else:
            return sampled["latents"]
def decode(self, x):
return x
def compute_mean_kernel(x, y):
    # mean Gaussian (RBF) kernel value between all pairs of rows of x and y
    kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
    return torch.exp(-kernel_input).mean()
class Pretransform(nn.Module):
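    """Interface for an optional transform applied to the signal before the
    encoder and after the decoder."""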
def __init__(self, enable_grad, io_channels, is_discrete):
super().__init__()
self.is_discrete = is_discrete
self.io_channels = io_channels
self.encoded_channels = None
self.downsampling_ratio = None
self.enable_grad = enable_grad
def encode(self, x):
raise NotImplementedError
def decode(self, z):
raise NotImplementedError
def tokenize(self, x):
raise NotImplementedError
def decode_tokens(self, tokens):
raise NotImplementedError
class StableVAE(LoadPretrainedBase, AutoEncoderBase):
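    """Convolutional VAE tying together encoder, bottleneck, and decoder.

    `encode` maps a waveform batch to latents with `latent_dim` channels plus a
    padding mask derived from `waveform_lengths`; `decode` maps latents back to
    waveforms. If `pretrained_ckpt` is given, weights are loaded on construction.
    """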
def __init__(
self,
encoder,
decoder,
latent_dim,
downsampling_ratio,
sample_rate,
io_channels=2,
        bottleneck: Bottleneck | None = None,
        pretransform: Pretransform | None = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path | None = None
):
LoadPretrainedBase.__init__(self)
AutoEncoderBase.__init__(
self,
downsampling_ratio=downsampling_ratio,
sample_rate=sample_rate,
latent_shape=(latent_dim, None)
)
self.latent_dim = latent_dim
self.io_channels = io_channels
self.in_channels = io_channels
self.out_channels = io_channels
self.min_length = self.downsampling_ratio
if in_channels is not None:
self.in_channels = in_channels
if out_channels is not None:
self.out_channels = out_channels
self.bottleneck = bottleneck
self.encoder = encoder
self.decoder = decoder
self.pretransform = pretransform
self.soft_clip = soft_clip
self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
"autoencoder."
)
if pretrained_ckpt is not None:
self.load_pretrained(pretrained_ckpt)
    def process_state_dict(self, model_dict, state_dict):
        # checkpoints nest the weights under a "state_dict" key and prefix
        # parameter names with "autoencoder.", which is stripped here
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict
    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        z = self.encoder(waveform)
        if self.bottleneck is not None:
            z = self.bottleneck.encode(z)
        # latent frame count and the corresponding padding mask
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length)
        return z, z_mask
def decode(self, latents: torch.Tensor) -> torch.Tensor:
waveform = self.decoder(latents)
return waveform
if __name__ == '__main__':
import hydra
from utils.config import generate_config_from_command_line_overrides
model_config = generate_config_from_command_line_overrides(
"configs/model/autoencoder/stable_vae.yaml"
)
autoencoder: StableVAE = hydra.utils.instantiate(model_config)
autoencoder.eval()
waveform, sr = torchaudio.load(
"/hpc_stor03/sjtu_home/xuenan.xu/workspace/singing_voice_synthesis/diffsinger/data/raw/opencpop/segments/wavs/2007000230.wav"
)
waveform = torchaudio.functional.resample(
waveform, sr, model_config["sample_rate"]
)
print("waveform: ", waveform.shape)
with torch.no_grad():
        latent, latent_mask = autoencoder.encode(
waveform, torch.as_tensor([waveform.shape[-1]])
)
print("latent: ", latent.shape)
reconstructed = autoencoder.decode(latent)
print("reconstructed: ", reconstructed.shape)
import soundfile as sf
sf.write(
"./reconstructed.wav",
reconstructed[0, 0].numpy(),
samplerate=model_config["sample_rate"]
)