import numpy as np
import torch
import torch.nn as nn
from torchaudio.models import Conformer

from models.svc.transformer.transformer import PositionalEncoding
from utils.f0 import f0_to_coarse


class ContentEncoder(nn.Module):
    """Project frame-level content features (e.g. Whisper, ContentVec, MERT,
    WeNet) to a shared dimension, optionally refining them with a Conformer."""

    def __init__(self, cfg, input_dim, output_dim):
        super().__init__()
        self.cfg = cfg

        assert input_dim != 0
        self.nn = nn.Linear(input_dim, output_dim)

        if (
            "use_conformer_for_content_features" in cfg
            and cfg.use_conformer_for_content_features
        ):
            self.pos_encoder = PositionalEncoding(input_dim)
            self.conformer = Conformer(
                input_dim=input_dim,
                num_heads=2,
                ffn_dim=256,
                num_layers=6,
                depthwise_conv_kernel_size=3,
            )
        else:
            self.conformer = None

    def forward(self, x, length=None):
        # x: (B, T, input_dim); length: (B,) valid frame counts, required by
        # the Conformer for masking.
        if self.conformer is not None:
            x = self.pos_encoder(x)
            x, _ = self.conformer(x, length)
        return self.nn(x)


class MelodyEncoder(nn.Module):
    """Encode frame-level F0 into melody embeddings, either via a linear
    layer on the raw values or by quantizing F0 into coarse bins and
    embedding the bin indices (optionally adding a voiced/unvoiced embedding)."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_melody_dim
        self.output_dim = self.cfg.output_melody_dim
        self.n_bins = self.cfg.n_bins_melody

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Use continuous F0 values directly.
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # Quantize F0 into n_bins coarse bins and embed them. Note
                # that the voiced/unvoiced embedding is only constructed in
                # this quantized branch, so cfg.use_uv requires n_bins > 0.
                self.f0_min = cfg.f0_min
                self.f0_max = cfg.f0_max

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )
                self.uv_embedding = nn.Embedding(2, self.output_dim)

    def forward(self, x, uv=None, length=None):
        # x: (B, T) frame-level F0; uv: (B, T) integer voiced/unvoiced flags.
        # length is unused here and kept for interface symmetry.
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = f0_to_coarse(x, self.n_bins, self.f0_min, self.f0_max)
        x = self.nn(x)

        if self.cfg.use_uv:
            uv = self.uv_embedding(uv)
            x = x + uv
        return x


class LoudnessEncoder(nn.Module):
    """Encode frame-level energy, either via a linear layer on the raw
    values or by bucketizing into log-spaced bins and embedding the bin
    indices."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = self.cfg.input_loudness_dim
        self.output_dim = self.cfg.output_loudness_dim
        self.n_bins = self.cfg.n_bins_loudness

        if self.input_dim != 0:
            if self.n_bins == 0:
                # Use continuous loudness values directly.
                self.nn = nn.Linear(self.input_dim, self.output_dim)
            else:
                # n_bins - 1 log-spaced boundaries give n_bins buckets for
                # torch.bucketize in forward. Stored with requires_grad=False
                # so the boundaries are fixed but still saved in checkpoints.
                self.loudness_min = 1e-30
                self.loudness_max = 1.5
                self.energy_bins = nn.Parameter(
                    torch.exp(
                        torch.linspace(
                            np.log(self.loudness_min),
                            np.log(self.loudness_max),
                            self.n_bins - 1,
                        )
                    ),
                    requires_grad=False,
                )

                self.nn = nn.Embedding(
                    num_embeddings=self.n_bins,
                    embedding_dim=self.output_dim,
                    padding_idx=None,
                )

    def forward(self, x):
        # x: (B, T) frame-level energy.
        if self.n_bins == 0:
            x = x.unsqueeze(-1)
        else:
            x = torch.bucketize(x, self.energy_bins)
        return self.nn(x)
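
# A tiny illustration of the bucketize-then-embed scheme above (hypothetical
# values): with boundaries [0.1, 0.5], torch.bucketize maps 0.05 -> 0,
# 0.3 -> 1, and 0.9 -> 2. In general, n_bins - 1 boundaries yield indices in
# [0, n_bins - 1], exactly the range nn.Embedding(num_embeddings=n_bins)
# accepts.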


class SingerEncoder(nn.Module):
    """Look up a fixed-size speaker embedding from an integer singer ID."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.input_dim = 1
        self.output_dim = self.cfg.output_singer_dim

        self.nn = nn.Embedding(
            num_embeddings=cfg.singer_table_size,
            embedding_dim=self.output_dim,
            padding_idx=None,
        )

    def forward(self, x):
        # x: (B, 1) integer singer IDs -> (B, 1, output_dim).
        return self.nn(x)


class ConditionEncoder(nn.Module):
    """Fuse content, melody, loudness, and singer embeddings into one
    conditioning sequence, merged by concatenation or summation."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.merge_mode = cfg.merge_mode

        # One ContentEncoder per enabled content feature; all are projected
        # to the same content_encoder_dim.
        if cfg.use_whisper:
            self.whisper_encoder = ContentEncoder(
                self.cfg, self.cfg.whisper_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_contentvec:
            self.contentvec_encoder = ContentEncoder(
                self.cfg, self.cfg.contentvec_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_mert:
            self.mert_encoder = ContentEncoder(
                self.cfg, self.cfg.mert_dim, self.cfg.content_encoder_dim
            )
        if cfg.use_wenet:
            self.wenet_encoder = ContentEncoder(
                self.cfg, self.cfg.wenet_dim, self.cfg.content_encoder_dim
            )

        if cfg.use_f0:
            self.melody_encoder = MelodyEncoder(self.cfg)
        if cfg.use_energy:
            self.loudness_encoder = LoudnessEncoder(self.cfg)

        if cfg.use_spkid:
            self.singer_encoder = SingerEncoder(self.cfg)

    def forward(self, x):
        # x is a dict of batched features; every frame-level feature is
        # assumed to share the same time resolution.
        outputs = []

        if self.cfg.use_f0:
            if self.cfg.use_uv:
                pitch_enc_out = self.melody_encoder(
                    x["frame_pitch"], uv=x["frame_uv"], length=x["target_len"]
                )
            else:
                pitch_enc_out = self.melody_encoder(
                    x["frame_pitch"], uv=None, length=x["target_len"]
                )
            outputs.append(pitch_enc_out)

        if self.cfg.use_energy:
            loudness_enc_out = self.loudness_encoder(x["frame_energy"])
            outputs.append(loudness_enc_out)

        if self.cfg.use_whisper:
            whisper_enc_out = self.whisper_encoder(
                x["whisper_feat"], length=x["target_len"]
            )
            outputs.append(whisper_enc_out)
            seq_len = whisper_enc_out.shape[1]

        if self.cfg.use_contentvec:
            contentvec_enc_out = self.contentvec_encoder(
                x["contentvec_feat"], length=x["target_len"]
            )
            outputs.append(contentvec_enc_out)
            seq_len = contentvec_enc_out.shape[1]

        if self.cfg.use_mert:
            mert_enc_out = self.mert_encoder(x["mert_feat"], length=x["target_len"])
            outputs.append(mert_enc_out)
            seq_len = mert_enc_out.shape[1]

        if self.cfg.use_wenet:
            wenet_enc_out = self.wenet_encoder(x["wenet_feat"], length=x["target_len"])
            outputs.append(wenet_enc_out)
            seq_len = wenet_enc_out.shape[1]

        if self.cfg.use_spkid:
            speaker_enc_out = self.singer_encoder(x["spk_id"])  # (B, 1, D)
            # seq_len is only bound when at least one content feature is used.
            assert (
                "whisper_feat" in x.keys()
                or "contentvec_feat" in x.keys()
                or "mert_feat" in x.keys()
                or "wenet_feat" in x.keys()
            )
            # Broadcast the per-utterance singer embedding across all frames.
            singer_info = speaker_enc_out.expand(-1, seq_len, -1)
            outputs.append(singer_info)

        encoder_output = None
        if self.merge_mode == "concat":
            encoder_output = torch.cat(outputs, dim=-1)
        if self.merge_mode == "add":
            # Summation requires all encoders to share one output dimension.
            outputs = torch.cat([out[None, :, :, :] for out in outputs], dim=0)
            encoder_output = torch.sum(outputs, dim=0)

        return encoder_output
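

if __name__ == "__main__":
    # A minimal usage sketch, not part of the original module: the cfg values
    # below are hypothetical stand-ins for the project's real config object
    # (which also supports `in` membership checks), chosen so that
    # merge_mode="concat" works without further setup.
    from types import SimpleNamespace

    class _Cfg(SimpleNamespace):
        # The encoders test membership with `"key" in cfg`, so mimic that.
        def __contains__(self, key):
            return hasattr(self, key)

    cfg = _Cfg(
        merge_mode="concat",
        use_whisper=True,
        whisper_dim=1024,
        content_encoder_dim=384,
        use_contentvec=False,
        use_mert=False,
        use_wenet=False,
        use_f0=True,
        input_melody_dim=1,
        output_melody_dim=384,
        n_bins_melody=256,
        f0_min=50,
        f0_max=1100,
        use_uv=True,
        use_energy=True,
        input_loudness_dim=1,
        output_loudness_dim=384,
        n_bins_loudness=256,
        use_spkid=True,
        output_singer_dim=384,
        singer_table_size=512,
    )
    encoder = ConditionEncoder(cfg)

    B, T = 2, 100
    x = {
        "whisper_feat": torch.randn(B, T, cfg.whisper_dim),
        "frame_pitch": torch.rand(B, T) * 300.0 + 100.0,
        "frame_uv": torch.randint(0, 2, (B, T)),
        "frame_energy": torch.rand(B, T),
        "spk_id": torch.zeros(B, 1, dtype=torch.long),
        "target_len": torch.full((B,), T, dtype=torch.long),
    }
    out = encoder(x)
    print(out.shape)  # (B, T, 4 * 384) under merge_mode="concat"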