# llmvc/stylekan/mlstm_models.py
# import os
# import os.path as osp
# import copy
# import math
# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
# from Utils.ASR.models import ASRCNN
# from Utils.JDC.model import JDCNet
# from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
# from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
# from Modules.diffusion.diffusion import AudioDiffusionConditional
# from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
# from munch import Munch
# import yaml
# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
# from hflayers.auxiliary.data import BitPatternSet
# # Import auxiliary modules.
# from distutils.version import LooseVersion
# from typing import List, Tuple
# import math
# import torch
# from xlstm import (
# xLSTMBlockStack,
# xLSTMBlockStackConfig,
# mLSTMBlockConfig,
# mLSTMLayerConfig,
# sLSTMBlockConfig,
# sLSTMLayerConfig,
# FeedForwardConfig,
# )
# class LearnedDownSample(nn.Module):
# def __init__(self, layer_type, dim_in):
# super().__init__()
# self.layer_type = layer_type
# if self.layer_type == 'none':
# self.conv = nn.Identity()
# elif self.layer_type == 'timepreserve':
# self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
# elif self.layer_type == 'half':
# self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
# else:
#             raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# def forward(self, x):
# return self.conv(x)
# class LearnedUpSample(nn.Module):
# def __init__(self, layer_type, dim_in):
# super().__init__()
# self.layer_type = layer_type
# if self.layer_type == 'none':
# self.conv = nn.Identity()
# elif self.layer_type == 'timepreserve':
# self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
# elif self.layer_type == 'half':
# self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
# else:
#             raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# def forward(self, x):
# return self.conv(x)
# class DownSample(nn.Module):
# def __init__(self, layer_type):
# super().__init__()
# self.layer_type = layer_type
# def forward(self, x):
# if self.layer_type == 'none':
# return x
# elif self.layer_type == 'timepreserve':
# return F.avg_pool2d(x, (2, 1))
# elif self.layer_type == 'half':
# if x.shape[-1] % 2 != 0:
# x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
# return F.avg_pool2d(x, 2)
# else:
#             raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# class UpSample(nn.Module):
# def __init__(self, layer_type):
# super().__init__()
# self.layer_type = layer_type
# def forward(self, x):
# if self.layer_type == 'none':
# return x
# elif self.layer_type == 'timepreserve':
# return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
# elif self.layer_type == 'half':
# return F.interpolate(x, scale_factor=2, mode='nearest')
# else:
#             raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# class ResBlk(nn.Module):
# def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
# normalize=False, downsample='none'):
# super().__init__()
# self.actv = actv
# self.normalize = normalize
# self.downsample = DownSample(downsample)
# self.downsample_res = LearnedDownSample(downsample, dim_in)
# self.learned_sc = dim_in != dim_out
# self._build_weights(dim_in, dim_out)
# def _build_weights(self, dim_in, dim_out):
# self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
# self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
# if self.normalize:
# self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
# self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
# if self.learned_sc:
# self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
# def _shortcut(self, x):
# if self.learned_sc:
# x = self.conv1x1(x)
# if self.downsample:
# x = self.downsample(x)
# return x
# def _residual(self, x):
# if self.normalize:
# x = self.norm1(x)
# x = self.actv(x)
# x = self.conv1(x)
# x = self.downsample_res(x)
# if self.normalize:
# x = self.norm2(x)
# x = self.actv(x)
# x = self.conv2(x)
# return x
# def forward(self, x):
# x = self._shortcut(x) + self._residual(x)
# return x / math.sqrt(2) # unit variance
# class StyleEncoder(nn.Module):
# def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
# super().__init__()
# blocks = []
# blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
# repeat_num = 4
# for _ in range(repeat_num):
# dim_out = min(dim_in*2, max_conv_dim)
# blocks += [ResBlk(dim_in, dim_out, downsample='half')]
# dim_in = dim_out
# blocks += [nn.LeakyReLU(0.2)]
# blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
# blocks += [nn.AdaptiveAvgPool2d(1)]
# blocks += [nn.LeakyReLU(0.2)]
# self.shared = nn.Sequential(*blocks)
# self.unshared = nn.Linear(dim_out, style_dim)
# def forward(self, x):
# h = self.shared(x)
# h = h.view(h.size(0), -1)
# s = self.unshared(h)
# return s
# class LinearNorm(torch.nn.Module):
# def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
# super(LinearNorm, self).__init__()
# self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
# torch.nn.init.xavier_uniform_(
# self.linear_layer.weight,
# gain=torch.nn.init.calculate_gain(w_init_gain))
# def forward(self, x):
# return self.linear_layer(x)
# class Discriminator2d(nn.Module):
# def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
# super().__init__()
# blocks = []
# blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
# for lid in range(repeat_num):
# dim_out = min(dim_in*2, max_conv_dim)
# blocks += [ResBlk(dim_in, dim_out, downsample='half')]
# dim_in = dim_out
# blocks += [nn.LeakyReLU(0.2)]
# blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
# blocks += [nn.LeakyReLU(0.2)]
# blocks += [nn.AdaptiveAvgPool2d(1)]
# blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
# self.main = nn.Sequential(*blocks)
# def get_feature(self, x):
# features = []
# for l in self.main:
# x = l(x)
# features.append(x)
# out = features[-1]
# out = out.view(out.size(0), -1) # (batch, num_domains)
# return out, features
# def forward(self, x):
# out, features = self.get_feature(x)
# out = out.squeeze() # (batch)
# return out, features
# class ResBlk1d(nn.Module):
# def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
# normalize=False, downsample='none', dropout_p=0.2):
# super().__init__()
# self.actv = actv
# self.normalize = normalize
# self.downsample_type = downsample
# self.learned_sc = dim_in != dim_out
# self._build_weights(dim_in, dim_out)
# self.dropout_p = dropout_p
# if self.downsample_type == 'none':
# self.pool = nn.Identity()
# else:
# self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
# def _build_weights(self, dim_in, dim_out):
# self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
# self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
# if self.normalize:
# self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
# self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
# if self.learned_sc:
# self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
# def downsample(self, x):
# if self.downsample_type == 'none':
# return x
# else:
# if x.shape[-1] % 2 != 0:
# x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
# return F.avg_pool1d(x, 2)
# def _shortcut(self, x):
# if self.learned_sc:
# x = self.conv1x1(x)
# x = self.downsample(x)
# return x
# def _residual(self, x):
# if self.normalize:
# x = self.norm1(x)
# x = self.actv(x)
# x = F.dropout(x, p=self.dropout_p, training=self.training)
# x = self.conv1(x)
# x = self.pool(x)
# if self.normalize:
# x = self.norm2(x)
# x = self.actv(x)
# x = F.dropout(x, p=self.dropout_p, training=self.training)
# x = self.conv2(x)
# return x
# def forward(self, x):
# x = self._shortcut(x) + self._residual(x)
# return x / math.sqrt(2) # unit variance
# class LayerNorm(nn.Module):
# def __init__(self, channels, eps=1e-5):
# super().__init__()
# self.channels = channels
# self.eps = eps
# self.gamma = nn.Parameter(torch.ones(channels))
# self.beta = nn.Parameter(torch.zeros(channels))
# def forward(self, x):
# x = x.transpose(1, -1)
# x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
# return x.transpose(1, -1)
# class TextEncoder(nn.Module):
# def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
# super().__init__()
# self.embedding = nn.Embedding(n_symbols, channels)
# padding = (kernel_size - 1) // 2
# self.cnn = nn.ModuleList()
# for _ in range(depth):
# self.cnn.append(nn.Sequential(
# weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
# LayerNorm(channels),
# actv,
# nn.Dropout(0.2),
# ))
# # self.cnn = nn.Sequential(*self.cnn)
# self.lstm = Hopfield(input_size=channels,
# hidden_size=channels // 2,
# num_heads=32,
# # scaling=.75,
# add_zero_association=True,
# batch_first=True)
# def forward(self, x, input_lengths, m):
# x = self.embedding(x) # [B, T, emb]
# x = x.transpose(1, 2) # [B, emb, T]
# m = m.to(input_lengths.device).unsqueeze(1)
# x.masked_fill_(m, 0.0)
# for c in self.cnn:
# x = c(x)
# x.masked_fill_(m, 0.0)
# x = x.transpose(1, 2) # [B, T, chn]
# input_lengths = input_lengths.cpu().numpy()
# # x = nn.utils.rnn.pack_padded_sequence(
# # x, input_lengths, batch_first=True, enforce_sorted=False)
# # self.lstm.flatten_parameters()
# x = self.lstm(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# x = x.transpose(-1, -2)
# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
# # x_pad[:, :, :x.shape[-1]] = x
# # x = x_pad.to(x.device)
# x.masked_fill_(m, 0.0)
# return x
# def inference(self, x):
# x = self.embedding(x)
# x = x.transpose(1, 2)
# x = self.cnn(x)
# x = x.transpose(1, 2)
# # self.lstm.flatten_parameters()
# x = self.lstm(x)
# return x
# def length_to_mask(self, lengths):
# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# mask = torch.gt(mask+1, lengths.unsqueeze(1))
# return mask
# class AdaIN1d(nn.Module):
# def __init__(self, style_dim, num_features):
# super().__init__()
# self.norm = nn.InstanceNorm1d(num_features, affine=False)
# self.fc = nn.Linear(style_dim, num_features*2)
# def forward(self, x, s):
# h = self.fc(s)
# h = h.view(h.size(0), h.size(1), 1)
# gamma, beta = torch.chunk(h, chunks=2, dim=1)
# return (1 + gamma) * self.norm(x) + beta
# class UpSample1d(nn.Module):
# def __init__(self, layer_type):
# super().__init__()
# self.layer_type = layer_type
# def forward(self, x):
# if self.layer_type == 'none':
# return x
# else:
# return F.interpolate(x, scale_factor=2, mode='nearest')
# class AdainResBlk1d(nn.Module):
# def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
# upsample='none', dropout_p=0.0):
# super().__init__()
# self.actv = actv
# self.upsample_type = upsample
# self.upsample = UpSample1d(upsample)
# self.learned_sc = dim_in != dim_out
# self._build_weights(dim_in, dim_out, style_dim)
# self.dropout = nn.Dropout(dropout_p)
# if upsample == 'none':
# self.pool = nn.Identity()
# else:
# self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
# def _build_weights(self, dim_in, dim_out, style_dim):
# self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
# self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
# self.norm1 = AdaIN1d(style_dim, dim_in)
# self.norm2 = AdaIN1d(style_dim, dim_out)
# if self.learned_sc:
# self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
# def _shortcut(self, x):
# x = self.upsample(x)
# if self.learned_sc:
# x = self.conv1x1(x)
# return x
# def _residual(self, x, s):
# x = self.norm1(x, s)
# x = self.actv(x)
# x = self.pool(x)
# x = self.conv1(self.dropout(x))
# x = self.norm2(x, s)
# x = self.actv(x)
# x = self.conv2(self.dropout(x))
# return x
# def forward(self, x, s):
# out = self._residual(x, s)
# out = (out + self._shortcut(x)) / math.sqrt(2)
# return out
# class AdaLayerNorm(nn.Module):
# def __init__(self, style_dim, channels, eps=1e-5):
# super().__init__()
# self.channels = channels
# self.eps = eps
# self.fc = nn.Linear(style_dim, channels*2)
# def forward(self, x, s):
# x = x.transpose(-1, -2)
# x = x.transpose(1, -1)
# h = self.fc(s)
# h = h.view(h.size(0), h.size(1), 1)
# gamma, beta = torch.chunk(h, chunks=2, dim=1)
# gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
# x = F.layer_norm(x, (self.channels,), eps=self.eps)
# x = (1 + gamma) * x + beta
# return x.transpose(1, -1).transpose(-1, -2)
# # class ProsodyPredictor(nn.Module):
# # def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
# # super().__init__()
# # self.text_encoder = DurationEncoder(sty_dim=style_dim,
# # d_model=d_hid,
# # nlayers=nlayers,
# # dropout=dropout)
# # self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# # self.duration_proj = LinearNorm(d_hid, max_dur)
# # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# # self.F0 = nn.ModuleList()
# # self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# # self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# # self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# # self.N = nn.ModuleList()
# # self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# # self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# # self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# # self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# # self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# # def forward(self, texts, style, text_lengths, alignment, m):
# # d = self.text_encoder(texts, style, text_lengths, m)
# # batch_size = d.shape[0]
# # text_size = d.shape[1]
# # # predict duration
# # input_lengths = text_lengths.cpu().numpy()
# # x = nn.utils.rnn.pack_padded_sequence(
# # d, input_lengths, batch_first=True, enforce_sorted=False)
# # m = m.to(text_lengths.device).unsqueeze(1)
# # self.lstm.flatten_parameters()
# # x, _ = self.lstm(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
# # x_pad[:, :x.shape[1], :] = x
# # x = x_pad.to(x.device)
# # duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
# # en = (d.transpose(-1, -2) @ alignment)
# # return duration.squeeze(-1), en
# class ProsodyPredictor(nn.Module):
# def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
# super().__init__()
# self.text_encoder = DurationEncoder(sty_dim=style_dim,
# d_model=d_hid,
# nlayers=nlayers,
# dropout=dropout)
# self.lstm = Hopfield(input_size=d_hid + style_dim,
# hidden_size=d_hid // 2,
# num_heads=32,
# # scaling=.75,
# add_zero_association=True,
# batch_first=True)
# self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)
# self.duration_proj = LinearNorm(d_hid , max_dur)
# self.shared = Hopfield(input_size=d_hid + style_dim,
# hidden_size=d_hid // 2,
# num_heads=32,
# # scaling=.75,
# add_zero_association=True,
# batch_first=True)
# #self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# self.F0 = nn.ModuleList()
# self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# self.N = nn.ModuleList()
# self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# def forward(self, texts, style, text_lengths, alignment, m):
# d = self.text_encoder(texts, style, text_lengths, m)
# batch_size = d.shape[0]
# text_size = d.shape[1]
# # predict duration
# input_lengths = text_lengths.cpu().numpy()
# # x = nn.utils.rnn.pack_padded_sequence(
# # d, input_lengths, batch_first=True, enforce_sorted=False)
# x = d # this dude can handle variable seq len so no need for packing
# m = m.to(text_lengths.device).unsqueeze(1)
# # self.lstm.flatten_parameters()
# x = self.lstm(x) # no longer using lstm
# x = self.prepare_projection(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
# x_pad[:, :x.shape[1], :] = x
# x = x_pad.to(x.device)
# x = x.transpose(-1,-2)
# x = x.permute(0,2,1)
# duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
# en = (d.transpose(-1, -2) @ alignment)
# return duration.squeeze(-1), en
# def F0Ntrain(self, x, s):
# x = self.shared(x.transpose(-1, -2))
# x = self.prepare_projection(x)
# F0 = x.transpose(-1, -2)
# for block in self.F0:
# F0 = block(F0, s)
# F0 = self.F0_proj(F0)
# N = x.transpose(-1, -2)
# for block in self.N:
# N = block(N, s)
# N = self.N_proj(N)
# return F0.squeeze(1), N.squeeze(1)
# def length_to_mask(self, lengths):
# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# mask = torch.gt(mask+1, lengths.unsqueeze(1))
# return mask
# class DurationEncoder(nn.Module):
# def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
# super().__init__()
# self.lstms = nn.ModuleList()
# for _ in range(nlayers):
# self.lstms.append(nn.GRU(d_model + sty_dim,
# d_model // 2,
# num_layers=1,
# batch_first=True,
# bidirectional=True,
# dropout=dropout))
# self.lstms.append(AdaLayerNorm(sty_dim, d_model))
# self.dropout = dropout
# self.d_model = d_model
# self.sty_dim = sty_dim
# def forward(self, x, style, text_lengths, m):
# masks = m.to(text_lengths.device)
# x = x.permute(2, 0, 1)
# s = style.expand(x.shape[0], x.shape[1], -1)
# x = torch.cat([x, s], axis=-1)
# x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
# x = x.transpose(0, 1)
# input_lengths = text_lengths.cpu().numpy()
# x = x.transpose(-1, -2)
# for block in self.lstms:
# if isinstance(block, AdaLayerNorm):
# x = block(x.transpose(-1, -2), style).transpose(-1, -2)
# x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
# x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
# else:
# x = x.transpose(-1, -2)
# x = nn.utils.rnn.pack_padded_sequence(
# x, input_lengths, batch_first=True, enforce_sorted=False)
# block.flatten_parameters()
# x, _ = block(x)
# x, _ = nn.utils.rnn.pad_packed_sequence(
# x, batch_first=True)
# x = F.dropout(x, p=self.dropout, training=self.training)
# x = x.transpose(-1, -2)
# x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
# x_pad[:, :, :x.shape[-1]] = x
# x = x_pad.to(x.device)
# return x.transpose(-1, -2)
# def inference(self, x, style):
# x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
# style = style.expand(x.shape[0], x.shape[1], -1)
# x = torch.cat([x, style], axis=-1)
# src = self.pos_encoder(x)
# output = self.transformer_encoder(src).transpose(0, 1)
# return output
# def length_to_mask(self, lengths):
# mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# mask = torch.gt(mask+1, lengths.unsqueeze(1))
# return mask
# def load_F0_models(path):
# # load F0 model
# F0_model = JDCNet(num_class=1, seq_len=192)
# params = torch.load(path, map_location='cpu')['net']
# F0_model.load_state_dict(params)
# _ = F0_model.train()
# return F0_model
# def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
# # load ASR model
# def _load_config(path):
# with open(path) as f:
# config = yaml.safe_load(f)
# model_config = config['model_params']
# return model_config
# def _load_model(model_config, model_path):
# model = ASRCNN(**model_config)
# params = torch.load(model_path, map_location='cpu')['model']
# model.load_state_dict(params)
# return model
# asr_model_config = _load_config(ASR_MODEL_CONFIG)
# asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
# _ = asr_model.train()
# return asr_model
# def build_model(args, text_aligner, pitch_extractor, bert):
# assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
# if args.decoder.type == "istftnet":
# from Modules.istftnet import Decoder
# decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
# resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
# upsample_rates = args.decoder.upsample_rates,
# upsample_initial_channel=args.decoder.upsample_initial_channel,
# resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
# upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
# gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
# else:
# from Modules.hifigan import Decoder
# decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
# resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
# upsample_rates = args.decoder.upsample_rates,
# upsample_initial_channel=args.decoder.upsample_initial_channel,
# resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
# upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
# text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
# predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
# style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
# predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
# # define diffusion model
# if args.multispeaker:
# transformer = StyleTransformer1d(channels=args.style_dim*2,
# context_embedding_features=bert.config.hidden_size,
# context_features=args.style_dim*2,
# **args.diffusion.transformer)
# else:
# transformer = Transformer1d(channels=args.style_dim*2,
# context_embedding_features=bert.config.hidden_size,
# **args.diffusion.transformer)
# diffusion = AudioDiffusionConditional(
# in_channels=1,
# embedding_max_length=bert.config.max_position_embeddings,
# embedding_features=bert.config.hidden_size,
# embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
# channels=args.style_dim*2,
# context_features=args.style_dim*2,
# )
# diffusion.diffusion = KDiffusion(
# net=diffusion.unet,
# sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
# sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
# dynamic_threshold=0.0
# )
# diffusion.diffusion.net = transformer
# diffusion.unet = transformer
# nets = Munch(
# bert=bert,
# bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
# predictor=predictor,
# decoder=decoder,
# text_encoder=text_encoder,
# predictor_encoder=predictor_encoder,
# style_encoder=style_encoder,
# diffusion=diffusion,
# text_aligner = text_aligner,
# pitch_extractor=pitch_extractor,
# mpd = MultiPeriodDiscriminator(),
# msd = MultiResSpecDiscriminator(),
# # slm discriminator head
# wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
# )
# return nets
# def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
# state = torch.load(path, map_location='cpu')
# params = state['net']
# for key in model:
# if key in params and key not in ignore_modules:
# print('%s loaded' % key)
# try:
# model[key].load_state_dict(params[key], strict=True)
# except:
# from collections import OrderedDict
# state_dict = params[key]
# new_state_dict = OrderedDict()
# print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
# for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
# new_state_dict[k_m] = v_c
# model[key].load_state_dict(new_state_dict, strict=True)
# _ = [model[key].eval() for key in model]
# if not load_only_params:
# epoch = state["epoch"]
# iters = state["iters"]
# optimizer.load_state_dict(state["optimizer"])
# else:
# epoch = 0
# iters = 0
# return model, optimizer, epoch, iters
##############################################################################################################
##############################################################################################################
##############################################################################################################
# mLSTM
##############################################################################################################
##############################################################################################################
##############################################################################################################
import os
import os.path as osp
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet
from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
from Modules.diffusion.diffusion import AudioDiffusionConditional
from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
from munch import Munch
import yaml
# from hflayers import Hopfield, HopfieldPooling, HopfieldLayer
# from hflayers.auxiliary.data import BitPatternSet
# Import auxiliary modules.
from distutils.version import LooseVersion
from typing import List, Tuple
# from liger_kernel.ops.layer_norm import LigerLayerNormFunction
# from liger_kernel.transformers.experimental.embedding import LigerEmbedding  # class name assumed
from xlstm import (
xLSTMBlockStack,
xLSTMBlockStackConfig,
mLSTMBlockConfig,
mLSTMLayerConfig,
sLSTMBlockConfig,
sLSTMLayerConfig,
FeedForwardConfig,
)
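# A minimal shape sketch for the stack used below (assumption: the public
# xlstm API, where xLSTMBlockStack maps [batch, seq_len, embedding_dim] to a
# tensor of the same shape):
#
#   cfg = xLSTMBlockStackConfig(
#       mlstm_block=mLSTMBlockConfig(
#           mlstm=mLSTMLayerConfig(conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4)
#       ),
#       context_length=256, num_blocks=2, embedding_dim=64,
#   )
#   stack = xLSTMBlockStack(cfg)
#   y = stack(torch.randn(8, 128, 64))  # -> torch.Size([8, 128, 64])
#
# The modules below rely on this channel-last contract, which is why they
# transpose away from the conv-style [B, C, T] layout before calling a stack.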
class LearnedDownSample(nn.Module):
def __init__(self, layer_type, dim_in):
super().__init__()
self.layer_type = layer_type
if self.layer_type == 'none':
self.conv = nn.Identity()
elif self.layer_type == 'timepreserve':
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
elif self.layer_type == 'half':
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
else:
            raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
def forward(self, x):
return self.conv(x)
class LearnedUpSample(nn.Module):
def __init__(self, layer_type, dim_in):
super().__init__()
self.layer_type = layer_type
if self.layer_type == 'none':
self.conv = nn.Identity()
elif self.layer_type == 'timepreserve':
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
elif self.layer_type == 'half':
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
else:
            raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
def forward(self, x):
return self.conv(x)
class DownSample(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.avg_pool2d(x, (2, 1))
elif self.layer_type == 'half':
if x.shape[-1] % 2 != 0:
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
return F.avg_pool2d(x, 2)
else:
            raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
class UpSample(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
elif self.layer_type == 'timepreserve':
return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
elif self.layer_type == 'half':
return F.interpolate(x, scale_factor=2, mode='nearest')
else:
            raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
class ResBlk(nn.Module):
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
normalize=False, downsample='none'):
super().__init__()
self.actv = actv
self.normalize = normalize
self.downsample = DownSample(downsample)
self.downsample_res = LearnedDownSample(downsample, dim_in)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out)
def _build_weights(self, dim_in, dim_out):
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
if self.normalize:
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
if self.learned_sc:
self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
def _shortcut(self, x):
if self.learned_sc:
x = self.conv1x1(x)
if self.downsample:
x = self.downsample(x)
return x
def _residual(self, x):
if self.normalize:
x = self.norm1(x)
x = self.actv(x)
x = self.conv1(x)
x = self.downsample_res(x)
if self.normalize:
x = self.norm2(x)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x):
x = self._shortcut(x) + self._residual(x)
return x / math.sqrt(2) # unit variance
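# Scaling note: if the shortcut and residual branches are roughly independent
# with unit variance, their sum has variance 2, so dividing by sqrt(2) keeps
# the block's output at unit variance.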
class StyleEncoder(nn.Module):
def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
super().__init__()
blocks = []
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
repeat_num = 4
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, downsample='half')]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
blocks += [nn.AdaptiveAvgPool2d(1)]
blocks += [nn.LeakyReLU(0.2)]
self.shared = nn.Sequential(*blocks)
self.unshared = nn.Linear(dim_out, style_dim)
def forward(self, x):
h = self.shared(x)
h = h.view(h.size(0), -1)
s = self.unshared(h)
return s
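# Shape sketch for StyleEncoder (illustrative sizes only): a mel batch
# [B, 1, n_mels, T] is halved four times by the 'half' ResBlks, pooled to
# [B, dim_out, 1, 1], flattened, and projected to a style vector:
#
#   enc = StyleEncoder(dim_in=48, style_dim=128, max_conv_dim=384)
#   s = enc(torch.randn(4, 1, 80, 192))  # -> torch.Size([4, 128])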
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class Discriminator2d(nn.Module):
def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
super().__init__()
blocks = []
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
        for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, downsample='half')]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
blocks += [nn.LeakyReLU(0.2)]
blocks += [nn.AdaptiveAvgPool2d(1)]
blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
self.main = nn.Sequential(*blocks)
def get_feature(self, x):
features = []
for l in self.main:
x = l(x)
features.append(x)
out = features[-1]
out = out.view(out.size(0), -1) # (batch, num_domains)
return out, features
def forward(self, x):
out, features = self.get_feature(x)
out = out.squeeze() # (batch)
return out, features
class ResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
normalize=False, downsample='none', dropout_p=0.2):
super().__init__()
self.actv = actv
self.normalize = normalize
self.downsample_type = downsample
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out)
self.dropout_p = dropout_p
if self.downsample_type == 'none':
self.pool = nn.Identity()
else:
self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
def _build_weights(self, dim_in, dim_out):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
if self.normalize:
self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
if self.learned_sc:
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
def downsample(self, x):
if self.downsample_type == 'none':
return x
else:
if x.shape[-1] % 2 != 0:
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
return F.avg_pool1d(x, 2)
def _shortcut(self, x):
if self.learned_sc:
x = self.conv1x1(x)
x = self.downsample(x)
return x
def _residual(self, x):
if self.normalize:
x = self.norm1(x)
x = self.actv(x)
x = F.dropout(x, p=self.dropout_p, training=self.training)
x = self.conv1(x)
x = self.pool(x)
if self.normalize:
x = self.norm2(x)
x = self.actv(x)
x = F.dropout(x, p=self.dropout_p, training=self.training)
x = self.conv2(x)
return x
def forward(self, x):
x = self._shortcut(x) + self._residual(x)
return x / math.sqrt(2) # unit variance
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
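# This LayerNorm variant accepts channel-first [B, C, ...] tensors: it swaps
# the channel axis to the end, applies F.layer_norm over it with learned
# gamma/beta, and swaps back.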
class TextEncoder(nn.Module):
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
super().__init__()
self.embedding = nn.Embedding(n_symbols, channels)
        self.prepare_projection = LinearNorm(channels, channels // 2)
        self.post_projection = LinearNorm(channels // 2, channels)
self.cfg = xLSTMBlockStackConfig(
mlstm_block=mLSTMBlockConfig(
mlstm=mLSTMLayerConfig(
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
)
),
# slstm_block=sLSTMBlockConfig(
# slstm=sLSTMLayerConfig(
# backend="cuda",
# num_heads=4,
# conv1d_kernel_size=4,
# bias_init="powerlaw_blockdependent",
# ),
# feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
# ),
context_length=channels,
num_blocks=8,
embedding_dim=channels // 2,
# slstm_at=[1],
)
padding = (kernel_size - 1) // 2
self.cnn = nn.ModuleList()
for _ in range(depth):
self.cnn.append(nn.Sequential(
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
LayerNorm(channels),
actv,
nn.Dropout(0.2),
))
# self.cnn = nn.Sequential(*self.cnn)
self.lstm = xLSTMBlockStack(self.cfg)
def forward(self, x, input_lengths, m):
x = self.embedding(x) # [B, T, emb]
x = x.transpose(1, 2) # [B, emb, T]
m = m.to(input_lengths.device).unsqueeze(1)
x.masked_fill_(m, 0.0)
for c in self.cnn:
x = c(x)
x.masked_fill_(m, 0.0)
x = x.transpose(1, 2) # [B, T, chn]
input_lengths = input_lengths.cpu().numpy()
x = self.prepare_projection(x)
# x = nn.utils.rnn.pack_padded_sequence(
# x, input_lengths, batch_first=True, enforce_sorted=False)
# self.lstm.flatten_parameters()
x = self.lstm(x)
x = self.post_projection(x)
# x, _ = nn.utils.rnn.pad_packed_sequence(
# x, batch_first=True)
x = x.transpose(-1, -2)
# x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
# x_pad[:, :, :x.shape[-1]] = x
# x = x_pad.to(x.device)
x.masked_fill_(m, 0.0)
return x
    def inference(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        for c in self.cnn:  # self.cnn is a ModuleList, so apply each block in turn
            x = c(x)
        x = x.transpose(1, 2)
        x = self.prepare_projection(x)  # match forward(): project down to the xLSTM width
        x = self.lstm(x)
        x = self.post_projection(x)
        return x
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
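# Usage sketch for TextEncoder (hypothetical sizes; in the mask, True marks
# padded positions):
#
#   te = TextEncoder(channels=512, kernel_size=5, depth=3, n_symbols=178)
#   tokens = torch.randint(0, 178, (2, 50))
#   lengths = torch.tensor([50, 37])
#   mask = te.length_to_mask(lengths)  # [2, 50] bool, True where padded
#   h = te(tokens, lengths, mask)      # -> torch.Size([2, 512, 50])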
class AdaIN1d(nn.Module):
def __init__(self, style_dim, num_features):
super().__init__()
self.norm = nn.InstanceNorm1d(num_features, affine=False)
self.fc = nn.Linear(style_dim, num_features*2)
def forward(self, x, s):
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
return (1 + gamma) * self.norm(x) + beta
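# AdaIN1d sketch: a style vector of size style_dim is mapped to per-channel
# (gamma, beta) pairs that modulate instance-normalized features:
#
#   adain = AdaIN1d(style_dim=128, num_features=64)
#   y = adain(torch.randn(4, 64, 100), torch.randn(4, 128))  # -> [4, 64, 100]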
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class AdainResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
upsample='none', dropout_p=0.0):
super().__init__()
self.actv = actv
self.upsample_type = upsample
self.upsample = UpSample1d(upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
self.dropout = nn.Dropout(dropout_p)
if upsample == 'none':
self.pool = nn.Identity()
else:
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
def _build_weights(self, dim_in, dim_out, style_dim):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
self.norm1 = AdaIN1d(style_dim, dim_in)
self.norm2 = AdaIN1d(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
def _shortcut(self, x):
x = self.upsample(x)
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
x = self.pool(x)
x = self.conv1(self.dropout(x))
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(self.dropout(x))
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) / math.sqrt(2)
return out
class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s):
x = x.transpose(-1, -2)
x = x.transpose(1, -1)
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2)
# class ProsodyPredictor(nn.Module):
# def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
# super().__init__()
# self.text_encoder = DurationEncoder(sty_dim=style_dim,
# d_model=d_hid,
# nlayers=nlayers,
# dropout=dropout)
# self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# self.duration_proj = LinearNorm(d_hid, max_dur)
# self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# self.F0 = nn.ModuleList()
# self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# self.N = nn.ModuleList()
# self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
# self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# def forward(self, texts, style, text_lengths, alignment, m):
# d = self.text_encoder(texts, style, text_lengths, m)
# batch_size = d.shape[0]
# text_size = d.shape[1]
# # predict duration
# input_lengths = text_lengths.cpu().numpy()
# x = nn.utils.rnn.pack_padded_sequence(
# d, input_lengths, batch_first=True, enforce_sorted=False)
# m = m.to(text_lengths.device).unsqueeze(1)
# self.lstm.flatten_parameters()
# x, _ = self.lstm(x)
# x, _ = nn.utils.rnn.pad_packed_sequence(
# x, batch_first=True)
# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
# x_pad[:, :x.shape[1], :] = x
# x = x_pad.to(x.device)
# duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
# en = (d.transpose(-1, -2) @ alignment)
# return duration.squeeze(-1), en
class ProsodyPredictor(nn.Module):
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
super().__init__()
self.cfg = xLSTMBlockStackConfig(
mlstm_block=mLSTMBlockConfig(
mlstm=mLSTMLayerConfig(
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
)
),
context_length=d_hid,
num_blocks=8,
embedding_dim=d_hid + style_dim,
)
self.cfg_pred = xLSTMBlockStackConfig(
mlstm_block=mLSTMBlockConfig(
mlstm=mLSTMLayerConfig(
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
)
),
context_length=4096,
num_blocks=8,
embedding_dim=d_hid,
)
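        # NOTE: self.cfg_pred is constructed but currently unused; both
        # self.lstm and self.shared below are built from self.cfg.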
# self.shared = Hopfield(input_size=d_hid + style_dim,
# hidden_size=d_hid // 2,
# num_heads=32,
# # scaling=.75,
# add_zero_association=True,
# batch_first=True)
self.text_encoder = DurationEncoder(sty_dim=style_dim,
d_model=d_hid,
nlayers=nlayers,
dropout=dropout)
self.lstm = xLSTMBlockStack(self.cfg)
self.prepare_projection = nn.Linear(d_hid + style_dim, d_hid)
        self.duration_proj = LinearNorm(d_hid, max_dur)
self.shared = xLSTMBlockStack(self.cfg)
# self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.F0 = nn.ModuleList()
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.N = nn.ModuleList()
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
def forward(self, texts, style, text_lengths=None, alignment=None, m=None, f0=False):
        if f0:
            # The f0 branch duplicates F0Ntrain exactly, so delegate to it.
            return self.F0Ntrain(texts, style)
else:
# Problem is here
d = self.text_encoder(texts, style, text_lengths, m)
batch_size = d.shape[0]
text_size = d.shape[1]
# predict duration
input_lengths = text_lengths.cpu().numpy()
# x = nn.utils.rnn.pack_padded_sequence(
# d, input_lengths, batch_first=True, enforce_sorted=False)
            x = d  # the xLSTM stack handles variable sequence lengths, so packing is not needed
m = m.to(text_lengths.device).unsqueeze(1)
# self.lstm.flatten_parameters()
            x = self.lstm(x)  # self.lstm is an xLSTMBlockStack rather than an nn.LSTM
x = self.prepare_projection(x)
# x, _ = nn.utils.rnn.pad_packed_sequence(
# x, batch_first=True)
# x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
# x_pad[:, :x.shape[1], :] = x
# x = x_pad.to(x.device)
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
en = (d.transpose(-1, -2) @ alignment)
return duration.squeeze(-1), en
def F0Ntrain(self, x, s):
# x = self.prepare_projection(x.transpose(-1, -2))
# x = self.shared(x)
x = self.shared(x.transpose(-1, -2))
x = self.prepare_projection(x)
F0 = x.transpose(-1, -2)
for block in self.F0:
F0 = block(F0, s)
F0 = self.F0_proj(F0)
N = x.transpose(-1, -2)
for block in self.N:
N = block(N, s)
N = self.N_proj(N)
return F0.squeeze(1), N.squeeze(1)
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
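# F0Ntrain shape sketch (illustrative sizes): x is the aligned encoding
# [B, d_hid + style_dim, T]; the shared xLSTM stack plus projection bring it
# to [B, T, d_hid], and the F0/N heads upsample T by 2x:
#
#   pp = ProsodyPredictor(style_dim=128, d_hid=512, nlayers=3)
#   F0, N = pp.F0Ntrain(torch.randn(2, 640, 80), torch.randn(2, 128))
#   # F0 and N both have shape [2, 160]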
class DurationEncoder(nn.Module):
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
super().__init__()
self.lstms = nn.ModuleList()
for _ in range(nlayers):
self.lstms.append(nn.LSTM(d_model + sty_dim,
d_model // 2,
num_layers=1,
batch_first=True,
bidirectional=True,
dropout=dropout))
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
self.dropout = dropout
self.d_model = d_model
self.sty_dim = sty_dim
def forward(self, x, style, text_lengths, m):
masks = m.to(text_lengths.device)
x = x.permute(2, 0, 1)
s = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, s], axis=-1)
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
x = x.transpose(0, 1)
input_lengths = text_lengths.cpu().numpy()
x = x.transpose(-1, -2)
for block in self.lstms:
if isinstance(block, AdaLayerNorm):
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
else:
x = x.transpose(-1, -2)
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True, enforce_sorted=False)
block.flatten_parameters()
x, _ = block(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x = F.dropout(x, p=self.dropout, training=self.training)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
x_pad[:, :, :x.shape[-1]] = x
x = x_pad.to(x.device)
return x.transpose(-1, -2)
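    # NOTE: inference() below references self.embedding, self.pos_encoder and
    # self.transformer_encoder, none of which are defined in this class; it
    # appears to be a leftover from a Transformer-based variant and would
    # raise AttributeError if called.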
def inference(self, x, style):
x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
style = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, style], axis=-1)
src = self.pos_encoder(x)
output = self.transformer_encoder(src).transpose(0, 1)
return output
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
def load_F0_models(path):
# load F0 model
F0_model = JDCNet(num_class=1, seq_len=192)
params = torch.load(path, map_location='cpu')['net']
F0_model.load_state_dict(params)
_ = F0_model.train()
return F0_model
def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
# load ASR model
def _load_config(path):
with open(path) as f:
config = yaml.safe_load(f)
model_config = config['model_params']
return model_config
def _load_model(model_config, model_path):
model = ASRCNN(**model_config)
params = torch.load(model_path, map_location='cpu')['model']
model.load_state_dict(params)
return model
asr_model_config = _load_config(ASR_MODEL_CONFIG)
asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
_ = asr_model.train()
return asr_model
def build_model(args, text_aligner, pitch_extractor, bert):
assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
if args.decoder.type == "istftnet":
from Modules.istftnet import Decoder
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
upsample_rates = args.decoder.upsample_rates,
upsample_initial_channel=args.decoder.upsample_initial_channel,
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
else:
from Modules.hifigan import Decoder
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
upsample_rates = args.decoder.upsample_rates,
upsample_initial_channel=args.decoder.upsample_initial_channel,
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
# define diffusion model
if args.multispeaker:
transformer = StyleTransformer1d(channels=args.style_dim*2,
context_embedding_features=bert.config.hidden_size,
context_features=args.style_dim*2,
**args.diffusion.transformer)
else:
transformer = Transformer1d(channels=args.style_dim*2,
context_embedding_features=bert.config.hidden_size,
**args.diffusion.transformer)
diffusion = AudioDiffusionConditional(
in_channels=1,
embedding_max_length=bert.config.max_position_embeddings,
embedding_features=bert.config.hidden_size,
embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
channels=args.style_dim*2,
context_features=args.style_dim*2,
)
diffusion.diffusion = KDiffusion(
net=diffusion.unet,
sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
sigma_data=args.diffusion.dist.sigma_data, # a placeholder, will be changed dynamically when start training diffusion model
dynamic_threshold=0.0
)
diffusion.diffusion.net = transformer
diffusion.unet = transformer
nets = Munch(
bert=bert,
bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
predictor=predictor,
decoder=decoder,
text_encoder=text_encoder,
predictor_encoder=predictor_encoder,
style_encoder=style_encoder,
diffusion=diffusion,
text_aligner = text_aligner,
pitch_extractor=pitch_extractor,
mpd = MultiPeriodDiscriminator(),
msd = MultiResSpecDiscriminator(),
# slm discriminator head
wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
)
return nets
# def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
# state = torch.load(path, map_location='cpu')
# params = state['net']
# for key in model:
# if key in params and key not in ignore_modules:
# print('%s loaded' % key)
# model[key].load_state_dict(params[key], strict=False)
# _ = [model[key].eval() for key in model]
# if not load_only_params:
# epoch = state["epoch"]
# iters = state["iters"]
# optimizer.load_state_dict(state["optimizer"])
# else:
# epoch = 0
# iters = 0
# return model, optimizer, epoch, iters
def load_checkpoint(model, optimizer, path, load_only_params=False, ignore_modules=[]):
state = torch.load(path, map_location='cpu')
params = state['net']
    print('Loading checkpoint from %s' % path)
for key in model:
if key in params and key not in ignore_modules:
try:
model[key].load_state_dict(params[key], strict=True)
            except Exception:
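                # Fallback for renamed keys: copy checkpoint tensors into the
                # model positionally. This assumes both state dicts enumerate
                # parameters in the same order with matching shapes.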
from collections import OrderedDict
state_dict = params[key]
new_state_dict = OrderedDict()
print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict key length: {len(state_dict.keys())}')
for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
new_state_dict[k_m] = v_c
model[key].load_state_dict(new_state_dict, strict=True)
print('%s loaded' % key)
if not load_only_params:
epoch = state["epoch"]
iters = state["iters"]
optimizer.load_state_dict(state["optimizer"])
else:
epoch = 0
iters = 0
return model, optimizer, epoch, iters
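# End-to-end sketch (hypothetical: 'args' is a Munch mirroring the training
# YAML, and the auxiliary models are loaded with the helpers above):
#
#   nets = build_model(args, text_aligner, pitch_extractor, bert)
#   model, optimizer, epoch, iters = load_checkpoint(
#       nets, optimizer, 'path/to/checkpoint.pth', load_only_params=False)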
################################################################################################
################################################################################################
################################################################################################
# LSTM ORIGINAL
################################################################################################
################################################################################################
# # import os
# # import os.path as osp
# # import copy
# # import math
# # import numpy as np
# # import torch
# # import torch.nn as nn
# # import torch.nn.functional as F
# # from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
# # from Utils.ASR.models import ASRCNN
# # from Utils.JDC.model import JDCNet
# # from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
# # from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
# # from Modules.diffusion.diffusion import AudioDiffusionConditional
# # from Modules.discriminators import MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator
# # from munch import Munch
# # import yaml
# # class LearnedDownSample(nn.Module):
# # def __init__(self, layer_type, dim_in):
# # super().__init__()
# # self.layer_type = layer_type
# # if self.layer_type == 'none':
# # self.conv = nn.Identity()
# # elif self.layer_type == 'timepreserve':
# # self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
# # elif self.layer_type == 'half':
# # self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
# # else:
# #             raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# # def forward(self, x):
# # return self.conv(x)
# # class LearnedUpSample(nn.Module):
# # def __init__(self, layer_type, dim_in):
# # super().__init__()
# # self.layer_type = layer_type
# # if self.layer_type == 'none':
# # self.conv = nn.Identity()
# # elif self.layer_type == 'timepreserve':
# # self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
# # elif self.layer_type == 'half':
# # self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
# # else:
# #             raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# # def forward(self, x):
# # return self.conv(x)
# # class DownSample(nn.Module):
# # def __init__(self, layer_type):
# # super().__init__()
# # self.layer_type = layer_type
# # def forward(self, x):
# # if self.layer_type == 'none':
# # return x
# # elif self.layer_type == 'timepreserve':
# # return F.avg_pool2d(x, (2, 1))
# # elif self.layer_type == 'half':
# # if x.shape[-1] % 2 != 0:
# # x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
# # return F.avg_pool2d(x, 2)
# # else:
# #             raise RuntimeError('Got unexpected downsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# # class UpSample(nn.Module):
# # def __init__(self, layer_type):
# # super().__init__()
# # self.layer_type = layer_type
# # def forward(self, x):
# # if self.layer_type == 'none':
# # return x
# # elif self.layer_type == 'timepreserve':
# # return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
# # elif self.layer_type == 'half':
# # return F.interpolate(x, scale_factor=2, mode='nearest')
# # else:
# #             raise RuntimeError('Got unexpected upsample type %s; expected one of [none, timepreserve, half]' % self.layer_type)
# # class ResBlk(nn.Module):
# # def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
# # normalize=False, downsample='none'):
# # super().__init__()
# # self.actv = actv
# # self.normalize = normalize
# # self.downsample = DownSample(downsample)
# # self.downsample_res = LearnedDownSample(downsample, dim_in)
# # self.learned_sc = dim_in != dim_out
# # self._build_weights(dim_in, dim_out)
# # def _build_weights(self, dim_in, dim_out):
# # self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
# # self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
# # if self.normalize:
# # self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
# # self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
# # if self.learned_sc:
# # self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
# # def _shortcut(self, x):
# # if self.learned_sc:
# # x = self.conv1x1(x)
# # if self.downsample:
# # x = self.downsample(x)
# # return x
# # def _residual(self, x):
# # if self.normalize:
# # x = self.norm1(x)
# # x = self.actv(x)
# # x = self.conv1(x)
# # x = self.downsample_res(x)
# # if self.normalize:
# # x = self.norm2(x)
# # x = self.actv(x)
# # x = self.conv2(x)
# # return x
# # def forward(self, x):
# # x = self._shortcut(x) + self._residual(x)
# # return x / math.sqrt(2) # unit variance
# # class StyleEncoder(nn.Module):
# # def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
# # super().__init__()
# # blocks = []
# # blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
# # repeat_num = 4
# # for _ in range(repeat_num):
# # dim_out = min(dim_in*2, max_conv_dim)
# # blocks += [ResBlk(dim_in, dim_out, downsample='half')]
# # dim_in = dim_out
# # blocks += [nn.LeakyReLU(0.2)]
# # blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
# # blocks += [nn.AdaptiveAvgPool2d(1)]
# # blocks += [nn.LeakyReLU(0.2)]
# # self.shared = nn.Sequential(*blocks)
# # self.unshared = nn.Linear(dim_out, style_dim)
# # def forward(self, x):
# # h = self.shared(x)
# # h = h.view(h.size(0), -1)
# # s = self.unshared(h)
# # return s
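# # # Usage sketch (added for illustration; the 80-bin mel size and the
# # # hyperparameters below are assumptions, not values fixed by this file):
# # # enc = StyleEncoder(dim_in=64, style_dim=128, max_conv_dim=512)
# # # mels = torch.randn(4, 1, 80, 192) # [batch, 1, n_mels, frames]
# # # style = enc(mels) # -> [4, 128], one style vector per utterance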
# # class LinearNorm(torch.nn.Module):
# # def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
# # super(LinearNorm, self).__init__()
# # self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
# # torch.nn.init.xavier_uniform_(
# # self.linear_layer.weight,
# # gain=torch.nn.init.calculate_gain(w_init_gain))
# # def forward(self, x):
# # return self.linear_layer(x)
# # class Discriminator2d(nn.Module):
# # def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
# # super().__init__()
# # blocks = []
# # blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
# # for lid in range(repeat_num):
# # dim_out = min(dim_in*2, max_conv_dim)
# # blocks += [ResBlk(dim_in, dim_out, downsample='half')]
# # dim_in = dim_out
# # blocks += [nn.LeakyReLU(0.2)]
# # blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
# # blocks += [nn.LeakyReLU(0.2)]
# # blocks += [nn.AdaptiveAvgPool2d(1)]
# # blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
# # self.main = nn.Sequential(*blocks)
# # def get_feature(self, x):
# # features = []
# # for l in self.main:
# # x = l(x)
# # features.append(x)
# # out = features[-1]
# # out = out.view(out.size(0), -1) # (batch, num_domains)
# # return out, features
# # def forward(self, x):
# # out, features = self.get_feature(x)
# # out = out.squeeze() # (batch); note: a batch of size 1 collapses to a 0-d tensor
# # return out, features
# # class ResBlk1d(nn.Module):
# # def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
# # normalize=False, downsample='none', dropout_p=0.2):
# # super().__init__()
# # self.actv = actv
# # self.normalize = normalize
# # self.downsample_type = downsample
# # self.learned_sc = dim_in != dim_out
# # self._build_weights(dim_in, dim_out)
# # self.dropout_p = dropout_p
# # if self.downsample_type == 'none':
# # self.pool = nn.Identity()
# # else:
# # self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
# # def _build_weights(self, dim_in, dim_out):
# # self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
# # self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
# # if self.normalize:
# # self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
# # self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
# # if self.learned_sc:
# # self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
# # def downsample(self, x):
# # if self.downsample_type == 'none':
# # return x
# # else:
# # if x.shape[-1] % 2 != 0:
# # x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
# # return F.avg_pool1d(x, 2)
# # def _shortcut(self, x):
# # if self.learned_sc:
# # x = self.conv1x1(x)
# # x = self.downsample(x)
# # return x
# # def _residual(self, x):
# # if self.normalize:
# # x = self.norm1(x)
# # x = self.actv(x)
# # x = F.dropout(x, p=self.dropout_p, training=self.training)
# # x = self.conv1(x)
# # x = self.pool(x)
# # if self.normalize:
# # x = self.norm2(x)
# # x = self.actv(x)
# # x = F.dropout(x, p=self.dropout_p, training=self.training)
# # x = self.conv2(x)
# # return x
# # def forward(self, x):
# # x = self._shortcut(x) + self._residual(x)
# # return x / math.sqrt(2) # unit variance
# # class LayerNorm(nn.Module):
# # def __init__(self, channels, eps=1e-5):
# # super().__init__()
# # self.channels = channels
# # self.eps = eps
# # self.gamma = nn.Parameter(torch.ones(channels))
# # self.beta = nn.Parameter(torch.zeros(channels))
# # def forward(self, x):
# # x = x.transpose(1, -1)
# # x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
# # return x.transpose(1, -1)
# # class TextEncoder(nn.Module):
# # def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
# # super().__init__()
# # self.embedding = nn.Embedding(n_symbols, channels)
# # padding = (kernel_size - 1) // 2
# # self.cnn = nn.ModuleList()
# # for _ in range(depth):
# # self.cnn.append(nn.Sequential(
# # weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
# # LayerNorm(channels),
# # actv,
# # nn.Dropout(0.2),
# # ))
# # # self.cnn = nn.Sequential(*self.cnn)
# # self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
# # def forward(self, x, input_lengths, m):
# # x = self.embedding(x) # [B, T, emb]
# # x = x.transpose(1, 2) # [B, emb, T]
# # m = m.to(input_lengths.device).unsqueeze(1)
# # x.masked_fill_(m, 0.0)
# # for c in self.cnn:
# # x = c(x)
# # x.masked_fill_(m, 0.0)
# # x = x.transpose(1, 2) # [B, T, chn]
# # input_lengths = input_lengths.cpu().numpy()
# # x = nn.utils.rnn.pack_padded_sequence(
# # x, input_lengths, batch_first=True, enforce_sorted=False)
# # self.lstm.flatten_parameters()
# # x, _ = self.lstm(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# # x = x.transpose(-1, -2)
# # # re-pad to the full mask length (allocated on CPU, moved to x.device below)
# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
# # x_pad[:, :, :x.shape[-1]] = x
# # x = x_pad.to(x.device)
# # x.masked_fill_(m, 0.0)
# # return x
# # def inference(self, x):
# # x = self.embedding(x)
# # x = x.transpose(1, 2)
# # # self.cnn is an nn.ModuleList, so it must be iterated; calling it directly raises a TypeError
# # for c in self.cnn:
# # x = c(x)
# # x = x.transpose(1, 2)
# # self.lstm.flatten_parameters()
# # x, _ = self.lstm(x)
# # return x
# # def length_to_mask(self, lengths):
# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
# # return mask
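# # # Usage sketch (illustrative hyperparameters; n_symbols=178 is an
# # # assumption, not a value fixed by this file):
# # # enc = TextEncoder(channels=512, kernel_size=5, depth=3, n_symbols=178)
# # # tokens = torch.randint(0, 178, (2, 50)) # [batch, token ids]
# # # lengths = torch.tensor([50, 32])
# # # mask = enc.length_to_mask(lengths) # True where padded
# # # out = enc(tokens, lengths, mask) # -> [2, 512, 50]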
# # class AdaIN1d(nn.Module):
# # def __init__(self, style_dim, num_features):
# # super().__init__()
# # self.norm = nn.InstanceNorm1d(num_features, affine=False)
# # self.fc = nn.Linear(style_dim, num_features*2)
# # def forward(self, x, s):
# # h = self.fc(s)
# # h = h.view(h.size(0), h.size(1), 1)
# # gamma, beta = torch.chunk(h, chunks=2, dim=1)
# # return (1 + gamma) * self.norm(x) + beta
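# # # In effect: AdaIN1d(x, s) = (1 + gamma(s)) * InstanceNorm(x) + beta(s),
# # # where gamma and beta are the two halves of a linear projection of the
# # # style vector s, broadcast over the time axis.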
# # class UpSample1d(nn.Module):
# # def __init__(self, layer_type):
# # super().__init__()
# # self.layer_type = layer_type
# # def forward(self, x):
# # if self.layer_type == 'none':
# # return x
# # else:
# # return F.interpolate(x, scale_factor=2, mode='nearest')
# # class AdainResBlk1d(nn.Module):
# # def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
# # upsample='none', dropout_p=0.0):
# # super().__init__()
# # self.actv = actv
# # self.upsample_type = upsample
# # self.upsample = UpSample1d(upsample)
# # self.learned_sc = dim_in != dim_out
# # self._build_weights(dim_in, dim_out, style_dim)
# # self.dropout = nn.Dropout(dropout_p)
# # if upsample == 'none':
# # self.pool = nn.Identity()
# # else:
# # self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
# # def _build_weights(self, dim_in, dim_out, style_dim):
# # self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
# # self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
# # self.norm1 = AdaIN1d(style_dim, dim_in)
# # self.norm2 = AdaIN1d(style_dim, dim_out)
# # if self.learned_sc:
# # self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
# # def _shortcut(self, x):
# # x = self.upsample(x)
# # if self.learned_sc:
# # x = self.conv1x1(x)
# # return x
# # def _residual(self, x, s):
# # x = self.norm1(x, s)
# # x = self.actv(x)
# # x = self.pool(x)
# # x = self.conv1(self.dropout(x))
# # x = self.norm2(x, s)
# # x = self.actv(x)
# # x = self.conv2(self.dropout(x))
# # return x
# # def forward(self, x, s):
# # out = self._residual(x, s)
# # out = (out + self._shortcut(x)) / math.sqrt(2)
# # return out
# # class AdaLayerNorm(nn.Module):
# # def __init__(self, style_dim, channels, eps=1e-5):
# # super().__init__()
# # self.channels = channels
# # self.eps = eps
# # self.fc = nn.Linear(style_dim, channels*2)
# # def forward(self, x, s):
# # x = x.transpose(-1, -2)
# # x = x.transpose(1, -1)
# # h = self.fc(s)
# # h = h.view(h.size(0), h.size(1), 1)
# # gamma, beta = torch.chunk(h, chunks=2, dim=1)
# # gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
# # x = F.layer_norm(x, (self.channels,), eps=self.eps)
# # x = (1 + gamma) * x + beta
# # return x.transpose(1, -1).transpose(-1, -2)
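# # # In effect: AdaLayerNorm(x, s) = (1 + gamma(s)) * LayerNorm(x) + beta(s),
# # # the layer-norm counterpart of AdaIN1d above, with the same style-derived
# # # affine parameters.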
# # class ProsodyPredictor(nn.Module):
# # def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
# # super().__init__()
# # self.text_encoder = DurationEncoder(sty_dim=style_dim,
# # d_model=d_hid,
# # nlayers=nlayers,
# # dropout=dropout)
# # self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# # self.duration_proj = LinearNorm(d_hid, max_dur)
# # self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
# # self.F0 = nn.ModuleList()
# # self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# # self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample='half', dropout_p=dropout)) # any value other than 'none' enables 2x upsampling
# # self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# # self.N = nn.ModuleList()
# # self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
# # self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample='half', dropout_p=dropout)) # any value other than 'none' enables 2x upsampling
# # self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
# # self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# # self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
# # def forward(self, texts, style, text_lengths, alignment, m):
# # d = self.text_encoder(texts, style, text_lengths, m)
# # batch_size = d.shape[0]
# # text_size = d.shape[1]
# # # predict duration
# # input_lengths = text_lengths.cpu().numpy()
# # x = nn.utils.rnn.pack_padded_sequence(
# # d, input_lengths, batch_first=True, enforce_sorted=False)
# # m = m.to(text_lengths.device).unsqueeze(1)
# # self.lstm.flatten_parameters()
# # x, _ = self.lstm(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# # x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
# # x_pad[:, :x.shape[1], :] = x
# # x = x_pad.to(x.device)
# # duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
# # en = (d.transpose(-1, -2) @ alignment)
# # return duration.squeeze(-1), en
# # def F0Ntrain(self, x, s):
# # x, _ = self.shared(x.transpose(-1, -2))
# # F0 = x.transpose(-1, -2)
# # for block in self.F0:
# # F0 = block(F0, s)
# # F0 = self.F0_proj(F0)
# # N = x.transpose(-1, -2)
# # for block in self.N:
# # N = block(N, s)
# # N = self.N_proj(N)
# # return F0.squeeze(1), N.squeeze(1)
# # def length_to_mask(self, lengths):
# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
# # return mask
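# # # Usage sketch (shapes illustrative, tensors are placeholders): forward()
# # # yields per-token duration logits plus the aligned encoding `en`;
# # # F0Ntrain() then predicts frame-level pitch and energy from `en` and the
# # # prosodic style vector `s`.
# # # pred = ProsodyPredictor(style_dim=128, d_hid=512, nlayers=3)
# # # duration, en = pred(texts, s, text_lengths, alignment, mask)
# # # F0_pred, N_pred = pred.F0Ntrain(en, s) # each [batch, 2 * frames]; the middle block upsamples 2x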
# # class DurationEncoder(nn.Module):
# # def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
# # super().__init__()
# # self.lstms = nn.ModuleList()
# # for _ in range(nlayers):
# # self.lstms.append(nn.LSTM(d_model + sty_dim,
# # d_model // 2,
# # num_layers=1,
# # batch_first=True,
# # bidirectional=True,
# # dropout=dropout)) # no effect on a single-layer nn.LSTM; PyTorch emits a warning
# # self.lstms.append(AdaLayerNorm(sty_dim, d_model))
# # self.dropout = dropout
# # self.d_model = d_model
# # self.sty_dim = sty_dim
# # def forward(self, x, style, text_lengths, m):
# # masks = m.to(text_lengths.device)
# # x = x.permute(2, 0, 1)
# # s = style.expand(x.shape[0], x.shape[1], -1)
# # x = torch.cat([x, s], axis=-1)
# # x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
# # x = x.transpose(0, 1)
# # input_lengths = text_lengths.cpu().numpy()
# # x = x.transpose(-1, -2)
# # for block in self.lstms:
# # if isinstance(block, AdaLayerNorm):
# # x = block(x.transpose(-1, -2), style).transpose(-1, -2)
# # x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
# # x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
# # else:
# # x = x.transpose(-1, -2)
# # x = nn.utils.rnn.pack_padded_sequence(
# # x, input_lengths, batch_first=True, enforce_sorted=False)
# # block.flatten_parameters()
# # x, _ = block(x)
# # x, _ = nn.utils.rnn.pad_packed_sequence(
# # x, batch_first=True)
# # x = F.dropout(x, p=self.dropout, training=self.training)
# # x = x.transpose(-1, -2)
# # x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
# # x_pad[:, :, :x.shape[-1]] = x
# # x = x_pad.to(x.device)
# # return x.transpose(-1, -2)
# # def inference(self, x, style):
# # # NOTE: stale code path -- self.embedding, self.pos_encoder and
# # # self.transformer_encoder are never defined on this LSTM-based class,
# # # so calling this method raises AttributeError.
# # x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
# # style = style.expand(x.shape[0], x.shape[1], -1)
# # x = torch.cat([x, style], axis=-1)
# # src = self.pos_encoder(x)
# # output = self.transformer_encoder(src).transpose(0, 1)
# # return output
# # def length_to_mask(self, lengths):
# # mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
# # mask = torch.gt(mask+1, lengths.unsqueeze(1))
# # return mask
# # def load_F0_models(path):
# # # load F0 model
# # F0_model = JDCNet(num_class=1, seq_len=192)
# # params = torch.load(path, map_location='cpu')['net']
# # F0_model.load_state_dict(params)
# # _ = F0_model.train()
# # return F0_model
# # def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
# # # load ASR model
# # def _load_config(path):
# # with open(path) as f:
# # config = yaml.safe_load(f)
# # model_config = config['model_params']
# # return model_config
# # def _load_model(model_config, model_path):
# # model = ASRCNN(**model_config)
# # params = torch.load(model_path, map_location='cpu')['model']
# # model.load_state_dict(params)
# # return model
# # asr_model_config = _load_config(ASR_MODEL_CONFIG)
# # asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
# # _ = asr_model.train()
# # return asr_model
# # def build_model(args, text_aligner, pitch_extractor, bert):
# # assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'
# # if args.decoder.type == "istftnet":
# # from Modules.istftnet import Decoder
# # decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
# # resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
# # upsample_rates = args.decoder.upsample_rates,
# # upsample_initial_channel=args.decoder.upsample_initial_channel,
# # resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
# # upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
# # gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
# # else:
# # from Modules.hifigan import Decoder
# # decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
# # resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
# # upsample_rates = args.decoder.upsample_rates,
# # upsample_initial_channel=args.decoder.upsample_initial_channel,
# # resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
# # upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
# # text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
# # predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
# # style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # acoustic style encoder
# # predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim) # prosodic style encoder
# # # define diffusion model
# # if args.multispeaker:
# # transformer = StyleTransformer1d(channels=args.style_dim*2,
# # context_embedding_features=bert.config.hidden_size,
# # context_features=args.style_dim*2,
# # **args.diffusion.transformer)
# # else:
# # transformer = Transformer1d(channels=args.style_dim*2,
# # context_embedding_features=bert.config.hidden_size,
# # **args.diffusion.transformer)
# # diffusion = AudioDiffusionConditional(
# # in_channels=1,
# # embedding_max_length=bert.config.max_position_embeddings,
# # embedding_features=bert.config.hidden_size,
# # embedding_mask_proba=args.diffusion.embedding_mask_proba, # conditional dropout of batch elements
# # channels=args.style_dim*2,
# # context_features=args.style_dim*2,
# # )
# # diffusion.diffusion = KDiffusion(
# # net=diffusion.unet,
# # sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),
# # sigma_data=args.diffusion.dist.sigma_data, # placeholder; updated dynamically once diffusion training starts
# # dynamic_threshold=0.0
# # )
# # diffusion.diffusion.net = transformer
# # diffusion.unet = transformer
# # nets = Munch(
# # bert=bert,
# # bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
# # predictor=predictor,
# # decoder=decoder,
# # text_encoder=text_encoder,
# # predictor_encoder=predictor_encoder,
# # style_encoder=style_encoder,
# # diffusion=diffusion,
# # text_aligner = text_aligner,
# # pitch_extractor=pitch_extractor,
# # mpd = MultiPeriodDiscriminator(),
# # msd = MultiResSpecDiscriminator(),
# # # slm discriminator head
# # wd = WavLMDiscriminator(args.slm.hidden, args.slm.nlayers, args.slm.initial_channel),
# # )
# # return nets
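# # # Usage sketch (hypothetical wiring; 'config.yml', the config keys and
# # # Munch.fromDict are assumptions, and `bert` is assumed to be a pretrained
# # # text encoder loaded elsewhere):
# # # import yaml
# # # from munch import Munch
# # # config = yaml.safe_load(open('config.yml'))
# # # args = Munch.fromDict(config['model_params'])
# # # text_aligner = load_ASR_models(config['ASR_path'], config['ASR_config'])
# # # pitch_extractor = load_F0_models(config['F0_path'])
# # # nets = build_model(args, text_aligner, pitch_extractor, bert)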
# # def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
# # state = torch.load(path, map_location='cpu')
# # params = state['net']
# # for key in model:
# # if key in params and key not in ignore_modules:
# # print('%s loaded' % key)
# # try:
# # model[key].load_state_dict(params[key], strict=True)
# # except Exception:
# # # Fallback: key names differ (e.g. a module wrapper prefix), so copy
# # # tensors positionally; this assumes both state dicts enumerate
# # # parameters in the same order.
# # from collections import OrderedDict
# # state_dict = params[key]
# # new_state_dict = OrderedDict()
# # print(f'{key} key length: {len(model[key].state_dict().keys())}, state_dict length: {len(state_dict.keys())}')
# # for (k_m, v_m), (k_c, v_c) in zip(model[key].state_dict().items(), state_dict.items()):
# # new_state_dict[k_m] = v_c
# # model[key].load_state_dict(new_state_dict, strict=True)
# # _ = [model[key].eval() for key in model]
# # if not load_only_params:
# # epoch = state["epoch"]
# # iters = state["iters"]
# # optimizer.load_state_dict(state["optimizer"])
# # else:
# # epoch = 0
# # iters = 0
# # return model, optimizer, epoch, iters
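# # # Usage sketch (path, optimizer and ignore list are placeholders):
# # # nets, optimizer, epoch, iters = load_checkpoint(
# # # nets, optimizer, 'checkpoint.pth',
# # # load_only_params=False, ignore_modules=['diffusion'])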