# Copyright (c) 2024 NVIDIA CORPORATION. # Licensed under the MIT license. # Adapted from https://github.com/jik876/hifi-gan under the MIT license. # LICENSE is in incl_licenses directory. import os import json from pathlib import Path from collections import namedtuple from typing import Optional, List, Union, Dict import torch import torch.nn.functional as F import torch.nn as nn from torch.nn import Conv1d, ConvTranspose1d from torch.nn.utils import weight_norm, remove_weight_norm import activations from utils import init_weights, get_padding from alias_free_torch.act import Activation1d as TorchActivation1d from env import AttrDict from huggingface_hub import PyTorchModelHubMixin, hf_hub_download def load_hparams_from_json(path) -> AttrDict: with open(path) as f: data = f.read() h = json.loads(data) return AttrDict(h) class AMPBlock1(torch.nn.Module): def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None): super(AMPBlock1, self).__init__() self.h = h self.convs1 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2]))) ]) self.convs1.apply(init_weights) self.convs2 = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))) ]) self.convs2.apply(init_weights) self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers # select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): # faster CUDA kernel implementation of Activation1d from alias_free_cuda.activation1d import Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: Activation1d = TorchActivation1d if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") def forward(self, x): acts1, acts2 = self.activations[::2], self.activations[1::2] for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): xt = a1(x) xt = c1(xt) xt = a2(xt) xt = c2(xt) x = xt + x return x def remove_weight_norm(self): for l in self.convs1: remove_weight_norm(l) for l in self.convs2: remove_weight_norm(l) class AMPBlock2(torch.nn.Module): def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None): super(AMPBlock2, self).__init__() self.h = h self.convs = nn.ModuleList([ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))), weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))) ]) self.convs.apply(init_weights) self.num_layers = len(self.convs) # total number of conv layers # select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): # faster CUDA kernel implementation of Activation1d from alias_free_cuda.activation1d import Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: Activation1d = TorchActivation1d if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.Snake(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing self.activations = nn.ModuleList([ Activation1d( activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale)) for _ in range(self.num_layers) ]) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") def forward(self, x): for c, a in zip (self.convs, self.activations): xt = a(x) xt = c(xt) x = xt + x return x def remove_weight_norm(self): for l in self.convs: remove_weight_norm(l) class BigVGAN( torch.nn.Module, PyTorchModelHubMixin, library_name="bigvgan", repo_url="https://github.com/NVIDIA/BigVGAN", docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md", pipeline_tag="audio-to-audio", license="mit", tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"] ): # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. # New in v2: if use_cuda_kernel is set to True, it loads optimized CUDA kernels for AMP. # NOTE: use_cuda_kernel=True should be used for inference only (training is not supported). def __init__( self, h, use_cuda_kernel: bool=False ): super(BigVGAN, self).__init__() self.h = h self.h["use_cuda_kernel"] = use_cuda_kernel # add it to global hyperparameters (h) self.num_kernels = len(h.resblock_kernel_sizes) self.num_upsamples = len(h.upsample_rates) # pre conv self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)) # define which AMPBlock to use. BigVGAN uses AMPBlock1 as default resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2 # transposed conv-based upsamplers. does not apply anti-aliasing self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): self.ups.append(nn.ModuleList([ weight_norm(ConvTranspose1d(h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)) ])) # residual blocks using anti-aliased multi-periodicity composition modules (AMP) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = h.upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): self.resblocks.append(resblock(h, ch, k, d, activation=h.activation)) # select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): # faster CUDA kernel implementation of Activation1d from alias_free_cuda.activation1d import Activation1d as CudaActivation1d Activation1d = CudaActivation1d else: Activation1d = TorchActivation1d # post conv if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale) self.activation_post = Activation1d(activation=activation_post) elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) self.activation_post = Activation1d(activation=activation_post) else: raise NotImplementedError("activation incorrectly specified. check the config file and look for 'activation'.") # whether to use bias for the final conv_post. Defaults to True for backward compatibility self.use_bias_at_final = h.get("use_bias_at_final", True) self.conv_post = weight_norm(Conv1d( ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final )) # weight initialization for i in range(len(self.ups)): self.ups[i].apply(init_weights) self.conv_post.apply(init_weights) # final tanh activation. Defaults to True for backward compatibility self.use_tanh_at_final = h.get("use_tanh_at_final", True) def forward(self, x): # pre conv x = self.conv_pre(x) for i in range(self.num_upsamples): # upsampling for i_up in range(len(self.ups[i])): x = self.ups[i][i_up](x) # AMP blocks xs = None for j in range(self.num_kernels): if xs is None: xs = self.resblocks[i * self.num_kernels + j](x) else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels # post conv x = self.activation_post(x) x = self.conv_post(x) # final tanh activation if self.use_tanh_at_final: x = torch.tanh(x) else: x = torch.clamp(x, min=-1., max=1.) # bound the output to [-1, 1] return x def remove_weight_norm(self): try: print('Removing weight norm...') for l in self.ups: for l_i in l: remove_weight_norm(l_i) for l in self.resblocks: l.remove_weight_norm() remove_weight_norm(self.conv_pre) remove_weight_norm(self.conv_post) except ValueError: print('[INFO] Model already removed weight norm. Skipping!') pass ################################################################## # additional methods for huggingface_hub support ################################################################## def _save_pretrained(self, save_directory: Path) -> None: """Save weights and config.json from a Pytorch model to a local directory.""" model_path = save_directory / 'bigvgan_generator.pt' torch.save( {'generator': self.state_dict()}, model_path ) config_path = save_directory / 'config.json' with open(config_path, 'w') as config_file: json.dump(self.h, config_file, indent=4) @classmethod def _from_pretrained( cls, *, model_id: str, revision: str, cache_dir: str, force_download: bool, proxies: Optional[Dict], resume_download: bool, local_files_only: bool, token: Union[str, bool, None], map_location: str = "cpu", # additional argument strict: bool = False, # additional argument use_cuda_kernel: bool = False, **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" ################################################################## # download and load hyperparameters (h) used by BigVGAN ################################################################## if os.path.isdir(model_id): print("Loading config.json from local directory") config_file = os.path.join(model_id, 'config.json') else: config_file = hf_hub_download( repo_id=model_id, filename='config.json', revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) h = load_hparams_from_json(config_file) ################################################################## # instantiate BigVGAN using h ################################################################## if use_cuda_kernel: print(f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!") print(f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!") print(f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis") model = cls(h, use_cuda_kernel=use_cuda_kernel) ################################################################## # download and load pretrained generator weight ################################################################## if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, 'bigvgan_generator.pt') else: print(f"Loading weights from {model_id}") model_file = hf_hub_download( repo_id=model_id, filename='bigvgan_generator.pt', revision=revision, cache_dir=cache_dir, force_download=force_download, proxies=proxies, resume_download=resume_download, token=token, local_files_only=local_files_only, ) checkpoint_dict = torch.load(model_file, map_location=map_location) try: model.load_state_dict(checkpoint_dict['generator']) except RuntimeError: print(f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!") model.remove_weight_norm() model.load_state_dict(checkpoint_dict['generator']) return model