# DINO-HuVITS / dino_huvits.py
# Author: SazerLife
# Commit 36a67ca: "feat: added model"
# (Scraped Hugging Face page chrome: raw | history | blame | contribute | delete;
#  "No virus" scan badge; file size 2.07 kB)
import torch
from transformers import PreTrainedModel
from config import DinoHuVitsConfig
from src import CAMPPlus, Flow, HiFiGAN, PosteriorHubert
class DinoHuVits(PreTrainedModel):
    """Re-synthesise ``content`` conditioned on the style embedding of ``reference``.

    The model is built from four sub-modules, all configured from a
    ``DinoHuVitsConfig``:

    * ``enc_r``: ``CAMPPlus`` style encoder producing the conditioning vector ``g``.
    * ``enc_q``: ``PosteriorHubert`` encoder producing the latent ``z`` and a mask.
    * ``flow``: ``Flow`` run forward under the source style and in reverse under
      the target style.
    * ``dec``: ``HiFiGAN`` decoder that turns the latent into the output
      (presumably a waveform; TODO confirm).
    """

    config_class = DinoHuVitsConfig

    def __init__(self, config: DinoHuVitsConfig):
        """Instantiate all sub-modules from ``config``.

        The flow's ``kernel_size`` (5), ``dilation_rate`` (1) and ``n_layers`` (4)
        are hard-coded here rather than read from the config.
        """
        super().__init__(config)
        self.enc_r = CAMPPlus(embed_dim=config.gin_channels, pooling_func="TSTP")
        self.enc_q = PosteriorHubert(
            out_channels=config.inter_channels,
            feature_channels=config.hubert_feature_channels,
            downsample_channels=config.hubert_downsample_channels,
            output_layer=config.hubert_output_layer,
        )
        self.flow = Flow(
            channels=config.inter_channels,
            hidden_channels=config.hidden_channels,
            kernel_size=5,
            dilation_rate=1,
            n_layers=4,
            gin_channels=config.gin_channels,
        )
        self.dec = HiFiGAN(
            initial_channel=config.inter_channels,
            resblock=config.resblock,
            resblock_kernel_sizes=config.resblock_kernel_sizes,
            resblock_dilation_sizes=config.resblock_dilation_sizes,
            upsample_rates=config.upsample_rates,
            upsample_initial_channel=config.upsample_initial_channel,
            upsample_kernel_sizes=config.upsample_kernel_sizes,
            gin_channels=config.gin_channels,
        )

    def forward(
        self, content: torch.Tensor, lengths: torch.Tensor, reference: torch.Tensor
    ):
        """Convert ``content`` toward the style of ``reference``.

        Args:
            content: Source input. It is fed both to the style encoder (source
                style) and to the posterior encoder. Presumably a batch of
                waveforms; TODO confirm the expected shape.
            lengths: Per-item lengths, forwarded to ``enc_q``, which derives
                ``y_mask`` from them.
            reference: Input whose style embedding conditions the reverse
                flow pass and the decoder.

        Returns:
            A tuple ``(o_hat, y_mask)``: the decoder output and the mask
            returned by ``enc_q``.
        """
        g_src = self.__get_style_embedding(content)
        g_tgt = self.__get_style_embedding(reference)

        # Only the latent and its mask are needed from the posterior encoder.
        z, _, _, y_mask = self.enc_q(content, lengths, g=g_src)

        # Round-trip through the flow: go forward under the source style, then
        # invert under the target style. This swaps the style conditioning on
        # the latent.
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)

        # Zero the masked-out (padded) positions before decoding.
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask

    def __get_style_embedding(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised style embedding of ``waveform`` with a trailing axis.

        The ``enc_r`` output is normalised along dim 1 and then gains a new last
        axis via ``unsqueeze(-1)``.

        NOTE(review): an earlier comment said the ``enc_r`` output is
        ``[b, h, 1]``. After the ``unsqueeze`` that would give a 4-D
        ``[b, h, 1, 1]`` tensor. So ``enc_r`` presumably returns ``[b, h]`` and
        this method returns ``[b, h, 1]``; TODO confirm against ``CAMPPlus``.
        """
        g = self.enc_r(waveform)
        # Unit L2 norm along dim 1, so only the embedding's direction carries style.
        g = torch.nn.functional.normalize(g, dim=1)
        return g.unsqueeze(-1)