Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torchaudio.models import Conformer | |
from models.svc.transformer.transformer import PositionalEncoding | |
from utils.f0 import f0_to_coarse | |
class ContentEncoder(nn.Module):
    """Project a frame-level content-feature sequence to ``output_dim``.

    If ``cfg.use_conformer_for_content_features`` is set, the features first
    go through a ``PositionalEncoding`` and a torchaudio ``Conformer``. The
    projection layer is always an ``nn.Linear(input_dim, output_dim)``.

    Args:
        cfg: Model config. Must support ``in`` membership tests and attribute
            access for ``use_conformer_for_content_features``.
        input_dim: Feature size of the incoming content features (non-zero).
        output_dim: Feature size of the projected output.

    Raises:
        ValueError: If ``input_dim`` is 0.
    """

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg
        # Explicit check instead of ``assert`` so it still fires under ``python -O``.
        if input_dim == 0:
            raise ValueError("ContentEncoder requires a non-zero input_dim")
        self.nn = nn.Linear(input_dim, output_dim)

        # Introduce conformer or not
        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        """Encode content features.

        Args:
            x: Content features, (N, seq_len, input_dim).
            length: Passed to the Conformer as its ``lengths`` argument; ignored
                when no Conformer is configured.

        Returns:
            Tensor of shape (N, seq_len, output_dim).
        """
        if self.conformer is not None:
            x = self.pos_encoder(x)
            # torchaudio's Conformer returns (output, output_lengths).
            x, _ = self.conformer(x, length)
        return self.nn(x)
