- RVC/modules/attentions.py +138 -0
- RVC/modules/commons.py +56 -0
- RVC/modules/config.py +42 -0
- RVC/modules/cut.py +160 -0
- RVC/modules/download.py +132 -0
- RVC/modules/encoders.py +96 -0
- RVC/modules/fairseq.py +1396 -0
- RVC/modules/gdown.py +100 -0
- RVC/modules/generator.py +257 -0
- RVC/modules/hifigan.py +60 -0
- RVC/modules/mediafire.py +30 -0
- RVC/modules/meganz.py +122 -0
- RVC/modules/modules.py +60 -0
- RVC/modules/mrf_hifigan.py +150 -0
- RVC/modules/noisereduce.py +196 -0
- RVC/modules/normalization.py +15 -0
- RVC/modules/nsf_hifigan.py +116 -0
- RVC/modules/opencl.py +199 -0
- RVC/modules/pipeline.py +215 -0
- RVC/modules/pixeldrain.py +16 -0
- RVC/modules/pyworld.py +84 -0
- RVC/modules/refinegan.py +170 -0
- RVC/modules/residuals.py +140 -0
- RVC/modules/rms.py +30 -0
- RVC/modules/rmvpe.py +260 -0
- RVC/modules/swipe.py +200 -0
- RVC/modules/synthesizers.py +84 -0
- RVC/modules/torchcrepe.py +185 -0
- RVC/modules/torchfcpe.py +951 -0
- RVC/modules/utils.py +94 -0
RVC/modules/attentions.py
ADDED
@@ -0,0 +1,138 @@
import os
import sys
import math
import torch

import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.getcwd())

from modules.commons import convert_pad_shape

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None
        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5

            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        nn.init.xavier_uniform_(self.conv_o.weight)

        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        return self.conv_o(x)

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            assert (t_s == t_t)
            scores += self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s)))

        if self.proximal_bias:
            assert t_s == t_t
            scores += self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (t_s == t_t)
                scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)

        p_attn = self.drop(F.softmax(scores, dim=-1))
        output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))

        if self.window_size is not None: output += self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn), self._get_relative_embeddings(self.emb_rel_v, t_s))
        return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn

    def _matmul_with_relative_values(self, x, y):
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)

        return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]

    def _relative_position_to_absolute_position(self, x):
        batch, heads, length, _ = x.size()

        return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]

    def _absolute_position_to_relative_position(self, x):
        batch, heads, length, _ = x.size()

        return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length])[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        r = torch.arange(length, dtype=torch.float32)

        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)

class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal
        self.padding = self._causal_padding if causal else self._same_padding
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))

        return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))

    def _same_padding(self, x):
        if self.kernel_size == 1: return x

        return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))
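The file above is the standard VITS-style relative-position attention block plus a convolutional feed-forward layer. A minimal smoke test, assuming the repo root is on sys.path so `modules.commons` resolves; the shapes and hyperparameters are illustrative, not taken from the commit:

import torch
from modules.attentions import MultiHeadAttention, FFN

b, c, t = 2, 192, 50                       # batch, channels, time frames (illustrative)
x = torch.randn(b, c, t)                   # (B, C, T), as expected by the Conv1d projections
mask = torch.ones(b, 1, t, t)              # full attention mask, broadcast over heads

attn = MultiHeadAttention(c, c, n_heads=2, window_size=10)
y = attn(x, x, attn_mask=mask)             # self-attention: query and context are the same
print(y.shape)                             # torch.Size([2, 192, 50])

ffn = FFN(c, c, filter_channels=768, kernel_size=3)
z = ffn(y, torch.ones(b, 1, t))            # x_mask is (B, 1, T)
print(z.shape)                             # torch.Size([2, 192, 50])

Note that the windowed relative-position path asserts t_s == t_t, so it only applies under self-attention as above.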
RVC/modules/commons.py
ADDED
@@ -0,0 +1,56 @@
import torch

def init_weights(m, mean=0.0, std=0.01):
    if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def convert_pad_shape(pad_shape):
    return [item for sublist in pad_shape[::-1] for item in sublist]

def slice_segments(x, ids_str, segment_size = 4, dim = 2):
    if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
    elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        idx_str = ids_str[i].item()
        idx_end = idx_str + segment_size

        if dim == 2: ret[i] = x[i, idx_str:idx_end]
        else: ret[i] = x[i, :, idx_str:idx_end]

    return ret

def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, _, t = x.size()
    if x_lengths is None: x_lengths = t

    ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)

    return slice_segments(x, ids_str, segment_size, dim=3), ids_str

@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b

    return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])

def sequence_mask(length, max_length = None):
    if max_length is None: max_length = length.max()
    return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)

def clip_grad_value(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor): parameters = [parameters]
    norm_type = float(norm_type)

    if clip_value is not None: clip_value = float(clip_value)
    total_norm = 0

    for p in list(filter(lambda p: p.grad is not None, parameters)):
        total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type

        if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)

    return total_norm ** (1.0 / norm_type)
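A short sketch of how the slicing and masking helpers behave; all shapes are illustrative, not from the commit:

import torch
from modules.commons import rand_slice_segments, sequence_mask

x = torch.randn(4, 192, 400)                       # (B, C, T) latent frames
segs, ids = rand_slice_segments(x, segment_size=32)
print(segs.shape, ids.shape)                       # torch.Size([4, 192, 32]) torch.Size([4])

lengths = torch.tensor([3, 5, 2])
print(sequence_mask(lengths))                      # (3, 5) boolean mask, True inside each length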
RVC/modules/config.py
ADDED
@@ -0,0 +1,42 @@
import os
import sys
import torch

sys.path.append(os.getcwd())

from modules import opencl

def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances: instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance

@singleton
class Config:
    def __init__(self, cpu_mode=False, is_half=False):
        self.device = "cuda:0" if torch.cuda.is_available() else ("ocl:0" if opencl.is_available() else "cpu")
        self.is_half = is_half
        self.gpu_mem = None
        self.cpu_mode = cpu_mode
        if cpu_mode: self.device = "cpu"

    def device_config(self):
        if not self.cpu_mode:
            if self.device.startswith("cuda"): self.set_cuda_config()
            elif opencl.is_available(): self.device = "ocl:0"
            elif self.has_mps(): self.device = "mps"
            else: self.device = "cpu"

        if self.gpu_mem is not None and self.gpu_mem <= 4: return 1, 5, 30, 32
        return (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)

    def set_cuda_config(self):
        i_device = int(self.device.split(":")[-1])
        self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)

    def has_mps(self):
        return torch.backends.mps.is_available()
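A usage sketch, assuming `modules.opencl` is importable. Interpreting the four returned numbers as the usual RVC (x_pad, x_query, x_center, x_max) inference-window settings is an assumption; the commit does not name them here:

from modules.config import Config

cfg = Config(cpu_mode=False, is_half=True)
x_pad, x_query, x_center, x_max = cfg.device_config()   # assumed RVC naming, see note above
print(cfg.device, cfg.gpu_mem, (x_pad, x_query, x_center, x_max))

# Because of the @singleton decorator, a second "construction" returns the same object:
assert Config() is cfg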
RVC/modules/cut.py
ADDED
@@ -0,0 +1,160 @@
import numpy as np

class Slicer:
    def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        start_idx = begin * self.hop_size

        return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)] if len(waveform.shape) > 1 else waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]

    def slice(self, waveform):
        samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
        if samples.shape[0] <= self.min_length: return [waveform]
        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
        sil_tags = []
        silence_start, clip_start = None, 0

        for i, rms in enumerate(rms_list):
            if rms < self.threshold:
                if silence_start is None: silence_start = i
                continue

            if silence_start is None: continue
            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue

            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
                sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
                clip_start = pos_r

            silence_start = None
        total_frames = rms_list.shape[0]
        if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))

        if not sil_tags: return [waveform]
        else:
            chunks = []
            if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))

            for i in range(len(sil_tags) - 1):
                chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))

            if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
            return chunks

class Slicer2(Slicer):
    def slice2(self, waveform):
        samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform

        if samples.shape[0] <= self.min_length: return [(waveform, 0, samples.shape[0])]
        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)

        sil_tags = []
        silence_start, clip_start = None, 0

        for i, rms in enumerate(rms_list):
            if rms < self.threshold:
                if silence_start is None: silence_start = i
                continue

            if silence_start is None: continue

            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)

            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue

            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
                pos += i - self.max_sil_kept

                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
                sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
                clip_start = pos_r

            silence_start = None

        total_frames = rms_list.shape[0]
        if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))

        if not sil_tags: return [(waveform, 0, samples.shape[-1])]
        else:
            chunks = []
            if sil_tags[0][0] > 0: chunks.append((self._apply_slice(waveform, 0, sil_tags[0][0]), 0, sil_tags[0][0] * self.hop_size))

            for i in range(len(sil_tags) - 1):
                chunks.append((self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), sil_tags[i][1] * self.hop_size, sil_tags[i + 1][0] * self.hop_size))

            if sil_tags[-1][1] < total_frames: chunks.append((self._apply_slice(waveform, sil_tags[-1][1], total_frames), sil_tags[-1][1] * self.hop_size, samples.shape[-1]))
            return chunks

def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
    y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
    axis = -1

    x_shape_trimmed = list(y.shape)
    x_shape_trimmed[axis] -= frame_length - 1
    xw = np.moveaxis(np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]])), -1, axis - 1 if axis < 0 else axis + 1)

    slices = [slice(None)] * xw.ndim
    slices[axis] = slice(0, None, hop_length)

    return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))

def cut(audio, sr, db_thresh=-60, min_interval=250):
    slicer = Slicer2(sr=sr, threshold=db_thresh, min_interval=min_interval)
    return slicer.slice2(audio)

def restore(segments, total_len, dtype=np.float32):
    out = []
    last_end = 0

    for start, end, processed_seg in segments:
        if start > last_end: out.append(np.zeros(start - last_end, dtype=dtype))

        out.append(processed_seg)
        last_end = end

    if last_end < total_len: out.append(np.zeros(total_len - last_end, dtype=dtype))
    return np.concatenate(out, axis=-1)
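A round-trip sketch of cut/restore. Note the field-order mismatch visible in the code above: `slice2` yields (segment, start_sample, end_sample) tuples, while `restore` iterates (start, end, segment), so the caller must reorder; the audio here is an illustrative stand-in:

import numpy as np
from modules.cut import cut, restore

sr = 16000
audio = (np.random.randn(sr * 10) * 0.1).astype(np.float32)   # 10 s of mono noise stand-in

chunks = cut(audio, sr, db_thresh=-60, min_interval=250)       # [(segment, start, end), ...]
processed = [(start, end, seg) for seg, start, end in chunks]  # reorder for restore()
rebuilt = restore(processed, total_len=len(audio))
print(len(chunks), rebuilt.shape)                              # gaps between chunks come back as zeros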
RVC/modules/download.py
ADDED
@@ -0,0 +1,132 @@
import re
import os
import sys
import shutil

sys.path.append(os.getcwd())

from modules.utils import HF_download_file
from modules import gdown, meganz, mediafire, pixeldrain

def move_files_from_directory(src_dir, dest_models, model_name):
    for root, _, files in os.walk(src_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith(".index"):
                filepath = os.path.join(dest_models, file.replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip())

                shutil.move(file_path, filepath)
            elif file.endswith(".pth") and not file.startswith("D_") and not file.startswith("G_"):
                pth_path = os.path.join(dest_models, model_name + ".pth")

                shutil.move(file_path, pth_path)

def save_drop_model(dropbox):
    model_folders = "rvc_models"
    save_model_temp = "save_model_temp"

    if not os.path.exists(model_folders): os.makedirs(model_folders, exist_ok=True)
    if not os.path.exists(save_model_temp): os.makedirs(save_model_temp, exist_ok=True)

    shutil.move(dropbox, save_model_temp)

    try:
        print("[INFO] Start uploading...")

        file_name = os.path.basename(dropbox)
        model_folders = os.path.join(model_folders, file_name.replace(".zip", "").replace(".pth", "").replace(".index", ""))

        if file_name.endswith(".zip"):
            shutil.unpack_archive(os.path.join(save_model_temp, file_name), save_model_temp)
            move_files_from_directory(save_model_temp, model_folders, file_name.replace(".zip", ""))
        elif file_name.endswith(".pth"):
            output_file = os.path.join(model_folders, file_name)
            shutil.move(os.path.join(save_model_temp, file_name), output_file)
        elif file_name.endswith(".index"):
            def extract_name_model(filename):
                match = re.search(r"([A-Za-z]+)(?=_v|\.|$)", filename)
                return match.group(1) if match else None

            model_logs = os.path.join(model_folders, extract_name_model(file_name))
            if not os.path.exists(model_logs): os.makedirs(model_logs, exist_ok=True)
            shutil.move(os.path.join(save_model_temp, file_name), model_logs)
        else:
            print("[WARNING] Format not supported. Supported formats ('.zip', '.pth', '.index')")
            return

        print("[INFO] Completed upload.")
    except Exception as e:
        print(f"[ERROR] An error occurred during unpack: {e}")
    finally:
        shutil.rmtree(save_model_temp, ignore_errors=True)

def download_model(url=None, model=None):
    if not url:
        print("[WARNING] Please provide a valid url.")
        return

    if not model:
        print("[WARNING] Please provide a valid model name.")
        return

    model = model.replace(".pth", "").replace(".index", "").replace(".zip", "").replace(" ", "_").replace("(", "").replace(")", "").replace("[", "").replace("]", "").replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip()
    url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()

    download_dir = "download_model"
    model_folders = "rvc_models"

    if not os.path.exists(download_dir): os.makedirs(download_dir, exist_ok=True)
    if not os.path.exists(model_folders): os.makedirs(model_folders, exist_ok=True)

    model_folders = os.path.join(model_folders, model)
    os.makedirs(model_folders, exist_ok=True)

    try:
        print("[INFO] Start downloading...")

        if url.endswith(".pth"): HF_download_file(url, os.path.join(model_folders, f"{model}.pth"))
        elif url.endswith(".index"): HF_download_file(url, os.path.join(model_folders, f"{model}.index"))
        elif url.endswith(".zip"):
            output_path = HF_download_file(url, os.path.join(download_dir, model + ".zip"))
            shutil.unpack_archive(output_path, download_dir)

            move_files_from_directory(download_dir, model_folders, model)
        else:
            if "drive.google.com" in url or "drive.usercontent.google.com" in url:
                file_id = None

                if "/file/d/" in url: file_id = url.split("/d/")[1].split("/")[0]
                elif "open?id=" in url: file_id = url.split("open?id=")[1].split("/")[0]
                elif "/download?id=" in url: file_id = url.split("/download?id=")[1].split("&")[0]

                if file_id:
                    file = gdown.gdown_download(id=file_id, output=download_dir)
                    if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                    move_files_from_directory(download_dir, model_folders, model)
            elif "mega.nz" in url:
                meganz.mega_download_url(url, download_dir)

                file_download = next((f for f in os.listdir(download_dir)), None)
                if file_download.endswith(".zip"): shutil.unpack_archive(os.path.join(download_dir, file_download), download_dir)

                move_files_from_directory(download_dir, model_folders, model)
            elif "mediafire.com" in url:
                file = mediafire.Mediafire_Download(url, download_dir)
                if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                move_files_from_directory(download_dir, model_folders, model)
            elif "pixeldrain.com" in url:
                file = pixeldrain.pixeldrain(url, download_dir)
                if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)

                move_files_from_directory(download_dir, model_folders, model)
            else:
                print("[WARNING] The url path is not supported.")
                return

        print("[INFO] Model download complete.")
    except Exception as e:
        print(f"[INFO] An error has occurred: {e}")
    finally:
        shutil.rmtree(download_dir, ignore_errors=True)
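An example call, with a hypothetical URL, to illustrate the name sanitization and output layout the function applies (Hugging Face "blob" links are rewritten to "resolve" links internally):

from modules.download import download_model

download_model(
    url="https://huggingface.co/user/repo/blob/main/voice.zip",   # hypothetical URL
    model="My Voice (v2)",                                        # sanitized to My_Voice_v2
)
# expected result: rvc_models/My_Voice_v2/My_Voice_v2.pth plus any .index files from the archive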
RVC/modules/encoders.py
ADDED
@@ -0,0 +1,96 @@
import os
import sys
import math
import torch

sys.path.append(os.getcwd())

from modules.modules import WaveNet
from modules.commons import sequence_mask
from modules.normalization import LayerNorm
from modules.attentions import MultiHeadAttention, FFN

class Encoder(torch.nn.Module):
    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.drop = torch.nn.Dropout(p_dropout)
        self.attn_layers = torch.nn.ModuleList()
        self.norm_layers_1 = torch.nn.ModuleList()
        self.ffn_layers = torch.nn.ModuleList()
        self.norm_layers_2 = torch.nn.ModuleList()

        for _ in range(self.n_layers):
            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))

            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask

        for i in range(self.n_layers):
            x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
            x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))

        return x * x_mask

class TextEncoder(torch.nn.Module):
    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True, energy=False, onnx=False):
        super(TextEncoder, self).__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
        self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
        self.emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None
        self.emb_energy = torch.nn.Linear(1, hidden_channels) if energy else None
        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), onnx=onnx)
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths, energy):
        x = self.emb_phone(phone)

        if pitch is not None: x += self.emb_pitch(pitch)
        if energy is not None: x += self.emb_energy(energy.unsqueeze(-1))

        x = torch.transpose(self.lrelu(x * math.sqrt(self.hidden_channels)), 1, -1)
        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)

        return m, logs, x_mask

class PosteriorEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
        super(PosteriorEncoder, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g = None):
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)

        return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
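A shape-check sketch for TextEncoder, assuming 768-dimensional content features (the HuBERT hidden size, which is an assumption here); all sizes are illustrative:

import torch
from modules.encoders import TextEncoder

enc = TextEncoder(out_channels=192, hidden_channels=192, filter_channels=768,
                  n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.0,
                  embedding_dim=768, f0=True)

phone = torch.randn(1, 120, 768)            # (B, T, embedding_dim) content features
pitch = torch.randint(0, 256, (1, 120))     # coarse f0 buckets, matching the 256-entry embedding
lengths = torch.tensor([120])

m, logs, x_mask = enc(phone, pitch, lengths, energy=None)
print(m.shape, logs.shape, x_mask.shape)    # (1, 192, 120), (1, 192, 120), (1, 1, 120)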
RVC/modules/fairseq.py
ADDED
@@ -0,0 +1,1396 @@
| 1 |
+
import re
|
| 2 |
+
import sys
|
| 3 |
+
import math
|
| 4 |
+
import uuid
|
| 5 |
+
import torch
|
| 6 |
+
import types
|
| 7 |
+
import contextlib
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
from torch import nn
|
| 13 |
+
from omegaconf import DictConfig, open_dict
|
| 14 |
+
|
| 15 |
+
class Dictionary:
|
| 16 |
+
def __init__(self, *args, **kwargs):
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
fairseq = types.ModuleType("fairseq")
|
| 20 |
+
fairseq_data = types.ModuleType("fairseq.data")
|
| 21 |
+
fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
|
| 22 |
+
fairseq_data_dictionary.Dictionary = Dictionary
|
| 23 |
+
fairseq.data = fairseq_data
|
| 24 |
+
fairseq_data.dictionary = fairseq_data_dictionary
|
| 25 |
+
sys.modules["fairseq"] = fairseq
|
| 26 |
+
sys.modules["fairseq.data"] = fairseq_data
|
| 27 |
+
sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary
|
| 28 |
+
|
| 29 |
+
def load_model(filename):
|
| 30 |
+
state = torch.load(filename, map_location="cpu")
|
| 31 |
+
model = HubertModel(HubertConfig(**state['cfg']['model']))
|
| 32 |
+
model.load_state_dict(state['model'], strict=False)
|
| 33 |
+
return model
|
| 34 |
+
|
| 35 |
+
def softmax(x, dim, onnx_trace = False):
|
| 36 |
+
return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)
|
| 37 |
+
|
| 38 |
+
def log_softmax(x, dim, onnx_trace = False):
|
| 39 |
+
return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)
|
| 40 |
+
|
| 41 |
+
def eval_str_dict(x, type=dict):
|
| 42 |
+
if x is None: return None
|
| 43 |
+
if isinstance(x, str): x = eval(x)
|
| 44 |
+
return x
|
| 45 |
+
|
| 46 |
+
def with_incremental_state(cls):
|
| 47 |
+
cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
|
| 48 |
+
return cls
|
| 49 |
+
|
| 50 |
+
def quant_noise(module, p, block_size):
|
| 51 |
+
if p <= 0: return module
|
| 52 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
| 53 |
+
is_conv = module.weight.ndim == 4
|
| 54 |
+
if not is_conv: assert (module.weight.size(1) % block_size == 0)
|
| 55 |
+
else:
|
| 56 |
+
if module.kernel_size == (1, 1): assert (module.in_channels % block_size == 0)
|
| 57 |
+
else:
|
| 58 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
| 59 |
+
assert k % block_size == 0
|
| 60 |
+
|
| 61 |
+
def _forward_pre_hook(mod, input):
|
| 62 |
+
if mod.training:
|
| 63 |
+
if not is_conv:
|
| 64 |
+
weight = mod.weight
|
| 65 |
+
in_features = weight.size(1)
|
| 66 |
+
out_features = weight.size(0)
|
| 67 |
+
mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
|
| 68 |
+
mask.bernoulli_(p)
|
| 69 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
| 70 |
+
else:
|
| 71 |
+
weight = mod.weight
|
| 72 |
+
in_channels = mod.in_channels
|
| 73 |
+
out_channels = mod.out_channels
|
| 74 |
+
|
| 75 |
+
if mod.kernel_size == (1, 1):
|
| 76 |
+
mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
|
| 77 |
+
mask.bernoulli_(p)
|
| 78 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
| 79 |
+
else:
|
| 80 |
+
mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
|
| 81 |
+
mask.bernoulli_(p)
|
| 82 |
+
mask = (mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
|
| 83 |
+
|
| 84 |
+
mask = mask.to(torch.bool)
|
| 85 |
+
s = 1 / (1 - p)
|
| 86 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
| 87 |
+
|
| 88 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
| 89 |
+
return module
|
| 90 |
+
|
| 91 |
+
class FairseqDropout(nn.Module):
|
| 92 |
+
def __init__(self, p, module_name=None):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.p = p
|
| 95 |
+
self.module_name = module_name
|
| 96 |
+
self.apply_during_inference = False
|
| 97 |
+
|
| 98 |
+
def forward(self, x, inplace = False):
|
| 99 |
+
return F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x
|
| 100 |
+
|
| 101 |
+
def make_generation_fast_(self, name, retain_dropout = False, retain_dropout_modules = None, **kwargs):
|
| 102 |
+
if retain_dropout:
|
| 103 |
+
if (retain_dropout_modules is None or self.module_name in retain_dropout_modules): self.apply_during_inference = True
|
| 104 |
+
|
| 105 |
+
class FairseqIncrementalState(object):
|
| 106 |
+
def __init__(self, *args, **kwargs):
|
| 107 |
+
super().__init__(*args, **kwargs)
|
| 108 |
+
self.init_incremental_state()
|
| 109 |
+
|
| 110 |
+
def init_incremental_state(self):
|
| 111 |
+
self._incremental_state_id = str(uuid.uuid4())
|
| 112 |
+
|
| 113 |
+
def _get_full_incremental_state_key(self, key):
|
| 114 |
+
return "{}.{}".format(self._incremental_state_id, key)
|
| 115 |
+
|
| 116 |
+
def get_incremental_state(self, incremental_state, key):
|
| 117 |
+
full_key = self._get_full_incremental_state_key(key)
|
| 118 |
+
if incremental_state is None or full_key not in incremental_state: return None
|
| 119 |
+
return incremental_state[full_key]
|
| 120 |
+
|
| 121 |
+
def set_incremental_state(self, incremental_state, key, value):
|
| 122 |
+
if incremental_state is not None: incremental_state[self._get_full_incremental_state_key(key)] = value
|
| 123 |
+
return incremental_state
|
| 124 |
+
|
| 125 |
+
class FairseqDecoder(nn.Module):
|
| 126 |
+
def __init__(self, dictionary):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.dictionary = dictionary
|
| 129 |
+
self.onnx_trace = False
|
| 130 |
+
self.adaptive_softmax = None
|
| 131 |
+
|
| 132 |
+
def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
|
| 133 |
+
x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
|
| 134 |
+
return self.output_layer(x), extra
|
| 135 |
+
|
| 136 |
+
def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
def output_layer(self, features, **kwargs):
|
| 140 |
+
pass
|
| 141 |
+
|
| 142 |
+
def get_normalized_probs(self, net_output, log_probs, sample = None):
|
| 143 |
+
return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
|
| 144 |
+
|
| 145 |
+
def get_normalized_probs_scriptable(self, net_output, log_probs, sample = None):
|
| 146 |
+
if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
|
| 147 |
+
if sample is not None:
|
| 148 |
+
assert "target" in sample
|
| 149 |
+
target = sample["target"]
|
| 150 |
+
else: target = None
|
| 151 |
+
out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
|
| 152 |
+
return out.exp_() if not log_probs else out
|
| 153 |
+
|
| 154 |
+
logits = net_output[0]
|
| 155 |
+
return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
|
| 156 |
+
|
| 157 |
+
def max_positions(self):
|
| 158 |
+
return 1e6
|
| 159 |
+
|
| 160 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 161 |
+
return state_dict
|
| 162 |
+
|
| 163 |
+
def prepare_for_onnx_export_(self):
|
| 164 |
+
self.onnx_trace = True
|
| 165 |
+
|
| 166 |
+
@with_incremental_state
|
| 167 |
+
class FairseqIncrementalDecoder(FairseqDecoder):
|
| 168 |
+
def __init__(self, dictionary):
|
| 169 |
+
super().__init__(dictionary)
|
| 170 |
+
|
| 171 |
+
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
|
| 172 |
+
pass
|
| 173 |
+
|
| 174 |
+
def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
|
| 175 |
+
pass
|
| 176 |
+
|
| 177 |
+
def reorder_incremental_state(self, incremental_state, new_order):
|
| 178 |
+
pass
|
| 179 |
+
|
| 180 |
+
def reorder_incremental_state_scripting(self, incremental_state, new_order):
|
| 181 |
+
for module in self.modules():
|
| 182 |
+
if hasattr(module, "reorder_incremental_state"):
|
| 183 |
+
result = module.reorder_incremental_state(incremental_state, new_order)
|
| 184 |
+
if result is not None: incremental_state = result
|
| 185 |
+
|
| 186 |
+
def set_beam_size(self, beam_size):
|
| 187 |
+
if getattr(self, "_beam_size", -1) != beam_size:
|
| 188 |
+
seen = set()
|
| 189 |
+
|
| 190 |
+
def apply_set_beam_size(module):
|
| 191 |
+
if (module != self and hasattr(module, "set_beam_size") and module not in seen):
|
| 192 |
+
seen.add(module)
|
| 193 |
+
module.set_beam_size(beam_size)
|
| 194 |
+
|
| 195 |
+
self.apply(apply_set_beam_size)
|
| 196 |
+
self._beam_size = beam_size
|
| 197 |
+
|
| 198 |
+
class MultiheadAttention(FairseqIncrementalDecoder):
|
| 199 |
+
def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
|
| 200 |
+
super().__init__(dictionary)
|
| 201 |
+
xformers_att_config = eval_str_dict(xformers_att_config)
|
| 202 |
+
self.use_xformers = xformers_att_config is not None
|
| 203 |
+
if self.use_xformers: raise ImportError
|
| 204 |
+
self.embed_dim = embed_dim
|
| 205 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 206 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 207 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 208 |
+
self.num_heads = num_heads
|
| 209 |
+
self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
|
| 210 |
+
self.head_dim = embed_dim // num_heads
|
| 211 |
+
assert (self.head_dim * num_heads == self.embed_dim)
|
| 212 |
+
self.scaling = self.head_dim**-0.5
|
| 213 |
+
self.self_attention = self_attention
|
| 214 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 215 |
+
assert not self.self_attention or self.qkv_same_dim
|
| 216 |
+
self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 217 |
+
self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 218 |
+
self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 219 |
+
self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
|
| 220 |
+
if add_bias_kv: self.bias_k, self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim)), nn.Parameter(torch.Tensor(1, 1, embed_dim))
|
| 221 |
+
else: self.bias_k = self.bias_v = None
|
| 222 |
+
self.add_zero_attn = add_zero_attn
|
| 223 |
+
self.beam_size = 1
|
| 224 |
+
self.reset_parameters()
|
| 225 |
+
self.onnx_trace = False
|
| 226 |
+
self.skip_embed_dim_check = False
|
| 227 |
+
self.init_incremental_state()
|
| 228 |
+
|
| 229 |
+
    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def reset_parameters(self):
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)

    def _get_reserve_head_index(self, num_heads_to_keep: int):
        k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
        for i in range(self.num_heads):
            start_idx = i * self.head_dim
            end_idx = (i + 1) * self.head_dim
            k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
            q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
            v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())

        heads_norm = []
        for i in range(self.num_heads):
            heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])

        sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
        reserve_head_index = []
        for i in range(num_heads_to_keep):
            reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
        return reserve_head_index

    def _adaptive_prune_heads(self, reserve_head_index):
        new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []
        for ele in reserve_head_index:
            start_idx, end_idx = ele
            new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
            new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
            new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
            new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
            new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
            new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
            new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
        new_q_weight = torch.cat(new_q_weight).detach()
        new_k_weight = torch.cat(new_k_weight).detach()
        new_v_weight = torch.cat(new_v_weight).detach()
        new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
        new_q_weight.requires_grad = True
        new_k_weight.requires_grad = True
        new_v_weight.requires_grad = True
        new_out_proj_weight.requires_grad = True
        new_q_bias = torch.cat(new_q_bias).detach()
        new_q_bias.requires_grad = True
        new_k_bias = torch.cat(new_k_bias).detach()
        new_k_bias.requires_grad = True
        new_v_bias = torch.cat(new_v_bias).detach()
        new_v_bias.requires_grad = True
        self.q_proj.weight = nn.Parameter(new_q_weight)
        self.q_proj.bias = nn.Parameter(new_q_bias)
        self.k_proj.weight = nn.Parameter(new_k_weight)
        self.k_proj.bias = nn.Parameter(new_k_bias)
        self.v_proj.weight = nn.Parameter(new_v_weight)
        self.v_proj.bias = nn.Parameter(new_v_bias)
        self.out_proj.weight = nn.Parameter(new_out_proj_weight)
        self.num_heads = len(reserve_head_index)
        self.embed_dim = self.head_dim * self.num_heads
        self.q_proj.out_features = self.embed_dim
        self.k_proj.out_features = self.embed_dim
        self.v_proj.out_features = self.embed_dim

    def _set_skip_embed_dim_check(self):
        self.skip_embed_dim_check = True

    def _pad_masks(self, key_padding_mask, attn_mask):
        if attn_mask is not None:
            shape = attn_mask.size()[:-1] + torch.Size([1])
            attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)

        if key_padding_mask is not None:
            shape = key_padding_mask.size()[:-1] + torch.Size([1])
            key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)

        return key_padding_mask, attn_mask

    def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
        assert self.bias_k is not None or self.bias_v is not None
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask

    def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
        zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
        key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask

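    # forward() mirrors fairseq's control flow: with no incremental state, no
    # ONNX/TPU tracing, and embed_dim checks active, it delegates directly to
    # torch.nn.functional.multi_head_attention_forward with separate q/k/v
    # projection weights; otherwise it runs the explicit path that reshapes
    # (tgt_len, bsz, embed_dim) -> (bsz * num_heads, tgt_len, head_dim) and
    # concatenates any cached prev_key/prev_value tensors.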
    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None, need_weights=True, static_kv=False, attn_mask=None, before_softmax=False, need_head_weights=False):
        if need_head_weights: need_weights = True
        is_tpu = query.device.type == "xla"
        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert value is not None
                assert (src_len, key_bsz) == value.shape[:2]

        if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else: saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                if self.beam_size > 1 and bsz == key.size(1):
                    key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
                    if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)

        q *= self.scaling
        if self.bias_k is not None:
            assert self.bias_v is not None
            k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)

        q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
        kv_bsz = bsz
        if k is not None:
            kv_bsz = k.size(1)
            k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))

        if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
        if saved_state is not None:
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None

                kv_bsz = _prev_key.size(0)
                prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)

                if static_kv: k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)

            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None and kv_bsz == _prev_value.size(0)
                prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)
                if static_kv: v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)

            prev_key_padding_mask = None
            if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)
            saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)

        assert k is not None
        assert k.size(1) == src_len

        if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == kv_bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
            attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
        else: attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax: return attn_weights, v
        attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)
        assert v is not None
        attn = None

        if self.encoder_decoder_attention and bsz != kv_bsz:
            attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
            attn = attn.reshape((-1,) + attn.size()[-2:])
        else: attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) if self.onnx_trace and attn.size(1) == 1 else attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)
        attn_weights = None

        if need_weights:
            attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
            if not need_head_weights: attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

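    # A minimal self-attention smoke test (hypothetical shapes, not part of the
    # original file):
    #   mha = MultiheadAttention(embed_dim=768, num_heads=12, self_attention=True)
    #   x = torch.randn(50, 2, 768)             # (tgt_len, bsz, embed_dim)
    #   out, w = mha(query=x, key=x, value=x)   # out: (50, 2, 768)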
    @staticmethod
    def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
        if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
                new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
            else: new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
                new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
            else: new_key_padding_mask = key_padding_mask.float()
        else: new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    @torch.jit.export
    def reorder_incremental_state(self, incremental_state, new_order):
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention:
                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
                        elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
                        else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
                    else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def set_beam_size(self, beam_size):
        self.beam_size = beam_size

    def _get_input_buffer(self, incremental_state):
        result = self.get_incremental_state(incremental_state, "attn_state")
        return result if result is not None else {}

    def _set_input_buffer(self, incremental_state, buffer):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""
        items_to_add, keys_to_remove = {}, []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                dim = int(state_dict[k].shape[0] / 3)
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
                keys_to_remove.append(k)
                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    dim = int(state_dict[k].shape[0] / 3)
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
                    keys_to_remove.append(prefix + "in_proj_bias")

        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value

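# init_bert_params applies the BERT initialization convention: N(0, 0.02) for
# Linear and Embedding weights (sampled in fp32 on CPU, then cast back), zeroed
# biases and padding rows, and the same normal init for the q/k/v projections
# of any MultiheadAttention it visits.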
def init_bert_params(module):
    def normal_(data):
        data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))

    if isinstance(module, nn.Linear):
        normal_(module.weight.data)
        if module.bias is not None: module.bias.data.zero_()
    if isinstance(module, nn.Embedding):
        normal_(module.weight.data)
        if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_()
    if isinstance(module, MultiheadAttention):
        normal_(module.q_proj.weight.data)
        normal_(module.k_proj.weight.data)
        normal_(module.v_proj.weight.data)

def make_conv_pos(e, k, g):
    pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
    dropout = 0
    nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
    nn.init.constant_(pos_conv.bias, 0)
    return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())

def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"

def index_put(tensor, indices, value):
    if is_xla_tensor(tensor):
        for _ in range(indices.dim(), tensor.dim()):
            indices = indices.unsqueeze(-1)

        if indices.size(-1) < tensor.size(-1): indices = indices.expand_as(tensor)
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else: tensor[indices] = value

    return tensor

def pad_to_multiple(x, multiple, dim=-1, value=0):
    if x is None: return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if m.is_integer(): return x, 0
    return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder

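# compute_mask_indices draws SpecAugment-style mask spans for wav2vec2/HuBERT
# pretraining: per batch row it samples num_mask span starts, expands each to
# mask_length frames, and (with require_same_masks) equalizes the masked count
# across rows. A rough illustration with hypothetical numbers:
#   mask = compute_mask_indices((2, 100), None, mask_prob=0.65, mask_length=10)
#   mask.shape  # (2, 100), dtype bool; ~mask_prob * 100 frames masked per row
#               # before overlapping spans are merged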
def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type="static", mask_other=0.0, min_masks=0, no_overlap=False, min_space=0, require_same_masks=True, mask_dropout=0.0, add_masks=False, seed=None, epoch=None, indices=None, idc_select_ver=1, num_mask_ver=2):
    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)
    if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
    mask_idcs = []

    for i in range(bsz):
        seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
        rng = np.random.default_rng(seed_i)

        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            assert sz >= 0, sz
        else: sz = all_sz

        if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
        elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
        else: raise ValueError

        if mask_type == "static": lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform": lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
        elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
        else: raise Exception

        if sum(lengths) == 0:
            if mask_type == "static": raise ValueError
            else: lengths = [min(mask_length, sz - 1)]

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = rng.integers(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))
                new_parts = []
                if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
                l_sum = np.sum(lens)
                if l_sum == 0: break
                s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            if idc_select_ver == 1:
                min_len = min(lengths)
                if sz - min_len <= num_mask: min_len = sz - num_mask - 1
                mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
            elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
            else: raise ValueError

            mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])

        mask_idc = np.unique(mask_idc[mask_idc < sz])
        if len(mask_idc) >= sz: raise ValueError
        mask_idcs.append(mask_idc)

    target_len = None
    if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])

    for i, mask_idc in enumerate(mask_idcs):
        if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
        mask[i, mask_idc] = True

        if target_len is not None and len(mask_idc) < target_len:
            to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
            mask[i, to_mask] = True

        if mask_dropout > 0:
            masked = np.flatnonzero(mask[i])
            mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False

    return mask

def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    return nn.LayerNorm(normalized_shape, eps, elementwise_affine)

def prune_state_dict(state_dict, model_cfg):
    arch = None
    if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))
    if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict
    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
    if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}

    pruning_passes, new_state_dict = [], {}
    if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        original_layer_number = match.group(1)
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
                substitution_match = pruning_pass["substitution_regex"].search(layer_name)
                new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]

    with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
        if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None

    return new_state_dict

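# Note on get_activation_fn below: most names return a plain callable, but
# "swish" returns the nn.SiLU class itself. The Conformer modules in this file
# call get_activation_fn(name)(dim), so for "swish" the dim argument lands in
# nn.SiLU's constructor, where it is merely treated as a truthy inplace flag.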
def relu_squared(x):
    return F.relu(x).pow(2)

def get_activation_fn(activation):
    def gelu(x):
        return nn.functional.gelu(x.float()).type_as(x)

    def gelu_accurate(x):
        if not hasattr(gelu_accurate, "_a"):
            gelu_accurate._a = math.sqrt(2 / math.pi)
        return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))

    if activation == "relu": return F.relu
    elif activation == "relu_squared": return relu_squared
    elif activation == "gelu": return gelu
    elif activation == "gelu_fast": return gelu_accurate
    elif activation == "gelu_accurate": return gelu_accurate
    elif activation == "tanh": return torch.tanh
    elif activation == "linear": return lambda x: x
    elif activation == "swish": return nn.SiLU
    else: raise RuntimeError

class SamePad(nn.Module):
    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal: self.remove = kernel_size - 1
        else: self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0: x = x[:, :, : -self.remove]
        return x

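# TransformerSentenceEncoderLayer supports both residual orderings found in
# wav2vec2/HuBERT checkpoints: pre-LN (layer_norm_first=True, normalize before
# attention/FFN) and post-LN (normalize after each residual sum). The second
# element of the returned tuple carries the FFN output before the final
# residual as layer_result, which feature extractors can tap per layer.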
class TransformerSentenceEncoderLayer(nn.Module):
    def __init__(self, embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=8, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.layer_norm_first = layer_norm_first
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
        residual = x
        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
            x = residual + self.dropout1(x)
            residual = x
            x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
            layer_result = x
            x = residual + self.dropout3(x)
        else:
            x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
            x = self.self_attn_layer_norm(residual + self.dropout1(x))
            residual = x
            x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
            layer_result = x
            x = self.final_layer_norm(residual + self.dropout3(x))

        return x, (attn, layer_result)

class AdapterFast(nn.Module):
    def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
        super().__init__()
        self.adapter_num = adapter_num
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
        self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
        self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
        self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
        self.act_fn = nn.Identity()
        if act_fn == "relu": self.act_fn = nn.ReLU()
        elif act_fn == "gelu": self.act_fn = nn.GELU()
        elif act_fn == "selu": self.act_fn = nn.SELU()
        else: raise ValueError
        self.input_dim = input_dim
        self.reset_parameters()

    def reset_parameters(self):
        for ii in range(self.adapter_num):
            nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_a[ii], -bound, bound)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.b_b[ii], -bound, bound)

        nn.init.ones_(self.ln_W)
        nn.init.zeros_(self.ln_b)

    def forward(self, x, adapter_id):
        ii = adapter_id
        return F.linear(self.act_fn(F.linear(F.layer_norm(x, (self.input_dim,), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), self.W_b[ii], self.b_b[ii])

    def extra_repr(self):
        return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))

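# FeedForwardModule and ConvolutionModule are the two non-attention blocks of
# a Conformer layer: a pre-norm position-wise FFN, and a gated (GLU) pointwise
# conv -> depthwise conv -> BatchNorm -> activation -> pointwise conv stack.
# ConvolutionModule takes (B, T, C) input, runs the convolutions channel-first
# as (B, C, T), and transposes back before returning.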
class FeedForwardModule(nn.Module):
    def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
        super(FeedForwardModule, self).__init__()
        self.layer_norm = LayerNorm(input_feat)
        self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
        self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.activation = get_activation_fn(activation_fn)(hidden_units)

    def forward(self, x):
        return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))

class ConvolutionModule(nn.Module):
    def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
        super(ConvolutionModule, self).__init__()
        assert (depthwise_kernel_size - 1) % 2 == 0
        self.layer_norm = LayerNorm(embed_dim, export=export)
        self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
        self.batch_norm = nn.BatchNorm1d(channels)
        self.activation = get_activation_fn(activation_fn)(channels)
        self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.pointwise_conv2(self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))))).transpose(1, 2)

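# Rotary position embedding (RoPE) helpers: rotate_half negates and swaps the
# two halves of the last dimension, and apply_rotary_pos_emb rotates q and k
# by position-dependent angles, q' = q * cos + rotate_half(q) * sin, so that
# attention scores depend only on relative offsets. On a 4-dim toy vector
# (hypothetical values): rotate_half([a, b, c, d]) == [-c, -d, a, b].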
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = 0
        self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
        self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
        self.precision = precision

    def forward(self, x, seq_len=0):
        if seq_len > self.seq_len_cached:
            self.seq_len_cached = seq_len
            freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
            self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
        return self.cos_cached, self.sin_cached

class ESPNETMultiHeadedAttention(nn.Module):
    def __init__(self, n_feat, n_head, dropout):
        super(ESPNETMultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward_qkv(self, query, key, value, **kwargs):
        n_batch = query.size(0)
        return self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

    def forward_attention(self, value, scores, mask):
        n_batch = value.size(0)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
            self.attn = torch.softmax(scores, dim=-1)
        else: self.attn = torch.softmax(scores, dim=-1)

        return self.linear_out((torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)))

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None

class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    def __init__(self, n_feat, n_head, dropout, zero_triu=False):
        super().__init__(n_feat, n_head, dropout)
        self.zero_triu = zero_triu
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
        nn.init.xavier_uniform_(self.pos_bias_u)
        nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
        if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
        return x

    def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
        pos_emb = pos_emb.transpose(0, 1)
        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        q = q.transpose(1, 2)

        return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None

class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
        super().__init__(n_feat, n_head, dropout)
        precision = torch.float
        self.rotary_ndims = self.d_k
        if precision == "fp16": precision = torch.half
        self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        T, B, C = value.size()
        query = query.view(T, B, self.h, self.d_k)
        key = key.view(T, B, self.h, self.d_k)
        value = value.view(T, B, self.h, self.d_k)
        cos, sin = self.rotary_emb(value, seq_len=T)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
        query = query.view(T, B, self.h * self.d_k)
        key = key.view(T, B, self.h * self.d_k)
        value = value.view(T, B, self.h * self.d_k)
        q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
        return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None

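# ConformerEncoderLayer follows the macaron layout of the Conformer paper:
# half-weighted FFN -> self-attention (abs / rel_pos / rope variants) ->
# convolution module -> half-weighted FFN -> final LayerNorm, each wrapped in
# a residual connection. attn_type="espnet" selects the ESPnet attention
# family above; anything else falls back to the fairseq MultiheadAttention.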
class ConformerEncoderLayer(nn.Module):
    def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
        self.pos_enc_type = pos_enc_type
        super(ConformerEncoderLayer, self).__init__()
        self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
        self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
        self.self_attn_dropout = nn.Dropout(dropout)
        if attn_type == "espnet":
            if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
            elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
            elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
            else: raise Exception
        else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
        self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
        self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
        self.final_layer_norm = LayerNorm(embed_dim, export=False)

    def forward(self, x, encoder_padding_mask, position_emb=None):
        residual = x
        x = self.ffn1(x) * 0.5 + residual
        residual = x
        x = self.self_attn_layer_norm(x)
        if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
        else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
        x = self.self_attn_dropout(x)
        x = x + residual
        residual = x
        x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
        residual = x
        x = self.ffn2(x)
        layer_result = x
        x = self.final_layer_norm(x * 0.5 + residual)
        return x, (attn, layer_result)

class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
        return super().forward(x, self_attn_padding_mask, position_emb)

class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
    def __init__(self, embedding_dim=768, ffn_embedding_dim=3072, num_attention_heads=8, dropout=0.1, attention_dropout=0.1, activation_dropout=0.1, activation_fn="relu", layer_norm_first=False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
        super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
        self.adapter_num = adapter_num
        self.adapter_dim = adapter_dim
        self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)

    def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
        x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
        assert corpus_key is not None
        assert len(set(corpus_key)) == 1
        return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)

class TransposeLast(nn.Module):
    def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx
        self.tranpose_dim = tranpose_dim

    def forward(self, x):
        if self.deconstruct_idx is not None: x = x[self.deconstruct_idx]
        return x.transpose(self.tranpose_dim, -1)

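# TransformerEncoder is the wav2vec2/HuBERT context network: a convolutional
# relative positional embedding (one wide grouped conv, or a stack when
# pos_conv_depth > 1) followed by encoder_layers transformer/conformer blocks
# with optional LayerDrop. extract_features pads the sequence length to
# required_seq_len_multiple, collects per-layer (x, attn, layer_result)
# tuples, and undoes the padding before returning.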
class TransformerEncoder(nn.Module):
    def build_encoder_layer(self, args, **kwargs):
        if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
        elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
        elif args.layer_type == "trf_adp":
            use_adp = False
            if args.adp_trf_idx == "all": use_adp = True
            else:
                if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True

            layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)

        return layer

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple
        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])

            self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
        else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)

        self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop
        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
        x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
        if self.layer_norm_first and layer is None: x = self.layer_norm(x)
        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
        if padding_mask is not None: x = index_put(x, padding_mask, 0)
        x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
        if not self.layer_norm_first: x = self.layer_norm(x)
        x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
        x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
        layer_results = []
        r = None

        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                layer_check = layer
                if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
                else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
                if i >= min_layer: layer_results.append((x, z, lr))
                if i == tgt_layer:
                    r = x
                    break

        if r is not None: x = r
        x = x.transpose(0, 1)

        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict

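# Fp32GroupNorm / Fp32LayerNorm compute normalization statistics in float32
# even when the surrounding network runs in half precision, then cast the
# result back to the input dtype; this avoids fp16 overflow/underflow in the
# variance computation.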
class Fp32GroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
        return output.type_as(input)

class Fp32LayerNorm(nn.LayerNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
        return output.type_as(input)

class ConvFeatureExtractionModel(nn.Module):
    def __init__(self, conv_layers, dropout=0.0, mode="default", conv_bias=False):
        super().__init__()
        assert mode in {"default", "layer_norm"}

        def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (is_layer_norm and is_group_norm) == False

            if is_layer_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
            elif is_group_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
            else: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3
            (dim, k, stride) = cl
            self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=mode == "layer_norm", is_group_norm=mode == "default" and i == 0, conv_bias=conv_bias))
            in_d = dim

    def forward(self, x):
        x = x.unsqueeze(1)
        for conv in self.conv_layers:
            x = conv(x)

        return x

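# GradMultiply is the identity on the forward pass but scales gradients on the
# backward pass; HuBERT uses it to damp the gradient flowing into the conv
# feature extractor when feature_grad_mult < 1.0. Usage sketch (hypothetical):
#   features = GradMultiply.apply(features, 0.1)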
class GradMultiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None

class BaseFairseqModel(nn.Module):
    def __init__(self):
        super().__init__()
        self._is_generation_fast = False

    def get_targets(self, sample, net_output):
        return sample["target"]

    def extract_features(self, *args, **kwargs):
        return self(*args, **kwargs)

    def load_state_dict(self, state_dict, strict=True, model_cfg=None, args=None):
        self.upgrade_state_dict(state_dict)
        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0: prefix += "."
            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)

    def make_generation_fast_(self, **kwargs):
        if self._is_generation_fast: return
        self._is_generation_fast = True

        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0: prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)

        apply_make_generation_fast_(self, "")
        self.eval()

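# HubertConfig mirrors fairseq's dataclass configuration as a plain Python
# class so checkpoints can be reconstructed without hydra/omegaconf.
# HubertModel is the content encoder used to turn 16 kHz waveforms into
# frame-level features; feat2tar_ratio below relates the conv extractor's
# downsampling to the label frame rate.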
class HubertConfig:
    def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode="default", encoder_layers=12, encoder_embed_dim=768, encoder_ffn_embed_dim=3072, encoder_attention_heads=12, activation_fn="gelu", layer_type="transformer", dropout=0.1, attention_dropout=0.1, activation_dropout=0.0, encoder_layerdrop=0.0, dropout_input=0.0, dropout_features=0.0, final_dim=0, untie_final_proj=False, layer_norm_first=False, conv_feature_layers="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias=False, logit_temp=0.1, target_glu=False, feature_grad_mult=1.0, mask_length=10, mask_prob=0.65, mask_selection="static", mask_other=0.0, no_mask_overlap=False, mask_min_space=1, mask_channel_length=10, mask_channel_prob=0.0, mask_channel_selection="static", mask_channel_other=0.0, no_mask_channel_overlap=False, mask_channel_min_space=1, conv_pos=128, conv_pos_groups=16, conv_pos_batch_norm=False, latent_temp=(2, 0.5, 0.999995), skip_masked=False, skip_nomask=False, checkpoint_activations=False, required_seq_len_multiple=2, depthwise_conv_kernel_size=31, attn_type="", pos_enc_type="abs", fp16=False):
        self._name = _name
        self.label_rate = label_rate
        self.encoder_layers_1 = encoder_layers_1
        self.logit_temp_ctr = logit_temp_ctr
        self.num_negatives = num_negatives
        self.cross_sample_negatives = cross_sample_negatives
        self.ctr_layers = ctr_layers
        self.extractor_mode = extractor_mode
        self.encoder_layers = encoder_layers
        self.encoder_embed_dim = encoder_embed_dim
        self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
        self.encoder_attention_heads = encoder_attention_heads
        self.activation_fn = activation_fn
        self.layer_type = layer_type
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.activation_dropout = activation_dropout
        self.encoder_layerdrop = encoder_layerdrop
        self.dropout_input = dropout_input
        self.dropout_features = dropout_features
        self.final_dim = final_dim
        self.untie_final_proj = untie_final_proj
        self.layer_norm_first = layer_norm_first
        self.conv_feature_layers = conv_feature_layers
        self.conv_bias = conv_bias
        self.logit_temp = logit_temp
        self.target_glu = target_glu
        self.feature_grad_mult = feature_grad_mult
        self.mask_length = mask_length
        self.mask_prob = mask_prob
        self.mask_selection = mask_selection
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.mask_channel_length = mask_channel_length
        self.mask_channel_prob = mask_channel_prob
        self.mask_channel_selection = mask_channel_selection
        self.mask_channel_other = mask_channel_other
        self.no_mask_channel_overlap = no_mask_channel_overlap
        self.mask_channel_min_space = mask_channel_min_space
        self.conv_pos = conv_pos
        self.conv_pos_groups = conv_pos_groups
        self.conv_pos_batch_norm = conv_pos_batch_norm
        self.latent_temp = latent_temp
        self.skip_masked = skip_masked
        self.skip_nomask = skip_nomask
        self.checkpoint_activations = checkpoint_activations
        self.required_seq_len_multiple = required_seq_len_multiple
        self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
        self.attn_type = attn_type
        self.pos_enc_type = pos_enc_type
        self.fp16 = fp16

class HubertModel(BaseFairseqModel):
|
| 1255 |
+
def __init__(self, cfg):
|
| 1256 |
+
super().__init__()
|
| 1257 |
+
feature_enc_layers = eval(cfg.conv_feature_layers)
|
| 1258 |
+
self.embed = feature_enc_layers[-1][0]
|
| 1259 |
+
self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
|
| 1260 |
+
feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
|
| 1261 |
+
self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
|
| 1262 |
+
self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
|
| 1263 |
+
self.mask_prob = cfg.mask_prob
|
| 1264 |
+
self.mask_selection = cfg.mask_selection
|
| 1265 |
+
self.mask_other = cfg.mask_other
|
| 1266 |
+
self.mask_length = cfg.mask_length
|
| 1267 |
+
self.no_mask_overlap = cfg.no_mask_overlap
|
| 1268 |
+
self.mask_min_space = cfg.mask_min_space
|
| 1269 |
+
self.mask_channel_prob = cfg.mask_channel_prob
|
| 1270 |
+
self.mask_channel_selection = cfg.mask_channel_selection
|
| 1271 |
+
self.mask_channel_other = cfg.mask_channel_other
|
| 1272 |
+
self.mask_channel_length = cfg.mask_channel_length
|
| 1273 |
+
self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
|
| 1274 |
+
self.mask_channel_min_space = cfg.mask_channel_min_space
|
| 1275 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
| 1276 |
+
self.dropout_features = nn.Dropout(cfg.dropout_features)
|
| 1277 |
+
self.feature_grad_mult = cfg.feature_grad_mult
|
| 1278 |
+
self.logit_temp = cfg.logit_temp
|
| 1279 |
+
self.skip_masked = cfg.skip_masked
|
| 1280 |
+
self.skip_nomask = cfg.skip_nomask
|
| 1281 |
+
final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
|
| 1282 |
+
self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
|
| 1283 |
+
self.encoder = TransformerEncoder(cfg)
|
| 1284 |
+
self.layer_norm = LayerNorm(self.embed)
|
| 1285 |
+
self.target_glu = None
|
| 1286 |
+
if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
|
| 1287 |
+
self.untie_final_proj = cfg.untie_final_proj
|
| 1288 |
+
self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
|
| 1289 |
+
self.num_classes = [504]
|
| 1290 |
+
self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
|
| 1291 |
+
nn.init.uniform_(self.label_embs_concat)
|
| 1292 |
+
|
| 1293 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 1294 |
+
super().upgrade_state_dict_named(state_dict, name)
|
| 1295 |
+
return state_dict
|
| 1296 |
+
|
| 1297 |
+
def apply_mask(self, x, padding_mask, target_list):
|
| 1298 |
+
B, T, C = x.shape
|
| 1299 |
+
if self.mask_prob > 0:
|
| 1300 |
+
mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
|
| 1301 |
+
x[mask_indices] = self.mask_emb
|
| 1302 |
+
else: mask_indices = None
|
| 1303 |
+
|
| 1304 |
+
if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
|
| 1305 |
+
return x, mask_indices
|
| 1306 |
+
|
| 1307 |
+
def compute_nce(self, x, pos, negs):
|
| 1308 |
+
neg_is_pos = (pos == negs).all(-1)
|
| 1309 |
+
logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
|
| 1310 |
+
logits /= self.logit_temp
|
| 1311 |
+
if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
|
| 1312 |
+
return logits.transpose(0, 1)
|
| 1313 |
+
|
| 1314 |
+
def forward_features(self, source):
|
| 1315 |
+
if self.feature_grad_mult > 0:
|
| 1316 |
+
features = self.feature_extractor(source)
|
| 1317 |
+
if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
|
| 1318 |
+
else:
|
| 1319 |
+
with torch.no_grad():
|
| 1320 |
+
features = self.feature_extractor(source)
|
| 1321 |
+
return features
|
| 1322 |
+
|
| 1323 |
+
def forward_targets(self, features, target_list):
|
| 1324 |
+
feat_tsz = features.size(2)
|
| 1325 |
+
targ_tsz = min([t.size(1) for t in target_list])
|
| 1326 |
+
if self.feat2tar_ratio * feat_tsz > targ_tsz:
|
| 1327 |
+
feat_tsz = int(targ_tsz / self.feat2tar_ratio)
|
| 1328 |
+
features = features[..., :feat_tsz]
|
| 1329 |
+
|
| 1330 |
+
return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
|
| 1331 |
+
|
| 1332 |
+
def forward_padding_mask(self, features, padding_mask):
|
| 1333 |
+
extra = padding_mask.size(1) % features.size(1)
|
| 1334 |
+
if extra > 0: padding_mask = padding_mask[:, :-extra]
|
| 1335 |
+
return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
|
| 1336 |
+
|
| 1337 |
+
def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
|
| 1338 |
+
features = self.forward_features(source)
|
| 1339 |
+
if target_list is not None: features, target_list = self.forward_targets(features, target_list)
|
| 1340 |
+
features_pen = features.float().pow(2).mean()
|
| 1341 |
+
features = self.layer_norm(features.transpose(1, 2))
|
| 1342 |
+
unmasked_features = features.clone()
|
| 1343 |
+
if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
|
| 1344 |
+
if self.post_extract_proj is not None: features = self.post_extract_proj(features)
|
| 1345 |
+
features = self.dropout_input(features)
|
| 1346 |
+
unmasked_features = self.dropout_features(unmasked_features)
|
| 1347 |
+
if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
|
| 1348 |
+
else: x, mask_indices = features, None
|
| 1349 |
+
x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
|
| 1350 |
+
if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}
|
| 1351 |
+
|
| 1352 |
+
def compute_pred(proj_x, target, label_embs):
|
| 1353 |
+
y = torch.index_select(label_embs, 0, target.long())
|
| 1354 |
+
negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
|
| 1355 |
+
if self.target_glu:
|
| 1356 |
+
y = self.target_glu(y)
|
| 1357 |
+
negs = self.target_glu(negs)
|
| 1358 |
+
|
| 1359 |
+
return self.compute_nce(proj_x, y, negs)
|
| 1360 |
+
|
| 1361 |
+
label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
|
| 1362 |
+
if not self.skip_masked:
|
| 1363 |
+
masked_indices = torch.logical_and(~padding_mask, mask_indices)
|
| 1364 |
+
proj_x_m = self.final_proj(x[masked_indices])
|
| 1365 |
+
logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
|
| 1366 |
+
else: logit_m_list = [None for _ in target_list]
|
| 1367 |
+
|
| 1368 |
+
if not self.skip_nomask:
|
| 1369 |
+
nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
|
| 1370 |
+
proj_x_u = self.final_proj(x[nomask_indices])
|
| 1371 |
+
logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
|
| 1372 |
+
else: logit_u_list = [None for _ in target_list]
|
| 1373 |
+
|
| 1374 |
+
return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
|
| 1375 |
+
|
| 1376 |
+
def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
|
| 1377 |
+
res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
|
| 1378 |
+
return res["features"] if ret_conv else res["x"], res["padding_mask"]
|
| 1379 |
+
|
| 1380 |
+
def get_logits(self, net_output, is_masked=True):
|
| 1381 |
+
return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
|
| 1382 |
+
|
| 1383 |
+
def get_targets(self, net_output, is_masked=True):
|
| 1384 |
+
return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
|
| 1385 |
+
|
| 1386 |
+
def get_extra_losses(self, net_output):
|
| 1387 |
+
extra_losses, names = [], []
|
| 1388 |
+
if "features_pen" in net_output:
|
| 1389 |
+
extra_losses.append(net_output["features_pen"])
|
| 1390 |
+
names.append("features_pen")
|
| 1391 |
+
|
| 1392 |
+
return extra_losses, names
|
| 1393 |
+
|
| 1394 |
+
def remove_pretraining_modules(self):
|
| 1395 |
+
self.target_glu = None
|
| 1396 |
+
self.final_proj = None
|
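A minimal usage sketch for the HubertModel above, assuming the supporting classes defined earlier in this file (ConvFeatureExtractionModel, TransformerEncoder, etc.) and this repo's `modules.*` import layout; the contrastive-specific config values passed here are placeholders, since only feature extraction is exercised:

import torch
from modules.fairseq import HubertConfig, HubertModel  # assumes this repo layout

# placeholder values for the contrastive-training fields; defaults cover the rest
cfg = HubertConfig(_name="hubert", label_rate=50, encoder_layers_1=0, logit_temp_ctr=0.1, num_negatives=100, cross_sample_negatives=0, ctr_layers=[-6])
model = HubertModel(cfg).eval()

wav = torch.randn(1, 16000)  # one second of dummy 16 kHz audio
with torch.no_grad():
    feats, _ = model.extract_features(wav, output_layer=9)  # (1, T, 768) frame-level features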
RVC/modules/gdown.py
ADDED
@@ -0,0 +1,100 @@
import os
import re
import sys
import json
import codecs
import tempfile
import requests

from urllib.parse import urlparse, parse_qs, unquote

def parse_url(url):
    parsed = urlparse(url)
    is_download_link = parsed.path.endswith("/uc")
    if parsed.hostname not in ("drive.google.com", "docs.google.com"): return None, is_download_link
    file_id = parse_qs(parsed.query).get("id", [None])[0]

    if file_id is None:
        for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
            match = re.match(pattern, parsed.path)
            if match:
                file_id = match.group(1)
                break
    return file_id, is_download_link

def get_url_from_gdrive_confirmation(contents):
    for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
        match = re.search(pattern, contents)
        if match:
            url = match.group(1)
            if pattern == r'href="/open\?id=([^"]+)"': url = (codecs.decode("uggcf://qevir.hfrepbagrag.tbbtyr.pbz/qbjaybnq?vq=", "rot13") + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
            elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
            else: url = codecs.decode("uggcf://qbpf.tbbtyr.pbz", "rot13") + url.replace("&amp;", "&")
            return url

    match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
    if match: raise Exception(match.group(1))
    raise Exception

def _get_session(use_cookies, return_cookies_file=False):
    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
    cookies_file = os.path.join(os.path.expanduser("~"), ".cache/gdown/cookies.json")

    if os.path.exists(cookies_file) and use_cookies:
        with open(cookies_file) as f:
            for k, v in json.load(f):
                sess.cookies[k] = v
    return (sess, cookies_file) if return_cookies_file else sess

def gdown_download(url=None, id=None, output=None):
    if not (id is None) ^ (url is None): raise ValueError
    if id is not None: url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{id}"

    url_origin = url
    sess, cookies_file = _get_session(use_cookies=True, return_cookies_file=True)
    gdrive_file_id, is_gdrive_download_link = parse_url(url)

    if gdrive_file_id:
        url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{gdrive_file_id}"
        url_origin = url
        is_gdrive_download_link = True

    while 1:
        res = sess.get(url, stream=True, verify=True)
        if url == url_origin and res.status_code == 500:
            url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/bcra?vq=', 'rot13')}{gdrive_file_id}"
            continue

        os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
        with open(cookies_file, "w") as f:
            json.dump([(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")], f, indent=2)

        if "Content-Disposition" in res.headers: break
        if not (gdrive_file_id and is_gdrive_download_link): break

        try:
            url = get_url_from_gdrive_confirmation(res.text)
        except Exception as e:
            raise Exception(e)

    if gdrive_file_id and is_gdrive_download_link:
        content_disposition = unquote(res.headers["Content-Disposition"])
        filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1).replace(os.path.sep, "_")
    else: filename_from_url = os.path.basename(url)

    output = os.path.join(output or ".", filename_from_url)
    tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=os.path.basename(output), dir=os.path.dirname(output))
    f = open(tmp_file, "ab")

    if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=True)
    print("To:", os.path.abspath(output), file=sys.stderr)

    try:
        for chunk in res.iter_content(chunk_size=512 * 1024):
            f.write(chunk)
        if tmp_file: f.close()
    finally:
        os.rename(tmp_file, output)
        sess.close()
    return output
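A hedged usage sketch for gdown_download; the file id below is a placeholder, not a real Drive file. Exactly one of url or id must be given (the XOR check above raises ValueError otherwise):

from modules.gdown import gdown_download  # assumes this repo layout

saved = gdown_download(id="<drive-file-id>", output="downloads")  # placeholder id
print(saved)  # path of the downloaded file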
RVC/modules/generator.py
ADDED
@@ -0,0 +1,257 @@
import os
import sys
import math
import torch
import parselmouth

import numba as nb
import numpy as np

from librosa import yin, pyin
from scipy.signal import medfilt

sys.path.append(os.getcwd())

from modules.rmvpe import RMVPE
from modules.utils import Autotune
from modules.torchfcpe import FCPE
from modules.pyworld import PYWORLD
from modules.swipe import swipe, stonemask
from modules.torchcrepe import CREPE, mean, median

@nb.jit(nopython=True)
def post_process(f0, f0_up_key, f0_mel_min, f0_mel_max):
    f0 = np.multiply(f0, pow(2, f0_up_key / 12))

    f0_mel = 1127 * np.log(1 + f0 / 700)
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > 255] = 255

    return np.rint(f0_mel).astype(np.int32), f0

class Generator:
    def __init__(self, sample_rate=16000, hop_length=160, f0_min=50, f0_max=1100, is_half=False, device="cpu"):
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.is_half = is_half
        self.device = device
        self.window = 160
        self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
        self.autotune = Autotune(self.ref_freqs)
        self.note_dict = self.autotune.note_dict

    def calculator(self, f0_method, x, f0_up_key=0, p_len=None, filter_radius=3, f0_autotune=False, f0_autotune_strength=1):
        if p_len is None: p_len = x.shape[0] // self.window
        f0 = self.compute_f0(f0_method, x, p_len, filter_radius if filter_radius % 2 != 0 else filter_radius + 1)

        if isinstance(f0, tuple): f0 = f0[0]
        if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)

        return post_process(f0, f0_up_key, 1127 * math.log(1 + self.f0_min / 700), 1127 * math.log(1 + self.f0_max / 700))

    def _resize_f0(self, x, target_len):
        source = np.array(x)
        source[source < 0.001] = np.nan

        return np.nan_to_num(np.interp(np.arange(0, len(source) * target_len, len(source)) / target_len, np.arange(0, len(source)), source))

    def compute_f0(self, f0_method, x, p_len, filter_radius):
        return {
            "pm": lambda: self.get_f0_pm(x, p_len),
            "dio": lambda: self.get_f0_pyworld(x, p_len, filter_radius, "dio"),
            "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, "tiny"),
            "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, "small"),
            "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, "medium"),
            "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, "large"),
            "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, "full"),
            "crepe-tiny": lambda: self.get_f0_crepe(x, p_len, "tiny"),
            "crepe-small": lambda: self.get_f0_crepe(x, p_len, "small"),
            "crepe-medium": lambda: self.get_f0_crepe(x, p_len, "medium"),
            "crepe-large": lambda: self.get_f0_crepe(x, p_len, "large"),
            "crepe-full": lambda: self.get_f0_crepe(x, p_len, "full"),
            "fcpe": lambda: self.get_f0_fcpe(x, p_len),
            "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, legacy=True),
            "rmvpe": lambda: self.get_f0_rmvpe(x, p_len),
            "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, p_len, legacy=True),
            "harvest": lambda: self.get_f0_pyworld(x, p_len, filter_radius, "harvest"),
            "yin": lambda: self.get_f0_yin(x, p_len, mode="yin"),
            "pyin": lambda: self.get_f0_yin(x, p_len, mode="pyin"),
            "swipe": lambda: self.get_f0_swipe(x, p_len)
        }[f0_method]()

    def get_f0_pm(self, x, p_len):
        f0 = parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=160 / self.sample_rate, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
        pad_size = (p_len - len(f0) + 1) // 2

        if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
        return f0

    def get_f0_mangio_crepe(self, x, p_len, model="full"):
        if not hasattr(self, "mangio_crepe"):
            self.mangio_crepe = CREPE(os.path.join("models", f"crepe_{model}.pth"), model_size=model, hop_length=self.hop_length, batch_size=self.hop_length * 2, f0_min=self.f0_min, f0_max=self.f0_max, device=self.device, sample_rate=self.sample_rate, return_periodicity=False)

        x = x.astype(np.float32)
        x /= np.quantile(np.abs(x), 0.999)

        audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
        if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()

        f0 = self.mangio_crepe.compute_f0(audio.detach(), pad=True)
        return self._resize_f0(f0.squeeze(0).cpu().float().numpy(), p_len)

    def get_f0_crepe(self, x, p_len, model="full"):
        if not hasattr(self, "crepe"):
            self.crepe = CREPE(os.path.join("models", f"crepe_{model}.pth"), model_size=model, hop_length=self.hop_length, batch_size=512, f0_min=self.f0_min, f0_max=self.f0_max, device=self.device, sample_rate=self.sample_rate, return_periodicity=True)

        f0, pd = self.crepe.compute_f0(torch.tensor(np.copy(x))[None].float(), pad=True)
        f0, pd = mean(f0, 3), median(pd, 3)
        f0[pd < 0.1] = 0

        return self._resize_f0(f0[0].cpu().numpy(), p_len)

    def get_f0_fcpe(self, x, p_len, legacy=False):
        if not hasattr(self, "fcpe"):
            self.fcpe = FCPE(os.path.join("models", ("fcpe_legacy" if legacy else "fcpe") + ".pt"), hop_length=self.hop_length, f0_min=self.f0_min, f0_max=self.f0_max, dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03 if legacy else 0.006, legacy=legacy)

        return self.fcpe.compute_f0(x, p_len)

    def get_f0_rmvpe(self, x, p_len, legacy=False):
        if not hasattr(self, "rmvpe"):
            self.rmvpe = RMVPE(os.path.join("models", "rmvpe.pt"), is_half=self.is_half, device=self.device)

        f0 = self.rmvpe.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else self.rmvpe.infer_from_audio(x, thred=0.03)
        return self._resize_f0(f0, p_len)

    def get_f0_pyworld(self, x, p_len, filter_radius, model="harvest"):
        if not hasattr(self, "pw"): self.pw = PYWORLD()

        x = x.astype(np.double)
        pw = self.pw.harvest if model == "harvest" else self.pw.dio

        f0, t = pw(x, fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.window / self.sample_rate)
        f0 = self.pw.stonemask(x, self.sample_rate, t, f0)

        if filter_radius > 2 and model == "harvest": f0 = medfilt(f0, filter_radius)
        elif model == "dio":
            for index, pitch in enumerate(f0):
                f0[index] = round(pitch, 1)

        return self._resize_f0(f0, p_len)

    def get_f0_swipe(self, x, p_len):
        f0, t = swipe(x.astype(np.float32), self.sample_rate, f0_floor=self.f0_min, f0_ceil=self.f0_max, frame_period=1000 * self.window / self.sample_rate)
        return self._resize_f0(stonemask(x, self.sample_rate, t, f0), p_len)

    def get_f0_yin(self, x, p_len, mode="yin"):
        self.if_yin = mode == "yin"
        self.yin = yin if self.if_yin else pyin

        f0 = self.yin(x.astype(np.float32), sr=self.sample_rate, fmin=self.f0_min, fmax=self.f0_max, hop_length=self.hop_length)

        if not self.if_yin: f0 = f0[0]
        return self._resize_f0(f0, p_len)
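For orientation, a minimal sketch of driving the Generator above; the "pm" method assumes parselmouth is installed, the import assumes this repo layout, and the audio is a random stand-in. calculator returns both the coarse pitch (mel-scaled, quantized to integer bins 1-255 by post_process) and the transposed F0 curve in Hz:

import numpy as np
from modules.generator import Generator  # assumes this repo layout

gen = Generator(sample_rate=16000, hop_length=160, device="cpu")
audio = np.random.randn(16000).astype(np.float32)       # stand-in for 1 s of 16 kHz speech
coarse, f0 = gen.calculator("pm", audio, f0_up_key=2)   # transpose up 2 semitones
print(coarse.dtype, coarse.min(), coarse.max())         # int32, values within 1..255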
RVC/modules/hifigan.py
ADDED
@@ -0,0 +1,60 @@
import os
import sys
import torch
import torch.nn.functional as F

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from modules.commons import init_weights
from modules.residuals import ResBlock, LRELU_SLOPE

class HiFiGANGenerator(torch.nn.Module):
    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super(HiFiGANGenerator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.ups_and_resblocks = torch.nn.ModuleList()

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
            ch = upsample_initial_channel // (2 ** (i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.ups_and_resblocks.append(ResBlock(ch, k, d))

        self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups_and_resblocks.apply(init_weights)
        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None: x = x + self.cond(g)

        resblock_idx = 0

        for _ in range(self.num_upsamples):
            x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
            resblock_idx += 1
            xs = 0

            for _ in range(self.num_kernels):
                xs += self.ups_and_resblocks[resblock_idx](x)
                resblock_idx += 1

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def __prepare_scriptable__(self):
        for l in self.ups_and_resblocks:
            for hook in l._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)

        return self

    def remove_weight_norm(self):
        for l in self.ups_and_resblocks:
            remove_weight_norm(l)
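A minimal sketch of instantiating the HiFiGANGenerator above; the hyperparameters are illustrative, not a trained checkpoint's config, and ResBlock/init_weights come from the sibling modules as imported above:

import torch
from modules.hifigan import HiFiGANGenerator  # assumes this repo layout

net = HiFiGANGenerator(initial_channel=192, resblock_kernel_sizes=[3, 7, 11], resblock_dilation_sizes=[[1, 3, 5]] * 3, upsample_rates=[10, 10, 2, 2], upsample_initial_channel=512, upsample_kernel_sizes=[16, 16, 4, 4])
mel = torch.randn(1, 192, 50)  # (batch, channels, frames)
wav = net(mel)                 # (1, 1, 50 * 400): each frame expands by 10*10*2*2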
RVC/modules/mediafire.py
ADDED
@@ -0,0 +1,30 @@
import os
import sys
import requests

from bs4 import BeautifulSoup

def Mediafire_Download(url, output=None, filename=None):
    if not filename: filename = url.split('/')[-2]
    if not output: output = os.path.dirname(os.path.realpath(__file__))
    output_file = os.path.join(output, filename)

    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    try:
        with requests.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
            r.raise_for_status()
            with open(output_file, "wb") as f:
                total_length = int(r.headers.get('content-length'))
                download_progress = 0

                for chunk in r.iter_content(chunk_size=1024):
                    download_progress += len(chunk)
                    f.write(chunk)
                    sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
                    sys.stdout.flush()
                sys.stdout.write("\n")
        return output_file
    except Exception as e:
        raise RuntimeError(e)
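A hedged usage sketch; the link below is a placeholder. The function scrapes the page's downloadButton href with BeautifulSoup and then streams the file with a progress readout:

from modules.mediafire import Mediafire_Download  # assumes this repo layout

Mediafire_Download("https://www.mediafire.com/file/<id>/model.pth/file", output="downloads", filename="model.pth")  # placeholder link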
RVC/modules/meganz.py
ADDED
@@ -0,0 +1,122 @@
import os
import re
import json
import codecs
import random
import base64
import struct
import shutil
import requests
import tempfile

from Crypto.Cipher import AES
from Crypto.Util import Counter

def makebyte(x):
    return codecs.latin_1_encode(x)[0]

def a32_to_str(a):
    return struct.pack('>%dI' % len(a), *a)

def get_chunks(size):
    p, s = 0, 0x20000

    while p + s < size:
        yield p, s
        p += s

        if s < 0x100000: s += 0x20000

    yield p, size - p

def aes_cbc_decrypt(data, key):
    aes_cipher = AES.new(key, AES.MODE_CBC, makebyte('\0' * 16))
    return aes_cipher.decrypt(data)

def decrypt_attr(attr, key):
    attr = codecs.latin_1_decode(aes_cbc_decrypt(attr, a32_to_str(key)))[0].rstrip('\0')
    return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False

def _api_request(data):
    sequence_num = random.randint(0, 0xFFFFFFFF)
    params = {'id': sequence_num}
    sequence_num += 1

    if not isinstance(data, list): data = [data]
    json_resp = json.loads(requests.post('{0}://g.api.{1}/cs'.format('https', 'mega.co.nz'), params=params, data=json.dumps(data), timeout=160).text)
    if isinstance(json_resp, int): raise Exception(json_resp)

    return json_resp[0]

def base64_url_decode(data):
    data += '=='[(2 - len(data) * 3) % 4:]

    for search, replace in (('-', '+'), ('_', '/'), (',', '')):
        data = data.replace(search, replace)

    return base64.b64decode(data)

def str_to_a32(b):
    if isinstance(b, str): b = makebyte(b)
    if len(b) % 4: b += b'\0' * (4 - len(b) % 4)
    return struct.unpack('>%dI' % (len(b) // 4), b)

def base64_to_a32(s):
    return str_to_a32(base64_url_decode(s))

def mega_download_file(file_handle, file_key, dest_path=None):
    file_key = base64_to_a32(file_key)
    file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})

    k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
    iv = file_key[4:6] + (0, 0)

    if 'g' not in file_data: raise Exception

    file_size = file_data['s']
    attribs = decrypt_attr(base64_url_decode(file_data['at']), k)
    input_file = requests.get(file_data['g'], stream=True).raw

    temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)
    k_str = a32_to_str(k)
    aes = AES.new(k_str, AES.MODE_CTR, counter=Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64))

    mac_str = b'\0' * 16
    mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)
    iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])

    for _, chunk_size in get_chunks(file_size):
        chunk = aes.decrypt(input_file.read(chunk_size))
        temp_output_file.write(chunk)

        encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)

        for i in range(0, len(chunk) - 16, 16):
            block = chunk[i:i + 16]
            encryptor.encrypt(block)

        i = (i + 16) if file_size > 16 else 0
        block = chunk[i:i + 16]
        if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))

        mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))

    file_mac = str_to_a32(mac_str)
    temp_output_file.close()

    if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != file_key[6:8]: raise ValueError

    file_path = os.path.join(dest_path, attribs['n'])
    if os.path.exists(file_path): os.remove(file_path)

    shutil.move(temp_output_file.name, file_path)

def mega_download_url(url, dest_path=None):
    if '/file/' in url:
        url = url.replace(' ', '')
        file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]
        path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
    elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
    else: raise Exception

    return mega_download_file(path[0], path[1], dest_path)
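A hedged usage sketch; the link below is a placeholder. mega_download_file decrypts the stream chunk by chunk with AES-CTR while accumulating an AES-CBC MAC, and only moves the file into dest_path once the MAC matches the last two words of the key:

from modules.meganz import mega_download_url  # assumes this repo layout

mega_download_url("https://mega.nz/file/<8-char-handle>#<base64-key>", dest_path="downloads")  # placeholder link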
RVC/modules/modules.py
ADDED
@@ -0,0 +1,60 @@
import os
import sys
import torch

sys.path.append(os.getcwd())

from .commons import fused_add_tanh_sigmoid_multiply

class WaveNet(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WaveNet, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)
        if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
        dilations = [dilation_rate ** i for i in range(n_layers)]
        paddings = [(kernel_size * d - d) // 2 for d in dilations]

        for i in range(n_layers):
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)
            res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None):
        output = x.clone().zero_()
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None: g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            g_l = (g[:, i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, :] if g is not None else 0)
            res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))

            if i < self.n_layers - 1:
                x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else: output = output + res_skip_acts

        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)

        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)

        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
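A minimal sketch of running the WaveNet block above without speaker conditioning (gin_channels=0, so cond_layer is never created and g stays None); shapes are illustrative, and the relative import assumes the modules directory is importable as a package:

import torch
from modules.modules import WaveNet  # assumes this repo layout

wn = WaveNet(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=16)
x = torch.randn(2, 192, 100)    # (batch, hidden_channels, frames)
x_mask = torch.ones(2, 1, 100)  # all frames valid
y = wn(x, x_mask)               # same shape as x: summed skip connections, masked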
RVC/modules/mrf_hifigan.py
ADDED
@@ -0,0 +1,150 @@
import math
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.utils import remove_weight_norm
from torch.utils.checkpoint import checkpoint
from torch.nn.utils.parametrizations import weight_norm

LRELU_SLOPE = 0.1

class MRFLayer(nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
        self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))

    def forward(self, x):
        return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv1)
        remove_weight_norm(self.conv2)

class MRFBlock(nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super().__init__()
        self.layers = nn.ModuleList()

        for dilation in dilations:
            self.layers.append(MRFLayer(channels, kernel_size, dilation))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

    def remove_weight_norm(self):
        for layer in self.layers:
            layer.remove_weight_norm()

class SineGenerator(nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return sine_waves

class SourceModuleHnNSF(nn.Module):
    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshold=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))

class HiFiGANMRFGenerator(nn.Module):
    def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing=False):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.checkpointing = checkpointing
        self.f0_upsample = nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
        self.conv_pre = weight_norm(nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
        self.upsamples = nn.ModuleList()
        self.noise_convs = nn.ModuleList()
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.upsamples.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        self.mrfs = nn.ModuleList()
        for i in range(len(self.upsamples)):
            channel = upsample_initial_channel // (2 ** (i + 1))
            self.mrfs.append(nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))

        self.conv_post = weight_norm(nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
        if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g=None):
        har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
        x = self.conv_pre(x)
        if g is not None: x += self.cond(g)

        for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
            x = F.leaky_relu(x, LRELU_SLOPE)

            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False) + noise_conv(har_source)
                xs = sum([checkpoint(layer, x, use_reentrant=False) for layer in mrf])
            else:
                x = ups(x) + noise_conv(har_source)
                xs = sum([layer(x) for layer in mrf])

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        remove_weight_norm(self.conv_pre)

        for up in self.upsamples:
            remove_weight_norm(up)

        for mrf in self.mrfs:
            for block in mrf:
                block.remove_weight_norm()

        remove_weight_norm(self.conv_post)
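A minimal sketch of the NSF-style MRF generator above; the hyperparameters are illustrative, not a checkpoint's config. The F0 track is upsampled to sample rate, turned into a harmonic sine source, and injected at every upsampling stage through the noise_convs:

import torch
from modules.mrf_hifigan import HiFiGANMRFGenerator  # assumes this repo layout

net = HiFiGANMRFGenerator(in_channel=192, upsample_initial_channel=512, upsample_rates=[10, 10, 2, 2], upsample_kernel_sizes=[20, 20, 4, 4], resblock_kernel_sizes=[3, 7, 11], resblock_dilations=[[1, 3, 5]] * 3, gin_channels=0, sample_rate=40000, harmonic_num=8)
x = torch.randn(1, 192, 50)      # frame-level features
f0 = torch.full((1, 50), 220.0)  # frame-level F0 in Hz
wav = net(x, f0)                 # (1, 1, 50 * 400)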
RVC/modules/noisereduce.py
ADDED
@@ -0,0 +1,196 @@
import torch
import tempfile
import numpy as np

from joblib import Parallel, delayed
from torch.nn.functional import conv1d, conv2d

@torch.no_grad()
def amp_to_db(x, eps=torch.finfo(torch.float32).eps, top_db=40):
    x_db = 20 * torch.log10(x.abs() + eps)
    return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))

@torch.no_grad()
def temperature_sigmoid(x, x0, temp_coeff):
    return torch.sigmoid((x - x0) / temp_coeff)

@torch.no_grad()
def linspace(start, stop, num=50, endpoint=True, **kwargs):
    return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]

def _smoothing_filter(n_grad_freq, n_grad_time):
    smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
    return smoothing_filter / np.sum(smoothing_filter)

class SpectralGate:
    def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
        self.sr = sr
        self.flat = False
        y = np.array(y)

        if len(y.shape) == 1:
            self.y = np.expand_dims(y, 0)
            self.flat = True
        elif len(y.shape) > 2: raise ValueError
        else: self.y = y

        self._dtype = y.dtype
        self.n_channels, self.n_frames = self.y.shape
        self._chunk_size = chunk_size
        self.padding = padding
        self.n_jobs = n_jobs
        self.use_tqdm = use_tqdm
        self._tmp_folder = tmp_folder
        self._n_fft = n_fft
        self._win_length = self._n_fft if win_length is None else win_length
        self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
        self._time_constant_s = time_constant_s
        self._prop_decrease = prop_decrease

        if (freq_mask_smooth_hz is None) & (time_mask_smooth_ms is None): self.smooth_mask = False
        else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)

    def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
        if freq_mask_smooth_hz is None: n_grad_freq = 1
        else:
            n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
            if n_grad_freq < 1: raise ValueError

        if time_mask_smooth_ms is None: n_grad_time = 1
        else:
            n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
            if n_grad_time < 1: raise ValueError

        if (n_grad_time == 1) & (n_grad_freq == 1): self.smooth_mask = False
        else:
            self.smooth_mask = True
            self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)

    def _read_chunk(self, i1, i2):
        i1b = 0 if i1 < 0 else i1
        i2b = self.n_frames if i2 > self.n_frames else i2
        chunk = np.zeros((self.n_channels, i2 - i1))
        chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
        return chunk

    def filter_chunk(self, start_frame, end_frame):
        i1 = start_frame - self.padding
        return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]

    def _get_filtered_chunk(self, ind):
        start0 = ind * self._chunk_size
        end0 = (ind + 1) * self._chunk_size
        return self.filter_chunk(start_frame=start0, end_frame=end0)

    def _do_filter(self, chunk):
        pass

    def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
        filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
        pos += end0 - start0

    def get_traces(self, start_frame=None, end_frame=None):
        if start_frame is None: start_frame = 0
        if end_frame is None: end_frame = self.n_frames

        if self._chunk_size is not None:
            if end_frame - start_frame > self._chunk_size:
                ich1 = int(start_frame / self._chunk_size)
                ich2 = int((end_frame - 1) / self._chunk_size)

                with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
                    filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
                    pos_list, start_list, end_list = [], [], []
                    pos = 0

                    for ich in range(ich1, ich2 + 1):
                        start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
                        end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
                        pos_list.append(pos)
                        start_list.append(start0)
                        end_list.append(end0)
                        pos += end0 - start0

                    Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(pos_list, start_list, end_list, range(ich1, ich2 + 1)))
                    return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)

        filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
        return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)

class TG(torch.nn.Module):
    @torch.no_grad()
    def __init__(self, sr, nonstationary=False, n_std_thresh_stationary=1.5, n_thresh_nonstationary=1.3, temp_coeff_nonstationary=0.1, n_movemean_nonstationary=20, prop_decrease=1.0, n_fft=1024, win_length=None, hop_length=None, freq_mask_smooth_hz=500, time_mask_smooth_ms=50):
        super().__init__()
        self.sr = sr
        self.nonstationary = nonstationary
        assert 0.0 <= prop_decrease <= 1.0
        self.prop_decrease = prop_decrease
        self.n_fft = n_fft
        self.win_length = self.n_fft if win_length is None else win_length
        self.hop_length = self.win_length // 4 if hop_length is None else hop_length
        self.n_std_thresh_stationary = n_std_thresh_stationary
        self.temp_coeff_nonstationary = temp_coeff_nonstationary
        self.n_movemean_nonstationary = n_movemean_nonstationary
        self.n_thresh_nonstationary = n_thresh_nonstationary
        self.freq_mask_smooth_hz = freq_mask_smooth_hz
        self.time_mask_smooth_ms = time_mask_smooth_ms
        self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())

    @torch.no_grad()
    def _generate_mask_smoothing_filter(self):
        if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: return None
|
| 142 |
+
n_grad_freq = (1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))))
|
| 143 |
+
if n_grad_freq < 1: raise ValueError
|
| 144 |
+
|
| 145 |
+
n_grad_time = (1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)))
|
| 146 |
+
if n_grad_time < 1: raise ValueError
|
| 147 |
+
if n_grad_time == 1 and n_grad_freq == 1: return None
|
| 148 |
+
|
| 149 |
+
smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
|
| 150 |
+
return smoothing_filter / smoothing_filter.sum()
|
| 151 |
+
|
| 152 |
+
@torch.no_grad()
|
| 153 |
+
def _stationary_mask(self, X_db, xn = None):
|
| 154 |
+
XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
|
| 155 |
+
std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
|
| 156 |
+
return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))
|
| 157 |
+
|
| 158 |
+
@torch.no_grad()
|
| 159 |
+
def _nonstationary_mask(self, X_abs):
|
| 160 |
+
X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
|
| 161 |
+
return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)
|
| 162 |
+
|
| 163 |
+
def forward(self, x, xn = None):
|
| 164 |
+
assert x.ndim == 2
|
| 165 |
+
if x.shape[-1] < self.win_length * 2: raise Exception
|
| 166 |
+
assert xn is None or xn.ndim == 1 or xn.ndim == 2
|
| 167 |
+
if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception
|
| 168 |
+
|
| 169 |
+
X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
|
| 170 |
+
sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)
|
| 171 |
+
|
| 172 |
+
sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
|
| 173 |
+
if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")
|
| 174 |
+
|
| 175 |
+
Y = X * sig_mask.squeeze(1)
|
| 176 |
+
return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)
|
| 177 |
+
|
| 178 |
+
class StreamedTorchGate(SpectralGate):
|
| 179 |
+
def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
|
| 180 |
+
super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
|
| 181 |
+
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
|
| 182 |
+
|
| 183 |
+
if y_noise is not None:
|
| 184 |
+
if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
|
| 185 |
+
y_noise = torch.from_numpy(y_noise).to(device)
|
| 186 |
+
if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)
|
| 187 |
+
|
| 188 |
+
self.y_noise = y_noise
|
| 189 |
+
self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)
|
| 190 |
+
|
| 191 |
+
def _do_filter(self, chunk):
|
| 192 |
+
if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
|
| 193 |
+
return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()
|
| 194 |
+
|
| 195 |
+
def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
|
| 196 |
+
return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
|
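For orientation, a minimal usage sketch of `reduce_noise` above, assuming the RVC directory is the working directory; the synthetic array stands in for a real recording:

# Sketch: spectral gating over stand-in audio (not part of the commit).
import numpy as np
from modules.noisereduce import reduce_noise

sr = 16000
noisy = np.random.randn(sr * 2).astype(np.float32)   # two seconds of placeholder audio
clean = reduce_noise(y=noisy, sr=sr, stationary=True, device="cpu")  # same length as the input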
RVC/modules/normalization.py
ADDED
@@ -0,0 +1,15 @@
import torch

import torch.nn.functional as F

class LayerNorm(torch.nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = torch.nn.Parameter(torch.ones(channels))
        self.beta = torch.nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        return F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps).transpose(1, -1)
RVC/modules/nsf_hifigan.py
ADDED
@@ -0,0 +1,116 @@
import os
import sys
import math
import torch
import numpy as np
import torch.nn.functional as F

from torch.nn.utils import remove_weight_norm
from torch.utils.checkpoint import checkpoint
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from modules.commons import init_weights
from modules.residuals import ResBlock, LRELU_SLOPE

class SineGen(torch.nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0, upp):
        rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
        rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode='constant')
        rad = rad.reshape(f0.shape[0], -1, 1)
        rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
        rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
        rand_ini[..., 0] = 0
        rad += rand_ini

        return torch.sin(2 * np.pi * rad)

    def forward(self, f0, upp):
        with torch.no_grad():
            f0 = f0.unsqueeze(-1)
            sine_waves = self._f02sine(f0, upp) * self.sine_amp
            uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return sine_waves

class SourceModuleHnNSF(torch.nn.Module):
    def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upsample_factor = 1):
        return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))

class HiFiGANNRFGenerator(torch.nn.Module):
    def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
        super(HiFiGANNRFGenerator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.upp = math.prod(upsample_rates)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)
        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)

        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        self.checkpointing = checkpointing

        self.ups = torch.nn.ModuleList()
        self.noise_convs = torch.nn.ModuleList()

        channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]

        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2
            self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))

        self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)

        self.ups.apply(init_weights)
        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, f0, g = None):
        har_source = self.m_source(f0, self.upp).transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None: x += self.cond(g)

        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
            x = F.leaky_relu(x, LRELU_SLOPE)

            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
                xs = sum([checkpoint(resblock, x, use_reentrant=False) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
            else:
                x = ups(x) + noise_convs(har_source)
                xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])

            x = xs / self.num_kernels

        return torch.tanh(self.conv_post(F.leaky_relu(x)))

    def remove_weight_norm(self):
        for l in self.ups:
            remove_weight_norm(l)

        for l in self.resblocks:
            l.remove_weight_norm()
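A minimal sketch of the harmonic source module above turning an F0 contour into an excitation waveform; the sample rate, pitch, and frame counts are arbitrary:

# Sketch: SourceModuleHnNSF upsamples a frame-rate F0 contour into a quasi-sine signal.
import torch
from modules.nsf_hifigan import SourceModuleHnNSF

source = SourceModuleHnNSF(sample_rate=40000, harmonic_num=0)
f0 = torch.full((1, 200), 220.0)          # (batch, frames) pitch contour in Hz
har = source(f0, upsample_factor=400)     # (1, 200 * 400, 1) excitation waveform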
RVC/modules/opencl.py
ADDED
@@ -0,0 +1,199 @@
import torch
import platform
import subprocess

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from librosa.util import pad_center
from scipy.signal import get_window

try:
    import pytorch_ocl
except:
    pytorch_ocl = None

torch_available = pytorch_ocl is not None

def get_amd_gpu_windows():
    try:
        return [gpu.strip() for gpu in subprocess.check_output("wmic path win32_VideoController get name", shell=True).decode().split('\n')[1:] if 'AMD' in gpu or 'Radeon' in gpu or 'Vega' in gpu]
    except:
        return []

def get_amd_gpu_linux():
    try:
        return [gpu for gpu in subprocess.check_output("lspci | grep VGA", shell=True).decode().split('\n') if 'AMD' in gpu or 'Radeon' in gpu or 'Vega' in gpu]
    except:
        return []

def get_gpu_list():
    return (get_amd_gpu_windows() if platform.system() == "Windows" else get_amd_gpu_linux()) if torch_available else []

def device_count():
    return len(get_gpu_list()) if torch_available else 0

def device_name(device_id = 0):
    return (get_gpu_list()[device_id] if device_id >= 0 and device_id < device_count() else "") if torch_available else ""

def is_available():
    return (device_count() > 0) if torch_available else False

class STFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=512, win_length=None, window="hann"):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.pad_amount = int(self.filter_length / 2)
        self.win_length = win_length
        self.hann_window = {}

        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        if win_length is None or not win_length: win_length = filter_length
        assert filter_length >= win_length

        fft_window = torch.from_numpy(pad_center(get_window(window, win_length, fftbins=True), size=filter_length)).float()
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, eps):
        input_data = F.pad(input_data, (self.pad_amount, self.pad_amount), mode="reflect")
        forward_transform = torch.matmul(self.forward_basis, input_data.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1))
        cutoff = int(self.filter_length / 2 + 1)

        return torch.sqrt(forward_transform[:, :cutoff, :]**2 + forward_transform[:, cutoff:, :]**2 + eps)

class GRU(nn.RNNBase):
    def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False, device=None, dtype=None):
        super().__init__("GRU", input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, device=device, dtype=dtype)

    @staticmethod
    def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
        gate_x = F.linear(x, weight_ih, bias_ih)
        gate_h = F.linear(hx, weight_hh, bias_hh)

        i_r, i_i, i_n = gate_x.chunk(3, 1)
        h_r, h_i, h_n = gate_h.chunk(3, 1)

        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)

        hy = newgate + inputgate * (hx - newgate)
        return hy

    def _gru_layer(self, x, hx, weights):
        weight_ih, weight_hh, bias_ih, bias_hh = weights
        outputs = []

        for x_t in x.unbind(1):
            hx = self._gru_cell(x_t, hx, weight_ih, bias_ih, weight_hh, bias_hh)
            outputs.append(hx)

        return torch.stack(outputs, dim=1), hx

    def _gru(self, x, hx):
        if not self.batch_first: x = x.permute(1, 0, 2)
        num_directions = 2 if self.bidirectional else 1

        h_n = []
        output_fwd, output_bwd = x, x

        for layer in range(self.num_layers):
            fwd_idx = layer * num_directions
            bwd_idx = fwd_idx + 1 if self.bidirectional else None

            weights_fwd = self._get_weights(fwd_idx)
            h_fwd = hx[fwd_idx]

            out_fwd, h_out_fwd = self._gru_layer(output_fwd, h_fwd, weights_fwd)
            h_n.append(h_out_fwd)

            if self.bidirectional:
                weights_bwd = self._get_weights(bwd_idx)
                h_bwd = hx[bwd_idx]

                reversed_input = torch.flip(output_bwd, dims=[1])
                out_bwd, h_out_bwd = self._gru_layer(reversed_input, h_bwd, weights_bwd)

                out_bwd = torch.flip(out_bwd, dims=[1])
                h_n.append(h_out_bwd)

                output_fwd = torch.cat([out_fwd, out_bwd], dim=2)
                output_bwd = output_fwd
            else: output_fwd = out_fwd

            if layer < self.num_layers - 1 and self.dropout > 0:
                output_fwd = F.dropout(output_fwd, p=self.dropout, training=self.training)
                if self.bidirectional: output_bwd = output_fwd

        output = output_fwd
        h_n = torch.stack(h_n, dim=0)

        if not self.batch_first: output = output.permute(1, 0, 2)
        return output, h_n

    def _get_weights(self, layer_idx):
        weights = self._all_weights[layer_idx]

        weight_ih = getattr(self, weights[0])
        weight_hh = getattr(self, weights[1])

        bias_ih = getattr(self, weights[2]) if self.bias else None
        bias_hh = getattr(self, weights[3]) if self.bias else None

        return weight_ih, weight_hh, bias_ih, bias_hh

    def forward(self, input, hx=None):
        if input.dim() != 3: raise ValueError

        batch_size = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1

        if hx is None: hx = torch.zeros(self.num_layers * num_directions, batch_size, self.hidden_size, dtype=input.dtype, device=input.device)

        self.check_forward_args(input, hx, batch_sizes=None)
        return self._gru(input, hx)

def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
    N, C = x.shape[:2]
    assert C % num_groups == 0

    shape = (N, num_groups, C // num_groups) + x.shape[2:]
    x_reshaped = x.view(shape)

    dims = (2,) + tuple(range(3, x_reshaped.dim()))
    mean = x_reshaped.mean(dim=dims, keepdim=True)
    var = x_reshaped.var(dim=dims, keepdim=True, unbiased=False)

    x_norm = (x_reshaped - mean) / torch.sqrt(var + eps)
    x_norm = x_norm.view_as(x)

    if weight is not None:
        weight = weight.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm * weight

    if bias is not None:
        bias = bias.view(1, C, *([1] * (x.dim() - 2)))
        x_norm = x_norm + bias

    return x_norm

def script(f, *_, **__):
    f.graph = pytorch_ocl.torch._C.Graph()
    return f

if torch_available:
    nn.GRU = GRU
    F.group_norm = group_norm
    torch.jit.script = script
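A sketch of probing this optional OpenCL backend; it is safe to run even when `pytorch_ocl` is not installed, since every helper then reports unavailability:

# Sketch: backend probe before choosing a device.
import modules.opencl as opencl

if opencl.is_available():
    print("AMD GPU via OpenCL:", opencl.device_name(0))
else:
    print("pytorch_ocl missing or no AMD GPU detected; the CUDA/CPU paths apply")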
RVC/modules/pipeline.py
ADDED
@@ -0,0 +1,215 @@
import os
import sys
import torch
import faiss

import numpy as np
import torch.nn.functional as F

from scipy import signal

sys.path.append(os.getcwd())

from modules.generator import Generator
from modules.rms import RMSEnergyExtractor
from modules.utils import change_rms, clear_gpu_cache

bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)

class Pipeline:
    def __init__(self, tgt_sr, config):
        self.x_pad, self.x_query, self.x_center, self.x_max = config.device_config()
        self.sample_rate = 16000
        self.window = 160
        self.t_pad = self.sample_rate * self.x_pad
        self.t_pad_tgt = tgt_sr * self.x_pad
        self.t_pad2 = self.t_pad * 2
        self.t_query = self.sample_rate * self.x_query
        self.t_center = self.sample_rate * self.x_center
        self.t_max = self.sample_rate * self.x_max
        self.time_step = self.window / self.sample_rate * 1000
        self.f0_min = 50
        self.f0_max = 1100
        self.device = config.device
        self.is_half = config.is_half

    def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect, energy):
        feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())
        pitch_guidance = pitch is not None and pitchf is not None
        energy_use = energy is not None

        if feats.dim() == 2: feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()
        feats = feats.view(1, -1)

        with torch.no_grad():
            padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
            logits = model.extract_features(**{"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12})
            feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

        if protect < 0.5 and pitch_guidance: feats0 = feats.clone()

        if index is not None and big_npy is not None and index_rate != 0:
            npy = feats[0].cpu().numpy()
            if self.is_half: npy = npy.astype(np.float32)

            score, ix = index.search(npy, k=8)
            weight = np.square(1 / score)

            npy = np.sum(big_npy[ix] * np.expand_dims(weight / weight.sum(axis=1, keepdims=True), axis=2), axis=1)
            if self.is_half: npy = npy.astype(np.float16)

            feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)

        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        p_len = audio0.shape[0] // self.window

        if feats.shape[1] < p_len:
            p_len = feats.shape[1]
            if pitch_guidance: pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]
            if energy_use: energy = energy[:, :p_len]

        if protect < 0.5 and pitch_guidance:
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)

            feats = (feats * pitchff + feats0 * (1 - pitchff)).to(feats0.dtype)

        p_len = torch.tensor([p_len], device=self.device).long()
        feats = feats.half() if self.is_half else feats.float()

        if not pitch_guidance: pitch, pitchf = None, None
        else: pitchf = pitchf.half() if self.is_half else pitchf.float()
        if not energy_use: energy = None
        else: energy = energy.half() if self.is_half else energy.float()

        audio1 = (
            (
                net_g.infer(
                    feats,
                    p_len,
                    pitch,
                    pitchf,
                    sid,
                    energy
                )[0][0, 0]
            ).data.cpu().float().numpy()
        )

        del feats, p_len, net_g, model, padding_mask
        clear_gpu_cache()
        return audio1

    def pipeline(
        self,
        model,
        net_g,
        sid,
        audio,
        f0_up_key,
        f0_method,
        file_index,
        index_rate,
        pitch_guidance,
        filter_radius,
        volume_envelope,
        version,
        protect,
        hop_length,
        energy_use=False,
        f0_autotune=False,
        f0_autotune_strength=False
    ):
        if file_index != "" and os.path.exists(file_index) and index_rate != 0:
            try:
                index = faiss.read_index(file_index)
                big_npy = index.reconstruct_n(0, index.ntotal)
            except Exception as e:
                print(f"[ERROR] Error occurred while reading index file: {e}")
                index = big_npy = None
        else: index = big_npy = None

        opt_ts, audio_opt = [], []
        audio = signal.filtfilt(bh, ah, audio)
        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")

        if audio_pad.shape[0] > self.t_max:
            audio_sum = np.zeros_like(audio)

            for i in range(self.window):
                audio_sum += audio_pad[i : i - self.window]

            for t in range(self.t_center, audio.shape[0], self.t_center):
                opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])

        s = 0
        t = None
        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
        p_len = audio_pad.shape[0] // self.window

        if pitch_guidance:
            if not hasattr(self, "f0_generator"): self.f0_generator = Generator(self.sample_rate, hop_length, self.f0_min, self.f0_max, self.is_half, self.device)
            pitch, pitchf = self.f0_generator.calculator(f0_method, audio_pad, f0_up_key, p_len, filter_radius, f0_autotune, f0_autotune_strength)

            if self.device == "mps": pitchf = pitchf.astype(np.float32)
            pitch, pitchf = torch.tensor(pitch[:p_len], device=self.device).unsqueeze(0).long(), torch.tensor(pitchf[:p_len], device=self.device).unsqueeze(0).float()

        if energy_use:
            if not hasattr(self, "rms_extract"): self.rms_extract = RMSEnergyExtractor(frame_length=2048, hop_length=self.window, center=True, pad_mode = "reflect").to(self.device).eval()
            energy = self.rms_extract(torch.from_numpy(audio_pad).to(self.device).unsqueeze(0)).cpu().numpy()

            if self.device == "mps": energy = energy.astype(np.float32)
            energy = torch.tensor(energy[:p_len], device=self.device).unsqueeze(0).float()

        for t in opt_ts:
            t = t // self.window * self.window
            audio_opt.append(
                self.voice_conversion(
                    model,
                    net_g,
                    sid,
                    audio_pad[s : t + self.t_pad2 + self.window],
                    pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None,
                    pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None,
                    index,
                    big_npy,
                    index_rate,
                    version,
                    protect,
                    energy[:, s // self.window : (t + self.t_pad2) // self.window] if energy_use else None
                )[self.t_pad_tgt : -self.t_pad_tgt]
            )
            s = t

        audio_opt.append(
            self.voice_conversion(
                model,
                net_g,
                sid,
                audio_pad[t:],
                (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None,
                (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None,
                index,
                big_npy,
                index_rate,
                version,
                protect,
                (energy[:, t // self.window :] if t is not None else energy) if energy_use else None
            )[self.t_pad_tgt : -self.t_pad_tgt]
        )

        audio_opt = np.concatenate(audio_opt)

        if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, self.sample_rate, volume_envelope)
        audio_max = np.abs(audio_opt).max() / 0.99
        if audio_max > 1: audio_opt /= audio_max

        if pitch_guidance: del pitch, pitchf
        del sid

        clear_gpu_cache()
        return audio_opt
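As a small self-contained illustration of the module-level pre-filter above, this sketch applies the same 5th-order, 48 Hz Butterworth high-pass to stand-in audio:

# Sketch: zero-phase high-pass filtering as done at the start of Pipeline.pipeline.
import numpy as np
from scipy import signal

bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
audio = np.random.randn(32000)                 # two seconds of placeholder 16 kHz audio
filtered = signal.filtfilt(bh, ah, audio)      # removes DC offset and rumble below ~48 Hz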
RVC/modules/pixeldrain.py
ADDED
@@ -0,0 +1,16 @@
import os
import requests

def pixeldrain(url, output_dir):
    try:
        response = requests.get(f"https://pixeldrain.com/api/file/{url.split('pixeldrain.com/u/')[1]}")

        if response.status_code == 200:
            file_path = os.path.join(output_dir, (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";')))

            with open(file_path, "wb") as newfile:
                newfile.write(response.content)
            return file_path
        else: return None
    except Exception as e:
        raise RuntimeError(e)
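A usage sketch for the downloader above; the file id "AbCd1234" is a placeholder, not a real share link:

# Sketch: fetching a shared file into the models directory.
from modules.pixeldrain import pixeldrain

path = pixeldrain("https://pixeldrain.com/u/AbCd1234", output_dir="models")
print(path if path else "download failed")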
RVC/modules/pyworld.py
ADDED
@@ -0,0 +1,84 @@
import os
import pickle
import ctypes
import platform

import numpy as np

class DioOption(ctypes.Structure):
    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]

class HarvestOption(ctypes.Structure):
    _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]

class PYWORLD:
    def __init__(self):
        self.world_path = os.path.join("models", "world")
        os.makedirs(self.world_path, exist_ok=True)
        model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
        self.world_file_path = os.path.join(self.world_path, f"{model_type}{suffix}")

        if not os.path.exists(self.world_file_path):
            with open(os.path.join("models", "world.bin"), "rb") as f:
                model = pickle.load(f)

            with open(self.world_file_path, "wb") as w:
                w.write(model[model_type])

        self.world_dll = ctypes.CDLL(self.world_file_path)

    def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
        self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
        self.world_dll.Harvest.restype = None
        self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
        self.world_dll.InitializeHarvestOption.restype = None
        self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
        self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int

        option = HarvestOption()
        self.world_dll.InitializeHarvestOption(ctypes.byref(option))

        option.F0Floor = f0_floor
        option.F0Ceil = f0_ceil
        option.FramePeriod = frame_period

        f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
        f0 = (ctypes.c_double * f0_length)()
        tpos = (ctypes.c_double * f0_length)()

        self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)

    def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
        self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
        self.world_dll.Dio.restype = None
        self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
        self.world_dll.InitializeDioOption.restype = None
        self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
        self.world_dll.GetSamplesForDIO.restype = ctypes.c_int

        option = DioOption()
        self.world_dll.InitializeDioOption(ctypes.byref(option))

        option.F0Floor = f0_floor
        option.F0Ceil = f0_ceil
        option.ChannelsInOctave = channels_in_octave
        option.FramePeriod = frame_period
        option.Speed = speed
        option.AllowedRange = allowed_range

        f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
        f0 = (ctypes.c_double * f0_length)()
        tpos = (ctypes.c_double * f0_length)()

        self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
        return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)

    def stonemask(self, x, fs, tpos, f0):
        self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
        self.world_dll.StoneMask.restype = None

        out_f0 = (ctypes.c_double * len(f0))()
        self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)

        return np.array(out_f0, dtype=np.float32)
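A sketch of the Harvest-plus-StoneMask flow above; it assumes the bundled models/world.bin blob ships with the repo, since the class unpacks the native WORLD library from it on first use:

# Sketch: coarse Harvest F0 extraction followed by StoneMask refinement.
import numpy as np
from modules.pyworld import PYWORLD

pw = PYWORLD()
x = np.random.randn(16000)      # one second of placeholder audio at 16 kHz
f0, tpos = pw.harvest(x, fs=16000, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = pw.stonemask(x, 16000, tpos, f0)  # per-frame refinement of the coarse contour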
RVC/modules/refinegan.py
ADDED
@@ -0,0 +1,170 @@
import os
import sys
import math
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.checkpoint import checkpoint
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from modules.commons import init_weights, get_padding


class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
        super().__init__()
        self.leaky_relu_slope = leaky_relu_slope
        self.convs1 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for d in dilation])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1, padding=get_padding(kernel_size, 1))) for _ in dilation])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            x = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope)) + x

        return x

    def remove_weight_norm(self):
        for c1, c2 in zip(self.convs1, self.convs2):
            remove_weight_norm(c1)
            remove_weight_norm(c2)

class AdaIN(nn.Module):
    def __init__(self, *, channels, leaky_relu_slope = 0.2):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.activation = nn.LeakyReLU(leaky_relu_slope)

    def forward(self, x):
        return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))

class ParallelResBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # weight-normed so that remove_weight_norm() below can undo it
        self.input_conv = weight_norm(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3))
        self.input_conv.apply(init_weights)
        self.blocks = nn.ModuleList([nn.Sequential(AdaIN(channels=out_channels), ResBlock(out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])

    def forward(self, x):
        x = self.input_conv(x)
        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)

    def remove_weight_norm(self):
        remove_weight_norm(self.input_conv)
        for block in self.blocks:
            block[1].remove_weight_norm()

class SineGenerator(nn.Module):
    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
        super(SineGenerator, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.merge = nn.Sequential(nn.Linear(self.dim, 1, bias=False), nn.Tanh())

    def _f02uv(self, f0):
        return torch.ones_like(f0) * (f0 > self.voiced_threshold)

    def _f02sine(self, f0_values):
        rad_values = (f0_values / self.sampling_rate) % 1
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)

        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0

        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
            f0_buf[:, :, 0] = f0[:, :, 0]

            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp
            uv = self._f02uv(f0)
            sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))

        return self.merge(sine_waves)

class RefineGANGenerator(nn.Module):
    def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
        super().__init__()
        self.upsample_rates = upsample_rates
        self.checkpointing = checkpointing
        self.leaky_relu_slope = leaky_relu_slope
        self.upp = np.prod(upsample_rates)
        self.m_source = SineGenerator(sample_rate)
        self.pre_conv = weight_norm(nn.Conv1d(1, upsample_initial_channel // 2, 7, 1, padding=3))
        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]

        channels = upsample_initial_channel
        self.downsample_blocks = nn.ModuleList([])

        for i, _ in enumerate(upsample_rates):
            stride = stride_f0s[i]
            kernel = 1 if stride == 1 else stride * 2 - stride % 2

            self.downsample_blocks.append(weight_norm(nn.Conv1d(1, channels // 2 ** (i + 2), kernel, stride, padding=0 if stride == 1 else (kernel - stride) // 2)))

        self.mel_conv = weight_norm(nn.Conv1d(num_mels, channels // 2, 7, 1, padding=3))
        self.mel_conv.apply(init_weights)

        if gin_channels != 0: self.cond = nn.Conv1d(256, channels // 2, 1)

        self.upsample_blocks = nn.ModuleList([])
        self.upsample_conv_blocks = nn.ModuleList([])

        for rate in upsample_rates:
            new_channels = channels // 2
            self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
            self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
            channels = new_channels

        self.conv_post = weight_norm(nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False))
        self.conv_post.apply(init_weights)

    def forward(self, mel, f0, g = None):
        har_source = self.m_source(F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear").transpose(1, 2)).transpose(1, 2)
        x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")

        mel = self.mel_conv(mel)
        if g is not None: mel += self.cond(g)

        x = torch.cat([mel, x], dim=1)

        for ups, res, down in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks):
            x = F.leaky_relu(x, self.leaky_relu_slope)
            x = checkpoint(res, torch.cat([checkpoint(ups, x, use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([ups(x), down(har_source)], dim=1))

        return torch.tanh(self.conv_post(F.leaky_relu(x, self.leaky_relu_slope)))
    def remove_weight_norm(self):
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.mel_conv)
        remove_weight_norm(self.conv_post)

        for block in self.downsample_blocks:
            # the downsample blocks are weight-normed Conv1d layers, so the
            # functional helper is used here (Conv1d has no remove_weight_norm
            # method), mirroring pre_conv/mel_conv/conv_post above
            remove_weight_norm(block)

        for block in self.upsample_conv_blocks:
            block.remove_weight_norm()
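A forward-pass shape sketch for the generator above; all sizes are arbitrary, and gin_channels=0 disables the speaker-conditioning branch:

# Sketch: mel spectrogram + F0 contour in, waveform out (256 = prod(upsample_rates)).
import torch
from modules.refinegan import RefineGANGenerator

gen = RefineGANGenerator(sample_rate=44100, upsample_rates=(8, 8, 2, 2), num_mels=128, gin_channels=0).eval()
mel = torch.randn(1, 128, 50)            # (batch, mels, frames)
f0 = torch.full((1, 50), 220.0)          # matching pitch contour in Hz
with torch.no_grad():
    wav = gen(mel, f0)                   # (1, 1, 50 * 256) waveform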
RVC/modules/residuals.py
ADDED
@@ -0,0 +1,140 @@
import os
import sys
import torch

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

sys.path.append(os.getcwd())

from .modules import WaveNet
from .commons import get_padding, init_weights


LRELU_SLOPE = 0.1

def create_conv1d_layer(channels, kernel_size, dilation):
    return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))

def apply_mask(tensor, mask):
    return tensor * mask if mask is not None else tensor

class ResBlockBase(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super(ResBlockBase, self).__init__()

        self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
        self.convs1.apply(init_weights)

        self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x

        return apply_mask(x, x_mask)

    def remove_weight_norm(self):
        for conv in self.convs1 + self.convs2:
            remove_weight_norm(conv)

class ResBlock(ResBlockBase):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__(channels, kernel_size, dilation)

class Log(torch.nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            return y, torch.sum(-y, [1, 2])
        else: return torch.exp(x) * x_mask

class Flip(torch.nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])

        if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
        else: return x

class ElementwiseAffine(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = torch.nn.Parameter(torch.zeros(channels, 1))
        self.logs = torch.nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
        else: return (x - self.m) * torch.exp(-self.logs) * x_mask

class ResidualCouplingBlock(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super(ResidualCouplingBlock, self).__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels
        self.flows = torch.nn.ModuleList()

        for _ in range(n_flows):
            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow.forward(x, x_mask, g=g, reverse=reverse)

        return x

    def remove_weight_norm(self):
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()

    def __prepare_scriptable__(self):
        for i in range(self.n_flows):
            for hook in self.flows[i * 2]._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])

        return self

class ResidualCouplingLayer(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
        assert channels % 2 == 0, "Channels/2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)

        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask

        if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
        else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
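For context, a quick invertibility check on the self-contained flow layer above; a minimal sketch (not part of the commit) assuming this file is importable as modules.residuals and using arbitrary shapes.

import torch
from modules.residuals import ElementwiseAffine

x = torch.randn(1, 4, 10)      # (batch, channels, frames)
x_mask = torch.ones(1, 1, 10)  # all frames valid

aff = ElementwiseAffine(4)
with torch.no_grad():          # give the affine non-trivial parameters
    aff.m.uniform_(-0.5, 0.5)
    aff.logs.uniform_(-0.5, 0.5)

y, logdet = aff(x, x_mask)              # forward: output plus log-determinant term
x_back = aff(y, x_mask, reverse=True)   # reverse: exact inverse under the mask
print(torch.allclose(x, x_back, atol=1e-5))  # True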
RVC/modules/rms.py
ADDED
@@ -0,0 +1,30 @@
import torch
import librosa

import torch.nn as nn

class RMSEnergyExtractor(nn.Module):
    def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode="reflect"):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center
        self.pad_mode = pad_mode

    def forward(self, x):
        assert x.ndim == 2
        assert x.shape[0] == 1

        if str(x.device).startswith("ocl"): x = x.contiguous()

        rms = torch.from_numpy(
            librosa.feature.rms(
                y=x.squeeze(0).cpu().numpy(),
                frame_length=self.frame_length,
                hop_length=self.hop_length,
                center=self.center,
                pad_mode=self.pad_mode
            )
        )

        return rms.squeeze(-2).to(x.device) if not str(x.device).startswith("ocl") else rms.contiguous().squeeze(-2).to(x.device)
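A minimal usage sketch for the extractor above (assuming the modules.rms import path and an installed librosa; not part of the commit):

import torch
from modules.rms import RMSEnergyExtractor

extractor = RMSEnergyExtractor(frame_length=2048, hop_length=512)
wav = torch.randn(1, 16000)   # forward() asserts a (1, samples) input
energy = extractor(wav)
print(energy.shape)           # torch.Size([32]) here: 1 + 16000 // 512 frames with center=True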
RVC/modules/rmvpe.py
ADDED
@@ -0,0 +1,260 @@
# RMVPE pitch estimator: a DeepUnet + BiGRU network over mel spectrograms,
# decoded to F0 via a local average over N_CLASS pitch bins (cents scale).
import os
import sys
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from librosa.filters import mel

sys.path.append(os.getcwd())

from modules import opencl

N_MELS, N_CLASS = 128, 360

class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
        else: self.is_shortcut = False

    def forward(self, x):
        return (self.conv(x) + self.shortcut(x)) if self.is_shortcut else (self.conv(x) + x)

class ResEncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))

        for _ in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))

        self.kernel_size = kernel_size
        if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i in range(self.n_blocks):
            x = self.conv[i](x)

        if self.kernel_size is not None: return x, self.pool(x)
        else: return x

class Encoder(nn.Module):
    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()

        for _ in range(self.n_encoders):
            self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2

        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x):
        concat_tensors = []
        x = self.bn(x)

        for layer in self.layers:
            t, x = layer(x)
            concat_tensors.append(t)

        return x, concat_tensors

class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.layers = nn.ModuleList()
        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))

        for _ in range(n_inters - 1):
            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        return x

class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))

        for _ in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = torch.cat((self.conv1(x), concat_tensor), dim=1)
        for conv2 in self.conv2:
            x = conv2(x)

        return x

class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()

        for _ in range(n_decoders):
            out_channels = in_channels // 2
            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
            in_channels = out_channels

    def forward(self, x, concat_tensors):
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])

        return x

class DeepUnet(nn.Module):
    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)

    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        return self.decoder(self.intermediate(x), concat_tensors)

class E2E(nn.Module):
    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
        super(E2E, self).__init__()
        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())

    def forward(self, mel):
        return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))

class MelSpectrogram(torch.nn.Module):
    def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = win_length if n_fft is None else n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
        win_length_new = int(np.round(self.win_length * factor))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)

        n_fft = int(np.round(self.n_fft * factor))
        hop_length = int(np.round(self.hop_length * speed))

        if str(audio.device).startswith("ocl"):
            stft = opencl.STFT(filter_length=n_fft, hop_length=hop_length, win_length=win_length_new).to(audio.device)
            magnitude = stft.transform(audio, 1e-9)
        else:
            fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
            magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))

        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half: mel_output = mel_output.half()

        return torch.log(torch.clamp(mel_output, min=self.clamp))

class RMVPE:
    def __init__(self, model_path, is_half, device=None):
        self.resample_kernel = {}
        model = E2E(4, 1, (2, 2))
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()
        if is_half: model = model.half()
        self.model = model.to(device)
        self.is_half = is_half
        self.device = device
        self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames  # pad frame count up to a multiple of 32 for the U-net
            if n_pad > 0: mel = F.pad(mel, (0, n_pad), mode="constant")

            hidden = self.model(mel.half() if self.is_half else mel.float())
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
        f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
        f0[f0 == 10] = 0

        return f0

    def infer_from_audio(self, audio, thred=0.03):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))

        return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))

        f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0

    def to_local_average_cents(self, salience, thred=0.05):
        center = np.argmax(salience, axis=1)
        salience = np.pad(salience, ((0, 0), (4, 4)))
        center += 4
        todo_salience, todo_cents_mapping = [], []
        starts = center - 4
        ends = center + 5

        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx]:ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx]:ends[idx]])

        todo_salience = np.array(todo_salience)
        devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        devided[np.max(salience, axis=1) <= thred] = 0

        return devided

class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)

    def forward(self, x):
        try:
            return self.gru(x)[0]
        except:
            torch.backends.cudnn.enabled = False
            return self.gru(x)[0]
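An inference sketch for the wrapper above (not part of the commit): "rmvpe.pt" is a placeholder checkpoint path, the weights being fetched elsewhere in this Space, and random noise stands in for 16 kHz speech.

import numpy as np
from modules.rmvpe import RMVPE

model = RMVPE("rmvpe.pt", is_half=False, device="cpu")   # placeholder checkpoint path
audio = np.random.randn(16000).astype(np.float32)        # one second at 16 kHz
f0 = model.infer_from_audio(audio, thred=0.03)           # per-frame F0 in Hz, 0 where unvoiced
print(f0.shape)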
RVC/modules/swipe.py
ADDED
@@ -0,0 +1,200 @@
# SWIPE pitch estimator (NumPy port) plus a stonemask-style
# instantaneous-frequency refinement pass.
import math

from sys import float_info

import numba as nb
import numpy as np

from matplotlib import mlab
from scipy import interpolate
from decimal import Decimal, ROUND_HALF_UP


def swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10, sTHR=0.3):
    plim = np.array([f0_floor, f0_ceil])
    t = np.arange(0, int(1000 * len(x) / fs / (frame_period) + 1)) * (frame_period / 1000)

    log2pc = np.arange(np.log2(plim[0]) * 96, np.log2(plim[-1]) * 96)
    log2pc *= (1 / 96)

    pc = 2 ** log2pc
    S = np.zeros((len(pc), len(t)))

    logWs = [round_matlab(elm) for elm in np.log2(4 * 2 * fs / plim)]
    ws = 2 ** np.arange(logWs[0], logWs[1] - 1, -1)
    p0 = 4 * 2 * fs / ws

    d = 1 + log2pc - np.log2(4 * 2 * fs / ws[0])
    fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))

    for i in range(len(ws)):
        dn = round_matlab(4 * fs / p0[i])
        X, f, ti = mlab.specgram(x=np.r_[np.zeros(int(ws[i] / 2)), np.r_[x, np.zeros(int(dn + ws[i] / 2))]], NFFT=ws[i], Fs=fs, window=np.hanning(ws[i] + 2)[1:-1], noverlap=max(0, np.round(ws[i] - dn)), mode='complex')
        ti = np.r_[0, ti[:-1]]
        M = np.maximum(0, interpolate.interp1d(f, np.abs(X.T), kind='cubic')(fERBs)).T

        if i == len(ws) - 1:
            j = np.where(d - (i + 1) > -1)[0]
            k = np.where(d[j] - (i + 1) < 0)[0]
        elif i == 0:
            j = np.where(d - (i + 1) < 1)[0]
            k = np.where(d[j] - (i + 1) > 0)[0]
        else:
            j = np.where(np.abs(d - (i + 1)) < 1)[0]
            k = np.arange(len(j))

        Si = pitchStrengthAllCandidates(fERBs, np.sqrt(M), pc[j])
        Si = interpolate.interp1d(ti, Si, bounds_error=False, fill_value='nan')(t) if Si.shape[1] > 1 else np.full((len(Si), len(t)), np.nan)

        mu = np.ones(j.shape)
        mu[k] = 1 - np.abs(d[j[k]] - i - 1)
        S[j, :] = S[j, :] + np.tile(mu.reshape(-1, 1), (1, Si.shape[1])) * Si

    p = np.full((S.shape[1], 1), np.nan)
    s = np.full((S.shape[1], 1), np.nan)

    for j in range(S.shape[1]):
        s[j] = np.max(S[:, j])
        i = np.argmax(S[:, j])

        if s[j] < sTHR: continue

        if i == 0: p[j] = pc[0]
        elif i == len(pc) - 1: p[j] = pc[-1]  # top of the candidate grid: take the highest candidate
        else:
            I = np.arange(i - 1, i + 2)
            tc = 1 / pc[I]

            ntc = (tc / tc[1] - 1) * 2 * np.pi
            idx = np.isfinite(S[I, j])

            c = np.zeros(len(ntc))
            c += np.nan

            I_ = I[idx]

            if len(I_) < 2: c[idx] = (S[I, j])[0] / ntc[0]
            else: c[idx] = np.polyfit(ntc[idx], (S[I_, j]), 2)

            pval = np.polyval(c, ((1 / (2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]) + 1 / 12 / 64, 1 / 12 / 64))) / tc[1] - 1) * 2 * np.pi)
            s[j] = np.max(pval)
            p[j] = 2 ** (np.log2(pc[I[0]]) + (np.argmax(pval)) / 12 / 64)

    p = p.flatten()
    p[np.isnan(p)] = 0

    return np.array(p, dtype=np.float32), np.array(t, dtype=np.float32)

def round_matlab(n):
    return int(Decimal(n).quantize(0, ROUND_HALF_UP))

def pitchStrengthAllCandidates(f, L, pc):
    den = np.sqrt(np.sum(L * L, axis=0))
    den = np.where(den == 0, 2.220446049250313e-16, den)

    L = L / den
    S = np.zeros((len(pc), L.shape[1]))

    for j in range(len(pc)):
        S[j, :] = pitchStrengthOneCandidate(f, L, pc[j])

    return S

def pitchStrengthOneCandidate(f, L, pc):
    k = np.zeros(len(f))
    q = f / pc

    for i in ([1] + sieve(int(np.fix(f[-1] / pc - 0.75)))):
        a = np.abs(q - i)
        p = a < 0.25
        k[p] = np.cos(2 * np.pi * q[p])

        v = np.logical_and((0.25 < a), (a < 0.75))
        k[v] = k[v] + np.cos(2 * np.pi * q[v]) / 2

    k *= np.sqrt(1 / f)
    k /= np.linalg.norm(k[k > 0])

    return k @ L

def hz2erbs(hz):
    return 21.4 * np.log10(1 + hz / 229)

def erbs2hz(erbs):
    return (10 ** (erbs / 21.4) - 1) * 229

def sieve(n):
    primes = list(range(2, n + 1))
    num = 2

    while num < math.sqrt(n):
        i = num

        while i <= n:
            i += num

            if i in primes: primes.remove(i)

        for j in primes:
            if j > num:
                num = j
                break

    return primes

def stonemask(x, fs, temporal_positions, f0):
    refined_f0 = np.copy(f0)

    for i in range(len(temporal_positions)):
        if f0[i] != 0:
            refined_f0[i] = get_refined_f0(x, fs, temporal_positions[i], f0[i])
            if abs(refined_f0[i] - f0[i]) / f0[i] > 0.2: refined_f0[i] = f0[i]  # reject refinements that drift more than 20%

    return np.array(refined_f0, dtype=np.float32)

def get_refined_f0(x, fs, current_time, current_f0):
    f0_initial = current_f0
    half_window_length = np.ceil(3 * fs / f0_initial / 2)
    window_length_in_time = (2 * half_window_length + 1) / fs

    base_time = np.arange(-half_window_length, half_window_length + 1) / fs
    fft_size = 2 ** math.ceil(math.log((half_window_length * 2 + 1), 2) + 1)

    base_time = np.array([float("{0:.4f}".format(elm)) for elm in base_time])
    index_raw = round_matlab_2((current_time + base_time) * fs)

    window_time = ((index_raw - 1) / fs) - current_time
    main_window = 0.42 + 0.5 * np.cos(2 * math.pi * window_time / window_length_in_time) + 0.08 * np.cos(4 * math.pi * window_time / window_length_in_time)

    index = np.array(np.maximum(1, np.minimum(len(x), index_raw)), dtype=int)
    spectrum = np.fft.fft(x[index - 1] * main_window, fft_size)

    diff_spectrum = np.fft.fft(x[index - 1] * (-(np.diff(np.r_[0, main_window]) + np.diff(np.r_[main_window, 0])) / 2), fft_size)
    power_spectrum = np.abs(spectrum) ** 2

    power_spectrum[power_spectrum == 0] = float_info.epsilon
    instantaneous_frequency = (np.arange(fft_size) / fft_size * fs) + (np.real(spectrum) * np.imag(diff_spectrum) - np.imag(spectrum) * np.real(diff_spectrum)) / power_spectrum * fs / 2 / math.pi

    trim_index = np.array([1, 2])
    index_list_trim = np.array(round_matlab_2(f0_initial * fft_size / fs * trim_index) + 1, int)

    amp_list = np.sqrt(power_spectrum[index_list_trim - 1])
    f0_initial = np.sum(amp_list * instantaneous_frequency[index_list_trim - 1]) / np.sum(amp_list * trim_index)

    if f0_initial < 0: return 0

    trim_index = np.array([1, 2, 3, 4, 5, 6])
    index_list_trim = np.array(round_matlab_2(f0_initial * fft_size / fs * trim_index) + 1, int)
    amp_list = np.sqrt(power_spectrum[index_list_trim - 1])

    return np.sum(amp_list * instantaneous_frequency[index_list_trim - 1]) / np.sum(amp_list * trim_index)

@nb.jit((nb.float64[:],), nopython=True, cache=True)
def round_matlab_2(x):
    y = x.copy()

    y[x > 0] += 0.5
    y[x <= 0] -= 0.5

    return y
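A usage sketch (not part of the commit), with a synthetic 220 Hz tone standing in for speech; frame_period is in milliseconds and stonemask refines each voiced frame:

import numpy as np
from modules.swipe import swipe, stonemask

fs = 16000
x = np.sin(2 * np.pi * 220 * np.arange(fs) / fs)   # 220 Hz test tone
f0, t = swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = stonemask(x, fs, t, f0)                       # instantaneous-frequency refinement
print(f0[:10])                                     # values near 220 on voiced frames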
RVC/modules/synthesizers.py
ADDED
@@ -0,0 +1,84 @@
import os
import sys
import torch

sys.path.append(os.getcwd())

from modules.hifigan import HiFiGANGenerator
from modules.refinegan import RefineGANGenerator
from modules.residuals import ResidualCouplingBlock
from modules.mrf_hifigan import HiFiGANMRFGenerator
from modules.nsf_hifigan import HiFiGANNRFGenerator
from modules.encoders import TextEncoder, PosteriorEncoder
from modules.commons import slice_segments, rand_slice_segments

class Synthesizer(torch.nn.Module):
    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, energy=False, **kwargs):
        super(Synthesizer, self).__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = float(p_dropout)
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.use_f0 = use_f0
        self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0, energy=energy)

        if use_f0:
            if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
            elif vocoder in ["MRF-HiFi-GAN", "MRF HiFi-GAN"]: self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
            else: self.dec = HiFiGANNRFGenerator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
        else: self.dec = HiFiGANGenerator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)

        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
        self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)

    def remove_weight_norm(self):
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    @torch.jit.ignore
    def forward(self, phone, phone_lengths, pitch=None, pitchf=None, y=None, y_lengths=None, ds=None, energy=None):
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, energy)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)

            return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
        else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(self, phone, phone_lengths, pitch=None, nsff0=None, sid=None, energy=None, rate=None):
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, energy)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            assert isinstance(rate, torch.Tensor)
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p = z_p[:, :, head:]
            x_mask = x_mask[:, :, head:]
            if self.use_f0: nsff0 = nsff0[:, head:]

        if self.use_f0:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, nsff0, g=g)
        else:
            z = self.flow(z_p, x_mask, g=g, reverse=True)
            o = self.dec(z * x_mask, g=g)

        return o, x_mask, (z, z_p, m_p, logs_p)
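For orientation, a shape sketch of Synthesizer.infer (not part of the commit, and not executed here): the hyperparameters below are illustrative RVC-v2-style values rather than this project's actual config, the tensors are random placeholders, and the encoder/vocoder modules are assumed importable from this repo.

import torch
from modules.synthesizers import Synthesizer

# spec, segment, inter, hidden, filter, heads, layers, kernel, dropout, resblock, ...
net_g = Synthesizer(1025, 32, 192, 192, 768, 2, 6, 3, 0, "1", [3, 7, 11], [[1, 3, 5]] * 3, [10, 10, 2, 2], 512, [16, 16, 4, 4], 109, 256, 40000, True)
phone = torch.randn(1, 200, 768)            # HuBERT-style content features
phone_lengths = torch.tensor([200])
pitch = torch.randint(1, 255, (1, 200))     # coarse pitch indices
nsff0 = torch.rand(1, 200) * 400 + 50       # F0 contour in Hz
sid = torch.tensor([0])                     # speaker id
audio, _, _ = net_g.infer(phone, phone_lengths, pitch, nsff0, sid)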
RVC/modules/torchcrepe.py
ADDED
@@ -0,0 +1,185 @@
# CREPE pitch estimator (PyTorch port): six-layer CNN over 1024-sample frames,
# Viterbi decoding over 360 pitch bins, with optional periodicity output.
import torch
import librosa
import functools
import scipy.stats

import numpy as np

CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024

def mean(signals, win_length=9):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
    avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
    avg_pooled[avg_pooled == 0] = float("nan")

    return avg_pooled.squeeze(1)

def median(signals, win_length):
    assert signals.dim() == 2

    signals = signals.unsqueeze(1)
    mask = ~torch.isnan(signals)
    padding = win_length // 2

    x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
    mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)

    x = x.unfold(2, win_length, 1)
    mask = mask.unfold(2, win_length, 1)

    x = x.contiguous().view(x.size()[:3] + (-1,))
    mask = mask.contiguous().view(mask.size()[:3] + (-1,))

    x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)

    median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
    median_pooled[torch.isinf(median_pooled)] = float("nan")

    return median_pooled.squeeze(1)

class CREPE_MODEL(torch.nn.Module):
    def __init__(self, model='full'):
        super().__init__()
        in_channels = {"full": [1, 1024, 128, 128, 128, 256], "large": [1, 768, 96, 96, 96, 192], "medium": [1, 512, 64, 64, 64, 128], "small": [1, 256, 32, 32, 32, 64], "tiny": [1, 128, 16, 16, 16, 32]}[model]
        out_channels = {"full": [1024, 128, 128, 128, 256, 512], "large": [768, 96, 96, 96, 192, 384], "medium": [512, 64, 64, 64, 128, 256], "small": [256, 32, 32, 32, 64, 128], "tiny": [128, 16, 16, 16, 32, 64]}[model]
        self.in_features = {"full": 2048, "large": 1536, "medium": 1024, "small": 512, "tiny": 256}[model]

        kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
        strides = [(4, 1)] + 5 * [(1, 1)]
        batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)

        self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
        self.conv1_BN = batch_norm_fn(num_features=out_channels[0])

        self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
        self.conv2_BN = batch_norm_fn(num_features=out_channels[1])

        self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
        self.conv3_BN = batch_norm_fn(num_features=out_channels[2])

        self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
        self.conv4_BN = batch_norm_fn(num_features=out_channels[3])

        self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
        self.conv5_BN = batch_norm_fn(num_features=out_channels[4])

        self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
        self.conv6_BN = batch_norm_fn(num_features=out_channels[5])

        self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)

    def forward(self, x, embed=False):
        x = self.embed(x)
        if embed: return x
        return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))

    def embed(self, x):
        x = x[:, None, :, None]
        return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)

    def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
        return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))

class CREPE:
    def __init__(self, model_path, model_size="full", hop_length=512, batch_size=None, f0_min=50, f0_max=1100, device=None, sample_rate=16000, return_periodicity=False):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.hop_length = hop_length
        self.batch_size = batch_size
        self.sample_rate = sample_rate
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.return_periodicity = return_periodicity
        model = CREPE_MODEL(model_size)
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()
        self.model = model.to(device)

    def bins_to_frequency(self, bins):
        if str(bins.device).startswith("ocl"): bins = bins.to(torch.float32)

        cents = CENTS_PER_BIN * bins + 1997.3794084376191
        return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)

    def frequency_to_bins(self, frequency, quantize_fn=torch.floor):
        return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()

    def viterbi(self, logits):
        if not hasattr(self, 'transition'):
            xx, yy = np.meshgrid(range(360), range(360))
            transition = np.maximum(12 - abs(xx - yy), 0)
            self.transition = transition / transition.sum(axis=1, keepdims=True)

        with torch.no_grad():
            probs = torch.nn.functional.softmax(logits, dim=1)

        bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, self.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
        return bins, self.bins_to_frequency(bins)

    def preprocess(self, audio, pad=True):
        hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length

        if self.sample_rate != SAMPLE_RATE:
            audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=self.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
            hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate)

        if pad:
            total_frames = 1 + int(audio.size(1) // hop_length)
            audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
        else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)

        batch_size = total_frames if self.batch_size is None else self.batch_size

        for i in range(0, total_frames, batch_size):
            frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))

            if self.device.startswith("ocl"):
                frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device)
            else:
                frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device)

            frames -= frames.mean(dim=1, keepdim=True)
            frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))

            yield frames

    def periodicity(self, probabilities, bins):
        probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
        periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))

        return periodicity.reshape(probabilities.size(0), probabilities.size(2))

    def postprocess(self, probabilities):
        probabilities = probabilities.detach()
        probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float('inf')
        probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float('inf')

        bins, pitch = self.viterbi(probabilities)

        if not self.return_periodicity: return pitch
        return pitch, self.periodicity(probabilities, bins)

    def compute_f0(self, audio, pad=True):
        results = []

        for frames in self.preprocess(audio, pad):
            with torch.no_grad():
                model = self.model(frames, embed=False).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)

            result = self.postprocess(model)
            results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))

        if self.return_periodicity:
            pitch, periodicity = zip(*results)
            return torch.cat(pitch, 1), torch.cat(periodicity, 1)

        return torch.cat(results, 1)
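A usage sketch for the wrapper above (not part of the commit), assuming converted CREPE weights exist at the placeholder path below:

import torch
from modules.torchcrepe import CREPE

crepe = CREPE("crepe_full.pt", model_size="full", hop_length=160, device="cpu", sample_rate=16000, return_periodicity=True)
audio = torch.randn(1, 16000)                  # (1, samples) at 16 kHz
pitch, periodicity = crepe.compute_f0(audio)   # both (1, n_frames)
print(pitch.shape, periodicity.shape)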
RVC/modules/torchfcpe.py
ADDED
@@ -0,0 +1,951 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from torch import einsum
|
| 11 |
+
from functools import partial
|
| 12 |
+
from librosa.filters import mel
|
| 13 |
+
from torchaudio.transforms import Resample
|
| 14 |
+
from einops import rearrange, repeat, pack, unpack
|
| 15 |
+
from torch.nn.utils.parametrizations import weight_norm
|
| 16 |
+
|
| 17 |
+
sys.path.append(os.getcwd())
|
| 18 |
+
|
| 19 |
+
from modules import opencl
|
| 20 |
+
|
| 21 |
+
os.environ["LRU_CACHE_CAPACITY"] = "3"
|
| 22 |
+
|
| 23 |
+
def spawn_wav2mel(args, device = None):
|
| 24 |
+
_type = args.mel.type
|
| 25 |
+
if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
|
| 26 |
+
elif str(_type).lower() == 'stft': _type = 'stft'
|
| 27 |
+
wav2mel = Wav2MelModule(sr=args.mel.sr, n_mels=args.mel.num_mels, n_fft=args.mel.n_fft, win_size=args.mel.win_size, hop_length=args.mel.hop_size, fmin=args.mel.fmin, fmax=args.mel.fmax, clip_val=1e-05, mel_type=_type)
|
| 28 |
+
|
| 29 |
+
return wav2mel.to(torch.device(device))
|
| 30 |
+
|
| 31 |
+
def calc_same_padding(kernel_size):
|
| 32 |
+
pad = kernel_size // 2
|
| 33 |
+
return (pad, pad - (kernel_size + 1) % 2)
|
| 34 |
+
|
| 35 |
+
def l2_regularization(model, l2_alpha):
|
| 36 |
+
l2_loss = []
|
| 37 |
+
for module in model.modules():
|
| 38 |
+
if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
|
| 39 |
+
|
| 40 |
+
return l2_alpha * sum(l2_loss)
|
| 41 |
+
|
| 42 |
+
def torch_interp(x, xp, fp):
|
| 43 |
+
sort_idx = torch.argsort(xp)
|
| 44 |
+
xp = xp[sort_idx]
|
| 45 |
+
fp = fp[sort_idx]
|
| 46 |
+
|
| 47 |
+
right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
|
| 48 |
+
left_idxs = (right_idxs - 1).clamp(min=0)
|
| 49 |
+
x_left = xp[left_idxs]
|
| 50 |
+
y_left = fp[left_idxs]
|
| 51 |
+
|
| 52 |
+
interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
|
| 53 |
+
interp_vals[x < xp[0]] = fp[0]
|
| 54 |
+
interp_vals[x > xp[-1]] = fp[-1]
|
| 55 |
+
|
| 56 |
+
return interp_vals
|
| 57 |
+
|
| 58 |
+
def batch_interp_with_replacement_detach(uv, f0):
|
| 59 |
+
result = f0.clone()
|
| 60 |
+
for i in range(uv.shape[0]):
|
| 61 |
+
interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
|
| 62 |
+
result[i][uv[i]] = interp_vals
|
| 63 |
+
|
| 64 |
+
return result
|
| 65 |
+
|
| 66 |
+
def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
|
| 67 |
+
device = f0s.device
|
| 68 |
+
f0s = f0s / (torch.pow(2, torch.tensor(key_shift_list, device=device).to(device).unsqueeze(0).unsqueeze(0) / 12))
|
| 69 |
+
notes = torch.log2(f0s / 440) * 12 + 69
|
| 70 |
+
notes[notes < 0] = 0
|
| 71 |
+
|
| 72 |
+
uv_penalty = tta_uv_penalty**2
|
| 73 |
+
dp = torch.zeros_like(notes, device=device)
|
| 74 |
+
backtrack = torch.zeros_like(notes, device=device).long()
|
| 75 |
+
dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
|
| 76 |
+
|
| 77 |
+
for t in range(1, notes.size(1)):
|
| 78 |
+
penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
|
| 79 |
+
t_uv = notes[:, t, :] <= 0
|
| 80 |
+
penalty += uv_penalty * t_uv.unsqueeze(1)
|
| 81 |
+
|
| 82 |
+
t1_uv = notes[:, t - 1, :] <= 0
|
| 83 |
+
l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
|
| 84 |
+
l2 = l2 * (l2 > 0)
|
| 85 |
+
|
| 86 |
+
penalty += l2
|
| 87 |
+
penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
|
| 88 |
+
|
| 89 |
+
min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
|
| 90 |
+
dp[:, t, :] = min_value
|
| 91 |
+
backtrack[:, t, :] = min_indices
|
| 92 |
+
|
| 93 |
+
t = f0s.size(1) - 1
|
| 94 |
+
f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
|
| 95 |
+
min_indices = torch.argmin(dp[:, t, :], dim=-1)
|
| 96 |
+
|
| 97 |
+
for i in range(0, t + 1):
|
| 98 |
+
f0_result[:, t - i] = f0s[:, t - i, min_indices]
|
| 99 |
+
min_indices = backtrack[:, t - i, min_indices]
|
| 100 |
+
|
| 101 |
+
return f0_result.unsqueeze(-1)
|
| 102 |
+
|
| 103 |
+
def exists(val):
|
| 104 |
+
return val is not None
|
| 105 |
+
|
| 106 |
+
def default(value, d):
|
| 107 |
+
return value if exists(value) else d
|
| 108 |
+
|
| 109 |
+
def empty(tensor):
|
| 110 |
+
return tensor.numel() == 0
|
| 111 |
+
|
| 112 |
+
def pad_to_multiple(tensor, multiple, dim=-1, value=0):
    seqlen = tensor.shape[dim]
    m = seqlen / multiple
    if m.is_integer(): return False, tensor
    return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)

def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
    t = x.shape[1]
    dims = (len(x.shape) - dim) * (0, 0)
    padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
    return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)

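# Shape sketch (added): look_around widens each window with its neighbours, e.g.
#   x = torch.arange(6).reshape(1, 3, 2)     # (batch, windows=3, window_size=2)
#   look_around(x, backward=1, forward=0)    # -> (1, 3, 4); window i now also sees window i-1,
#                                            #    with pad_value -1 filling before the first window
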
def rotate_half(x):
    x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(q, k, freqs, scale = 1):
    q_len = q.shape[-2]
    q_freqs = freqs[..., -q_len:, :]
    inv_scale = scale ** -1
    if scale.ndim == 2: scale = scale[-q_len:, :]
    q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
    k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)

    return q, k

def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
    unstructured_block = torch.randn((cols, cols), device=device)
    q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
    q, r = map(lambda t: t.to(device), (q, r))
    if qr_uniform_q:
        d = torch.diag(r, 0)
        q *= d.sign()

    return q.t()

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []
    for _ in range(nb_full_blocks):
        block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
    if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    elif scaling == 1: multiplier = math.sqrt(float(nb_columns)) * torch.ones((nb_rows,), device=device)
    else: raise ValueError

    return torch.diag(multiplier) @ torch.cat(block_list)

def linear_attention(q, k, v):
    return einsum("...ed,...nd->...ne", k, q) if v is None else einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))

def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
    ratio = projection_matrix.shape[0] ** -0.5
    data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
    diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)

    return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)

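# Note (added): softmax_kernel is the positive random-feature map of FAVOR+ (Performer,
# Choromanski et al., 2021): it projects (b, h, n, d) queries/keys onto nb_features random
# directions so that linear_attention approximates softmax attention in O(n) rather than O(n^2).
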
class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.use_xpos = use_xpos
        self.scale_base = scale_base
        assert not (use_xpos and not exists(scale_base))
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

    def forward(self, x):
        seq_len, device = x.shape[-2], x.device
        t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        if not self.use_xpos: return freqs, torch.ones(1, device = device)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')

        return freqs, torch.cat((scale, scale), dim = -1)

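# Usage sketch (added; shapes only, assuming use_xpos=False):
#   emb = SinusoidalEmbeddings(dim=64)
#   k = torch.randn(2, 4, 16, 64)                # (..., seq, dim)
#   freqs, scale = emb(k)                        # freqs: (16, 64), scale: (1,)
#   q, k = apply_rotary_pos_emb(k.clone(), k, freqs, scale=scale)
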
class LocalAttention(nn.Module):
    def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
        super().__init__()
        look_forward = default(look_forward, 0 if causal else 1)
        assert not (causal and look_forward > 0)
        self.scale = scale
        self.window_size = window_size
        self.autopad = autopad
        self.exact_windowsize = exact_windowsize
        self.causal = causal
        self.look_backward = look_backward
        self.look_forward = look_forward
        self.dropout = nn.Dropout(dropout)
        self.shared_qk = shared_qk
        self.rel_pos = None
        self.use_xpos = use_xpos
        if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
            if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
            self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))

    def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
        mask = default(mask, input_mask)
        assert not (exists(window_size) and not self.use_xpos)

        _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
        (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))

        if autopad:
            orig_seq_len = q.shape[1]
            (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))

        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
        scale = default(self.scale, dim_head ** -0.5)

        assert (n % window_size) == 0
        windows = n // window_size

        if shared_qk: k = F.normalize(k, dim = -1).type(k.dtype)

        seq = torch.arange(n, device = device)
        b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
        bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))

        bq = bq * scale
        look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)

        bk = look_around(bk, **look_around_kwargs)
        bv = look_around(bv, **look_around_kwargs)

        if exists(self.rel_pos):
            pos_emb, xpos_scale = self.rel_pos(bk)
            bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)

        bq_t = b_t
        bq_k = look_around(b_t, **look_around_kwargs)
        bq_t = rearrange(bq_t, '... i -> ... i 1')
        bq_k = rearrange(bq_k, '... j -> ... 1 j')

        pad_mask = bq_k == pad_value
        sim = einsum('b h i e, b h j e -> b h i j', bq, bk)

        if exists(attn_bias):
            heads = attn_bias.shape[0]
            assert (b % heads) == 0

            attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
            sim = sim + attn_bias

        mask_value = -torch.finfo(sim.dtype).max
        if shared_qk:
            self_mask = bq_t == bq_k
            sim = sim.masked_fill(self_mask, -5e4)
            del self_mask

        if causal:
            causal_mask = bq_t < bq_k
            if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
            sim = sim.masked_fill(causal_mask, mask_value)
            del causal_mask

        sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)

        if exists(mask):
            batch = mask.shape[0]
            assert (b % batch) == 0

            h = b // mask.shape[0]
            if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)

            mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
            sim = sim.masked_fill(~mask, mask_value)

            del mask

        out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
        if autopad: out = out[:, :orig_seq_len, :]

        out, *_ = unpack(out, packed_shape, '* n d')
        return out

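# Usage sketch (added): windowed attention over a padded sequence.
#   attn = LocalAttention(window_size=64, dim=64, autopad=True)
#   q = k = v = torch.randn(1, 250, 64)          # autopad extends 250 -> 256 internally
#   out = attn(q, k, v)                          # (1, 250, 64)
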
class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling
        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
        projection_matrix = self.create_projection()
        self.register_buffer("projection_matrix", projection_matrix)
        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        self.no_projection = no_projection
        self.causal = causal

    @torch.no_grad()
    def redraw_projection_matrix(self):
        projections = self.create_projection()
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
        else:
            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
            q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)

        # note: self.causal_linear_fn is never defined in this file, so only causal=False
        # (the non-causal linear_attention path, which is all this module uses) is supported
        attn_fn = linear_attention if not self.causal else self.causal_linear_fn
        return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)

class SelfAttention(nn.Module):
    # use_norm is accepted (and ignored) so that CFNEncoderLayer can pass it without a TypeError
    def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False, use_norm=False):
        super().__init__()
        assert dim % heads == 0
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def redraw_projection_matrix(self):
        self.fast_attention.redraw_projection_matrix()

    def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
        _, _, _, h, gh = *x.shape, self.heads, self.global_heads
        cross_attend = exists(context)
        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
            if cross_attend: pass  # cross-attention is not implemented for the global (Performer) heads
            else: out = self.fast_attention(q, k, v)

            attn_outs.append(out)

        if not empty(lq):
            assert (not cross_attend), "local attention does not support cross-attention"

            out = self.local_attn(lq, lk, lv, input_mask=mask)
            attn_outs.append(out)

        return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))

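# Usage sketch (added): all heads are global (Performer) heads by default.
#   attn = SelfAttention(dim=512, heads=8, dim_head=64)
#   out = attn(torch.randn(2, 128, 512))         # (2, 128, 512)
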
class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

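# Usage sketch (added): DotDict wraps nested dicts for attribute access, e.g.
#   cfg = DotDict({"mel": {"sr": 16000}}); cfg.mel.sr  # -> 16000
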
class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class ConformerConvModule_LEGACY(nn.Module):
    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class DepthWiseConv1d_LEGACY(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        return self.conv(F.pad(x, self.padding))

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
        super().__init__()
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)

    def forward(self, x):
        return self.conv(x)

class EncoderLayer(nn.Module):
    def __init__(self, parent):
        super().__init__()
        self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
        self.norm = nn.LayerNorm(parent.dim_model)
        self.dropout = nn.Dropout(parent.residual_dropout)
        self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)

    def forward(self, phone, mask=None):
        phone = phone + (self.attn(self.norm(phone), mask=mask))
        return phone + (self.conformer(phone))

class ConformerNaiveEncoder(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.use_norm = use_norm
        self.residual_dropout = 0.1
        self.attention_dropout = 0.1
        self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])

    def forward(self, x, mask=None):
        for layer in self.encoder_layers:
            x = layer(x, mask)

        return x

class CFNEncoderLayer(nn.Module):
    def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
        super().__init__()
        self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
        self.norm = nn.LayerNorm(dim_model)
        self.dropout = nn.Dropout(0.1)
        self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, use_norm=use_norm, dropout=atten_dropout) if not conv_only else None

    def forward(self, x, mask=None):
        if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
        return x + (self.conformer(x))

class HannWindow(torch.nn.Module):
    def __init__(self, win_size):
        super().__init__()
        self.register_buffer('window', torch.hann_window(win_size), persistent=False)

    def forward(self):
        return self.window

class MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
        super().__init__()
        if fmin is None: fmin = 0
        if fmax is None: fmax = sr / 2
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.register_buffer('mel_basis', torch.tensor(mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
        self.hann_window = torch.nn.ModuleDict()
        self.out_stft = out_stft

    @torch.no_grad()
    def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        clip_val = self.clip_val
        factor = 2 ** (key_shift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        y = y.squeeze(-1)
        key_shift_key = str(key_shift)

        if not no_cache_window:
            if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
            else:
                hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
                self.hann_window[key_shift_key] = hann_window

            hann_window_tensor = hann_window()
        else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)

        mode = 'reflect' if pad_right < y.size(-1) else 'constant'
        pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1)

        if str(y.device).startswith("ocl"):
            stft = opencl.STFT(filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
            spec = stft.transform(pad, 1e-9)
        else:
            spec = torch.stft(pad, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
            spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)

        if key_shift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)

            if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
            spec = spec[:, :size, :] * win_size / win_size_new

        spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
        return torch.log(torch.clamp(spec, min=clip_val)).transpose(-1, -2)

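# Usage sketch (added; assumes the `mel` filterbank helper from librosa.filters is imported
# near the top of this file, where it is already used by MelModule):
#   mel_mod = MelModule(sr=16000, n_mels=128, n_fft=1024, win_size=1024, hop_length=160)
#   m = mel_mod(torch.randn(1, 16000, 1))        # (1, n_frames, 128) log-mel
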
class Wav2MelModule(torch.nn.Module):
    def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
        super().__init__()
        if fmin is None: fmin = 0
        if fmax is None: fmax = sr / 2
        self.sampling_rate = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
        self.resample_kernel = torch.nn.ModuleDict()
        if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
        elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
        self.mel_type = mel_type

    @torch.no_grad()
    def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):
        if sample_rate == self.sampling_rate: audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel:
                if len(self.resample_kernel) > 8: self.resample_kernel.clear()
                self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)

            audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)

        mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
        if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]

        return mel

class STFT:
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmax = self.fmax
        factor = 2 ** (keyshift / 12)
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))
        mel_basis = self.mel_basis if not train else {}
        hann_window = self.hann_window if not train else {}
        mel_basis_key = str(fmax) + "_" + str(y.device)

        if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(mel(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
        keyshift_key = str(keyshift) + "_" + str(y.device)
        if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)

        pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1)
        n_fft = int(np.round(n_fft * factor))

        if str(y.device).startswith("ocl"):
            stft = opencl.STFT(filter_length=n_fft, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
            spec = stft.transform(pad, 1e-9)
        else:
            spec = torch.stft(pad, n_fft, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
            spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)

        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new

        return torch.log(torch.clamp(torch.matmul(mel_basis[mel_basis_key], spec), min=self.clip_val))

class Wav2Mel:
    def __init__(self, device=None, dtype=torch.float32):
        self.sample_rate = 16000
        self.hop_size = 160
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)
        if sample_rate == self.sample_rate: audio_res = audio
        else:
            key_str = str(sample_rate)
            if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
            self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
        return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)

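# Usage sketch (added): the legacy 16 kHz / 128-mel front end used by FCPE_LEGACY.
#   w2m = Wav2Mel(device="cpu")
#   m = w2m(torch.randn(1, 32000), sample_rate=16000)  # (1, 201, 128)
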
class PCmer(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout
        self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)])

    def forward(self, phone, mask=None):
        for layer in self._layers:
            phone = layer(phone, mask)

        return phone

class CFNaiveMelPE(nn.Module):
    def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
        super().__init__()
        self.input_channels = input_channels
        self.out_dims = out_dims
        self.hidden_dims = hidden_dims
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.use_fa_norm = use_fa_norm
        self.residual_dropout = 0.1
        self.attention_dropout = 0.1
        self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
        self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
        self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
        self.norm = nn.LayerNorm(hidden_dims)
        self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
        self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
        self.register_buffer("cent_table", self.cent_table_b)
        self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
        self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)

    def forward(self, x, _h_emb=None):
        x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
        if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
        return torch.sigmoid(self.output_proj(self.norm(self.net(x))))

    @torch.no_grad()
    def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)

        if mask:
            confident = torch.max(y, dim=-1, keepdim=True)[0]
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= threshold] = float("-INF")
            rtn = rtn * confident_mask

        return rtn

    @torch.no_grad()
    def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        confident, max_index = torch.max(y, dim=-1, keepdim=True)

        local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
        local_argmax_index[local_argmax_index < 0] = 0
        local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1

        y_l = torch.gather(y, -1, local_argmax_index)
        rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)

        if mask:
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= threshold] = float("-INF")
            rtn = rtn * confident_mask

        return rtn

    @torch.no_grad()
    def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
        latent = self.forward(mel)
        if decoder == "argmax": cents = self.latent2cents_decoder  # was latent2cents_local_decoder; "argmax" should select the global decoder
        elif decoder == "local_argmax": cents = self.latent2cents_local_decoder

        return self.cent_to_f0(cents(latent, threshold=threshold))

    @torch.no_grad()
    def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
        return 10 * 2 ** (cent / 1200)

    @torch.no_grad()
    def f0_to_cent(self, f0):
        return 1200 * torch.log2(f0 / 10)

class FCPE_LEGACY(nn.Module):
    def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
        super().__init__()
        self.loss_mse_scale = loss_mse_scale
        self.loss_l2_regularization = loss_l2_regularization
        self.loss_l2_regularization_scale = loss_l2_regularization_scale
        self.loss_grad1_mse = loss_grad1_mse
        self.loss_grad1_mse_scale = loss_grad1_mse_scale
        self.f0_max = f0_max
        self.f0_min = f0_min
        self.confidence = confidence
        self.threshold = threshold
        self.use_input_conv = use_input_conv
        self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
        self.register_buffer("cent_table", self.cent_table_b)
        self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
        self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
        self.norm = nn.LayerNorm(n_chans)
        self.n_out = out_dims
        self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))

    def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
        if cdecoder == "argmax": self.cdecoder = self.cents_decoder
        elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder

        x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))

        if not infer:
            loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
            if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
            x = loss_all
        else:
            x = self.cent_to_f0(self.cdecoder(x))
            x = (1 + x / 700).log() if not return_hz_f0 else x

            # length interpolation only makes sense for the inference output (the training branch returns a scalar loss)
            if output_interp_target_length is not None:
                x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
                x = torch.where(x.isnan(), float(0.0), x)

        return x

    def cents_decoder(self, y, mask=True):
        B, N, _ = y.size()
        rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)

        if mask:
            confident = torch.max(y, dim=-1, keepdim=True)[0]
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cents_local_decoder(self, y, mask=True):
        B, N, _ = y.size()

        confident, max_index = torch.max(y, dim=-1, keepdim=True)
        local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
        y_l = torch.gather(y, -1, local_argmax_index)
        rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)

        if mask:
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cent_to_f0(self, cent):
        return 10.0 * 2 ** (cent / 1200.0)

    def f0_to_cent(self, f0):
        return 1200.0 * torch.log2(f0 / 10.0)

    def gaussian_blurred_cent(self, cents):
        B, N, _ = cents.size()
        # parenthesized so the voiced-range mask is computed before multiplying (the original applied `&` to the product, which raises on float tensors)
        return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()

class InferCFNaiveMelPE(torch.nn.Module):
    def __init__(self, args, state_dict):
        super().__init__()
        self.wav2mel = spawn_wav2mel(args, device="cpu")
        self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.args_dict = dict(args)
        self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)

    def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
        with torch.no_grad():
            mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
            f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))

        return f0s

    def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
        if test_time_augmentation:
            assert len(tta_key_shifts) > 0
            flag = 0
            if tta_use_origin_uv:
                if 0 not in tta_key_shifts:
                    flag = 1
                    tta_key_shifts.append(0)

            tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
            f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
            f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
            f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
        else:
            f0 = self.__call__(wav, sr, decoder_mode, threshold)
            f0_for_uv = f0

        if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
        uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
        f0 = f0 * (1 - uv)

        if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
        if f0_max is not None: f0[f0 > f0_max] = f0_max
        if output_interp_target_length is not None:
            f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
            f0 = torch.where(f0.isnan(), float(0.0), f0)

        if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
        else: return f0

class FCPEInfer_LEGACY:
    def __init__(self, model_path, device=None, dtype=torch.float32, f0_min=50, f0_max=1100):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.f0_min = f0_min
        self.f0_max = f0_max
        ckpt = torch.load(model_path, map_location=torch.device(self.device))
        self.args = DotDict(ckpt["config"])
        model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
        model.to(self.device).to(self.dtype)
        model.load_state_dict(ckpt["model"])
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        self.model.threshold = threshold
        self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)

        return self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True, output_interp_target_length=p_len)

class FCPEInfer:
    def __init__(self, model_path, device=None, dtype=torch.float32, f0_min=50, f0_max=1100):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.dtype = dtype
        self.f0_min = f0_min
        self.f0_max = f0_max
        ckpt = torch.load(model_path, map_location=torch.device(device))
        ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
        self.args = DotDict(ckpt["config_dict"])
        model = InferCFNaiveMelPE(self.args, ckpt["model"])
        model = model.to(device).to(self.dtype)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05, p_len=None):
        return self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len)

class FCPE:
    def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, legacy=False):
        self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
        self.fcpe = self.model(model_path, device=device, dtype=dtype, f0_min=f0_min, f0_max=f0_max)
        self.hop_length = hop_length
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.legacy = legacy

    def compute_f0(self, wav, p_len=None):
        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len

        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
        f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]

        # note: unlike the normal single-array return, the all-unvoiced branch below returns a (f0, f0) pair
        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
        return f0.cpu().numpy()
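# Usage sketch (added; the checkpoint path is an assumption -- see check_predictors in
# RVC/modules/utils.py for where these files are downloaded):
#   f0_extractor = FCPE("models/fcpe.pt", hop_length=160, sample_rate=16000, device="cpu")
#   f0 = f0_extractor.compute_f0(np.random.randn(16000).astype(np.float32))  # (p_len,) f0 in Hz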
RVC/modules/utils.py
ADDED
@@ -0,0 +1,94 @@
import os
import gc
import sys
import torch
import codecs
import librosa
import requests

import numpy as np
import soundfile as sf
import torch.nn.functional as F

sys.path.append(os.getcwd())

from modules import opencl

def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
    rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
    return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())

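# Note (added): change_rms blends the RMS envelope of the source into the target: the output
# is target * rms_source^(1-rate) * rms_target^(rate-1), so rate=1 keeps the target's own
# loudness contour and rate=0 fully imposes the source's, e.g.
#   mixed = change_rms(src, 16000, tgt, 16000, rate=0.5)
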
def clear_gpu_cache():
    gc.collect()

    if torch.cuda.is_available(): torch.cuda.empty_cache()
    elif torch.backends.mps.is_available(): torch.mps.empty_cache()
    elif opencl.is_available(): opencl.pytorch_ocl.empty_cache()

def HF_download_file(url, output_path=None):
    url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()
    output_path = os.path.basename(url) if output_path is None else (os.path.join(output_path, os.path.basename(url)) if os.path.isdir(output_path) else output_path)
    response = requests.get(url, stream=True, timeout=300)

    if response.status_code == 200:
        with open(output_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
                f.write(chunk)

        return output_path
    else: raise ValueError(response.status_code)

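# Usage sketch (added; the URL is a placeholder): HF_download_file rewrites /blob/ links to
# /resolve/ and streams the file to disk in 10 MB chunks:
#   path = HF_download_file("https://huggingface.co/<repo>/blob/main/model.pt", "models")
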
def check_predictors(method):
    def download(predictors):
        if not os.path.exists(os.path.join("models", predictors)):
            HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("models", predictors))

    model_dict = {
        **dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"),
        **dict.fromkeys(["fcpe"], "fcpe.pt"),
        **dict.fromkeys(["fcpe-legacy"], "fcpe_legacy.pt"),
        **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"),
        **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"),
        **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"),
        **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"),
        **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"),
    }

    if method in model_dict: download(model_dict[method])

def check_embedders(hubert):
    if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "portuguese_hubert_base", "spin"]:
        hubert += ".pt"
        model_path = os.path.join("models", hubert)
        if not os.path.exists(model_path):
            HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13"), "fairseq/", hubert]), model_path)

def load_audio(file, sample_rate=16000):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        if not os.path.isfile(file): raise FileNotFoundError(f"[ERROR] Not found audio: {file}")

        try:
            audio, sr = sf.read(file, dtype=np.float32)
        except Exception:
            audio, sr = librosa.load(file, sr=None)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
    except Exception as e:
        raise RuntimeError(f"[ERROR] Error reading audio file: {e}")

    return audio.flatten()

class Autotune:
    def __init__(self, ref_freqs):
        self.ref_freqs = ref_freqs
        self.note_dict = self.ref_freqs

    def autotune_f0(self, f0, f0_autotune_strength):
        autotuned_f0 = np.zeros_like(f0)

        for i, freq in enumerate(f0):
            autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength

        return autotuned_f0
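# Usage sketch (added; reference frequencies are example values): snaps f0 toward the
# nearest reference pitch, scaled by strength (1.0 = fully snapped).
#   at = Autotune(ref_freqs=[261.63, 293.66, 329.63])
#   at.autotune_f0(np.array([270.0, 300.0]), f0_autotune_strength=1.0)  # ~[261.63, 293.66]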