# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer

from models.svc.transformer.transformer import PositionalEncoding
from utils.f0 import f0_to_coarse


class ContentEncoder(nn.Module):
    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        # Optionally refine the content features with a Conformer before projection
        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        # x: (N, seq_len, input_dim) -> (N, seq_len, output_dim)
        if self.conformer:
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)


class MelodyEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody
        self.pitch_min = self.cfg.pitch_min
        self.pitch_max = self.cfg.pitch_max

        if self.input_dim != 0:
            if self.n_bins == 0:
                # No quantization: project the raw F0 value directly
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                self.f0_min = cfg.f0_min
                self.f0_max = cfg.f0_max

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )
                self.uv_embedding = nn.Embedding(2, self.output_dim)
                # self.conformer = Conformer(
                #     input_dim=self.output_dim,
                #     num_heads=4,
                #     ffn_dim=128,
                #     num_layers=4,
                #     depthwise_conv_kernel_size=3,
                # )

    def forward(self, x, uv=None, length=None):
        # x: (N, frame_len)
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        x = self.nn(x)

        if uv is not None:
            uv = self.uv_embedding(uv)
            x = x + uv

        # x, _ = self.conformer(x, length)
        return x


class LoudnessEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # No quantization: project the raw loudness value directly
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # TODO: these loudness bounds are placeholder values
                self.loudness_min = 1e-30
                self.loudness_max = 1.5

                if cfg.use_log_loudness:
                    # Log-spaced bin boundaries
                    self.energy_bins = nn.Parameter(
                        torch.exp(
                            torch.linspace(
                                np.log(self.loudness_min),
                                np.log(self.loudness_max),
                                self.n_bins - 1,
                            )
                        ),
                        requires_grad=False,
                    )
                else:
                    # Fallback: linearly spaced bin boundaries, so that energy_bins
                    # is always defined when quantization is enabled
                    self.energy_bins = nn.Parameter(
                        torch.linspace(
                            self.loudness_min,
                            self.loudness_max,
                            self.n_bins - 1,
                        ),
                        requires_grad=False,
                    )

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        # x: (N, frame_len)
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)


class SingerEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        self.nn = nn.Embedding(
            num_embeddings=cfg.singer_table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        # x: (N, 1) -> (N, 1, output_dim)
        return self.nn(x)


class ConditionEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.merge_mode = cfg.merge_mode

        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        self.melody_encoder = MelodyEncoder(self.cfg)
        self.loudness_encoder = LoudnessEncoder(self.cfg)

        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        # x: a batch dict; every feature present in it contributes one embedding
        outputs = []

        if "frame_pitch" in x.keys():
            if "frame_uv" not in x.keys():
                x["frame_uv"] = None
            pitch_enc_out = self.melody_encoder(
                x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
            )
            outputs.append(pitch_enc_out)

        if "frame_energy" in x.keys():
            loudness_enc_out = self.loudness_encoder(x["frame_energy"])
            outputs.append(loudness_enc_out)

        if "whisper_feat" in x.keys():
            # whisper_feat: (N, seq_len, whisper_dim)
            whisper_enc_out = self.whisper_encoder(
                x["whisper_feat"], length=x["target_len"]
            )
            outputs.append(whisper_enc_out)
            seq_len = whisper_enc_out.shape[1]

        if "contentvec_feat" in x.keys():
            contentvec_enc_out = self.contentvec_encoder(
                x["contentvec_feat"], length=x["target_len"]
            )
            outputs.append(contentvec_enc_out)
            seq_len = contentvec_enc_out.shape[1]

        if "mert_feat" in x.keys():
            mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
            outputs.append(mert_enc_out)
            seq_len = mert_enc_out.shape[1]

        if "wenet_feat" in x.keys():
            wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
            outputs.append(wenet_enc_out)
            seq_len = wenet_enc_out.shape[1]

        if "spk_id" in x.keys():
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # (N, 1, output_singer_dim)
            # At least one frame-level content feature is required so that
            # seq_len is defined for broadcasting the singer embedding
            assert (
                "whisper_feat" in x.keys()
                or "contentvec_feat" in x.keys()
                or "mert_feat" in x.keys()
                or "wenet_feat" in x.keys()
            )
            singer_info = speaker_enc_out.expand(-1, seq_len, -1)
            outputs.append(singer_info)

        encoder_output = None
        if self.merge_mode == "concat":
            encoder_output = torch.cat(outputs, dim=-1)
        elif self.merge_mode == "add":
            # (#modules, N, seq_len, output_dim)
            outputs = torch.stack(outputs, dim=0)
            # (N, seq_len, output_dim)
            encoder_output = torch.sum(outputs, dim=0)

        return encoder_output
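

# --- Minimal usage sketch (illustration only, not part of the original module) ---
# The block below is a smoke test showing how ConditionEncoder consumes a batch
# dict. The _DemoConfig helper and every dimension/flag value here are hypothetical
# stand-ins; real values come from the project's own configuration files.
if __name__ == "__main__":

    class _DemoConfig:
        """Tiny stand-in for the project's config object: supports attribute
        access plus `in` membership, which ContentEncoder relies on."""

        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def __contains__(self, key):
            return key in self.__dict__

    cfg = _DemoConfig(
        merge_mode="concat",
        use_whisper=True,
        use_contentvec=False,
        use_mert=False,
        use_wenet=False,
        use_spkid=True,
        whisper_dim=1024,
        content_encoder_dim=384,
        input_melody_dim=1,
        output_melody_dim=384,
        n_bins_melody=256,
        pitch_min=50,
        pitch_max=1100,
        f0_min=50,
        f0_max=1100,
        input_loudness_dim=1,
        output_loudness_dim=384,
        n_bins_loudness=256,
        use_log_loudness=True,
        singer_table_size=512,
        output_singer_dim=384,
    )

    encoder = ConditionEncoder(cfg)

    batch_size, frame_len = 2, 100
    batch = {
        "target_len": torch.full((batch_size,), frame_len, dtype=torch.long),
        "frame_pitch": torch.rand(batch_size, frame_len) * 300 + 100,  # Hz
        "frame_uv": torch.randint(0, 2, (batch_size, frame_len)),
        "frame_energy": torch.rand(batch_size, frame_len),
        "whisper_feat": torch.randn(batch_size, frame_len, cfg.whisper_dim),
        "spk_id": torch.zeros(batch_size, 1, dtype=torch.long),
    }

    # With merge_mode="concat", the condition is the channel-wise concatenation of
    # the melody, loudness, content, and singer embeddings.
    condition = encoder(batch)
    print(condition.shape)  # expected: (2, 100, 4 * 384)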