class MelodyEncoder(nn.Module):
    """Embed a frame-level F0 (pitch) contour, optionally adding a U/V embedding.

    With ``cfg.n_bins_melody == 0`` the raw value is projected by
    ``nn.Linear``. Otherwise it is quantized by the project's ``f0_to_coarse``
    helper and looked up in an ``nn.Embedding`` with ``n_bins`` rows. When
    ``cfg.input_melody_dim`` is 0 no layers are built, so ``forward`` must not
    be called.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = cfg.input_melody_dim
        self.output_dim = cfg.output_melody_dim
        self.n_bins = cfg.n_bins_melody
        # Stored on the instance; not read anywhere inside this class.
        self.pitch_min = cfg.pitch_min
        self.pitch_max = cfg.pitch_max

        if self.input_dim == 0:
            return

        if self.n_bins != 0:
            # Quantized pitch: f0_to_coarse bins F0 between these bounds.
            self.f0_min = cfg.f0_min
            self.f0_max = cfg.f0_max
            self.nn = nn.Embedding(
                num_embeddings=self.n_bins,
                embedding_dim=self.output_dim,
                padding_idx=None,
            )
        else:
            # Continuous pitch value, no quantization.
            self.nn = nn.Linear(self.input_dim, self.output_dim)

        # Two rows, so ``uv`` indices passed to forward() must be 0 or 1.
        self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        """Encode pitch.

        Args:
            x: Pitch values, (N, frame_len).
            uv: Optional integer index tensor (values 0/1); when given, its
                embedding is added to the pitch encoding.
            length: Accepted for interface compatibility; not used.

        Returns:
            The pitch encoding with a trailing ``output_dim`` axis.
        """
        if self.n_bins != 0:
            x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        else:
            x = x.unsqueeze(-1)
        encoded = self.nn(x)
        if uv is None:
            return encoded
        return encoded + self.uv_embedding(uv)
class LoudnessEncoder(nn.Module):
    """Embed a frame-level loudness (energy) contour.

    There are two modes, selected by ``cfg.n_bins_loudness``:

    * ``0``: no quantization; the raw value is projected by ``nn.Linear``.
    * ``> 0``: values are bucketized against ``n_bins - 1`` boundaries between
      ``loudness_min`` and ``loudness_max``. The boundaries are log-spaced when
      ``cfg.use_log_loudness`` is true and linearly spaced otherwise. The
      bucket index is then looked up in an ``nn.Embedding`` with ``n_bins``
      rows.

    When ``cfg.input_loudness_dim`` is 0 no layers are built, so ``forward``
    must not be called.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Not use quantization
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # TODO: set trivially now
                self.loudness_min = 1e-30
                self.loudness_max = 1.5

                if cfg.use_log_loudness:
                    boundaries = torch.exp(
                        torch.linspace(
                            np.log(self.loudness_min),
                            np.log(self.loudness_max),
                            self.n_bins - 1,
                        )
                    )
                else:
                    # forward() bucketizes against energy_bins in both modes,
                    # so the non-log mode needs (linearly spaced) boundaries too.
                    boundaries = torch.linspace(
                        self.loudness_min, self.loudness_max, self.n_bins - 1
                    )
                # Frozen Parameter: saved in state_dict but never trained.
                self.energy_bins = nn.Parameter(boundaries, requires_grad=False)

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        """Encode loudness.

        Args:
            x: Loudness values, (N, frame_len).

        Returns:
            The loudness encoding with a trailing ``output_dim`` axis.
        """
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            # n_bins - 1 boundaries -> indices in [0, n_bins - 1], matching
            # the embedding table size.
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)
class SingerEncoder(nn.Module):
    """Look up a learned embedding vector for each singer/speaker id.

    The table has ``cfg.singer_table_size`` rows of width
    ``cfg.output_singer_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = 1
        self.output_dim = cfg.output_singer_dim
        table_size = cfg.singer_table_size
        self.nn = nn.Embedding(
            num_embeddings=table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        """Return embeddings for the integer ids in ``x``.

        For ``x`` of shape (N, 1) the result is (N, 1, output_dim). In general
        the output shape is ``x.shape + (output_dim,)``.
        """
        embedded = self.nn(x)
        return embedded
class ConditionEncoder(nn.Module):
    """Encode and merge all conditioning feature streams.

    Content encoders are built according to ``cfg`` flags (``use_whisper``,
    ``use_contentvec``, ``use_mert``, ``use_wenet``). The singer encoder is
    built when ``use_spkid`` is set; melody and loudness encoders are always
    built. ``forward`` encodes whichever features are present in its input
    dict and merges them according to ``cfg.merge_mode``: "concat" joins them
    along the last dim, "add" sums them element-wise.
    """

    # (input key, encoder attribute) for the content-feature streams, in the
    # order their outputs are appended; this order fixes the "concat" layout.
    _CONTENT_STREAMS = (
        ("whisper_feat", "whisper_encoder"),
        ("contentvec_feat", "contentvec_encoder"),
        ("mert_feat", "mert_encoder"),
        ("wenet_feat", "wenet_encoder"),
    )

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        self.melody_encoder = MelodyEncoder(self.cfg)
        self.loudness_encoder = LoudnessEncoder(self.cfg)
        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        """Encode the conditioning features in ``x`` and merge them.

        Args:
            x: Dict of batch tensors. Recognized keys are "frame_pitch" (with
                optional "frame_uv"), "frame_energy", "whisper_feat",
                "contentvec_feat", "mert_feat", "wenet_feat" and "spk_id".
                "target_len" must be present whenever pitch or a content
                feature is given. ``x`` is not modified.

        Returns:
            The merged encoding: per-stream outputs concatenated on the last
            dim ("concat") or summed element-wise ("add").

        Raises:
            ValueError: If "spk_id" is given without any content feature, or
                if ``merge_mode`` is neither "concat" nor "add".
        """
        outputs = []

        if "frame_pitch" in x:
            # .get() avoids writing a "frame_uv": None entry into the caller's dict.
            pitch_enc_out = self.melody_encoder(
                x["frame_pitch"], uv=x.get("frame_uv"), length=x["target_len"]
            )
            outputs.append(pitch_enc_out)

        if "frame_energy" in x:
            outputs.append(self.loudness_encoder(x["frame_energy"]))

        # The singer embedding is broadcast to the time length of the last
        # content feature encoded in this loop.
        seq_len = None
        for feat_key, encoder_name in self._CONTENT_STREAMS:
            if feat_key in x:
                content_enc_out = getattr(self, encoder_name)(
                    x[feat_key], length=x["target_len"]
                )
                outputs.append(content_enc_out)
                seq_len = content_enc_out.shape[1]

        if "spk_id" in x:
            if seq_len is None:
                raise ValueError(
                    "spk_id conditioning requires at least one content feature "
                    "(whisper_feat, contentvec_feat, mert_feat or wenet_feat)"
                )
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # [b, 1, dim]
            outputs.append(speaker_enc_out.expand(-1, seq_len, -1))

        if self.merge_mode == "concat":
            return torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # (num_streams, N, seq_len, dim) -> (N, seq_len, dim)
            return torch.stack(outputs, dim=0).sum(dim=0)
        raise ValueError(
            f"Unsupported merge_mode {self.merge_mode!r}; expected 'concat' or 'add'"
        )