Spaces:
Running
on
Zero
Running
on
Zero
from typing import Literal, Optional | |
import torch | |
import torch.nn as nn | |
from mmaudio.ext.autoencoder.vae import VAE, get_my_vae | |
from mmaudio.ext.bigvgan import BigVGAN | |
from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2 | |
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution | |
class AutoEncoderModule(nn.Module):
    """Frozen VAE + vocoder pair used to encode, decode and vocode audio features.

    The VAE architecture comes from ``get_my_vae(mode)`` and its weights from
    ``vae_ckpt_path``. The vocoder is a ``BigVGAN`` loaded from
    ``vocoder_ckpt_path`` in ``'16k'`` mode, or the pretrained
    ``nvidia/bigvgan_v2_44khz_128band_512x`` BigVGAN-v2 model in ``'44k'`` mode.
    Every parameter has ``requires_grad`` disabled after construction.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path: str,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        """Build the VAE and vocoder, load their weights, and freeze them.

        Args:
            vae_ckpt_path: Path to the VAE state-dict checkpoint (read with ``torch.load``).
            vocoder_ckpt_path: ``BigVGAN`` checkpoint path. Required when ``mode == '16k'``;
                not used when ``mode == '44k'``.
            mode: ``'16k'`` or ``'44k'``; selects the VAE configuration and the vocoder.
            need_vae_encoder: If False, the VAE's ``encoder`` submodule is deleted after
                loading, so only decoding remains available.

        Raises:
            ValueError: If ``mode`` is not ``'16k'``/``'44k'``, or if ``mode == '16k'`` and
                ``vocoder_ckpt_path`` is None. Both are checked before any model is built
                or any checkpoint is read.
        """
        super().__init__()

        # Fail fast on bad arguments, before building the VAE, reading the checkpoint
        # from disk, or fetching the pretrained vocoder.
        if mode not in ('16k', '44k'):
            raise ValueError(f'Unknown mode: {mode}')
        if mode == '16k' and vocoder_ckpt_path is None:
            # Explicit exception instead of `assert`, which is stripped under `python -O`.
            raise ValueError("vocoder_ckpt_path is required when mode='16k'")

        self.vae: VAE = get_my_vae(mode).eval()
        # weights_only=True restricts unpickling to tensors and plain containers;
        # map_location='cpu' loads onto CPU regardless of the device it was saved from.
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        # strict=False: keys missing from / unexpected in the checkpoint are ignored
        # instead of raising. NOTE(review): confirm this tolerance is intentional.
        self.vae.load_state_dict(vae_state_dict, strict=False)
        self.vae.remove_weight_norm()

        if mode == '16k':
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        else:  # mode == '44k' (validated above)
            # .eval() for consistency with the 16k vocoder and the VAE.
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False).eval()
            self.vocoder.remove_weight_norm()

        # Disable gradients on every parameter (VAE and vocoder alike).
        self.requires_grad_(False)

        if not need_vae_encoder:
            # Free the encoder's weights when only decoding is needed;
            # encode() presumably fails afterwards.
            del self.vae.encoder

    def encode(self, x: torch.Tensor) -> 'DiagonalGaussianDistribution':
        """Encode ``x`` with the VAE and return its posterior distribution.

        NOTE(review): likely unusable when built with ``need_vae_encoder=False``,
        since the VAE encoder is deleted in that case.
        """
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent ``z`` with the VAE decoder."""
        return self.vae.decode(z)

    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``spec``.

        Presumably ``spec`` is the spectrogram produced by ``decode()``; confirm with callers.
        """
        return self.vocoder(spec)