| from typing import Literal, Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from mmaudio.ext.autoencoder.vae import VAE, get_my_vae |
| from mmaudio.ext.bigvgan import BigVGAN |
| from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2 |
| from mmaudio.model.utils.distributions import DiagonalGaussianDistribution |
|
|
|
|
class AutoEncoderModule(nn.Module):
    """Bundle a pretrained VAE and a BigVGAN vocoder as one frozen module.

    The VAE weights are read from ``vae_ckpt_path``; the vocoder
    implementation is selected by ``mode``. After construction every
    parameter registered on this module has ``requires_grad=False``.

    Args:
        vae_ckpt_path: Checkpoint location handed to ``torch.load``; the
            loaded object is used as the VAE state dict.
        vocoder_ckpt_path: Checkpoint passed to ``BigVGAN``. Required when
            ``mode`` is ``'16k'``; not read when ``mode`` is ``'44k'``.
        mode: ``'16k'`` builds ``BigVGAN`` from ``vocoder_ckpt_path``;
            ``'44k'`` builds ``BigVGANv2`` from the pretrained
            ``nvidia/bigvgan_v2_44khz_128band_512x`` weights.
        need_vae_encoder: When ``False`` the VAE's ``encoder`` attribute is
            deleted after loading, so :meth:`encode` presumably becomes
            unusable -- confirm against the ``VAE`` implementation.

    Raises:
        ValueError: If ``mode`` is ``'16k'`` and ``vocoder_ckpt_path`` is
            ``None``, or if ``mode`` is neither ``'16k'`` nor ``'44k'``
            (``get_my_vae`` may reject an unknown mode first).
    """

    def __init__(self,
                 *,
                 vae_ckpt_path,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        super().__init__()
        self.vae: VAE = get_my_vae(mode).eval()
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        # strict=False: keys that do not line up between checkpoint and model
        # are tolerated rather than raising.
        self.vae.load_state_dict(vae_state_dict, strict=False)
        self.vae.remove_weight_norm()

        if mode == '16k':
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would let None reach BigVGAN unnoticed.
            if vocoder_ckpt_path is None:
                raise ValueError("vocoder_ckpt_path is required when mode is '16k'")
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        elif mode == '44k':
            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
                                                     use_cuda_kernel=False)
            self.vocoder.remove_weight_norm()
        else:
            raise ValueError(f'Unknown mode: {mode}')

        # Freeze everything (VAE and vocoder) before optionally dropping the
        # encoder below.
        for param in self.parameters():
            param.requires_grad = False

        if not need_vae_encoder:
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Return ``self.vae.encode(x)``, evaluated under inference mode."""
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``self.vae.decode(z)``, evaluated under inference mode."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``spec`` under inference mode.

        ``spec`` is presumably a spectrogram and the result a waveform --
        confirm against the vocoder implementations.
        """
        return self.vocoder(spec)
|
|