Spaces:
Sleeping
Sleeping
# code adapted from: https://github.com/Stability-AI/stable-audio-tools | |
import torch | |
from torch import nn | |
from typing import Literal, Dict, Any | |
import math | |
import comfy.ops | |
ops = comfy.ops.disable_weight_init | |
def vae_sample(mean, scale):
    """Draw a reparameterized sample from a diagonal Gaussian and return its KL term.

    ``scale`` is passed through softplus and offset by 1e-4 to obtain a strictly
    positive standard deviation. The sample is ``eps * std + mean`` with ``eps``
    drawn by ``torch.randn_like(mean)``.

    Returns:
        (latents, kl): ``latents`` has the shape of ``mean``. ``kl`` is a scalar:
        ``mean + var - log(var) - 1`` (with ``mean`` squared) summed over dim 1,
        then averaged over everything else. No 0.5 factor is applied, so the
        value is twice the textbook Gaussian KL to N(0, 1).
    """
    std = nn.functional.softplus(scale) + 1e-4
    variance = std * std
    log_variance = torch.log(variance)

    noise = torch.randn_like(mean)
    latents = noise * std + mean

    kl_per_item = (mean * mean + variance - log_variance - 1).sum(1)
    return latents, kl_per_item.mean()
class VAEBottleneck(nn.Module):
    """Continuous (non-quantized) VAE bottleneck.

    ``encode`` splits its input in half along dim 1 into mean and scale, then
    samples with ``vae_sample``. ``decode`` is the identity.
    """

    def __init__(self):
        super().__init__()
        # Continuous latents: there is no codebook.
        self.is_discrete = False

    def encode(self, x, return_info=False, **kwargs):
        """Sample latents from ``x`` (first half of dim 1 = mean, second half = scale).

        Extra keyword arguments are accepted and ignored. When ``return_info``
        is true, also return ``{"kl": kl}`` from ``vae_sample``.
        """
        mean, scale = x.chunk(2, dim=1)
        sample, kl = vae_sample(mean, scale)
        info = {"kl": kl}
        return (sample, info) if return_info else sample

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x
def snake_beta(x, alpha, beta):
    """Snake-beta activation: ``x + sin(alpha * x) ** 2 / (beta + 1e-9)``.

    ``alpha`` scales the argument of the sine and ``1 / beta`` scales its squared
    output. The 1e-9 epsilon guards against division by zero.
    """
    eps = 1e-9
    periodic = torch.sin(x * alpha).pow(2)
    return x + periodic * (1.0 / (beta + eps))
# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
class SnakeBeta(nn.Module):
    """Per-channel snake-beta activation with learnable ``alpha`` and ``beta`` vectors.

    ``forward`` evaluates ``snake_beta(x, alpha, beta)``. Both parameters have
    length ``in_features`` and are broadcast against ``x`` laid out as ``[B, C, T]``.

    Args:
        in_features: Number of channels (length of ``alpha`` and ``beta``).
        alpha: Multiplier applied to the initial values of both parameters.
        alpha_trainable: Accepted for interface compatibility; not applied here
            (the ``requires_grad`` assignment is disabled in this port).
        alpha_logscale: If true, the parameters are stored in log space
            (initialized to zeros times ``alpha``) and exponentiated in ``forward``.
            Otherwise they are used directly (initialized to ones times ``alpha``).
    """

    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale
        # Log-space parameters start at zero (exp(0) == 1); linear ones start at one.
        make_initial = torch.zeros if alpha_logscale else torch.ones
        self.alpha = nn.Parameter(make_initial(in_features) * alpha)
        self.beta = nn.Parameter(make_initial(in_features) * alpha)
        # Kept for compatibility; forward relies on snake_beta's own epsilon.
        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        def per_channel(param):
            # [C] -> [1, C, 1] so it lines up with x of shape [B, C, T].
            shaped = param.unsqueeze(0).unsqueeze(-1).to(x.device)
            return torch.exp(shaped) if self.alpha_logscale else shaped

        return snake_beta(x, per_channel(self.alpha), per_channel(self.beta))
def _weight_norm(module):
    """Wrap ``module`` in weight normalization and return the result.

    Uses ``torch.nn.utils.parametrizations.weight_norm`` when this PyTorch build
    provides it. Otherwise falls back to the legacy ``torch.nn.utils.weight_norm``.
    Only the missing-attribute case triggers the fallback. Genuine errors (and
    KeyboardInterrupt/SystemExit) propagate instead of being swallowed.
    """
    try:
        apply_weight_norm = torch.nn.utils.parametrizations.weight_norm
    except AttributeError:
        # Older PyTorch releases only ship the hook-based implementation.
        apply_weight_norm = torch.nn.utils.weight_norm
    return apply_weight_norm(module)


