# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torchaudio.transforms as T
from einops import rearrange
from nnAudio import features

from modules.vocoder_blocks import *

LRELU_SLOPE = 0.1


class DiscriminatorCQT(nn.Module):
    def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
        super(DiscriminatorCQT, self).__init__()
        self.cfg = cfg

        self.filters = cfg.model.mssbcqtd.filters
        self.max_filters = cfg.model.mssbcqtd.max_filters
        self.filters_scale = cfg.model.mssbcqtd.filters_scale
        self.kernel_size = (3, 9)
        self.dilations = cfg.model.mssbcqtd.dilations
        self.stride = (1, 2)

        self.in_channels = cfg.model.mssbcqtd.in_channels
        self.out_channels = cfg.model.mssbcqtd.out_channels
        self.fs = cfg.preprocess.sample_rate
        self.hop_length = hop_length
        self.n_octaves = n_octaves
        self.bins_per_octave = bins_per_octave

        # forward() resamples the input to 2x the source rate, so the CQT is
        # configured for sr = 2 * fs.
        self.cqt_transform = features.cqt.CQT2010v2(
            sr=self.fs * 2,
            hop_length=self.hop_length,
            n_bins=self.bins_per_octave * self.n_octaves,
            bins_per_octave=self.bins_per_octave,
            output_format="Complex",
            pad_mode="constant",
        )

        # One pre-conv per octave sub-band; each keeps the channel count at
        # in_channels * 2 (the two CQT component planes).
        self.conv_pres = nn.ModuleList()
        for _ in range(self.n_octaves):
            self.conv_pres.append(
                NormConv2d(
                    self.in_channels * 2,
                    self.in_channels * 2,
                    kernel_size=self.kernel_size,
                    padding=get_2d_padding(self.kernel_size),
                )
            )

        self.convs = nn.ModuleList()
        self.convs.append(
            NormConv2d(
                self.in_channels * 2,
                self.filters,
                kernel_size=self.kernel_size,
                padding=get_2d_padding(self.kernel_size),
            )
        )

        in_chs = min(self.filters_scale * self.filters, self.max_filters)
        for i, dilation in enumerate(self.dilations):
            out_chs = min(
                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
            )
            self.convs.append(
                NormConv2d(
                    in_chs,
                    out_chs,
                    kernel_size=self.kernel_size,
                    stride=self.stride,
                    dilation=(dilation, 1),
                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
                    norm="weight_norm",
                )
            )
            in_chs = out_chs

        out_chs = min(
            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
            self.max_filters,
        )
        self.convs.append(
            NormConv2d(
                in_chs,
                out_chs,
                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
                norm="weight_norm",
            )
        )

        self.conv_post = NormConv2d(
            out_chs,
            self.out_channels,
            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
            norm="weight_norm",
        )

        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)

    def forward(self, x):
        fmap = []

        x = self.resample(x)
        z = self.cqt_transform(x)

        # With output_format="Complex", nnAudio stacks the real and imaginary
        # parts of the CQT in the last dimension; they are treated here as two
        # input planes.
        z_amplitude = z[:, :, :, 0].unsqueeze(1)
        z_phase = z[:, :, :, 1].unsqueeze(1)

        z = torch.cat([z_amplitude, z_phase], dim=1)
        z = rearrange(z, "b c w t -> b c t w")

        # Run each octave's sub-band through its own pre-conv, then
        # re-concatenate along the frequency axis.
        latent_z = []
        for i in range(self.n_octaves):
            latent_z.append(
                self.conv_pres[i](
                    z[
                        :,
                        :,
                        :,
                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
                    ]
                )
            )
        latent_z = torch.cat(latent_z, dim=-1)

        for layer in self.convs:
            latent_z = layer(latent_z)
            latent_z = self.activation(latent_z)
            fmap.append(latent_z)

        latent_z = self.conv_post(latent_z)

        return latent_z, fmap


class MultiScaleSubbandCQTDiscriminator(nn.Module):
    def __init__(self, cfg):
        super(MultiScaleSubbandCQTDiscriminator, self).__init__()
        self.cfg = cfg
        # One CQT discriminator per scale, each with its own hop length,
        # octave count, and bins-per-octave resolution.
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorCQT(
                    cfg,
                    hop_length=cfg.model.mssbcqtd.hop_lengths[i],
                    n_octaves=cfg.model.mssbcqtd.n_octaves[i],
                    bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i],
                )
                for i in range(len(cfg.model.mssbcqtd.hop_lengths))
            ]
        )

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []

        for disc in self.discriminators:
            y_d_r, fmap_r = disc(y)
            y_d_g, fmap_g = disc(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
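

# The smoke test below is an illustrative sketch, not part of the Amphion API:
# the cfg values are plausible placeholders, not necessarily the shipped
# defaults, and running it requires NormConv2d / get_2d_padding from
# modules.vocoder_blocks plus the nnAudio, einops, and torchaudio dependencies.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Throwaway config mimicking the cfg.model.mssbcqtd / cfg.preprocess
    # namespaces this module reads from.
    cfg = SimpleNamespace(
        preprocess=SimpleNamespace(sample_rate=24000),
        model=SimpleNamespace(
            mssbcqtd=SimpleNamespace(
                filters=32,
                max_filters=1024,
                filters_scale=1,
                dilations=[1, 2, 4],
                in_channels=1,
                out_channels=1,
                hop_lengths=[512, 256, 256],
                n_octaves=[9, 9, 9],
                bins_per_octaves=[24, 36, 48],
            )
        ),
    )

    disc = MultiScaleSubbandCQTDiscriminator(cfg)
    y = torch.randn(2, 1, 24000)      # reference audio: (batch, 1, samples)
    y_hat = torch.randn(2, 1, 24000)  # generated audio, same shape
    y_d_rs, y_d_gs, fmap_rs, fmap_gs = disc(y, y_hat)

    # One logit map and one feature-map list per scale.
    for y_d, fmap in zip(y_d_rs, fmap_rs):
        print(tuple(y_d.shape), len(fmap))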