Spaces:
Paused
Paused
| from pathlib import Path | |
| import torch | |
| import torch.nn as nn | |
| from omegaconf import OmegaConf | |
| from mmaudio.ext.bigvgan.models import BigVGANVocoder | |
# Default vocoder config, expected to sit next to this module.
_bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'


class BigVGAN(nn.Module):
    """Inference wrapper around ``BigVGANVocoder``.

    Construction loads the vocoder config, builds the vocoder in eval mode,
    restores the ``'generator'`` weights from a checkpoint, and removes weight
    normalization, so a freshly built instance can be called right away.
    """

    def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path) -> None:
        """Build the vocoder and load its weights.

        Args:
            ckpt_path: Checkpoint readable by ``torch.load``. It must contain
                a ``'generator'`` entry holding the vocoder state dict.
            config_path: Config file handed to ``OmegaConf.load``. Defaults to
                ``bigvgan_vocoder.yml`` in this module's directory.

        Raises:
            KeyError: If the checkpoint has no ``'generator'`` entry.
        """
        super().__init__()
        vocoder_cfg = OmegaConf.load(config_path)
        self.vocoder = BigVGANVocoder(vocoder_cfg).eval()

        # Load onto CPU regardless of the device the checkpoint was saved
        # from; weights_only=True keeps torch.load on its restricted
        # unpickler.
        vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
        self.vocoder.load_state_dict(vocoder_ckpt)

        # The flag starts False so the call below performs the actual removal.
        self.weight_norm_removed = False
        self.remove_weight_norm()

    def forward(self, x):
        """Run the wrapped vocoder on ``x`` and return its output.

        Raises:
            AssertionError: If weight normalization has not been removed yet.
        """
        # Raised explicitly instead of via `assert` so the guard survives
        # `python -O`; exception type and message are the same as before.
        if not self.weight_norm_removed:
            raise AssertionError('call remove_weight_norm() before inference')
        return self.vocoder(x)

    def remove_weight_norm(self):
        """Remove weight normalization from the wrapped vocoder, at most once.

        ``__init__`` already calls this, so a later explicit call is a no-op
        rather than a second delegation to the vocoder.
        NOTE(review): the vocoder's own ``remove_weight_norm`` presumably
        fails when invoked twice -- confirm against ``BigVGANVocoder``.

        Returns:
            This module, so calls can be chained.
        """
        if not self.weight_norm_removed:
            self.vocoder.remove_weight_norm()
            self.weight_norm_removed = True
        return self