def WNConv1d(*args, **kwargs):
    """Build ``ops.Conv1d(*args, **kwargs)`` wrapped in weight normalization.

    The convolution is constructed exactly once, whichever weight_norm
    implementation is used.
    """
    return _weight_norm(ops.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    """Build ``ops.ConvTranspose1d(*args, **kwargs)`` wrapped in weight normalization.

    The transposed convolution is constructed exactly once, whichever
    weight_norm implementation is used.
    """
    return _weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
    """Build the activation module named by ``activation``.

    Args:
        activation: ``"elu"`` -> ``torch.nn.ELU()``, ``"snake"`` -> ``SnakeBeta(channels)``,
            ``"none"`` -> ``torch.nn.Identity()``.
        antialias: When truthy, the chosen module is wrapped in ``Activation1d``.
        channels: Channel count; only used for ``"snake"``.

    Raises:
        ValueError: If ``activation`` is not one of the names above.
    """
    def finish(module):
        # NOTE(review): Activation1d is neither defined nor imported in this file as
        # seen here, so antialias=True presumably raises NameError — confirm.
        return Activation1d(module) if antialias else module

    if activation == "elu":
        return finish(torch.nn.ELU())
    if activation == "snake":
        return finish(SnakeBeta(channels))
    if activation == "none":
        return finish(torch.nn.Identity())
    raise ValueError(f"Unknown activation {activation}")
class ResidualUnit(nn.Module):
    """Residual block: act -> dilated kernel-7 conv -> act -> kernel-1 conv, plus a skip.

    Padding ``dilation * 6 // 2`` keeps the time length unchanged for the
    kernel-7 conv. The skip addition requires ``in_channels == out_channels``;
    every caller in this file passes equal values.
    """

    def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
        super().__init__()
        self.dilation = dilation
        act_name = "snake" if use_snake else "elu"
        kernel = 7
        pad = dilation * (kernel - 1) // 2
        self.layers = nn.Sequential(
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=kernel, dilation=dilation, padding=pad),
            get_activation(act_name, antialias=antialias_activation, channels=out_channels),
            WNConv1d(in_channels=out_channels, out_channels=out_channels,
                     kernel_size=1),
        )

    def forward(self, x):
        # Block output plus the untouched input.
        return self.layers(x) + x
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9), then an activation and a strided conv.

    The final conv uses kernel ``2 * stride``, stride ``stride`` and padding
    ``ceil(stride / 2)``, mapping ``in_channels`` to ``out_channels``.
    ``antialias_activation`` reaches only the final activation; the residual
    units are built without it (as in the original code).
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
        super().__init__()
        residuals = [
            ResidualUnit(in_channels=in_channels,
                         out_channels=in_channels, dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        ]
        self.layers = nn.Sequential(
            *residuals,
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            WNConv1d(in_channels=in_channels, out_channels=out_channels,
                     kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)),
        )

    def forward(self, x):
        return self.layers(x)
