"""Inference wrapper around the BigVGAN vocoder from ``mmaudio.ext.bigvgan``."""
from pathlib import Path

import torch
import torch.nn as nn
from omegaconf import OmegaConf

from mmaudio.ext.bigvgan.models import BigVGANVocoder

# Default vocoder config: a YAML file shipped next to this module.
_bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'


class BigVGAN(nn.Module):
    """Inference-only wrapper around a pretrained ``BigVGANVocoder``.

    On construction the vocoder is built from an OmegaConf YAML config,
    loaded with the ``'generator'`` entry of a checkpoint, switched to eval
    mode, and has its weight norm removed. ``forward`` runs under
    ``torch.inference_mode``, so it is not usable for training.
    """

    def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
        """Build the vocoder and load its pretrained weights.

        Args:
            ckpt_path: Path to a checkpoint readable by ``torch.load`` whose
                ``'generator'`` entry is the vocoder's state dict.
            config_path: Path to the OmegaConf YAML config used to construct
                ``BigVGANVocoder``. Defaults to ``bigvgan_vocoder.yml`` in the
                directory of this module.

        Raises:
            KeyError: If the loaded checkpoint mapping has no ``'generator'``
                entry.
        """
        super().__init__()
        vocoder_cfg = OmegaConf.load(config_path)
        self.vocoder = BigVGANVocoder(vocoder_cfg).eval()

        # weights_only=True restricts unpickling to tensors and plain
        # containers, so loading an untrusted checkpoint cannot run arbitrary
        # code. Loading onto CPU lets the caller move the module to any device.
        vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
        self.vocoder.load_state_dict(vocoder_ckpt)

        # Tracks whether weight norm has been folded away; remove_weight_norm()
        # uses it to stay idempotent and forward() uses it as a precondition.
        self.weight_norm_removed = False
        self.remove_weight_norm()

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the vocoder on ``x`` without autograd tracking.

        Args:
            x: Vocoder input. Presumably a mel spectrogram in the layout
                ``BigVGANVocoder`` expects -- TODO confirm shape against
                ``BigVGANVocoder.forward``.

        Returns:
            Whatever ``BigVGANVocoder`` returns for ``x``, produced as an
            inference tensor (cannot later be used in autograd).

        Raises:
            RuntimeError: If weight norm has not been removed. A plain
                ``assert`` would be stripped under ``python -O``.
        """
        if not self.weight_norm_removed:
            raise RuntimeError('call remove_weight_norm() before inference')
        return self.vocoder(x)

    def remove_weight_norm(self):
        """Remove weight normalization from the vocoder, at most once.

        ``__init__`` already calls this. Later calls are no-ops instead of
        asking the vocoder to strip weight norm from layers that no longer
        have it, which would presumably fail.

        Returns:
            ``self``, so the call can be chained.
        """
        if not self.weight_norm_removed:
            self.vocoder.remove_weight_norm()
            self.weight_norm_removed = True
        return self