class DecoderBlock(nn.Module):
    """Activation, an upsampling layer, then three residual units (dilations 1, 3, 9).

    Upsampling happens in one of two ways:

    - With ``use_nearest_upsample``: nearest-neighbour ``nn.Upsample`` by
      ``stride``, followed by a bias-free ``padding='same'`` conv with kernel
      ``2 * stride``.
    - Otherwise: a weight-normalized transposed conv with kernel ``2 * stride``,
      stride ``stride`` and padding ``ceil(stride / 2)``.

    ``antialias_activation`` reaches only the leading activation.
    """

    def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
        super().__init__()
        if use_nearest_upsample:
            upsample = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=2 * stride,
                         stride=1,
                         bias=False,
                         padding='same'),
            )
        else:
            upsample = WNConvTranspose1d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=2 * stride, stride=stride,
                                         padding=math.ceil(stride / 2))

        modules = [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
            upsample,
        ]
        modules.extend(
            ResidualUnit(in_channels=out_channels, out_channels=out_channels,
                         dilation=rate, use_snake=use_snake)
            for rate in (1, 3, 9)
        )
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        return self.layers(x)
class OobleckEncoder(nn.Module):
    """Strided convolutional encoder.

    Layout: a kernel-7 stem conv, one ``EncoderBlock`` per entry of ``c_mults``
    (each using the matching ``strides`` entry), then an activation and a
    kernel-3 conv projecting to ``latent_dim`` channels.

    Args:
        in_channels: Channels of the input signal.
        channels: Base width; block ``i`` outputs ``c_mults[i] * channels`` channels.
        latent_dim: Output channels of the final projection.
        c_mults: Width multipliers, one per encoder block (any sequence).
        strides: Stride per encoder block; needs at least ``len(c_mults)`` entries.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Passed to ``get_activation`` for the final activation.

    Raises:
        ValueError: If ``strides`` is shorter than ``c_mults``.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False
                 ):
        super().__init__()
        # Immutable tuple defaults avoid shared mutable default arguments.
        # Unpacking into a fresh list accepts lists and tuples alike and prepends
        # the stem multiplier.
        c_mults = [1, *c_mults]
        self.depth = len(c_mults)
        if len(strides) < self.depth - 1:
            raise ValueError(
                f"strides needs at least {self.depth - 1} entries (one per c_mults entry), "
                f"got {len(strides)}"
            )

        layers = [
            WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
        ]
        # NOTE(review): unlike OobleckDecoder, antialias_activation is not forwarded
        # to the blocks here; kept as-is to preserve the existing architecture.
        for i in range(self.depth - 1):
            layers.append(EncoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i + 1] * channels,
                stride=strides[i],
                use_snake=use_snake,
            ))
        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation,
                           channels=c_mults[-1] * channels),
            WNConv1d(in_channels=c_mults[-1] * channels, out_channels=latent_dim, kernel_size=3, padding=1),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
class OobleckDecoder(nn.Module):
    """Convolutional decoder mirroring ``OobleckEncoder``.

    Layout: a kernel-7 conv from ``latent_dim`` to the widest channel count,
    ``DecoderBlock``s walking ``c_mults`` from last to first (block for index
    ``i`` uses ``strides[i - 1]``), then an activation, a bias-free kernel-7 conv
    to ``out_channels``, and ``Tanh`` (if ``final_tanh``) or ``Identity``.

    Args:
        out_channels: Channels of the reconstructed signal.
        channels: Base width; widths are ``c_mults[i] * channels``.
        latent_dim: Channels of the incoming latent.
        c_mults: Width multipliers, one per decoder block (any sequence).
        strides: Stride per decoder block; needs at least ``len(c_mults)`` entries.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Forwarded to every ``DecoderBlock`` and the final activation.
        use_nearest_upsample: Forwarded to every ``DecoderBlock``.
        final_tanh: Append ``nn.Tanh`` (otherwise ``nn.Identity``).

    Raises:
        ValueError: If ``strides`` is shorter than ``c_mults``.
    """

    def __init__(self,
                 out_channels=2,
                 channels=128,
                 latent_dim=32,
                 c_mults=(1, 2, 4, 8),
                 strides=(2, 4, 8, 8),
                 use_snake=False,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=True):
        super().__init__()
        # Immutable tuple defaults avoid shared mutable default arguments.
        # Unpacking into a fresh list accepts lists and tuples alike.
        c_mults = [1, *c_mults]
        self.depth = len(c_mults)
        if len(strides) < self.depth - 1:
            raise ValueError(
                f"strides needs at least {self.depth - 1} entries (one per c_mults entry), "
                f"got {len(strides)}"
            )

        layers = [
            WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1] * channels, kernel_size=7, padding=3),
        ]
        for i in range(self.depth - 1, 0, -1):
            layers.append(DecoderBlock(
                in_channels=c_mults[i] * channels,
                out_channels=c_mults[i - 1] * channels,
                stride=strides[i - 1],
                use_snake=use_snake,
                antialias_activation=antialias_activation,
                use_nearest_upsample=use_nearest_upsample,
            ))
        layers += [
            get_activation("snake" if use_snake else "elu", antialias=antialias_activation,
                           channels=c_mults[0] * channels),
            WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
            nn.Tanh() if final_tanh else nn.Identity(),
        ]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
class AudioOobleckVAE(nn.Module):
    """Oobleck encoder, Gaussian VAE bottleneck, and Oobleck decoder.

    The encoder outputs ``2 * latent_dim`` channels, which
    ``VAEBottleneck.encode`` splits along dim 1 into mean and scale halves.
    The decoder maps ``latent_dim`` channels back to ``in_channels``.

    Args:
        in_channels: Channels of the signal (encoder input and decoder output).
        channels: Base width shared by encoder and decoder.
        latent_dim: Channels of the sampled latent.
        c_mults: Width multipliers (any sequence), shared by encoder and decoder.
        strides: Per-block strides (any sequence), shared by encoder and decoder.
        use_snake, antialias_activation: Forwarded to encoder and decoder.
        use_nearest_upsample, final_tanh: Forwarded to the decoder.
    """

    def __init__(self,
                 in_channels=2,
                 channels=128,
                 latent_dim=64,
                 c_mults=(1, 2, 4, 8, 16),
                 strides=(2, 4, 4, 8, 8),
                 use_snake=True,
                 antialias_activation=False,
                 use_nearest_upsample=False,
                 final_tanh=False):
        super().__init__()
        # Tuple defaults avoid shared mutable default arguments. Pass fresh list
        # copies downstream so list-based handling of c_mults keeps working and
        # the model never aliases caller-owned sequences.
        c_mults = list(c_mults)
        strides = list(strides)
        self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
        self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
                                      use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
        self.bottleneck = VAEBottleneck()

    def encode(self, x):
        """Encode ``x`` and return a sampled latent (the bottleneck's info dict is not returned)."""
        return self.bottleneck.encode(self.encoder(x))

    def decode(self, x):
        """Pass ``x`` through the bottleneck's ``decode`` and then the decoder."""
        return self.decoder(self.bottleneck.decode(x))