NeoPy commited on
Commit
05aac64
·
verified ·
1 Parent(s): f8a7cd6
RVC/modules/attentions.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ sys.path.append(os.getcwd())
10
+
11
+ from modules.commons import convert_pad_shape
12
+
13
+ class MultiHeadAttention(nn.Module):
14
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
15
+ super().__init__()
16
+ assert channels % n_heads == 0
17
+ self.channels = channels
18
+ self.out_channels = out_channels
19
+ self.n_heads = n_heads
20
+ self.p_dropout = p_dropout
21
+ self.window_size = window_size
22
+ self.heads_share = heads_share
23
+ self.block_length = block_length
24
+ self.proximal_bias = proximal_bias
25
+ self.proximal_init = proximal_init
26
+ self.attn = None
27
+ self.k_channels = channels // n_heads
28
+ self.conv_q = nn.Conv1d(channels, channels, 1)
29
+ self.conv_k = nn.Conv1d(channels, channels, 1)
30
+ self.conv_v = nn.Conv1d(channels, channels, 1)
31
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
32
+ self.drop = nn.Dropout(p_dropout)
33
+
34
+ if window_size is not None:
35
+ n_heads_rel = 1 if heads_share else n_heads
36
+ rel_stddev = self.k_channels**-0.5
37
+
38
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
39
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
40
+
41
+ nn.init.xavier_uniform_(self.conv_q.weight)
42
+ nn.init.xavier_uniform_(self.conv_k.weight)
43
+ nn.init.xavier_uniform_(self.conv_v.weight)
44
+ nn.init.xavier_uniform_(self.conv_o.weight)
45
+
46
+ if proximal_init:
47
+ with torch.no_grad():
48
+ self.conv_k.weight.copy_(self.conv_q.weight)
49
+ self.conv_k.bias.copy_(self.conv_q.bias)
50
+
51
+ def forward(self, x, c, attn_mask=None):
52
+ q, k, v = self.conv_q(x), self.conv_k(c), self.conv_v(c)
53
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
54
+
55
+ return self.conv_o(x)
56
+
57
+ def attention(self, query, key, value, mask=None):
58
+ b, d, t_s, t_t = (*key.size(), query.size(2))
59
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
60
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
61
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
62
+
63
+ if self.window_size is not None:
64
+ assert (t_s == t_t)
65
+ scores += self._relative_position_to_absolute_position(self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), self._get_relative_embeddings(self.emb_rel_k, t_s)))
66
+
67
+ if self.proximal_bias:
68
+ assert t_s == t_t
69
+ scores += self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
70
+
71
+ if mask is not None:
72
+ scores = scores.masked_fill(mask == 0, -1e4)
73
+ if self.block_length is not None:
74
+ assert (t_s == t_t)
75
+ scores = scores.masked_fill((torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)) == 0, -1e4)
76
+
77
+ p_attn = self.drop(F.softmax(scores, dim=-1))
78
+ output = torch.matmul(p_attn, value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3))
79
+
80
+ if self.window_size is not None: output += self._matmul_with_relative_values(self._absolute_position_to_relative_position(p_attn), self._get_relative_embeddings(self.emb_rel_v, t_s))
81
+ return (output.transpose(2, 3).contiguous().view(b, d, t_t)), p_attn
82
+
83
+ def _matmul_with_relative_values(self, x, y):
84
+ return torch.matmul(x, y.unsqueeze(0))
85
+
86
+ def _matmul_with_relative_keys(self, x, y):
87
+ return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
88
+
89
+ def _get_relative_embeddings(self, relative_embeddings, length):
90
+ pad_length = max(length - (self.window_size + 1), 0)
91
+ slice_start_position = max((self.window_size + 1) - length, 0)
92
+
93
+ return (F.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) if pad_length > 0 else relative_embeddings)[:, slice_start_position:(slice_start_position + 2 * length - 1)]
94
+
95
+ def _relative_position_to_absolute_position(self, x):
96
+ batch, heads, length, _ = x.size()
97
+
98
+ return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])).view([batch, heads, length * 2 * length]), convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
99
+
100
+ def _absolute_position_to_relative_position(self, x):
101
+ batch, heads, length, _ = x.size()
102
+
103
+ return F.pad(F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])).view([batch, heads, length**2 + length * (length - 1)]), convert_pad_shape([[0, 0], [0, 0], [length, 0]])).view([batch, heads, length, 2 * length])[:, :, :, 1:]
104
+
105
+ def _attention_bias_proximal(self, length):
106
+ r = torch.arange(length, dtype=torch.float32)
107
+
108
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs((torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)))), 0), 0)
109
+
110
+ class FFN(nn.Module):
111
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
112
+ super().__init__()
113
+ self.in_channels = in_channels
114
+ self.out_channels = out_channels
115
+ self.filter_channels = filter_channels
116
+ self.kernel_size = kernel_size
117
+ self.p_dropout = p_dropout
118
+ self.activation = activation
119
+ self.causal = causal
120
+ self.padding = self._causal_padding if causal else self._same_padding
121
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
122
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
123
+ self.drop = nn.Dropout(p_dropout)
124
+
125
+ def forward(self, x, x_mask):
126
+ x = self.conv_1(self.padding(x * x_mask))
127
+
128
+ return self.conv_2(self.padding(self.drop(((x * torch.sigmoid(1.702 * x)) if self.activation == "gelu" else torch.relu(x))) * x_mask)) * x_mask
129
+
130
+ def _causal_padding(self, x):
131
+ if self.kernel_size == 1: return x
132
+
133
+ return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [(self.kernel_size - 1), 0]]))
134
+
135
+ def _same_padding(self, x):
136
+ if self.kernel_size == 1: return x
137
+
138
+ return F.pad(x, convert_pad_shape([[0, 0], [0, 0], [((self.kernel_size - 1) // 2), (self.kernel_size // 2)]]))
RVC/modules/commons.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def init_weights(m, mean=0.0, std=0.01):
4
+ if m.__class__.__name__.find("Conv") != -1: m.weight.data.normal_(mean, std)
5
+
6
+ def get_padding(kernel_size, dilation=1):
7
+ return int((kernel_size * dilation - dilation) / 2)
8
+
9
+ def convert_pad_shape(pad_shape):
10
+ return [item for sublist in pad_shape[::-1] for item in sublist]
11
+
12
+ def slice_segments(x, ids_str, segment_size = 4, dim = 2):
13
+ if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
14
+ elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])
15
+
16
+ for i in range(x.size(0)):
17
+ idx_str = ids_str[i].item()
18
+ idx_end = idx_str + segment_size
19
+
20
+ if dim == 2: ret[i] = x[i, idx_str:idx_end]
21
+ else: ret[i] = x[i, :, idx_str:idx_end]
22
+
23
+ return ret
24
+
25
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
26
+ b, _, t = x.size()
27
+ if x_lengths is None: x_lengths = t
28
+
29
+ ids_str = (torch.rand([b]).to(device=x.device) * (x_lengths - segment_size + 1)).to(dtype=torch.long)
30
+
31
+ return slice_segments(x, ids_str, segment_size, dim=3), ids_str
32
+
33
+ @torch.jit.script
34
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
35
+ n_channels_int = n_channels[0]
36
+ in_act = input_a + input_b
37
+
38
+ return torch.tanh(in_act[:, :n_channels_int, :]) * torch.sigmoid(in_act[:, n_channels_int:, :])
39
+
40
+ def sequence_mask(length, max_length = None):
41
+ if max_length is None: max_length = length.max()
42
+ return torch.arange(max_length, dtype=length.dtype, device=length.device).unsqueeze(0) < length.unsqueeze(1)
43
+
44
+ def clip_grad_value(parameters, clip_value, norm_type=2):
45
+ if isinstance(parameters, torch.Tensor): parameters = [parameters]
46
+ norm_type = float(norm_type)
47
+
48
+ if clip_value is not None: clip_value = float(clip_value)
49
+ total_norm = 0
50
+
51
+ for p in list(filter(lambda p: p.grad is not None, parameters)):
52
+ total_norm += (p.grad.data.norm(norm_type)).item() ** norm_type
53
+
54
+ if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)
55
+
56
+ return total_norm ** (1.0 / norm_type)
RVC/modules/config.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ sys.path.append(os.getcwd())
6
+
7
+ from modules import opencl
8
+
9
+ def singleton(cls):
10
+ instances = {}
11
+
12
+ def get_instance(*args, **kwargs):
13
+ if cls not in instances: instances[cls] = cls(*args, **kwargs)
14
+ return instances[cls]
15
+
16
+ return get_instance
17
+
18
+ @singleton
19
+ class Config:
20
+ def __init__(self, cpu_mode=False, is_half=False):
21
+ self.device = "cuda:0" if torch.cuda.is_available() else ("ocl:0" if opencl.is_available() else "cpu")
22
+ self.is_half = is_half
23
+ self.gpu_mem = None
24
+ self.cpu_mode = cpu_mode
25
+ if cpu_mode: self.device = "cpu"
26
+
27
+ def device_config(self):
28
+ if not self.cpu_mode:
29
+ if self.device.startswith("cuda"): self.set_cuda_config()
30
+ elif opencl.is_available(): self.device = "ocl:0"
31
+ elif self.has_mps(): self.device = "mps"
32
+ else: self.device = "cpu"
33
+
34
+ if self.gpu_mem is not None and self.gpu_mem <= 4: return 1, 5, 30, 32
35
+ return (3, 10, 60, 65) if self.is_half else (1, 6, 38, 41)
36
+
37
+ def set_cuda_config(self):
38
+ i_device = int(self.device.split(":")[-1])
39
+ self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
40
+
41
+ def has_mps(self):
42
+ return torch.backends.mps.is_available()
RVC/modules/cut.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ class Slicer:
4
+ def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
5
+ min_interval = sr * min_interval / 1000
6
+ self.threshold = 10 ** (threshold / 20.0)
7
+ self.hop_size = round(sr * hop_size / 1000)
8
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
9
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
10
+ self.min_interval = round(min_interval / self.hop_size)
11
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
12
+
13
+ def _apply_slice(self, waveform, begin, end):
14
+ start_idx = begin * self.hop_size
15
+
16
+ return waveform[:, start_idx:min(waveform.shape[1], end * self.hop_size)] if len(waveform.shape) > 1 else waveform[start_idx:min(waveform.shape[0], end * self.hop_size)]
17
+
18
+ def slice(self, waveform):
19
+ samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
20
+ if samples.shape[0] <= self.min_length: return [waveform]
21
+ rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
22
+ sil_tags = []
23
+ silence_start, clip_start = None, 0
24
+
25
+ for i, rms in enumerate(rms_list):
26
+ if rms < self.threshold:
27
+ if silence_start is None: silence_start = i
28
+ continue
29
+
30
+ if silence_start is None: continue
31
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
32
+ need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
33
+ if not is_leading_silence and not need_slice_middle:
34
+ silence_start = None
35
+ continue
36
+
37
+ if i - silence_start <= self.max_sil_kept:
38
+ pos = rms_list[silence_start : i + 1].argmin() + silence_start
39
+ sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
40
+ clip_start = pos
41
+ elif i - silence_start <= self.max_sil_kept * 2:
42
+ pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
43
+ pos += i - self.max_sil_kept
44
+ pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
45
+ if silence_start == 0:
46
+ sil_tags.append((0, pos_r))
47
+ clip_start = pos_r
48
+ else:
49
+ sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
50
+ clip_start = max(pos_r, pos)
51
+ else:
52
+ pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
53
+ sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
54
+ clip_start = pos_r
55
+
56
+ silence_start = None
57
+ total_frames = rms_list.shape[0]
58
+ if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
59
+
60
+ if not sil_tags: return [waveform]
61
+ else:
62
+ chunks = []
63
+ if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
64
+
65
+ for i in range(len(sil_tags) - 1):
66
+ chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
67
+
68
+ if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
69
+ return chunks
70
+
71
+ class Slicer2(Slicer):
72
+ def slice2(self, waveform):
73
+ samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
74
+
75
+ if samples.shape[0] <= self.min_length: return [(waveform, 0, samples.shape[0])]
76
+ rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
77
+
78
+ sil_tags = []
79
+ silence_start, clip_start = None, 0
80
+
81
+ for i, rms in enumerate(rms_list):
82
+ if rms < self.threshold:
83
+ if silence_start is None: silence_start = i
84
+ continue
85
+
86
+ if silence_start is None: continue
87
+
88
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
89
+ need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
90
+
91
+ if not is_leading_silence and not need_slice_middle:
92
+ silence_start = None
93
+ continue
94
+
95
+ if i - silence_start <= self.max_sil_kept:
96
+ pos = rms_list[silence_start : i + 1].argmin() + silence_start
97
+ sil_tags.append((0, pos) if silence_start == 0 else (pos, pos))
98
+ clip_start = pos
99
+ elif i - silence_start <= self.max_sil_kept * 2:
100
+ pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
101
+ pos += i - self.max_sil_kept
102
+
103
+ pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
104
+
105
+ if silence_start == 0:
106
+ sil_tags.append((0, pos_r))
107
+ clip_start = pos_r
108
+ else:
109
+ sil_tags.append((min((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos), max(pos_r, pos)))
110
+ clip_start = max(pos_r, pos)
111
+ else:
112
+ pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
113
+ sil_tags.append((0, pos_r) if silence_start == 0 else ((rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start), pos_r))
114
+ clip_start = pos_r
115
+
116
+ silence_start = None
117
+
118
+ total_frames = rms_list.shape[0]
119
+ if (silence_start is not None and total_frames - silence_start >= self.min_interval): sil_tags.append((rms_list[silence_start : min(total_frames, silence_start + self.max_sil_kept) + 1].argmin() + silence_start, total_frames + 1))
120
+
121
+ if not sil_tags: return [(waveform, 0, samples.shape[-1])]
122
+ else:
123
+ chunks = []
124
+ if sil_tags[0][0] > 0: chunks.append((self._apply_slice(waveform, 0, sil_tags[0][0]), 0, sil_tags[0][0] * self.hop_size))
125
+
126
+ for i in range(len(sil_tags) - 1):
127
+ chunks.append((self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]), sil_tags[i][1] * self.hop_size, sil_tags[i + 1][0] * self.hop_size))
128
+
129
+ if sil_tags[-1][1] < total_frames: chunks.append((self._apply_slice(waveform, sil_tags[-1][1], total_frames), sil_tags[-1][1] * self.hop_size, samples.shape[-1]))
130
+ return chunks
131
+
132
+ def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
133
+ y = np.pad(y, (int(frame_length // 2), int(frame_length // 2)), mode=pad_mode)
134
+ axis = -1
135
+
136
+ x_shape_trimmed = list(y.shape)
137
+ x_shape_trimmed[axis] -= frame_length - 1
138
+ xw = np.moveaxis(np.lib.stride_tricks.as_strided(y, shape=tuple(x_shape_trimmed) + tuple([frame_length]), strides=y.strides + tuple([y.strides[axis]])), -1, axis - 1 if axis < 0 else axis + 1)
139
+
140
+ slices = [slice(None)] * xw.ndim
141
+ slices[axis] = slice(0, None, hop_length)
142
+
143
+ return np.sqrt(np.mean(np.abs(xw[tuple(slices)]) ** 2, axis=-2, keepdims=True))
144
+
145
+ def cut(audio, sr, db_thresh=-60, min_interval=250):
146
+ slicer = Slicer2(sr=sr, threshold=db_thresh, min_interval=min_interval)
147
+ return slicer.slice2(audio)
148
+
149
+ def restore(segments, total_len, dtype=np.float32):
150
+ out = []
151
+ last_end = 0
152
+
153
+ for start, end, processed_seg in segments:
154
+ if start > last_end: out.append(np.zeros(start - last_end, dtype=dtype))
155
+
156
+ out.append(processed_seg)
157
+ last_end = end
158
+
159
+ if last_end < total_len: out.append(np.zeros(total_len - last_end, dtype=dtype))
160
+ return np.concatenate(out, axis=-1)
RVC/modules/download.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import sys
4
+ import shutil
5
+
6
+ sys.path.append(os.getcwd())
7
+
8
+ from modules.utils import HF_download_file
9
+ from modules import gdown, meganz, mediafire, pixeldrain
10
+
11
+ def move_files_from_directory(src_dir, dest_models, model_name):
12
+ for root, _, files in os.walk(src_dir):
13
+ for file in files:
14
+ file_path = os.path.join(root, file)
15
+ if file.endswith(".index"):
16
+ filepath = os.path.join(dest_models, file.replace(' ', '_').replace('(', '').replace(')', '').replace('[', '').replace(']', '').replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip())
17
+
18
+ shutil.move(file_path, filepath)
19
+ elif file.endswith(".pth") and not file.startswith("D_") and not file.startswith("G_"):
20
+ pth_path = os.path.join(dest_models, model_name + ".pth")
21
+
22
+ shutil.move(file_path, pth_path)
23
+
24
+ def save_drop_model(dropbox):
25
+ model_folders = "rvc_models"
26
+ save_model_temp = "save_model_temp"
27
+
28
+ if not os.path.exists(model_folders): os.makedirs(model_folders, exist_ok=True)
29
+ if not os.path.exists(save_model_temp): os.makedirs(save_model_temp, exist_ok=True)
30
+
31
+ shutil.move(dropbox, save_model_temp)
32
+
33
+ try:
34
+ print("[INFO] Start uploading...")
35
+
36
+ file_name = os.path.basename(dropbox)
37
+ model_folders = os.path.join(model_folders, file_name.replace(".zip", "").replace(".pth", "").replace(".index", ""))
38
+
39
+ if file_name.endswith(".zip"):
40
+ shutil.unpack_archive(os.path.join(save_model_temp, file_name), save_model_temp)
41
+ move_files_from_directory(save_model_temp, model_folders, file_name.replace(".zip", ""))
42
+ elif file_name.endswith(".pth"):
43
+ output_file = os.path.join(model_folders, file_name)
44
+ shutil.move(os.path.join(save_model_temp, file_name), output_file)
45
+ elif file_name.endswith(".index"):
46
+ def extract_name_model(filename):
47
+ match = re.search(r"([A-Za-z]+)(?=_v|\.|$)", filename)
48
+ return match.group(1) if match else None
49
+
50
+ model_logs = os.path.join(model_folders, extract_name_model(file_name))
51
+ if not os.path.exists(model_logs): os.makedirs(model_logs, exist_ok=True)
52
+ shutil.move(os.path.join(save_model_temp, file_name), model_logs)
53
+ else:
54
+ print("[WARNING] Format not supported. Supported formats ('.zip', '.pth', '.index')")
55
+ return
56
+
57
+ print("[INFO] Completed upload.")
58
+ except Exception as e:
59
+ print(f"[ERROR] An error occurred during unpack: {e}")
60
+ finally:
61
+ shutil.rmtree(save_model_temp, ignore_errors=True)
62
+
63
+ def download_model(url=None, model=None):
64
+ if not url:
65
+ print("[WARNING] Please provide a valid url.")
66
+ return
67
+
68
+ if not model:
69
+ print("[WARNING] Please provide a valid model name.")
70
+ return
71
+
72
+ model = model.replace(".pth", "").replace(".index", "").replace(".zip", "").replace(" ", "_").replace("(", "").replace(")", "").replace("[", "").replace("]", "").replace(",", "").replace('"', "").replace("'", "").replace("|", "").strip()
73
+ url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()
74
+
75
+ download_dir = "download_model"
76
+ model_folders = "rvc_models"
77
+
78
+ if not os.path.exists(download_dir): os.makedirs(download_dir, exist_ok=True)
79
+ if not os.path.exists(model_folders): os.makedirs(model_folders, exist_ok=True)
80
+
81
+ model_folders = os.path.join(model_folders, model)
82
+ os.makedirs(model_folders, exist_ok=True)
83
+
84
+ try:
85
+ print("[INFO] Start downloading...")
86
+
87
+ if url.endswith(".pth"): HF_download_file(url, os.path.join(model_folders, f"{model}.pth"))
88
+ elif url.endswith(".index"): HF_download_file(url, os.path.join(model_folders, f"{model}.index"))
89
+ elif url.endswith(".zip"):
90
+ output_path = HF_download_file(url, os.path.join(download_dir, model + ".zip"))
91
+ shutil.unpack_archive(output_path, download_dir)
92
+
93
+ move_files_from_directory(download_dir, model_folders, model)
94
+ else:
95
+ if "drive.google.com" in url or "drive.usercontent.google.com" in url:
96
+ file_id = None
97
+
98
+ if "/file/d/" in url: file_id = url.split("/d/")[1].split("/")[0]
99
+ elif "open?id=" in url: file_id = url.split("open?id=")[1].split("/")[0]
100
+ elif "/download?id=" in url: file_id = url.split("/download?id=")[1].split("&")[0]
101
+
102
+ if file_id:
103
+ file = gdown.gdown_download(id=file_id, output=download_dir)
104
+ if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)
105
+
106
+ move_files_from_directory(download_dir, model_folders, model)
107
+ elif "mega.nz" in url:
108
+ meganz.mega_download_url(url, download_dir)
109
+
110
+ file_download = next((f for f in os.listdir(download_dir)), None)
111
+ if file_download.endswith(".zip"): shutil.unpack_archive(os.path.join(download_dir, file_download), download_dir)
112
+
113
+ move_files_from_directory(download_dir, model_folders, model)
114
+ elif "mediafire.com" in url:
115
+ file = mediafire.Mediafire_Download(url, download_dir)
116
+ if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)
117
+
118
+ move_files_from_directory(download_dir, model_folders, model)
119
+ elif "pixeldrain.com" in url:
120
+ file = pixeldrain.pixeldrain(url, download_dir)
121
+ if file.endswith(".zip"): shutil.unpack_archive(file, download_dir)
122
+
123
+ move_files_from_directory(download_dir, model_folders, model)
124
+ else:
125
+ print("[WARNING] The url path is not supported.")
126
+ return
127
+
128
+ print("[INFO] Model download complete.")
129
+ except Exception as e:
130
+ print(f"[INFO] An error has occurred: {e}")
131
+ finally:
132
+ shutil.rmtree(download_dir, ignore_errors=True)
RVC/modules/encoders.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+
6
+ sys.path.append(os.getcwd())
7
+
8
+ from modules.modules import WaveNet
9
+ from modules.commons import sequence_mask
10
+ from modules.normalization import LayerNorm
11
+ from modules.attentions import MultiHeadAttention, FFN
12
+
13
+ class Encoder(torch.nn.Module):
14
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
15
+ super().__init__()
16
+ self.hidden_channels = hidden_channels
17
+ self.filter_channels = filter_channels
18
+ self.n_heads = n_heads
19
+ self.n_layers = n_layers
20
+ self.kernel_size = kernel_size
21
+ self.p_dropout = p_dropout
22
+ self.window_size = window_size
23
+ self.drop = torch.nn.Dropout(p_dropout)
24
+ self.attn_layers = torch.nn.ModuleList()
25
+ self.norm_layers_1 = torch.nn.ModuleList()
26
+ self.ffn_layers = torch.nn.ModuleList()
27
+ self.norm_layers_2 = torch.nn.ModuleList()
28
+
29
+ for _ in range(self.n_layers):
30
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
31
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
32
+
33
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
34
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
35
+
36
+ def forward(self, x, x_mask):
37
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
38
+ x = x * x_mask
39
+
40
+ for i in range(self.n_layers):
41
+ x = self.norm_layers_1[i](x + self.drop(self.attn_layers[i](x, x, attn_mask)))
42
+ x = self.norm_layers_2[i](x + self.drop(self.ffn_layers[i](x, x_mask)))
43
+
44
+ return x * x_mask
45
+
46
+ class TextEncoder(torch.nn.Module):
47
+ def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True, energy=False, onnx=False):
48
+ super(TextEncoder, self).__init__()
49
+ self.out_channels = out_channels
50
+ self.hidden_channels = hidden_channels
51
+ self.filter_channels = filter_channels
52
+ self.n_heads = n_heads
53
+ self.n_layers = n_layers
54
+ self.kernel_size = kernel_size
55
+ self.p_dropout = float(p_dropout)
56
+ self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
57
+ self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
58
+ self.emb_pitch = torch.nn.Embedding(256, hidden_channels) if f0 else None
59
+ self.emb_energy = torch.nn.Linear(1, hidden_channels) if energy else None
60
+ self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), onnx=onnx)
61
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
62
+
63
+ def forward(self, phone, pitch, lengths, energy):
64
+ x = self.emb_phone(phone)
65
+
66
+ if pitch is not None: x += self.emb_pitch(pitch)
67
+ if energy is not None: x += self.emb_energy(energy.unsqueeze(-1))
68
+
69
+ x = torch.transpose(self.lrelu(x * math.sqrt(self.hidden_channels)), 1, -1)
70
+ x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
71
+ m, logs = torch.split((self.proj(self.encoder(x * x_mask, x_mask)) * x_mask), self.out_channels, dim=1)
72
+
73
+ return m, logs, x_mask
74
+
75
+ class PosteriorEncoder(torch.nn.Module):
76
+ def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
77
+ super(PosteriorEncoder, self).__init__()
78
+ self.in_channels = in_channels
79
+ self.out_channels = out_channels
80
+ self.hidden_channels = hidden_channels
81
+ self.kernel_size = kernel_size
82
+ self.dilation_rate = dilation_rate
83
+ self.n_layers = n_layers
84
+ self.gin_channels = gin_channels
85
+ self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
86
+ self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
87
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
88
+
89
+ def forward(self, x, x_lengths, g = None):
90
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
91
+ m, logs = torch.split((self.proj(self.enc((self.pre(x) * x_mask), x_mask, g=g)) * x_mask), self.out_channels, dim=1)
92
+
93
+ return ((m + torch.randn_like(m) * torch.exp(logs)) * x_mask), m, logs, x_mask
94
+
95
+ def remove_weight_norm(self):
96
+ self.enc.remove_weight_norm()
RVC/modules/fairseq.py ADDED
@@ -0,0 +1,1396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sys
3
+ import math
4
+ import uuid
5
+ import torch
6
+ import types
7
+ import contextlib
8
+
9
+ import numpy as np
10
+ import torch.nn.functional as F
11
+
12
+ from torch import nn
13
+ from omegaconf import DictConfig, open_dict
14
+
15
+ class Dictionary:
16
+ def __init__(self, *args, **kwargs):
17
+ pass
18
+
19
+ fairseq = types.ModuleType("fairseq")
20
+ fairseq_data = types.ModuleType("fairseq.data")
21
+ fairseq_data_dictionary = types.ModuleType("fairseq.data.dictionary")
22
+ fairseq_data_dictionary.Dictionary = Dictionary
23
+ fairseq.data = fairseq_data
24
+ fairseq_data.dictionary = fairseq_data_dictionary
25
+ sys.modules["fairseq"] = fairseq
26
+ sys.modules["fairseq.data"] = fairseq_data
27
+ sys.modules["fairseq.data.dictionary"] = fairseq_data_dictionary
28
+
29
+ def load_model(filename):
30
+ state = torch.load(filename, map_location="cpu")
31
+ model = HubertModel(HubertConfig(**state['cfg']['model']))
32
+ model.load_state_dict(state['model'], strict=False)
33
+ return model
34
+
35
+ def softmax(x, dim, onnx_trace = False):
36
+ return F.softmax(x.float(), dim=dim) if onnx_trace else F.softmax(x, dim=dim, dtype=torch.float32)
37
+
38
+ def log_softmax(x, dim, onnx_trace = False):
39
+ return F.log_softmax(x.float(), dim=dim) if onnx_trace else F.log_softmax(x, dim=dim, dtype=torch.float32)
40
+
41
+ def eval_str_dict(x, type=dict):
42
+ if x is None: return None
43
+ if isinstance(x, str): x = eval(x)
44
+ return x
45
+
46
+ def with_incremental_state(cls):
47
+ cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
48
+ return cls
49
+
50
+ def quant_noise(module, p, block_size):
51
+ if p <= 0: return module
52
+ assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
53
+ is_conv = module.weight.ndim == 4
54
+ if not is_conv: assert (module.weight.size(1) % block_size == 0)
55
+ else:
56
+ if module.kernel_size == (1, 1): assert (module.in_channels % block_size == 0)
57
+ else:
58
+ k = module.kernel_size[0] * module.kernel_size[1]
59
+ assert k % block_size == 0
60
+
61
+ def _forward_pre_hook(mod, input):
62
+ if mod.training:
63
+ if not is_conv:
64
+ weight = mod.weight
65
+ in_features = weight.size(1)
66
+ out_features = weight.size(0)
67
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
68
+ mask.bernoulli_(p)
69
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
70
+ else:
71
+ weight = mod.weight
72
+ in_channels = mod.in_channels
73
+ out_channels = mod.out_channels
74
+
75
+ if mod.kernel_size == (1, 1):
76
+ mask = torch.zeros(int(in_channels // block_size * out_channels), device=weight.device)
77
+ mask.bernoulli_(p)
78
+ mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
79
+ else:
80
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
81
+ mask.bernoulli_(p)
82
+ mask = (mask.unsqueeze(2).unsqueeze(3).repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]))
83
+
84
+ mask = mask.to(torch.bool)
85
+ s = 1 / (1 - p)
86
+ mod.weight.data = s * weight.masked_fill(mask, 0)
87
+
88
+ module.register_forward_pre_hook(_forward_pre_hook)
89
+ return module
90
+
91
+ class FairseqDropout(nn.Module):
92
+ def __init__(self, p, module_name=None):
93
+ super().__init__()
94
+ self.p = p
95
+ self.module_name = module_name
96
+ self.apply_during_inference = False
97
+
98
+ def forward(self, x, inplace = False):
99
+ return F.dropout(x, p=self.p, training=True, inplace=inplace) if self.p > 0 and (self.training or self.apply_during_inference) else x
100
+
101
+ def make_generation_fast_(self, name, retain_dropout = False, retain_dropout_modules = None, **kwargs):
102
+ if retain_dropout:
103
+ if (retain_dropout_modules is None or self.module_name in retain_dropout_modules): self.apply_during_inference = True
104
+
105
+ class FairseqIncrementalState(object):
106
+ def __init__(self, *args, **kwargs):
107
+ super().__init__(*args, **kwargs)
108
+ self.init_incremental_state()
109
+
110
+ def init_incremental_state(self):
111
+ self._incremental_state_id = str(uuid.uuid4())
112
+
113
+ def _get_full_incremental_state_key(self, key):
114
+ return "{}.{}".format(self._incremental_state_id, key)
115
+
116
+ def get_incremental_state(self, incremental_state, key):
117
+ full_key = self._get_full_incremental_state_key(key)
118
+ if incremental_state is None or full_key not in incremental_state: return None
119
+ return incremental_state[full_key]
120
+
121
+ def set_incremental_state(self, incremental_state, key, value):
122
+ if incremental_state is not None: incremental_state[self._get_full_incremental_state_key(key)] = value
123
+ return incremental_state
124
+
125
+ class FairseqDecoder(nn.Module):
126
+ def __init__(self, dictionary):
127
+ super().__init__()
128
+ self.dictionary = dictionary
129
+ self.onnx_trace = False
130
+ self.adaptive_softmax = None
131
+
132
+ def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
133
+ x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
134
+ return self.output_layer(x), extra
135
+
136
+ def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
137
+ pass
138
+
139
+ def output_layer(self, features, **kwargs):
140
+ pass
141
+
142
+ def get_normalized_probs(self, net_output, log_probs, sample = None):
143
+ return self.get_normalized_probs_scriptable(net_output, log_probs, sample)
144
+
145
+ def get_normalized_probs_scriptable(self, net_output, log_probs, sample = None):
146
+ if hasattr(self, "adaptive_softmax") and self.adaptive_softmax is not None:
147
+ if sample is not None:
148
+ assert "target" in sample
149
+ target = sample["target"]
150
+ else: target = None
151
+ out = self.adaptive_softmax.get_log_prob(net_output[0], target=target)
152
+ return out.exp_() if not log_probs else out
153
+
154
+ logits = net_output[0]
155
+ return log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace) if log_probs else softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
156
+
157
+ def max_positions(self):
158
+ return 1e6
159
+
160
+ def upgrade_state_dict_named(self, state_dict, name):
161
+ return state_dict
162
+
163
+ def prepare_for_onnx_export_(self):
164
+ self.onnx_trace = True
165
+
166
+ @with_incremental_state
167
+ class FairseqIncrementalDecoder(FairseqDecoder):
168
+ def __init__(self, dictionary):
169
+ super().__init__(dictionary)
170
+
171
+ def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
172
+ pass
173
+
174
+ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
175
+ pass
176
+
177
+ def reorder_incremental_state(self, incremental_state, new_order):
178
+ pass
179
+
180
+ def reorder_incremental_state_scripting(self, incremental_state, new_order):
181
+ for module in self.modules():
182
+ if hasattr(module, "reorder_incremental_state"):
183
+ result = module.reorder_incremental_state(incremental_state, new_order)
184
+ if result is not None: incremental_state = result
185
+
186
+ def set_beam_size(self, beam_size):
187
+ if getattr(self, "_beam_size", -1) != beam_size:
188
+ seen = set()
189
+
190
+ def apply_set_beam_size(module):
191
+ if (module != self and hasattr(module, "set_beam_size") and module not in seen):
192
+ seen.add(module)
193
+ module.set_beam_size(beam_size)
194
+
195
+ self.apply(apply_set_beam_size)
196
+ self._beam_size = beam_size
197
+
198
+ class MultiheadAttention(FairseqIncrementalDecoder):
199
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, self_attention=False, encoder_decoder_attention=False, dictionary=None, q_noise=0.0, qn_block_size=8, xformers_att_config=None, xformers_blocksparse_layout=None, xformers_blocksparse_blocksize=16):
200
+ super().__init__(dictionary)
201
+ xformers_att_config = eval_str_dict(xformers_att_config)
202
+ self.use_xformers = xformers_att_config is not None
203
+ if self.use_xformers: raise ImportError
204
+ self.embed_dim = embed_dim
205
+ self.kdim = kdim if kdim is not None else embed_dim
206
+ self.vdim = vdim if vdim is not None else embed_dim
207
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
208
+ self.num_heads = num_heads
209
+ self.dropout_module = FairseqDropout(dropout, module_name=self.__class__.__name__)
210
+ self.head_dim = embed_dim // num_heads
211
+ assert (self.head_dim * num_heads == self.embed_dim)
212
+ self.scaling = self.head_dim**-0.5
213
+ self.self_attention = self_attention
214
+ self.encoder_decoder_attention = encoder_decoder_attention
215
+ assert not self.self_attention or self.qkv_same_dim
216
+ self.k_proj = quant_noise(nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size)
217
+ self.v_proj = quant_noise(nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size)
218
+ self.q_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
219
+ self.out_proj = quant_noise(nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size)
220
+ if add_bias_kv: self.bias_k, self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim)), nn.Parameter(torch.Tensor(1, 1, embed_dim))
221
+ else: self.bias_k = self.bias_v = None
222
+ self.add_zero_attn = add_zero_attn
223
+ self.beam_size = 1
224
+ self.reset_parameters()
225
+ self.onnx_trace = False
226
+ self.skip_embed_dim_check = False
227
+ self.init_incremental_state()
228
+
229
+ def prepare_for_onnx_export_(self):
230
+ self.onnx_trace = True
231
+
232
+ def reset_parameters(self):
233
+ if self.qkv_same_dim:
234
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
235
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
236
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
237
+ else:
238
+ nn.init.xavier_uniform_(self.k_proj.weight)
239
+ nn.init.xavier_uniform_(self.v_proj.weight)
240
+ nn.init.xavier_uniform_(self.q_proj.weight)
241
+
242
+ nn.init.xavier_uniform_(self.out_proj.weight)
243
+ if self.out_proj.bias is not None: nn.init.constant_(self.out_proj.bias, 0.0)
244
+ if self.bias_k is not None: nn.init.xavier_normal_(self.bias_k)
245
+ if self.bias_v is not None: nn.init.xavier_normal_(self.bias_v)
246
+
247
+ def _get_reserve_head_index(self, num_heads_to_keep: int):
248
+ k_proj_heads_norm, q_proj_heads_norm, v_proj_heads_norm = [], [], []
249
+ for i in range(self.num_heads):
250
+ start_idx = i * self.head_dim
251
+ end_idx = (i + 1) * self.head_dim
252
+ k_proj_heads_norm.append(torch.sum(torch.abs(self.k_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist())
253
+ q_proj_heads_norm.append(torch.sum(torch.abs(self.q_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist())
254
+ v_proj_heads_norm.append(torch.sum(torch.abs(self.v_proj.weight[start_idx:end_idx])).tolist() + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist())
255
+
256
+ heads_norm = []
257
+ for i in range(self.num_heads):
258
+ heads_norm.append(k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i])
259
+
260
+ sorted_head_index = sorted(range(self.num_heads), key=lambda k: heads_norm[k], reverse=True)
261
+ reserve_head_index = []
262
+ for i in range(num_heads_to_keep):
263
+ reserve_head_index.append((sorted_head_index[i] * self.head_dim, (sorted_head_index[i] + 1) * self.head_dim))
264
+ return reserve_head_index
265
+
266
+ def _adaptive_prune_heads(self, reserve_head_index):
267
+ new_q_weight, new_q_bias, new_k_weight, new_k_bias, new_v_weight, new_v_bias, new_out_proj_weight = [], [], [], [], [], [], []
268
+ for ele in reserve_head_index:
269
+ start_idx, end_idx = ele
270
+ new_q_weight.append(self.q_proj.weight[start_idx:end_idx])
271
+ new_q_bias.append(self.q_proj.bias[start_idx:end_idx])
272
+ new_k_weight.append(self.k_proj.weight[start_idx:end_idx])
273
+ new_k_bias.append(self.k_proj.bias[start_idx:end_idx])
274
+ new_v_weight.append(self.v_proj.weight[start_idx:end_idx])
275
+ new_v_bias.append(self.v_proj.bias[start_idx:end_idx])
276
+ new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx])
277
+ new_q_weight = torch.cat(new_q_weight).detach()
278
+ new_k_weight = torch.cat(new_k_weight).detach()
279
+ new_v_weight = torch.cat(new_v_weight).detach()
280
+ new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach()
281
+ new_q_weight.requires_grad = True
282
+ new_k_weight.requires_grad = True
283
+ new_v_weight.requires_grad = True
284
+ new_out_proj_weight.requires_grad = True
285
+ new_q_bias = torch.cat(new_q_bias).detach()
286
+ new_q_bias.requires_grad = True
287
+ new_k_bias = torch.cat(new_k_bias).detach()
288
+ new_k_bias.requires_grad = True
289
+ new_v_bias = torch.cat(new_v_bias).detach()
290
+ new_v_bias.requires_grad = True
291
+ self.q_proj.weight = nn.Parameter(new_q_weight)
292
+ self.q_proj.bias = nn.Parameter(new_q_bias)
293
+ self.k_proj.weight = nn.Parameter(new_k_weight)
294
+ self.k_proj.bias = nn.Parameter(new_k_bias)
295
+ self.v_proj.weight = nn.Parameter(new_v_weight)
296
+ self.v_proj.bias = nn.Parameter(new_v_bias)
297
+ self.out_proj.weight = nn.Parameter(new_out_proj_weight)
298
+ self.num_heads = len(reserve_head_index)
299
+ self.embed_dim = self.head_dim * self.num_heads
300
+ self.q_proj.out_features = self.embed_dim
301
+ self.k_proj.out_features = self.embed_dim
302
+ self.v_proj.out_features = self.embed_dim
303
+
304
+ def _set_skip_embed_dim_check(self):
305
+ self.skip_embed_dim_check = True
306
+
307
+ def _pad_masks(self, key_padding_mask, attn_mask):
308
+ if attn_mask is not None:
309
+ shape = attn_mask.size()[:-1] + torch.Size([1])
310
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1)
311
+
312
+ if key_padding_mask is not None:
313
+ shape = key_padding_mask.size()[:-1] + torch.Size([1])
314
+ key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(shape)], dim=-1)
315
+
316
+ return key_padding_mask, attn_mask
317
+
318
+ def _add_bias(self, k, v, key_padding_mask, attn_mask, bsz):
319
+ assert self.bias_k is not None or self.bias_v is not None
320
+ key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
321
+ return torch.cat([k, self.bias_k.repeat(1, bsz, 1)]), torch.cat([v, self.bias_v.repeat(1, bsz, 1)]), key_padding_mask, attn_mask
322
+
323
+ def _append_zero_attn(self, k, v, key_padding_mask, attn_mask):
324
+ zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:]
325
+ key_padding_mask, attn_mask = self._pad_masks(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
326
+ return torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2), torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2), key_padding_mask, attn_mask
327
+
328
+ def forward(self, query, key, value, key_padding_mask = None, incremental_state = None, need_weights = True, static_kv = False, attn_mask = None, before_softmax = False, need_head_weights = False):
329
+ if need_head_weights: need_weights = True
330
+ is_tpu = query.device.type == "xla"
331
+ tgt_len, bsz, embed_dim = query.size()
332
+ src_len = tgt_len
333
+ if not self.skip_embed_dim_check: assert (embed_dim == self.embed_dim)
334
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
335
+ if key is not None:
336
+ src_len, key_bsz, _ = key.size()
337
+ if not torch.jit.is_scripting():
338
+ assert value is not None
339
+ assert src_len, key_bsz == value.shape[:2]
340
+
341
+ if (not self.onnx_trace and not is_tpu and incremental_state is None and not static_kv and not torch.jit.is_scripting() and not self.skip_embed_dim_check):
342
+ assert key is not None and value is not None
343
+ return F.multi_head_attention_forward(query, key, value, self.embed_dim, self.num_heads, torch.empty([0]), torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), self.bias_k, self.bias_v, self.add_zero_attn, self.dropout_module.p, self.out_proj.weight, self.out_proj.bias, self.training or self.dropout_module.apply_during_inference, key_padding_mask.bool() if key_padding_mask is not None else None, need_weights, attn_mask, use_separate_proj_weight=True, q_proj_weight=self.q_proj.weight, k_proj_weight=self.k_proj.weight, v_proj_weight=self.v_proj.weight)
344
+
345
+ if incremental_state is not None:
346
+ saved_state = self._get_input_buffer(incremental_state)
347
+ if saved_state is not None and "prev_key" in saved_state:
348
+ if static_kv:
349
+ assert self.encoder_decoder_attention and not self.self_attention
350
+ key = value = None
351
+ else: saved_state = None
352
+
353
+ if self.self_attention:
354
+ q = self.q_proj(query)
355
+ k = self.k_proj(query)
356
+ v = self.v_proj(query)
357
+ elif self.encoder_decoder_attention:
358
+ q = self.q_proj(query)
359
+ if key is None:
360
+ assert value is None
361
+ k = v = None
362
+ else:
363
+ if self.beam_size > 1 and bsz == key.size(1):
364
+ key = key.view(key.size(0), -1, self.beam_size, key.size(2))[:, :, 0, :]
365
+ if key_padding_mask is not None: key_padding_mask = key_padding_mask.view(-1, self.beam_size, key_padding_mask.size(1))[:, 0, :]
366
+ k = self.k_proj(key)
367
+ v = self.v_proj(key)
368
+ else:
369
+ assert key is not None and value is not None
370
+ q = self.q_proj(query)
371
+ k = self.k_proj(key)
372
+ v = self.v_proj(value)
373
+
374
+ q *= self.scaling
375
+ if self.bias_k is not None:
376
+ assert self.bias_v is not None
377
+ k, v, attn_mask, key_padding_mask = self._add_bias(k, v, attn_mask, key_padding_mask, bsz)
378
+
379
+ q = (q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1))
380
+ kv_bsz = bsz
381
+ if k is not None:
382
+ kv_bsz = k.size(1)
383
+ k = (k.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
384
+
385
+ if v is not None: v = (v.contiguous().view(-1, kv_bsz * self.num_heads, self.head_dim).transpose(0, 1))
386
+ if saved_state is not None:
387
+ if "prev_key" in saved_state:
388
+ _prev_key = saved_state["prev_key"]
389
+ assert _prev_key is not None
390
+
391
+ kv_bsz = _prev_key.size(0)
392
+ prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim)
393
+
394
+ if static_kv: k = prev_key
395
+ else:
396
+ assert k is not None
397
+ k = torch.cat([prev_key, k], dim=1)
398
+ src_len = k.size(1)
399
+
400
+ if "prev_value" in saved_state:
401
+ _prev_value = saved_state["prev_value"]
402
+ assert _prev_value is not None or kv_bsz == _prev_value.size(0)
403
+ prev_value = _prev_value.view(kv_bsz * self.num_heads, -1, self.head_dim)
404
+ if static_kv: v = prev_value
405
+ else:
406
+ assert v is not None
407
+ v = torch.cat([prev_value, v], dim=1)
408
+
409
+ prev_key_padding_mask = None
410
+ if "prev_key_padding_mask" in saved_state: prev_key_padding_mask = saved_state["prev_key_padding_mask"]
411
+ assert k is not None and v is not None
412
+ key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(key_padding_mask=key_padding_mask, prev_key_padding_mask=prev_key_padding_mask, batch_size=kv_bsz, src_len=k.size(1), static_kv=static_kv)
413
+ saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim)
414
+ saved_state["prev_value"] = v.view(kv_bsz, self.num_heads, -1, self.head_dim)
415
+ saved_state["prev_key_padding_mask"] = key_padding_mask
416
+ assert incremental_state is not None
417
+ incremental_state = self._set_input_buffer(incremental_state, saved_state)
418
+
419
+ assert k is not None
420
+ assert k.size(1) == src_len
421
+
422
+ if key_padding_mask is not None and key_padding_mask.dim() == 0: key_padding_mask = None
423
+
424
+ if key_padding_mask is not None:
425
+ assert key_padding_mask.size(0) == kv_bsz
426
+ assert key_padding_mask.size(1) == src_len
427
+
428
+ if self.add_zero_attn:
429
+ assert v is not None
430
+ src_len += 1
431
+ k, v, key_padding_mask, attn_mask = self._append_zero_attn(k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
432
+
433
+ if self.encoder_decoder_attention and bsz != kv_bsz:
434
+ attn_weights = torch.einsum("bxhtd,bhsd->bxhts", q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), k.view((kv_bsz, self.num_heads) + k.size()[1:]))
435
+ attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:])
436
+ else: attn_weights = torch.bmm(q, k.transpose(1, 2))
437
+
438
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
439
+
440
+ if attn_mask is not None:
441
+ attn_mask = attn_mask.unsqueeze(0)
442
+ if self.onnx_trace: attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
443
+ attn_weights += attn_mask
444
+
445
+ if key_padding_mask is not None:
446
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
447
+ attn_weights = attn_weights.view(kv_bsz, -1, self.num_heads, tgt_len, src_len).masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3).to(torch.bool), float("-inf")) if not is_tpu else attn_weights.transpose(0, 2).masked_fill(key_padding_mask, float("-inf")).transpose(0, 2)
448
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
449
+
450
+ if before_softmax: return attn_weights, v
451
+ attn_weights_float = softmax(attn_weights, dim=-1, onnx_trace=self.onnx_trace)
452
+ attn_weights = attn_weights_float.type_as(attn_weights)
453
+ attn_probs = self.dropout_module(attn_weights)
454
+ assert v is not None
455
+ attn = None
456
+
457
+ if self.encoder_decoder_attention and bsz != kv_bsz:
458
+ attn = torch.einsum("bxhts,bhsd->bxhtd", attn_probs.view((kv_bsz, -1, self.num_heads) + attn_probs.size()[1:]), v.view((kv_bsz, self.num_heads) + v.size()[1:]))
459
+ attn = attn.reshape((-1,) + attn.size()[-2:])
460
+ else: attn = torch.bmm(attn_probs, v)
461
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
462
+
463
+ attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) if self.onnx_trace and attn.size(1) == 1 else attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
464
+ attn = self.out_proj(attn)
465
+ attn_weights = None
466
+
467
+ if need_weights:
468
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
469
+ if not need_head_weights: attn_weights = attn_weights.mean(dim=0)
470
+
471
+ return attn, attn_weights
472
+
473
+ @staticmethod
474
+ def _append_prev_key_padding_mask(key_padding_mask, prev_key_padding_mask, batch_size, src_len, static_kv):
475
+ if prev_key_padding_mask is not None and static_kv: new_key_padding_mask = prev_key_padding_mask
476
+ elif prev_key_padding_mask is not None and key_padding_mask is not None: new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), key_padding_mask.float()], dim=1)
477
+ elif prev_key_padding_mask is not None:
478
+ if src_len > prev_key_padding_mask.size(1):
479
+ filler = torch.zeros((batch_size, src_len - prev_key_padding_mask.size(1)), device=prev_key_padding_mask.device)
480
+ new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
481
+ else: new_key_padding_mask = prev_key_padding_mask.float()
482
+ elif key_padding_mask is not None:
483
+ if src_len > key_padding_mask.size(1):
484
+ filler = torch.zeros((batch_size, src_len - key_padding_mask.size(1)), device=key_padding_mask.device)
485
+ new_key_padding_mask = torch.cat([filler.float(), key_padding_mask.float()], dim=1)
486
+ else: new_key_padding_mask = key_padding_mask.float()
487
+ else: new_key_padding_mask = prev_key_padding_mask
488
+ return new_key_padding_mask
489
+
490
+ @torch.jit.export
491
+ def reorder_incremental_state(self, incremental_state, new_order):
492
+ input_buffer = self._get_input_buffer(incremental_state)
493
+ if input_buffer is not None:
494
+ for k in input_buffer.keys():
495
+ input_buffer_k = input_buffer[k]
496
+ if input_buffer_k is not None:
497
+ if self.encoder_decoder_attention:
498
+ if input_buffer_k.size(0) * self.beam_size == new_order.size(0): return incremental_state
499
+ elif self.beam_size > 1: input_buffer[k] = input_buffer_k.index_select(0, new_order.reshape(-1, self.beam_size)[:, 0] // self.beam_size)
500
+ else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
501
+ else: input_buffer[k] = input_buffer_k.index_select(0, new_order)
502
+ incremental_state = self._set_input_buffer(incremental_state, input_buffer)
503
+ return incremental_state
504
+
505
+ def set_beam_size(self, beam_size):
506
+ self.beam_size = beam_size
507
+
508
+ def _get_input_buffer(self, incremental_state):
509
+ result = self.get_incremental_state(incremental_state, "attn_state")
510
+ return result if result is not None else {}
511
+
512
+ def _set_input_buffer(self, incremental_state, buffer):
513
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
514
+
515
+ def upgrade_state_dict_named(self, state_dict, name):
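+ # split a legacy fused in_proj_weight / in_proj_bias checkpoint into separate q_proj / k_proj / v_proj entries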
516
+ prefix = name + "." if name != "" else ""
517
+ items_to_add, keys_to_remove = {}, []
518
+ for k in state_dict.keys():
519
+ if k.endswith(prefix + "in_proj_weight"):
520
+ dim = int(state_dict[k].shape[0] / 3)
521
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
522
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
523
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
524
+ keys_to_remove.append(k)
525
+ k_bias = prefix + "in_proj_bias"
526
+ if k_bias in state_dict.keys():
527
+ dim = int(state_dict[k].shape[0] / 3)
528
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
529
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][dim : 2 * dim]
530
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
531
+ keys_to_remove.append(prefix + "in_proj_bias")
532
+
533
+ for k in keys_to_remove:
534
+ del state_dict[k]
535
+
536
+ for key, value in items_to_add.items():
537
+ state_dict[key] = value
538
+
539
+ def init_bert_params(module):
540
+ def normal_(data):
541
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
542
+
543
+ if isinstance(module, nn.Linear):
544
+ normal_(module.weight.data)
545
+ if module.bias is not None: module.bias.data.zero_()
546
+ if isinstance(module, nn.Embedding):
547
+ normal_(module.weight.data)
548
+ if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_()
549
+ if isinstance(module, MultiheadAttention):
550
+ normal_(module.q_proj.weight.data)
551
+ normal_(module.k_proj.weight.data)
552
+ normal_(module.v_proj.weight.data)
553
+
554
+ def make_conv_pos(e, k, g):
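+ # convolutional positional encoding: a grouped, weight-normalised Conv1d followed by SamePad trimming and GELU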
555
+ pos_conv = nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g)
556
+ dropout = 0
557
+ nn.init.normal_(pos_conv.weight, mean=0, std=math.sqrt((4 * (1.0 - dropout)) / (k * e)))
558
+ nn.init.constant_(pos_conv.bias, 0)
559
+ return nn.Sequential(nn.utils.parametrizations.weight_norm(pos_conv, name="weight", dim=2), SamePad(k), nn.GELU())
560
+
561
+ def is_xla_tensor(tensor):
562
+ return torch.is_tensor(tensor) and tensor.device.type == "xla"
563
+
564
+ def index_put(tensor, indices, value):
565
+ if is_xla_tensor(tensor):
566
+ for _ in range(indices.dim(), tensor.dim()):
567
+ indices = indices.unsqueeze(-1)
568
+
569
+ if indices.size(-1) < tensor.size(-1): indices = indices.expand_as(tensor)
570
+ tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
571
+ else: tensor[indices] = value
572
+
573
+ return tensor
574
+
575
+ def pad_to_multiple(x, multiple, dim=-1, value=0):
576
+ if x is None: return None, 0
577
+ tsz = x.size(dim)
578
+ m = tsz / multiple
579
+ remainder = math.ceil(m) * multiple - tsz
580
+ if m.is_integer(): return x, 0
581
+ return F.pad(x, (*((0,) * (-1 - dim) * 2), 0, remainder), value=value), remainder
582
+
583
+ def compute_mask_indices(shape, padding_mask, mask_prob, mask_length, mask_type = "static", mask_other = 0.0, min_masks = 0, no_overlap = False, min_space = 0, require_same_masks = True, mask_dropout = 0.0, add_masks = False, seed = None, epoch = None, indices = None, idc_select_ver = 1, num_mask_ver = 2):
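+ # sample per-utterance span masks for masked-prediction training; returns a boolean numpy array of shape (batch, time)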
584
+ bsz, all_sz = shape
585
+ mask = np.full((bsz, all_sz), False)
586
+ if num_mask_ver == 1: all_num_mask = max(min_masks, int(mask_prob * all_sz / float(mask_length) + np.random.rand()))
587
+ mask_idcs = []
588
+
589
+ for i in range(bsz):
590
+ seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) if seed is not None and epoch is not None and indices is not None else None
591
+ rng = np.random.default_rng(seed_i)
592
+
593
+ if padding_mask is not None:
594
+ sz = all_sz - padding_mask[i].long().sum().item()
595
+ assert sz >= 0, sz
596
+ else: sz = all_sz
597
+
598
+ if num_mask_ver == 1: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + np.random.rand())) if padding_mask is not None else all_num_mask
599
+ elif num_mask_ver == 2: num_mask = max(min_masks, int(mask_prob * sz / float(mask_length) + rng.random()))
600
+ else: raise ValueError
601
+
602
+ if mask_type == "static": lengths = np.full(num_mask, mask_length)
603
+ elif mask_type == "uniform": lengths = rng.integers(int(mask_other), mask_length * 2 + 1, size=num_mask)
604
+ elif mask_type == "normal": lengths = [max(1, int(round(x))) for x in rng.normal(mask_length, mask_other, size=num_mask)]
605
+ elif mask_type == "poisson": lengths = [int(round(x)) for x in rng.poisson(mask_length, size=num_mask)]
606
+ else: raise Exception
607
+
608
+ if sum(lengths) == 0:
609
+ if mask_type == "static": raise ValueError
610
+ else: lengths = [min(mask_length, sz - 1)]
611
+
612
+ if no_overlap:
613
+ mask_idc = []
614
+
615
+ def arrange(s, e, length, keep_length):
616
+ span_start = rng.integers(s, e - length)
617
+ mask_idc.extend(span_start + i for i in range(length))
618
+ new_parts = []
619
+ if span_start - s - min_space >= keep_length: new_parts.append((s, span_start - min_space + 1))
620
+ if e - span_start - length - min_space > keep_length: new_parts.append((span_start + length + min_space, e))
621
+ return new_parts
622
+
623
+ parts = [(0, sz)]
624
+ min_length = min(lengths)
625
+ for length in sorted(lengths, reverse=True):
626
+ lens = np.fromiter((e - s if e - s >= length + min_space else 0 for s, e in parts), np.int32)
627
+ l_sum = np.sum(lens)
628
+ if l_sum == 0: break
629
+ s, e = parts.pop(rng.choice(len(parts), p=lens / np.sum(lens)))
630
+ parts.extend(arrange(s, e, length, min_length))
631
+ mask_idc = np.asarray(mask_idc)
632
+ else:
633
+ if idc_select_ver == 1:
634
+ min_len = min(lengths)
635
+ if sz - min_len <= num_mask: min_len = sz - num_mask - 1
636
+ mask_idc = rng.choice(sz - min_len, num_mask, replace=False)
637
+ elif idc_select_ver == 2: mask_idc = rng.choice(sz, num_mask, replace=False)
638
+ else: raise ValueError
639
+
640
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
641
+
642
+ mask_idc = np.unique(mask_idc[mask_idc < sz])
643
+ if len(mask_idc) >= sz: raise ValueError
644
+ mask_idcs.append(mask_idc)
645
+
646
+ target_len = None
647
+ if require_same_masks: target_len = max([len(m) for m in mask_idcs]) if add_masks else min([len(m) for m in mask_idcs])
648
+
649
+ for i, mask_idc in enumerate(mask_idcs):
650
+ if target_len is not None and len(mask_idc) > target_len: mask_idc = rng.choice(mask_idc, target_len, replace=False)
651
+ mask[i, mask_idc] = True
652
+
653
+ if target_len is not None and len(mask_idc) < target_len:
654
+ to_mask = rng.choice(np.flatnonzero(~mask[i]), target_len - len(mask_idc), replace=False)
655
+ mask[i, to_mask] = True
656
+
657
+ if mask_dropout > 0:
658
+ masked = np.flatnonzero(mask[i])
659
+ mask[i, rng.choice(masked, np.rint(len(masked) * mask_dropout).astype(int), replace=False)] = False
660
+
661
+ return mask
662
+
663
+ def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):  # export is accepted for call-site compatibility and ignored
664
+ return nn.LayerNorm(normalized_shape, eps, elementwise_affine)
665
+
666
+ def prune_state_dict(state_dict, model_cfg):
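+ # drop encoder/decoder layers that are not listed in encoder_layers_to_keep / decoder_layers_to_keep and renumber the surviving layers in the state-dict keys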
667
+ arch = None
668
+ if model_cfg is not None: arch = (model_cfg._name if isinstance(model_cfg, DictConfig) else getattr(model_cfg, "arch", None))
669
+ if not model_cfg or arch is None or arch == "ptt_transformer": return state_dict
670
+ encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
671
+ decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
672
+ if not encoder_layers_to_keep and not decoder_layers_to_keep: return state_dict
673
+
674
+ def create_pruning_pass(layers_to_keep, layer_name):
675
+ keep_layers = sorted(int(layer_string) for layer_string in layers_to_keep.split(","))
676
+ mapping_dict = {}
677
+ for i in range(len(keep_layers)):
678
+ mapping_dict[str(keep_layers[i])] = str(i)
679
+
680
+ return {"substitution_regex": re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name)), "mapping_dict": mapping_dict}
681
+
682
+ pruning_passes, new_state_dict = [], {}
683
+ if encoder_layers_to_keep: pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
684
+ if decoder_layers_to_keep: pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
685
+
686
+ for layer_name in state_dict.keys():
687
+ match = re.search(r"\.layers\.(\d+)\.", layer_name)
688
+ if not match:
689
+ new_state_dict[layer_name] = state_dict[layer_name]
690
+ continue
691
+
692
+ original_layer_number = match.group(1)
693
+ for pruning_pass in pruning_passes:
694
+ if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass["substitution_regex"].search(layer_name):
695
+ substitution_match = pruning_pass["substitution_regex"].search(layer_name)
696
+ new_state_dict[(layer_name[: substitution_match.start(1)] + pruning_pass["mapping_dict"][original_layer_number] + layer_name[substitution_match.end(1) :])] = state_dict[layer_name]
697
+
698
+ with open_dict(model_cfg) if isinstance(model_cfg, DictConfig) else contextlib.ExitStack():
699
+ if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None
700
+ if hasattr(model_cfg, "decoder_layers_to_keep"): model_cfg.decoder_layers_to_keep = None
701
+
702
+ return new_state_dict
703
+
704
+ def relu_squared(x):
705
+ return F.relu(x).pow(2)
706
+
707
+ def get_activation_fn(activation):
708
+ def gelu(x):
709
+ return nn.functional.gelu(x.float()).type_as(x)
710
+
711
+ def gelu_accurate(x):
712
+ if not hasattr(gelu_accurate, "_a"):
713
+ gelu_accurate._a = math.sqrt(2 / math.pi)
714
+ return (0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))))
715
+
716
+ if activation == "relu": return F.relu
717
+ elif activation == "relu_squared": return relu_squared
718
+ elif activation == "gelu": return gelu
719
+ elif activation == "gelu_fast": return gelu_accurate
720
+ elif activation == "gelu_accurate": return gelu_accurate
721
+ elif activation == "tanh": return torch.tanh
722
+ elif activation == "linear": return lambda x: x
723
+ elif activation == "swish": return nn.SiLU
724
+ else: raise RuntimeError
725
+
726
+ class SamePad(nn.Module):
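+ # removes the extra frame produced by a Conv1d with padding=k//2 when the kernel size is even (or all look-ahead frames in the causal case)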
727
+ def __init__(self, kernel_size, causal=False):
728
+ super().__init__()
729
+ if causal: self.remove = kernel_size - 1
730
+ else: self.remove = 1 if kernel_size % 2 == 0 else 0
731
+
732
+ def forward(self, x):
733
+ if self.remove > 0: x = x[:, :, : -self.remove]
734
+ return x
735
+
736
+ class TransformerSentenceEncoderLayer(nn.Module):
737
+ def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False):
738
+ super().__init__()
739
+ self.embedding_dim = embedding_dim
740
+ self.dropout = dropout
741
+ self.activation_dropout = activation_dropout
742
+ self.activation_fn = get_activation_fn(activation_fn)
743
+ self.self_attn = MultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout, self_attention=True)
744
+ self.dropout1 = nn.Dropout(dropout)
745
+ self.dropout2 = nn.Dropout(self.activation_dropout)
746
+ self.dropout3 = nn.Dropout(dropout)
747
+ self.layer_norm_first = layer_norm_first
748
+ self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
749
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
750
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
751
+ self.final_layer_norm = LayerNorm(self.embedding_dim)
752
+
753
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None):
754
+ residual = x
755
+ if self.layer_norm_first:
756
+ x = self.self_attn_layer_norm(x)
757
+ x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, attn_mask=self_attn_mask, need_weights=False)
758
+ x = residual + self.dropout1(x)
759
+ residual = x
760
+ x = self.fc2(self.dropout2(self.activation_fn(self.fc1(self.final_layer_norm(x)))))
761
+ layer_result = x
762
+ x = residual + self.dropout3(x)
763
+ else:
764
+ x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_attn_padding_mask, need_weights=False)
765
+ x = self.self_attn_layer_norm(residual + self.dropout1(x))
766
+ residual = x
767
+ x = self.fc2(self.dropout2(self.activation_fn(self.fc1(x))))
768
+ layer_result = x
769
+ x = self.final_layer_norm(residual + self.dropout3(x))
770
+
771
+ return x, (attn, layer_result)
772
+
773
+ class AdapterFast(nn.Module):
774
+ def __init__(self, adapter_num, input_dim, hidden_dim, act_fn):
775
+ super().__init__()
776
+ self.adapter_num = adapter_num
777
+ self.input_dim = input_dim
778
+ self.hidden_dim = hidden_dim
779
+ self.W_a = nn.Parameter(torch.empty(adapter_num, hidden_dim, input_dim))
780
+ self.W_b = nn.Parameter(torch.empty(adapter_num, input_dim, hidden_dim))
781
+ self.b_a = nn.Parameter(torch.empty(adapter_num, hidden_dim))
782
+ self.b_b = nn.Parameter(torch.empty(adapter_num, input_dim))
783
+ self.ln_W = nn.Parameter(torch.empty(adapter_num, input_dim))
784
+ self.ln_b = nn.Parameter(torch.empty(adapter_num, input_dim))
785
+ self.act_fn = nn.Identity()
786
+ if act_fn == "relu": self.act_fn = nn.ReLU()
787
+ elif act_fn == "gelu": self.act_fn = nn.GELU()
788
+ elif act_fn == "selu": self.act_fn = nn.SELU()
789
+ else: raise ValueError
790
+ self.input_dim = input_dim
791
+ self.reset_parameters()
792
+
793
+ def reset_parameters(self):
794
+ for ii in range(self.adapter_num):
795
+ nn.init.kaiming_uniform_(self.W_a[ii], a=math.sqrt(5))
796
+ nn.init.kaiming_uniform_(self.W_b[ii], a=math.sqrt(5))
797
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_a[ii])
798
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
799
+ nn.init.uniform_(self.b_a[ii], -bound, bound)
800
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_b[ii])
801
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
802
+ nn.init.uniform_(self.b_b[ii], -bound, bound)
803
+
804
+ nn.init.ones_(self.ln_W)
805
+ nn.init.zeros_(self.ln_b)
806
+
807
+ def forward(self, x, adapter_id):
808
+ ii = adapter_id
809
+ return F.linear(self.act_fn(F.linear(F.layer_norm(x, (self.input_dim, ), self.ln_W[ii], self.ln_b[ii]), self.W_a[ii], self.b_a[ii])), self.W_b[ii], self.b_b[ii])
810
+
811
+ def extra_repr(self):
812
+ return ('adapter={}, input_dim={}, hidden_dim={}'.format(self.adapter_num, self.input_dim, self.hidden_dim))
813
+
814
+ class FeedForwardModule(nn.Module):
815
+ def __init__(self, input_feat, hidden_units, dropout1, dropout2, activation_fn="swish", bias=True):
816
+ super(FeedForwardModule, self).__init__()
817
+ self.layer_norm = LayerNorm(input_feat)
818
+ self.w_1 = nn.Linear(input_feat, hidden_units, bias=bias)
819
+ self.w_2 = nn.Linear(hidden_units, input_feat, bias=bias)
820
+ self.dropout1 = nn.Dropout(dropout1)
821
+ self.dropout2 = nn.Dropout(dropout2)
822
+ self.activation = get_activation_fn(activation_fn)(hidden_units)
823
+
824
+ def forward(self, x):
825
+ return self.dropout2(self.w_2(self.dropout1(self.activation(self.w_1(self.layer_norm(x))))))
826
+
827
+ class ConvolutionModule(nn.Module):
828
+ def __init__(self, embed_dim, channels, depthwise_kernel_size, dropout, activation_fn="swish", bias=False, export=False):
829
+ super(ConvolutionModule, self).__init__()
830
+ assert (depthwise_kernel_size - 1) % 2 == 0
831
+ self.layer_norm = LayerNorm(embed_dim, export=export)
832
+ self.pointwise_conv1 = nn.Conv1d(embed_dim, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias)
833
+ self.glu = nn.GLU(dim=1)
834
+ self.depthwise_conv = nn.Conv1d(channels, channels, depthwise_kernel_size, stride=1, padding=(depthwise_kernel_size - 1) // 2, groups=channels, bias=bias)
835
+ self.batch_norm = nn.BatchNorm1d(channels)
836
+ self.activation = get_activation_fn(activation_fn)(channels)
837
+ self.pointwise_conv2 = nn.Conv1d(channels, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
838
+ self.dropout = nn.Dropout(dropout)
839
+
840
+ def forward(self, x):
841
+ return self.dropout(self.pointwise_conv2(self.activation(self.batch_norm(self.depthwise_conv(self.glu(self.pointwise_conv1(self.layer_norm(x).transpose(1, 2)))))))).transpose(1, 2)
842
+
843
+ def rotate_half(x):
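+ # rotary embedding helper: split the last dimension in half and rotate, i.e. (x1, x2) -> (-x2, x1)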
844
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
845
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
846
+
847
+ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
848
+ cos, sin = (cos[offset : q.shape[0] + offset, ...], sin[offset : q.shape[0] + offset, ...])
849
+ return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
850
+
851
+ class RotaryPositionalEmbedding(nn.Module):
852
+ def __init__(self, dim, base=10000, precision=torch.half):
853
+ super().__init__()
854
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
855
+ self.register_buffer("inv_freq", inv_freq)
856
+ self.seq_len_cached = 0
857
+ self.cos_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
858
+ self.sin_cached = torch.empty(self.seq_len_cached, 1, 1, dim)
859
+ self.precision = precision
860
+
861
+ def forward(self, x, seq_len = 0):
862
+ if seq_len > self.seq_len_cached:
863
+ self.seq_len_cached = seq_len
864
+ freqs = torch.einsum("i,j->ij", torch.arange(seq_len, device=x.device).type_as(self.inv_freq), self.inv_freq)
865
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
866
+ self.cos_cached = emb.cos().view(emb.size(0), 1, 1, emb.size(1))
867
+ self.sin_cached = emb.sin().view(emb.size(0), 1, 1, emb.size(1))
868
+ return self.cos_cached, self.sin_cached
869
+
870
+ class ESPNETMultiHeadedAttention(nn.Module):
871
+ def __init__(self, n_feat, n_head, dropout):
872
+ super(ESPNETMultiHeadedAttention, self).__init__()
873
+ assert n_feat % n_head == 0
874
+ self.d_k = n_feat // n_head
875
+ self.h = n_head
876
+ self.linear_q = nn.Linear(n_feat, n_feat)
877
+ self.linear_k = nn.Linear(n_feat, n_feat)
878
+ self.linear_v = nn.Linear(n_feat, n_feat)
879
+ self.linear_out = nn.Linear(n_feat, n_feat)
880
+ self.attn = None
881
+ self.dropout = nn.Dropout(p=dropout)
882
+
883
+ def forward_qkv(self, query, key, value, **kwargs):
884
+ n_batch = query.size(0)
885
+ return self.linear_q(query).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_k(key).view(n_batch, -1, self.h, self.d_k).transpose(1, 2), self.linear_v(value).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)
886
+
887
+ def forward_attention(self, value, scores, mask):
888
+ n_batch = value.size(0)
889
+ if mask is not None:
890
+ scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2).to(bool), float("-inf"))
891
+ self.attn = torch.softmax(scores, dim=-1)
892
+ else: self.attn = torch.softmax(scores, dim=-1)
893
+
894
+ return self.linear_out((torch.matmul(self.dropout(self.attn), value).transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)))
895
+
896
+ def forward(self, query, key, value, key_padding_mask=None, **kwargs):
897
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
898
+ return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
899
+
900
+ class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
901
+ def __init__(self, n_feat, n_head, dropout, zero_triu=False):
902
+ super().__init__(n_feat, n_head, dropout)
903
+ self.zero_triu = zero_triu
904
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
905
+ self.pos_bias_u = nn.Parameter(torch.zeros(self.h, self.d_k))
906
+ self.pos_bias_v = nn.Parameter(torch.zeros(self.h, self.d_k))
907
+ nn.init.xavier_uniform_(self.pos_bias_u)
908
+ nn.init.xavier_uniform_(self.pos_bias_v)
909
+
910
+ def rel_shift(self, x):
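+ # Transformer-XL style relative shift: pad, reshape and slice so that column j of the score matrix lines up with relative distance (i - j)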
911
+ x = torch.cat([torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype), x], dim=-1).view(*x.size()[:2], x.size(3) + 1, x.size(2))[:, :, 1:].view_as(x)[:, :, :, : x.size(-1) // 2 + 1]
912
+ if self.zero_triu: x = x * torch.tril(torch.ones((x.size(2), x.size(3)), device=x.device), x.size(3) - x.size(2))[None, None, :, :]
913
+ return x
914
+
915
+ def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
916
+ pos_emb = pos_emb.transpose(0, 1)
917
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
918
+ q = q.transpose(1, 2)
919
+
920
+ return self.forward_attention(v, (torch.matmul((q + self.pos_bias_u).transpose(1, 2), k.transpose(-2, -1)) + self.rel_shift(torch.matmul((q + self.pos_bias_v).transpose(1, 2), self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k).transpose(1, 2).transpose(-2, -1)))) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
921
+
922
+ class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
923
+ def __init__(self, n_feat, n_head, dropout, precision, rotary_emd_base=10000):
924
+ super().__init__(n_feat, n_head, dropout)
925
+ if precision == "fp16": precision = torch.half
926
+ else: precision = torch.float
927
+ self.rotary_ndims = self.d_k
928
+ self.rotary_emb = RotaryPositionalEmbedding(self.rotary_ndims, base=rotary_emd_base, precision=precision)
929
+
930
+ def forward(self, query, key, value, key_padding_mask=None, **kwargs):
931
+ T, B, C = value.size()
932
+ query = query.view(T, B, self.h, self.d_k)
933
+ key = key.view(T, B, self.h, self.d_k)
934
+ value = value.view(T, B, self.h, self.d_k)
935
+ cos, sin = self.rotary_emb(value, seq_len=T)
936
+ query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)
937
+ query = query.view(T, B, self.h * self.d_k)
938
+ key = key.view(T, B, self.h * self.d_k)
939
+ value = value.view(T, B, self.h * self.d_k)
940
+ q, k, v = self.forward_qkv(query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1))
941
+ return self.forward_attention(v, torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k), key_padding_mask).transpose(0, 1), None
942
+
943
+ class ConformerEncoderLayer(nn.Module):
944
+ def __init__(self, embed_dim, ffn_embed_dim, attention_heads, dropout, use_fp16, depthwise_conv_kernel_size=31, activation_fn="swish", attn_type=None, pos_enc_type="abs"):
945
+ self.pos_enc_type = pos_enc_type
946
+ super(ConformerEncoderLayer, self).__init__()
947
+ self.ffn1 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout)
948
+ self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
949
+ self.self_attn_dropout = nn.Dropout(dropout)
950
+ if attn_type == "espnet":
951
+ if self.pos_enc_type == "rel_pos": self.self_attn = RelPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
952
+ elif self.pos_enc_type == "rope": self.self_attn = RotaryPositionMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout, precision=use_fp16)
953
+ elif self.pos_enc_type == "abs": self.self_attn = ESPNETMultiHeadedAttention(embed_dim, attention_heads, dropout=dropout)
954
+ else: raise Exception
955
+ else: self.self_attn = MultiheadAttention(embed_dim, attention_heads, dropout=dropout)
956
+ self.conv_module = ConvolutionModule(embed_dim=embed_dim, channels=embed_dim, depthwise_kernel_size=depthwise_conv_kernel_size, dropout=dropout, activation_fn=activation_fn)
957
+ self.ffn2 = FeedForwardModule(embed_dim, ffn_embed_dim, dropout, dropout, activation_fn=activation_fn)
958
+ self.final_layer_norm = LayerNorm(embed_dim, export=False)
959
+
960
+ def forward(self, x, encoder_padding_mask, position_emb = None):
961
+ residual = x
962
+ x = self.ffn1(x) * 0.5 + residual
963
+ residual = x
964
+ x = self.self_attn_layer_norm(x)
965
+ if self.pos_enc_type == "rel_pos": x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, pos_emb=position_emb, need_weights=False)
966
+ else: x, attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask, need_weights=False)
967
+ x = self.self_attn_dropout(x)
968
+ x = x + residual
969
+ residual = x
970
+ x = residual + self.conv_module(x.transpose(0, 1)).transpose(0, 1)
971
+ residual = x
972
+ x = self.ffn2(x)
973
+ layer_result = x
974
+ x = self.final_layer_norm(x * 0.5 + residual)
975
+ return x, (attn, layer_result)
976
+
977
+ class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
978
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, position_emb=None):
979
+ return super().forward(x, self_attn_padding_mask, position_emb)
980
+
981
+ class TransformerSentenceEncoderWithAdapterLayer(TransformerSentenceEncoderLayer):
982
+ def __init__(self, embedding_dim = 768, ffn_embedding_dim = 3072, num_attention_heads = 8, dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.1, activation_fn = "relu", layer_norm_first = False, adapter_num=201, adapter_dim=64, adapter_act_fn="relu"):
983
+ super().__init__(embedding_dim=embedding_dim, ffn_embedding_dim=ffn_embedding_dim, num_attention_heads=num_attention_heads, dropout=dropout, attention_dropout=attention_dropout, activation_dropout=activation_dropout, activation_fn=activation_fn, layer_norm_first=layer_norm_first)
984
+ self.adapter_num = adapter_num
985
+ self.adapter_dim = adapter_dim
986
+ self.adapter_layer = AdapterFast(adapter_num, self.embedding_dim, self.adapter_dim, adapter_act_fn)
987
+
988
+ def forward(self, x, self_attn_mask=None, self_attn_padding_mask=None, need_weights=False, att_args=None, corpus_key=None):
989
+ x, (attn, layer_result) = super().forward(x=x, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, need_weights=need_weights, att_args=att_args)
990
+ assert corpus_key is not None
991
+ assert len(set(corpus_key)) == 1
992
+ return x + self.adapter_layer(x, corpus_key[0]), (attn, layer_result)
993
+
994
+ class TransposeLast(nn.Module):
995
+ def __init__(self, deconstruct_idx=None, tranpose_dim=-2):
996
+ super().__init__()
997
+ self.deconstruct_idx = deconstruct_idx
998
+ self.tranpose_dim = tranpose_dim
999
+
1000
+ def forward(self, x):
1001
+ if self.deconstruct_idx is not None: x = x[self.deconstruct_idx]
1002
+ return x.transpose(self.tranpose_dim, -1)
1003
+
1004
+ class TransformerEncoder(nn.Module):
1005
+ def build_encoder_layer(self, args, **kwargs):
1006
+ if args.layer_type == "transformer": layer = TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first)
1007
+ elif args.layer_type == "conformer": layer = ConformerWav2Vec2EncoderLayer(embed_dim=self.embedding_dim, ffn_embed_dim=args.encoder_ffn_embed_dim, attention_heads=args.encoder_attention_heads, dropout=args.dropout, depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, activation_fn="swish", attn_type=args.attn_type, use_fp16=args.fp16, pos_enc_type="abs")
1008
+ elif args.layer_type == "trf_adp":
1009
+ use_adp = False
1010
+ if args.adp_trf_idx == "all": use_adp = True
1011
+ else:
1012
+ if kwargs.get("layer_idx", None) in list(range(*[int(g) for g in args.adp_trf_idx.split(":")])): use_adp = True
1013
+
1014
+ layer = TransformerSentenceEncoderWithAdapterLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first, adapter_num=args.adp_num, adapter_dim=args.adp_dim, adapter_act_fn=args.adp_act_fn) if use_adp else TransformerSentenceEncoderLayer(embedding_dim=self.embedding_dim, ffn_embedding_dim=args.encoder_ffn_embed_dim, num_attention_heads=args.encoder_attention_heads, dropout=self.dropout, attention_dropout=args.attention_dropout, activation_dropout=args.activation_dropout, activation_fn=args.activation_fn, layer_norm_first=args.layer_norm_first,)
1015
+
1016
+ return layer
1017
+
1018
+ def __init__(self, args):
1019
+ super().__init__()
1020
+ self.dropout = args.dropout
1021
+ self.embedding_dim = args.encoder_embed_dim
1022
+ self.required_seq_len_multiple = args.required_seq_len_multiple
1023
+ pos_conv_depth = getattr(args, "pos_conv_depth", 1)
1024
+ if pos_conv_depth > 1:
1025
+ num_layers = args.pos_conv_depth
1026
+ k = max(3, args.conv_pos // num_layers)
1027
+
1028
+ def make_conv_block(e, k, g, l):
1029
+ return nn.Sequential(*[nn.Sequential(nn.Conv1d(e, e, kernel_size=k, padding=k // 2, groups=g), SamePad(k), TransposeLast(), LayerNorm(e, elementwise_affine=False), TransposeLast(), nn.GELU()) for _ in range(l)])
1030
+
1031
+ self.pos_conv = make_conv_block(self.embedding_dim, k, args.conv_pos_groups, num_layers)
1032
+ else: self.pos_conv = make_conv_pos(self.embedding_dim, args.conv_pos, args.conv_pos_groups)
1033
+
1034
+ self.layers = nn.ModuleList([self.build_encoder_layer(args, layer_idx=ii) for ii in range(args.encoder_layers)])
1035
+ self.layer_norm_first = args.layer_norm_first
1036
+ self.layer_norm = LayerNorm(self.embedding_dim)
1037
+ self.layerdrop = args.encoder_layerdrop
1038
+ self.apply(init_bert_params)
1039
+
1040
+ def forward(self, x, padding_mask=None, layer=None, corpus_key=None):
1041
+ x, layer_results = self.extract_features(x, padding_mask, layer, corpus_key=corpus_key)
1042
+ if self.layer_norm_first and layer is None: x = self.layer_norm(x)
1043
+ return x, layer_results
1044
+
1045
+ def extract_features(self, x, padding_mask=None, tgt_layer=None, min_layer=0, corpus_key=None):
1046
+ if padding_mask is not None: x = index_put(x, padding_mask, 0)
1047
+ x = x + self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
1048
+ if not self.layer_norm_first: x = self.layer_norm(x)
1049
+ x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
1050
+ if pad_length > 0 and padding_mask is None:
1051
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
1052
+ padding_mask[:, -pad_length:] = True
1053
+ else: padding_mask, _ = pad_to_multiple(padding_mask, self.required_seq_len_multiple, dim=-1, value=True)
1054
+ x = F.dropout(x, p=self.dropout, training=self.training).transpose(0, 1)
1055
+ layer_results = []
1056
+ r = None
1057
+
1058
+ for i, layer in enumerate(self.layers):
1059
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
1060
+ if not self.training or (dropout_probability > self.layerdrop):
1061
+ layer_check = layer
1062
+ if (corpus_key is None) or (not isinstance(layer_check, (TransformerSentenceEncoderWithAdapterLayer))): x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False)
1063
+ else: x, (z, lr) = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, corpus_key=corpus_key)
1064
+ if i >= min_layer: layer_results.append((x, z, lr))
1065
+ if i == tgt_layer:
1066
+ r = x
1067
+ break
1068
+
1069
+ if r is not None: x = r
1070
+ x = x.transpose(0, 1)
1071
+
1072
+ if pad_length > 0:
1073
+ x = x[:, :-pad_length]
1074
+ def undo_pad(a, b, c):
1075
+ return (a[:-pad_length], b[:-pad_length] if b is not None else b, c[:-pad_length])
1076
+
1077
+ layer_results = [undo_pad(*u) for u in layer_results]
1078
+
1079
+ return x, layer_results
1080
+
1081
+ def max_positions(self):
1082
+ return self.args.max_positions
1083
+
1084
+ def upgrade_state_dict_named(self, state_dict, name):
1085
+ return state_dict
1086
+
1087
+ class Fp32GroupNorm(nn.GroupNorm):
1088
+ def __init__(self, *args, **kwargs):
1089
+ super().__init__(*args, **kwargs)
1090
+
1091
+ def forward(self, input):
1092
+ output = F.group_norm(input.float(), self.num_groups, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
1093
+ return output.type_as(input)
1094
+
1095
+ class Fp32LayerNorm(nn.LayerNorm):
1096
+ def __init__(self, *args, **kwargs):
1097
+ super().__init__(*args, **kwargs)
1098
+
1099
+ def forward(self, input):
1100
+ output = F.layer_norm(input.float(), self.normalized_shape, self.weight.float() if self.weight is not None else None, self.bias.float() if self.bias is not None else None, self.eps)
1101
+ return output.type_as(input)
1102
+
1103
+ class ConvFeatureExtractionModel(nn.Module):
1104
+ def __init__(self, conv_layers, dropout = 0.0, mode = "default", conv_bias = False):
1105
+ super().__init__()
1106
+ assert mode in {"default", "layer_norm"}
1107
+
1108
+ def block(n_in, n_out, k, stride, is_layer_norm=False, is_group_norm=False, conv_bias=False):
1109
+ def make_conv():
1110
+ conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
1111
+ nn.init.kaiming_normal_(conv.weight)
1112
+ return conv
1113
+
1114
+ assert not (is_layer_norm and is_group_norm)
1115
+
1116
+ if is_layer_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.Sequential(TransposeLast(), Fp32LayerNorm(dim, elementwise_affine=True), TransposeLast()), nn.GELU())
1117
+ elif is_group_norm: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), Fp32GroupNorm(dim, dim, affine=True), nn.GELU())
1118
+ else: return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
1119
+
1120
+ in_d = 1
1121
+ self.conv_layers = nn.ModuleList()
1122
+ for i, cl in enumerate(conv_layers):
1123
+ assert len(cl) == 3
1124
+ (dim, k, stride) = cl
1125
+ self.conv_layers.append(block(in_d, dim, k, stride, is_layer_norm=mode == "layer_norm", is_group_norm=mode == "default" and i == 0, conv_bias=conv_bias))
1126
+ in_d = dim
1127
+
1128
+ def forward(self, x):
1129
+ x = x.unsqueeze(1)
1130
+ for conv in self.conv_layers:
1131
+ x = conv(x)
1132
+
1133
+ return x
1134
+
1135
+ class GradMultiply(torch.autograd.Function):
1136
+ @staticmethod
1137
+ def forward(ctx, x, scale):
1138
+ ctx.scale = scale
1139
+ res = x.new(x)
1140
+ return res
1141
+
1142
+ @staticmethod
1143
+ def backward(ctx, grad):
1144
+ return grad * ctx.scale, None
1145
+
1146
+ class BaseFairseqModel(nn.Module):
1147
+ def __init__(self):
1148
+ super().__init__()
1149
+ self._is_generation_fast = False
1150
+
1151
+ def get_targets(self, sample, net_output):
1152
+ return sample["target"]
1153
+
1154
+ def extract_features(self, *args, **kwargs):
1155
+ return self(*args, **kwargs)
1156
+
1157
+ def load_state_dict(self, state_dict, strict=True, model_cfg = None, args = None):
1158
+ self.upgrade_state_dict(state_dict)
1159
+ new_state_dict = prune_state_dict(state_dict, model_cfg)
1160
+ return super().load_state_dict(new_state_dict, strict)
1161
+
1162
+ def upgrade_state_dict(self, state_dict):
1163
+ self.upgrade_state_dict_named(state_dict, "")
1164
+
1165
+ def upgrade_state_dict_named(self, state_dict, name):
1166
+ assert state_dict is not None
1167
+
1168
+ def do_upgrade(m, prefix):
1169
+ if len(prefix) > 0: prefix += "."
1170
+ for n, c in m.named_children():
1171
+ name = prefix + n
1172
+ if hasattr(c, "upgrade_state_dict_named"): c.upgrade_state_dict_named(state_dict, name)
1173
+ elif hasattr(c, "upgrade_state_dict"): c.upgrade_state_dict(state_dict)
1174
+ do_upgrade(c, name)
1175
+
1176
+ do_upgrade(self, name)
1177
+
1178
+ def make_generation_fast_(self, **kwargs):
1179
+ if self._is_generation_fast: return
1180
+ self._is_generation_fast = True
1181
+
1182
+ def apply_remove_weight_norm(module):
1183
+ try:
1184
+ nn.utils.remove_weight_norm(module)
1185
+ except (AttributeError, ValueError):
1186
+ return
1187
+
1188
+ self.apply(apply_remove_weight_norm)
1189
+ def apply_make_generation_fast_(module, prefix):
1190
+ if len(prefix) > 0: prefix += "."
1191
+
1192
+ base_func = BaseFairseqModel.make_generation_fast_
1193
+ for n, m in module.named_modules():
1194
+ if (m != self and hasattr(m, "make_generation_fast_") and m.make_generation_fast_.__func__ is not base_func): m.make_generation_fast_(name=prefix + n, **kwargs)
1195
+
1196
+ apply_make_generation_fast_(self, "")
1197
+ self.eval()
1198
+
1199
+ class HubertConfig:
1200
+ def __init__(self, _name, label_rate, encoder_layers_1, logit_temp_ctr, num_negatives, cross_sample_negatives, ctr_layers, extractor_mode = "default", encoder_layers = 12, encoder_embed_dim = 768, encoder_ffn_embed_dim = 3072, encoder_attention_heads = 12, activation_fn = "gelu", layer_type = "transformer", dropout = 0.1, attention_dropout = 0.1, activation_dropout = 0.0, encoder_layerdrop = 0.0, dropout_input = 0.0, dropout_features = 0.0, final_dim = 0, untie_final_proj = False, layer_norm_first = False, conv_feature_layers = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", conv_bias = False, logit_temp = 0.1, target_glu = False, feature_grad_mult = 1.0, mask_length = 10, mask_prob = 0.65, mask_selection = "static", mask_other = 0.0, no_mask_overlap = False, mask_min_space = 1, mask_channel_length = 10, mask_channel_prob = 0.0, mask_channel_selection = "static", mask_channel_other = 0.0, no_mask_channel_overlap = False, mask_channel_min_space = 1, conv_pos = 128, conv_pos_groups = 16, conv_pos_batch_norm = False, latent_temp = (2, 0.5, 0.999995), skip_masked = False, skip_nomask = False, checkpoint_activations = False, required_seq_len_multiple = 2, depthwise_conv_kernel_size = 31, attn_type = "", pos_enc_type = "abs", fp16 = False):
1201
+ self._name = _name
1202
+ self.label_rate = label_rate
1203
+ self.encoder_layers_1 = encoder_layers_1
1204
+ self.logit_temp_ctr = logit_temp_ctr
1205
+ self.num_negatives = num_negatives
1206
+ self.cross_sample_negatives = cross_sample_negatives
1207
+ self.ctr_layers = ctr_layers
1208
+ self.extractor_mode = extractor_mode
1209
+ self.encoder_layers = encoder_layers
1210
+ self.encoder_embed_dim = encoder_embed_dim
1211
+ self.encoder_ffn_embed_dim = encoder_ffn_embed_dim
1212
+ self.encoder_attention_heads = encoder_attention_heads
1213
+ self.activation_fn = activation_fn
1214
+ self.layer_type = layer_type
1215
+ self.dropout = dropout
1216
+ self.attention_dropout = attention_dropout
1217
+ self.activation_dropout = activation_dropout
1218
+ self.encoder_layerdrop = encoder_layerdrop
1219
+ self.dropout_input = dropout_input
1220
+ self.dropout_features = dropout_features
1221
+ self.final_dim = final_dim
1222
+ self.untie_final_proj = untie_final_proj
1223
+ self.layer_norm_first = layer_norm_first
1224
+ self.conv_feature_layers = conv_feature_layers
1225
+ self.conv_bias = conv_bias
1226
+ self.logit_temp = logit_temp
1227
+ self.target_glu = target_glu
1228
+ self.feature_grad_mult = feature_grad_mult
1229
+ self.mask_length = mask_length
1230
+ self.mask_prob = mask_prob
1231
+ self.mask_selection = mask_selection
1232
+ self.mask_other = mask_other
1233
+ self.no_mask_overlap = no_mask_overlap
1234
+ self.mask_min_space = mask_min_space
1235
+ self.mask_channel_length = mask_channel_length
1236
+ self.mask_channel_prob = mask_channel_prob
1237
+ self.mask_channel_selection = mask_channel_selection
1238
+ self.mask_channel_other = mask_channel_other
1239
+ self.no_mask_channel_overlap = no_mask_channel_overlap
1240
+ self.mask_channel_min_space = mask_channel_min_space
1241
+ self.conv_pos = conv_pos
1242
+ self.conv_pos_groups = conv_pos_groups
1243
+ self.conv_pos_batch_norm = conv_pos_batch_norm
1244
+ self.latent_temp = latent_temp
1245
+ self.skip_masked = skip_masked
1246
+ self.skip_nomask = skip_nomask
1247
+ self.checkpoint_activations = checkpoint_activations
1248
+ self.required_seq_len_multiple = required_seq_len_multiple
1249
+ self.depthwise_conv_kernel_size = depthwise_conv_kernel_size
1250
+ self.attn_type = attn_type
1251
+ self.pos_enc_type = pos_enc_type
1252
+ self.fp16 = fp16
1253
+
1254
+ class HubertModel(BaseFairseqModel):
1255
+ def __init__(self, cfg):
1256
+ super().__init__()
1257
+ feature_enc_layers = eval(cfg.conv_feature_layers)
1258
+ self.embed = feature_enc_layers[-1][0]
1259
+ self.feature_extractor = ConvFeatureExtractionModel(conv_layers=feature_enc_layers, dropout=0.0, mode=cfg.extractor_mode, conv_bias=cfg.conv_bias)
1260
+ feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
1261
+ self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / 16000
1262
+ self.post_extract_proj = (nn.Linear(self.embed, cfg.encoder_embed_dim) if self.embed != cfg.encoder_embed_dim else None)
1263
+ self.mask_prob = cfg.mask_prob
1264
+ self.mask_selection = cfg.mask_selection
1265
+ self.mask_other = cfg.mask_other
1266
+ self.mask_length = cfg.mask_length
1267
+ self.no_mask_overlap = cfg.no_mask_overlap
1268
+ self.mask_min_space = cfg.mask_min_space
1269
+ self.mask_channel_prob = cfg.mask_channel_prob
1270
+ self.mask_channel_selection = cfg.mask_channel_selection
1271
+ self.mask_channel_other = cfg.mask_channel_other
1272
+ self.mask_channel_length = cfg.mask_channel_length
1273
+ self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
1274
+ self.mask_channel_min_space = cfg.mask_channel_min_space
1275
+ self.dropout_input = nn.Dropout(cfg.dropout_input)
1276
+ self.dropout_features = nn.Dropout(cfg.dropout_features)
1277
+ self.feature_grad_mult = cfg.feature_grad_mult
1278
+ self.logit_temp = cfg.logit_temp
1279
+ self.skip_masked = cfg.skip_masked
1280
+ self.skip_nomask = cfg.skip_nomask
1281
+ final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim
1282
+ self.mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_())
1283
+ self.encoder = TransformerEncoder(cfg)
1284
+ self.layer_norm = LayerNorm(self.embed)
1285
+ self.target_glu = None
1286
+ if cfg.target_glu: self.target_glu = nn.Sequential(nn.Linear(final_dim, final_dim * 2), nn.GLU())
1287
+ self.untie_final_proj = cfg.untie_final_proj
1288
+ self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
1289
+ self.num_classes = [504]
1290
+ self.label_embs_concat = nn.Parameter(torch.FloatTensor(sum(self.num_classes), final_dim))
1291
+ nn.init.uniform_(self.label_embs_concat)
1292
+
1293
+ def upgrade_state_dict_named(self, state_dict, name):
1294
+ super().upgrade_state_dict_named(state_dict, name)
1295
+ return state_dict
1296
+
1297
+ def apply_mask(self, x, padding_mask, target_list):
1298
+ B, T, C = x.shape
1299
+ if self.mask_prob > 0:
1300
+ mask_indices = torch.from_numpy(compute_mask_indices((B, T), padding_mask, self.mask_prob, self.mask_length, self.mask_selection, self.mask_other, min_masks=2, no_overlap=self.no_mask_overlap, min_space=self.mask_min_space)).to(x.device)
1301
+ x[mask_indices] = self.mask_emb
1302
+ else: mask_indices = None
1303
+
1304
+ if self.mask_channel_prob > 0: x[(torch.from_numpy(compute_mask_indices((B, C), None, self.mask_channel_prob, self.mask_channel_length, self.mask_channel_selection, self.mask_channel_other, no_overlap=self.no_mask_channel_overlap, min_space=self.mask_channel_min_space)).to(x.device).unsqueeze(1).expand(-1, T, -1))] = 0
1305
+ return x, mask_indices
1306
+
1307
+ def compute_nce(self, x, pos, negs):
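+ # contrastive (NCE) logits: cosine similarity of each frame against its positive target and the negative candidates, scaled by the logit temperature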
1308
+ neg_is_pos = (pos == negs).all(-1)
1309
+ logits = torch.cosine_similarity(x.float(), torch.cat([pos.unsqueeze(0), negs], dim=0).float(), dim=-1).type_as(x)
1310
+ logits /= self.logit_temp
1311
+ if neg_is_pos.any(): logits[1:][neg_is_pos] = float("-inf")
1312
+ return logits.transpose(0, 1)
1313
+
1314
+ def forward_features(self, source):
1315
+ if self.feature_grad_mult > 0:
1316
+ features = self.feature_extractor(source)
1317
+ if self.feature_grad_mult != 1.0: features = GradMultiply.apply(features, self.feature_grad_mult)
1318
+ else:
1319
+ with torch.no_grad():
1320
+ features = self.feature_extractor(source)
1321
+ return features
1322
+
1323
+ def forward_targets(self, features, target_list):
1324
+ feat_tsz = features.size(2)
1325
+ targ_tsz = min([t.size(1) for t in target_list])
1326
+ if self.feat2tar_ratio * feat_tsz > targ_tsz:
1327
+ feat_tsz = int(targ_tsz / self.feat2tar_ratio)
1328
+ features = features[..., :feat_tsz]
1329
+
1330
+ return features, [t[:, (torch.arange(feat_tsz).float() * self.feat2tar_ratio).long()] for t in target_list]
1331
+
1332
+ def forward_padding_mask(self, features, padding_mask):
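+ # downsample the sample-level padding mask to the feature frame rate; a frame counts as padded only if every sample inside it is padded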
1333
+ extra = padding_mask.size(1) % features.size(1)
1334
+ if extra > 0: padding_mask = padding_mask[:, :-extra]
1335
+ return padding_mask.view(padding_mask.size(0), features.size(1), -1).all(-1)
1336
+
1337
+ def forward(self, source, target_list = None, padding_mask = None, mask = True, features_only = False, output_layer = None):
1338
+ features = self.forward_features(source)
1339
+ if target_list is not None: features, target_list = self.forward_targets(features, target_list)
1340
+ features_pen = features.float().pow(2).mean()
1341
+ features = self.layer_norm(features.transpose(1, 2))
1342
+ unmasked_features = features.clone()
1343
+ if padding_mask is not None: padding_mask = self.forward_padding_mask(features, padding_mask)
1344
+ if self.post_extract_proj is not None: features = self.post_extract_proj(features)
1345
+ features = self.dropout_input(features)
1346
+ unmasked_features = self.dropout_features(unmasked_features)
1347
+ if mask: x, mask_indices = self.apply_mask(features, padding_mask, target_list)
1348
+ else: x, mask_indices = features, None
1349
+ x, _ = self.encoder(x, padding_mask=padding_mask, layer=None if output_layer is None else output_layer - 1)
1350
+ if features_only: return {"x": x, "padding_mask": padding_mask, "features": features}
1351
+
1352
+ def compute_pred(proj_x, target, label_embs):
1353
+ y = torch.index_select(label_embs, 0, target.long())
1354
+ negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1)
1355
+ if self.target_glu:
1356
+ y = self.target_glu(y)
1357
+ negs = self.target_glu(negs)
1358
+
1359
+ return self.compute_nce(proj_x, y, negs)
1360
+
1361
+ label_embs_list = self.label_embs_concat.split(self.num_classes, 0)
1362
+ if not self.skip_masked:
1363
+ masked_indices = torch.logical_and(~padding_mask, mask_indices)
1364
+ proj_x_m = self.final_proj(x[masked_indices])
1365
+ logit_m_list = [compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) for i, (proj_x_m, t) in enumerate(zip(proj_x_m.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_m for _ in range(len(target_list))], target_list))]
1366
+ else: logit_m_list = [None for _ in target_list]
1367
+
1368
+ if not self.skip_nomask:
1369
+ nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
1370
+ proj_x_u = self.final_proj(x[nomask_indices])
1371
+ logit_u_list = [compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) for i, (proj_x_u, t) in enumerate(zip(proj_x_u.chunk(len(target_list), dim=-1) if self.untie_final_proj else [proj_x_u for _ in range(len(target_list))], target_list))]
1372
+ else: logit_u_list = [None for _ in target_list]
1373
+
1374
+ return {"logit_m_list": logit_m_list, "logit_u_list": logit_u_list, "padding_mask": padding_mask, "features_pen": features_pen}
1375
+
1376
+ def extract_features(self, source, padding_mask = None, mask = False, ret_conv = False, output_layer = None):
1377
+ res = self.forward(source, padding_mask=padding_mask, mask=mask, features_only=True, output_layer=output_layer)
1378
+ return res["features"] if ret_conv else res["x"], res["padding_mask"]
1379
+
1380
+ def get_logits(self, net_output, is_masked=True):
1381
+ return [x.float() for x in (net_output["logit_m_list"] if is_masked else net_output["logit_u_list"]) if x is not None]
1382
+
1383
+ def get_targets(self, net_output, is_masked=True):
1384
+ return [x.new_zeros(x.size(0), dtype=torch.long) for x in self.get_logits(net_output, is_masked)]
1385
+
1386
+ def get_extra_losses(self, net_output):
1387
+ extra_losses, names = [], []
1388
+ if "features_pen" in net_output:
1389
+ extra_losses.append(net_output["features_pen"])
1390
+ names.append("features_pen")
1391
+
1392
+ return extra_losses, names
1393
+
1394
+ def remove_pretraining_modules(self):
1395
+ self.target_glu = None
1396
+ self.final_proj = None
RVC/modules/gdown.py ADDED
@@ -0,0 +1,100 @@
1
+ import os
2
+ import re
3
+ import sys
4
+ import json
5
+ import codecs
6
+ import tempfile
7
+ import requests
8
+
9
+ from urllib.parse import urlparse, parse_qs, unquote
10
+
11
+ def parse_url(url):
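+ # extract the Google Drive file id from the common share/view URL formats; also report whether the URL is already a direct /uc download link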
12
+ parsed = urlparse(url)
13
+ is_download_link = parsed.path.endswith("/uc")
14
+ if parsed.hostname not in ("drive.google.com", "docs.google.com"): return None, is_download_link
15
+ file_id = parse_qs(parsed.query).get("id", [None])[0]
16
+
17
+ if file_id is None:
18
+ for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
19
+ match = re.match(pattern, parsed.path)
20
+ if match:
21
+ file_id = match.group(1)
22
+ break
23
+ return file_id, is_download_link
24
+
25
+ def get_url_from_gdrive_confirmation(contents):
26
+ for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
27
+ match = re.search(pattern, contents)
28
+ if match:
29
+ url = match.group(1)
30
+ if pattern == r'href="/open\?id=([^"]+)"': url = (codecs.decode("uggcf://qevir.hfrepbagrag.tbbtyr.pbz/qbjaybnq?vq=", "rot13") + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
31
+ elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
32
+ else: url = codecs.decode("uggcf://qbpf.tbbtyr.pbz", "rot13") + url.replace("&amp;", "&")
33
+ return url
34
+
35
+ match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
36
+ if match: raise Exception(match.group(1))
37
+ raise Exception
38
+
39
+ def _get_session(use_cookies, return_cookies_file=False):
40
+ sess = requests.session()
41
+ sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
42
+ cookies_file = os.path.join(os.path.expanduser("~"), ".cache/gdown/cookies.json")
43
+
44
+ if os.path.exists(cookies_file) and use_cookies:
45
+ with open(cookies_file) as f:
46
+ for k, v in json.load(f):
47
+ sess.cookies[k] = v
48
+ return (sess, cookies_file) if return_cookies_file else sess
49
+
50
+ def gdown_download(url=None, id=None, output=None):
51
+ if not (id is None) ^ (url is None): raise ValueError
52
+ if id is not None: url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{id}"
53
+
54
+ url_origin = url
55
+ sess, cookies_file = _get_session(use_cookies=True, return_cookies_file=True)
56
+ gdrive_file_id, is_gdrive_download_link = parse_url(url)
57
+
58
+ if gdrive_file_id:
59
+ url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/hp?vq=', 'rot13')}{gdrive_file_id}"
60
+ url_origin = url
61
+ is_gdrive_download_link = True
62
+
63
+ while 1:
64
+ res = sess.get(url, stream=True, verify=True)
65
+ if url == url_origin and res.status_code == 500:
66
+ url = f"{codecs.decode('uggcf://qevir.tbbtyr.pbz/bcra?vq=', 'rot13')}{gdrive_file_id}"
67
+ continue
68
+
69
+ os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
70
+ with open(cookies_file, "w") as f:
71
+ json.dump([(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")], f, indent=2)
72
+
73
+ if "Content-Disposition" in res.headers: break
74
+ if not (gdrive_file_id and is_gdrive_download_link): break
75
+
76
+ try:
77
+ url = get_url_from_gdrive_confirmation(res.text)
78
+ except Exception as e:
79
+ raise Exception(e)
80
+
81
+ if gdrive_file_id and is_gdrive_download_link:
82
+ content_disposition = unquote(res.headers["Content-Disposition"])
83
+ filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1).replace(os.path.sep, "_")
84
+ else: filename_from_url = os.path.basename(url)
85
+
86
+ output = os.path.join(output or ".", filename_from_url)
87
+ tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=os.path.basename(output), dir=os.path.dirname(output))
88
+ f = open(tmp_file, "ab")
89
+
90
+ if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=True)
91
+ print("To:", os.path.abspath(output), file=sys.stderr)
92
+
93
+ try:
94
+ for chunk in res.iter_content(chunk_size=512 * 1024):
95
+ f.write(chunk)
96
+ if tmp_file: f.close()
97
+ finally:
98
+ os.rename(tmp_file, output)
99
+ sess.close()
100
+ return output
RVC/modules/generator.py ADDED
@@ -0,0 +1,257 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+ import parselmouth
6
+
7
+ import numba as nb
8
+ import numpy as np
9
+
10
+ from librosa import yin, pyin
11
+ from scipy.signal import medfilt
12
+
13
+ sys.path.append(os.getcwd())
14
+
15
+ from modules.rmvpe import RMVPE
16
+ from modules.utils import Autotune
17
+ from modules.torchfcpe import FCPE
18
+ from modules.pyworld import PYWORLD
19
+ from modules.swipe import swipe, stonemask
20
+ from modules.torchcrepe import CREPE, mean, median
21
+
22
+ @nb.jit(nopython=True)
23
+ def post_process(f0, f0_up_key, f0_mel_min, f0_mel_max):
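+ # transpose f0 by f0_up_key semitones, then quantise it onto 255 mel-spaced bins (coarse f0); returns (coarse_f0, shifted_f0)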
24
+ f0 = np.multiply(f0, pow(2, f0_up_key / 12))
25
+
26
+ f0_mel = 1127 * np.log(1 + f0 / 700)
27
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
28
+ f0_mel[f0_mel <= 1] = 1
29
+ f0_mel[f0_mel > 255] = 255
30
+
31
+ return np.rint(f0_mel).astype(np.int32), f0
32
+
33
+ class Generator:
34
+ def __init__(self, sample_rate = 16000, hop_length = 160, f0_min = 50, f0_max = 1100, is_half = False, device = "cpu"):
35
+ self.sample_rate = sample_rate
36
+ self.hop_length = hop_length
37
+ self.f0_min = f0_min
38
+ self.f0_max = f0_max
39
+ self.is_half = is_half
40
+ self.device = device
41
+ self.window = 160
42
+ self.ref_freqs = [49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78, 82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47, 130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00, 207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13, 329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88, 523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99, 830.61, 880.00, 932.33, 987.77, 1046.50]
43
+ self.autotune = Autotune(self.ref_freqs)
44
+ self.note_dict = self.autotune.note_dict
45
+
46
+ def calculator(self, f0_method, x, f0_up_key = 0, p_len = None, filter_radius = 3, f0_autotune = False, f0_autotune_strength = 1):
47
+ if p_len is None: p_len = x.shape[0] // self.window
48
+ f0 = self.compute_f0(f0_method, x, p_len, filter_radius if filter_radius % 2 != 0 else filter_radius + 1)
49
+
50
+ if isinstance(f0, tuple): f0 = f0[0]
51
+ if f0_autotune: f0 = Autotune.autotune_f0(self, f0, f0_autotune_strength)
52
+
53
+ return post_process(
54
+ f0,
55
+ f0_up_key,
56
+ 1127 * math.log(1 + self.f0_min / 700),
57
+ 1127 * math.log(1 + self.f0_max / 700),
58
+ )
59
+
60
+ def _resize_f0(self, x, target_len):
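+ # linearly interpolate the f0 contour to target_len frames, treating near-zero (unvoiced) values as NaN and mapping them back to 0 afterwards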
61
+ source = np.array(x)
62
+ source[source < 0.001] = np.nan
63
+
64
+ return np.nan_to_num(
65
+ np.interp(
66
+ np.arange(0, len(source) * target_len, len(source)) / target_len,
67
+ np.arange(0, len(source)),
68
+ source
69
+ )
70
+ )
71
+
72
+ def compute_f0(self, f0_method, x, p_len, filter_radius):
73
+ return {
74
+ "pm": lambda: self.get_f0_pm(x, p_len),
75
+ "dio": lambda: self.get_f0_pyworld(x, p_len, filter_radius, "dio"),
76
+ "mangio-crepe-tiny": lambda: self.get_f0_mangio_crepe(x, p_len, "tiny"),
77
+ "mangio-crepe-small": lambda: self.get_f0_mangio_crepe(x, p_len, "small"),
78
+ "mangio-crepe-medium": lambda: self.get_f0_mangio_crepe(x, p_len, "medium"),
79
+ "mangio-crepe-large": lambda: self.get_f0_mangio_crepe(x, p_len, "large"),
80
+ "mangio-crepe-full": lambda: self.get_f0_mangio_crepe(x, p_len, "full"),
81
+ "crepe-tiny": lambda: self.get_f0_crepe(x, p_len, "tiny"),
82
+ "crepe-small": lambda: self.get_f0_crepe(x, p_len, "small"),
83
+ "crepe-medium": lambda: self.get_f0_crepe(x, p_len, "medium"),
84
+ "crepe-large": lambda: self.get_f0_crepe(x, p_len, "large"),
85
+ "crepe-full": lambda: self.get_f0_crepe(x, p_len, "full"),
86
+ "fcpe": lambda: self.get_f0_fcpe(x, p_len),
87
+ "fcpe-legacy": lambda: self.get_f0_fcpe(x, p_len, legacy=True),
88
+ "rmvpe": lambda: self.get_f0_rmvpe(x, p_len),
89
+ "rmvpe-legacy": lambda: self.get_f0_rmvpe(x, p_len, legacy=True),
90
+ "harvest": lambda: self.get_f0_pyworld(x, p_len, filter_radius, "harvest"),
91
+ "yin": lambda: self.get_f0_yin(x, p_len, mode="yin"),
92
+ "pyin": lambda: self.get_f0_yin(x, p_len, mode="pyin"),
93
+ "swipe": lambda: self.get_f0_swipe(x, p_len)
94
+ }[f0_method]()
95
+
96
+ def get_f0_pm(self, x, p_len):
97
+ f0 = (
98
+ parselmouth.Sound(
99
+ x,
100
+ self.sample_rate
101
+ ).to_pitch_ac(
102
+ time_step=160 / self.sample_rate,
103
+ voicing_threshold=0.6,
104
+ pitch_floor=self.f0_min,
105
+ pitch_ceiling=self.f0_max
106
+ ).selected_array["frequency"]
107
+ )
108
+
109
+ pad_size = (p_len - len(f0) + 1) // 2
110
+
111
+ if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
112
+ return f0
113
+
114
+ def get_f0_mangio_crepe(self, x, p_len, model="full"):
115
+ if not hasattr(self, "mangio_crepe"):
116
+ self.mangio_crepe = CREPE(
117
+ os.path.join(
118
+ "models",
119
+ f"crepe_{model}.pth"
120
+ ),
121
+ model_size=model,
122
+ hop_length=self.hop_length,
123
+ batch_size=self.hop_length * 2,
124
+ f0_min=self.f0_min,
125
+ f0_max=self.f0_max,
126
+ device=self.device,
127
+ sample_rate=self.sample_rate,
128
+ return_periodicity=False
129
+ )
130
+
131
+ x = x.astype(np.float32)
132
+ x /= np.quantile(np.abs(x), 0.999)
133
+
134
+ audio = torch.unsqueeze(torch.from_numpy(x).to(self.device, copy=True), dim=0)
135
+ if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
136
+
137
+ f0 = self.mangio_crepe.compute_f0(audio.detach(), pad=True)
138
+ return self._resize_f0(f0.squeeze(0).cpu().float().numpy(), p_len)
139
+
140
+ def get_f0_crepe(self, x, p_len, model="full"):
141
+ if not hasattr(self, "crepe"):
142
+ self.crepe = CREPE(
143
+ os.path.join(
144
+ "models",
145
+ f"crepe_{model}.pth"
146
+ ),
147
+ model_size=model,
148
+ hop_length=self.hop_length,
149
+ batch_size=512,
150
+ f0_min=self.f0_min,
151
+ f0_max=self.f0_max,
152
+ device=self.device,
153
+ sample_rate=self.sample_rate,
154
+ return_periodicity=True
155
+ )
156
+
157
+ f0, pd = self.crepe.compute_f0(torch.tensor(np.copy(x))[None].float(), pad=True)
158
+ f0, pd = mean(f0, 3), median(pd, 3)
159
+ f0[pd < 0.1] = 0
160
+
161
+ return self._resize_f0(f0[0].cpu().numpy(), p_len)
162
+
163
+ def get_f0_fcpe(self, x, p_len, legacy=False):
164
+ if not hasattr(self, "fcpe"):
165
+ self.fcpe = FCPE(
166
+ os.path.join(
167
+ "models",
168
+ ("fcpe_legacy" if legacy else "fcpe") + ".pt"
169
+ ),
170
+ hop_length=self.hop_length,
171
+ f0_min=self.f0_min,
172
+ f0_max=self.f0_max,
173
+ dtype=torch.float32,
174
+ device=self.device,
175
+ sample_rate=self.sample_rate,
176
+ threshold=0.03 if legacy else 0.006,
177
+ legacy=legacy
178
+ )
179
+
180
+ f0 = self.fcpe.compute_f0(x, p_len)
181
+ return f0
182
+
183
+ def get_f0_rmvpe(self, x, p_len, legacy=False):
184
+ if not hasattr(self, "rmvpe"):
185
+ self.rmvpe = RMVPE(
186
+ os.path.join(
187
+ "models",
188
+ "rmvpe.pt"
189
+ ),
190
+ is_half=self.is_half,
191
+ device=self.device,
192
+ )
193
+
194
+ f0 = self.rmvpe.infer_from_audio_with_pitch(x, thred=0.03, f0_min=self.f0_min, f0_max=self.f0_max) if legacy else self.rmvpe.infer_from_audio(x, thred=0.03)
195
+ return self._resize_f0(f0, p_len)
196
+
197
+ def get_f0_pyworld(self, x, p_len, filter_radius, model="harvest"):
198
+ if not hasattr(self, "pw"): self.pw = PYWORLD()
199
+
200
+ x = x.astype(np.double)
201
+ pw = self.pw.harvest if model == "harvest" else self.pw.dio
202
+
203
+ f0, t = pw(
204
+ x,
205
+ fs=self.sample_rate,
206
+ f0_ceil=self.f0_max,
207
+ f0_floor=self.f0_min,
208
+ frame_period=1000 * self.window / self.sample_rate
209
+ )
210
+
211
+ f0 = self.pw.stonemask(
212
+ x,
213
+ self.sample_rate,
214
+ t,
215
+ f0
216
+ )
217
+
218
+ if filter_radius > 2 and model == "harvest": f0 = medfilt(f0, filter_radius)
219
+ elif model == "dio":
220
+ for index, pitch in enumerate(f0):
221
+ f0[index] = round(pitch, 1)
222
+
223
+ return self._resize_f0(f0, p_len)
224
+
225
+ def get_f0_swipe(self, x, p_len):
226
+ f0, t = swipe(
227
+ x.astype(np.float32),
228
+ self.sample_rate,
229
+ f0_floor=self.f0_min,
230
+ f0_ceil=self.f0_max,
231
+ frame_period=1000 * self.window / self.sample_rate
232
+ )
233
+
234
+ return self._resize_f0(
235
+ stonemask(
236
+ x,
237
+ self.sample_rate,
238
+ t,
239
+ f0
240
+ ),
241
+ p_len
242
+ )
243
+
244
+ def get_f0_yin(self, x, p_len, mode="yin"):
245
+ self.if_yin = mode == "yin"
246
+ self.yin = yin if self.if_yin else pyin
247
+
248
+ f0 = self.yin(
249
+ x.astype(np.float32),
250
+ sr=self.sample_rate,
251
+ fmin=self.f0_min,
252
+ fmax=self.f0_max,
253
+ hop_length=self.hop_length
254
+ )
255
+
256
+ if not self.if_yin: f0 = f0[0]
257
+ return self._resize_f0(f0, p_len)
RVC/modules/hifigan.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from torch.nn.utils import remove_weight_norm
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+
9
+ sys.path.append(os.getcwd())
10
+
11
+ from modules.commons import init_weights
12
+ from modules.residuals import ResBlock, LRELU_SLOPE
13
+
14
+ class HiFiGANGenerator(torch.nn.Module):
15
+ def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
16
+ super(HiFiGANGenerator, self).__init__()
17
+ self.num_kernels = len(resblock_kernel_sizes)
18
+ self.num_upsamples = len(upsample_rates)
19
+ self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
20
+ self.ups_and_resblocks = torch.nn.ModuleList()
21
+
22
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
23
+ self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
24
+ ch = upsample_initial_channel // (2 ** (i + 1))
25
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
26
+ self.ups_and_resblocks.append(ResBlock(ch, k, d))
27
+
28
+ self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
29
+ self.ups_and_resblocks.apply(init_weights)
30
+ if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
31
+
32
+ def forward(self, x, g = None):
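+ # Upsample stage by stage, then average the outputs of the multi-kernel residual blocks before the final projection.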
33
+ x = self.conv_pre(x)
34
+ if g is not None: x = x + self.cond(g)
35
+
36
+ resblock_idx = 0
37
+
38
+ for _ in range(self.num_upsamples):
39
+ x = self.ups_and_resblocks[resblock_idx](F.leaky_relu(x, LRELU_SLOPE))
40
+ resblock_idx += 1
41
+ xs = 0
42
+
43
+ for _ in range(self.num_kernels):
44
+ xs += self.ups_and_resblocks[resblock_idx](x)
45
+ resblock_idx += 1
46
+
47
+ x = xs / self.num_kernels
48
+
49
+ return torch.tanh(self.conv_post(F.leaky_relu(x)))
50
+
51
+ def __prepare_scriptable__(self):
52
+ for l in self.ups_and_resblocks:
53
+ for hook in l._forward_pre_hooks.values():
54
+ if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
55
+
56
+ return self
57
+
58
+ def remove_weight_norm(self):
59
+ for l in self.ups_and_resblocks:
60
+ remove_weight_norm(l)
RVC/modules/mediafire.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ import sys
3
+ import requests
4
+
5
+ from bs4 import BeautifulSoup
6
+
7
+ def Mediafire_Download(url, output=None, filename=None):
8
+ if not filename: filename = url.split('/')[-2]
9
+ if not output: output = os.path.dirname(os.path.realpath(__file__))
10
+ output_file = os.path.join(output, filename)
11
+
12
+ sess = requests.session()
13
+ sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
14
+
15
+ try:
16
+ with sess.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
17
+ r.raise_for_status()
18
+ with open(output_file, "wb") as f:
19
+ total_length = int(r.headers.get('content-length'))
20
+ download_progress = 0
21
+
22
+ for chunk in r.iter_content(chunk_size=1024):
23
+ download_progress += len(chunk)
24
+ f.write(chunk)
25
+ sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
26
+ sys.stdout.flush()
27
+ sys.stdout.write("\n")
28
+ return output_file
29
+ except Exception as e:
30
+ raise RuntimeError(e)
RVC/modules/meganz.py ADDED
@@ -0,0 +1,122 @@
1
+ import os
2
+ import re
3
+ import json
4
+ import codecs
5
+ import random
6
+ import base64
7
+ import struct
8
+ import shutil
9
+ import requests
10
+ import tempfile
11
+
12
+ from Crypto.Cipher import AES
13
+ from Crypto.Util import Counter
14
+
15
+ def makebyte(x):
16
+ return codecs.latin_1_encode(x)[0]
17
+
18
+ def a32_to_str(a):
19
+ return struct.pack('>%dI' % len(a), *a)
20
+
21
+ def get_chunks(size):
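+ # Yield (offset, size) pairs; the chunk size starts at 128 KiB and grows by 128 KiB per chunk, capped at 1 MiB.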
22
+ p, s = 0, 0x20000
23
+
24
+ while p + s < size:
25
+ yield(p, s)
26
+ p += s
27
+
28
+ if s < 0x100000: s += 0x20000
29
+
30
+ yield(p, size - p)
31
+
32
+ def aes_cbc_decrypt(data, key):
33
+ aes_cipher = AES.new(key, AES.MODE_CBC, makebyte('\0' * 16))
34
+ return aes_cipher.decrypt(data)
35
+
36
+ def decrypt_attr(attr, key):
37
+ attr = codecs.latin_1_decode(aes_cbc_decrypt(attr, a32_to_str(key)))[0].rstrip('\0')
38
+ return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False
39
+
40
+ def _api_request(data):
41
+ sequence_num = random.randint(0, 0xFFFFFFFF)
42
+ params = {'id': sequence_num}
43
+ sequence_num += 1
44
+
45
+ if not isinstance(data, list): data = [data]
46
+ json_resp = json.loads(requests.post('{0}://g.api.{1}/cs'.format('https', 'mega.co.nz'), params=params, data=json.dumps(data), timeout=160).text)
47
+ if isinstance(json_resp, int): raise Exception(json_resp)
48
+
49
+ return json_resp[0]
50
+
51
+ def base64_url_decode(data):
52
+ data += '=='[(2 - len(data) * 3) % 4:]
53
+
54
+ for search, replace in (('-', '+'), ('_', '/'), (',', '')):
55
+ data = data.replace(search, replace)
56
+
57
+ return base64.b64decode(data)
58
+
59
+ def str_to_a32(b):
60
+ if isinstance(b, str): b = makebyte(b)
61
+ if len(b) % 4: b += b'\0' * (4 - len(b) % 4)
62
+ return struct.unpack('>%dI' % (len(b) / 4), b)
63
+
64
+ def base64_to_a32(s):
65
+ return str_to_a32(base64_url_decode(s))
66
+
67
+ def mega_download_file(file_handle, file_key, dest_path=None):
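+ # Ask the MEGA API for the download URL, decrypt the stream with AES-CTR chunk by chunk, verify the file MAC, then move the result into dest_path.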
68
+ file_key = base64_to_a32(file_key)
69
+ file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})
70
+
71
+ k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
72
+ iv = file_key[4:6] + (0, 0)
73
+
74
+ if 'g' not in file_data: raise Exception
75
+
76
+ file_size = file_data['s']
77
+ attribs = decrypt_attr(base64_url_decode(file_data['at']), k)
78
+ input_file = requests.get(file_data['g'], stream=True).raw
79
+
80
+ temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)
81
+ k_str = a32_to_str(k)
82
+ aes = AES.new(k_str, AES.MODE_CTR, counter=Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64))
83
+
84
+ mac_str = b'\0' * 16
85
+ mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)
86
+ iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])
87
+
88
+ for _, chunk_size in get_chunks(file_size):
89
+ chunk = aes.decrypt(input_file.read(chunk_size))
90
+ temp_output_file.write(chunk)
91
+
92
+ encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)
93
+
94
+ for i in range(0, len(chunk) - 16, 16):
95
+ block = chunk[i:i + 16]
96
+ encryptor.encrypt(block)
97
+
98
+ i = (i + 16) if file_size > 16 else 0
99
+ block = chunk[i:i + 16]
100
+ if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))
101
+
102
+ mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))
103
+
104
+ file_mac = str_to_a32(mac_str)
105
+ temp_output_file.close()
106
+
107
+ if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != file_key[6:8]: raise ValueError
108
+
109
+ file_path = os.path.join(dest_path, attribs['n'])
110
+ if os.path.exists(file_path): os.remove(file_path)
111
+
112
+ shutil.move(temp_output_file.name, file_path)
113
+
114
+ def mega_download_url(url, dest_path=None):
115
+ if '/file/' in url:
116
+ url = url.replace(' ', '')
117
+ file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]
118
+ path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
119
+ elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
120
+ else: raise Exception
121
+
122
+ return mega_download_file(path[0], path[1], dest_path)
RVC/modules/modules.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ sys.path.append(os.getcwd())
6
+
7
+ from .commons import fused_add_tanh_sigmoid_multiply
8
+
9
+ class WaveNet(torch.nn.Module):
10
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
11
+ super(WaveNet, self).__init__()
12
+ assert kernel_size % 2 == 1
13
+ self.hidden_channels = hidden_channels
14
+ self.kernel_size = (kernel_size,)
15
+ self.dilation_rate = dilation_rate
16
+ self.n_layers = n_layers
17
+ self.gin_channels = gin_channels
18
+ self.p_dropout = p_dropout
19
+ self.in_layers = torch.nn.ModuleList()
20
+ self.res_skip_layers = torch.nn.ModuleList()
21
+ self.drop = torch.nn.Dropout(p_dropout)
22
+ if gin_channels != 0: self.cond_layer = torch.nn.utils.parametrizations.weight_norm(torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1), name="weight")
23
+ dilations = [dilation_rate ** i for i in range(n_layers)]
24
+ paddings = [(kernel_size * d - d) // 2 for d in dilations]
25
+
26
+ for i in range(n_layers):
27
+ in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
28
+ in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
29
+ self.in_layers.append(in_layer)
30
+ res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
31
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
32
+ res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
33
+ self.res_skip_layers.append(res_skip_layer)
34
+
35
+ def forward(self, x, x_mask, g=None):
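+ # Gated tanh/sigmoid activations per dilated layer with optional global conditioning g; each layer's 1x1 conv output is split into residual and skip parts.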
36
+ output = x.clone().zero_()
37
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
38
+
39
+ if g is not None: g = self.cond_layer(g)
40
+
41
+ for i in range(self.n_layers):
42
+ x_in = self.in_layers[i](x)
43
+ g_l = (g[:, i * 2 * self.hidden_channels : (i + 1) * 2 * self.hidden_channels, :] if g is not None else 0)
44
+ res_skip_acts = self.res_skip_layers[i](self.drop(fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)))
45
+
46
+ if i < self.n_layers - 1:
47
+ x = (x + (res_skip_acts[:, : self.hidden_channels, :])) * x_mask
48
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
49
+ else: output = output + res_skip_acts
50
+
51
+ return output * x_mask
52
+
53
+ def remove_weight_norm(self):
54
+ if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)
55
+
56
+ for l in self.in_layers:
57
+ torch.nn.utils.remove_weight_norm(l)
58
+
59
+ for l in self.res_skip_layers:
60
+ torch.nn.utils.remove_weight_norm(l)
RVC/modules/mrf_hifigan.py ADDED
@@ -0,0 +1,150 @@
1
+ import math
2
+ import torch
3
+
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from torch.nn.utils import remove_weight_norm
9
+ from torch.utils.checkpoint import checkpoint
10
+ from torch.nn.utils.parametrizations import weight_norm
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+ class MRFLayer(nn.Module):
15
+ def __init__(self, channels, kernel_size, dilation):
16
+ super().__init__()
17
+ self.conv1 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size * dilation - dilation) // 2, dilation=dilation))
18
+ self.conv2 = weight_norm(nn.Conv1d(channels, channels, kernel_size, padding=kernel_size // 2, dilation=1))
19
+
20
+ def forward(self, x):
21
+ return x + self.conv2(F.leaky_relu(self.conv1(F.leaky_relu(x, LRELU_SLOPE)), LRELU_SLOPE))
22
+
23
+ def remove_weight_norm(self):
24
+ remove_weight_norm(self.conv1)
25
+ remove_weight_norm(self.conv2)
26
+
27
+ class MRFBlock(nn.Module):
28
+ def __init__(self, channels, kernel_size, dilations):
29
+ super().__init__()
30
+ self.layers = nn.ModuleList()
31
+
32
+ for dilation in dilations:
33
+ self.layers.append(MRFLayer(channels, kernel_size, dilation))
34
+
35
+ def forward(self, x):
36
+ for layer in self.layers:
37
+ x = layer(x)
38
+
39
+ return x
40
+
41
+ def remove_weight_norm(self):
42
+ for layer in self.layers:
43
+ layer.remove_weight_norm()
44
+
45
+ class SineGenerator(nn.Module):
46
+ def __init__(self, samp_rate, harmonic_num = 0, sine_amp = 0.1, noise_std = 0.003, voiced_threshold = 0):
47
+ super(SineGenerator, self).__init__()
48
+ self.sine_amp = sine_amp
49
+ self.noise_std = noise_std
50
+ self.harmonic_num = harmonic_num
51
+ self.dim = self.harmonic_num + 1
52
+ self.sampling_rate = samp_rate
53
+ self.voiced_threshold = voiced_threshold
54
+
55
+ def _f02uv(self, f0):
56
+ return torch.ones_like(f0) * (f0 > self.voiced_threshold)
57
+
58
+ def _f02sine(self, f0_values):
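+ # Accumulate per-frame normalized frequency into a phase track (with a random initial phase), correct for wrap-around, and take the sine.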
59
+ rad_values = (f0_values / self.sampling_rate) % 1
60
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
61
+ rand_ini[:, 0] = 0
62
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
63
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
64
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
65
+ cumsum_shift = torch.zeros_like(rad_values)
66
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
67
+
68
+ return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
69
+
70
+ def forward(self, f0):
71
+ with torch.no_grad():
72
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
73
+ f0_buf[:, :, 0] = f0[:, :, 0]
74
+
75
+ for idx in np.arange(self.harmonic_num):
76
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
77
+
78
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
79
+ uv = self._f02uv(f0)
80
+ sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
81
+
82
+ return sine_waves
83
+
84
+ class SourceModuleHnNSF(nn.Module):
85
+ def __init__(self, sampling_rate, harmonic_num = 0, sine_amp = 0.1, add_noise_std = 0.003, voiced_threshold = 0):
86
+ super(SourceModuleHnNSF, self).__init__()
87
+ self.sine_amp = sine_amp
88
+ self.noise_std = add_noise_std
89
+ self.l_sin_gen = SineGenerator(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshold)
90
+ self.l_linear = nn.Linear(harmonic_num + 1, 1)
91
+ self.l_tanh = nn.Tanh()
92
+
93
+ def forward(self, x):
94
+ return self.l_tanh(self.l_linear(self.l_sin_gen(x).to(dtype=self.l_linear.weight.dtype)))
95
+
96
+ class HiFiGANMRFGenerator(nn.Module):
97
+ def __init__(self, in_channel, upsample_initial_channel, upsample_rates, upsample_kernel_sizes, resblock_kernel_sizes, resblock_dilations, gin_channels, sample_rate, harmonic_num, checkpointing = False):
98
+ super().__init__()
99
+ self.num_kernels = len(resblock_kernel_sizes)
100
+ self.checkpointing = checkpointing
101
+ self.f0_upsample = nn.Upsample(scale_factor=np.prod(upsample_rates))
102
+ self.m_source = SourceModuleHnNSF(sample_rate, harmonic_num)
103
+ self.conv_pre = weight_norm(nn.Conv1d(in_channel, upsample_initial_channel, kernel_size=7, stride=1, padding=3))
104
+ self.upsamples = nn.ModuleList()
105
+ self.noise_convs = nn.ModuleList()
106
+ stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
107
+
108
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
109
+ self.upsamples.append(weight_norm(nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), kernel_size=k, stride=u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
110
+ stride = stride_f0s[i]
111
+ kernel = 1 if stride == 1 else stride * 2 - stride % 2
112
+ self.noise_convs.append(nn.Conv1d(1, upsample_initial_channel // (2 ** (i + 1)), kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
113
+
114
+ self.mrfs = nn.ModuleList()
115
+ for i in range(len(self.upsamples)):
116
+ channel = upsample_initial_channel // (2 ** (i + 1))
117
+ self.mrfs.append(nn.ModuleList([MRFBlock(channel, kernel_size=k, dilations=d) for k, d in zip(resblock_kernel_sizes, resblock_dilations)]))
118
+
119
+ self.conv_post = weight_norm(nn.Conv1d(channel, 1, kernel_size=7, stride=1, padding=3))
120
+ if gin_channels != 0: self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
121
+
122
+ def forward(self, x, f0, g = None):
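+ # A harmonic source built from f0 is added after every upsampling stage; MRF block outputs are averaged, with gradient checkpointing when enabled during training.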
123
+ har_source = self.m_source(self.f0_upsample(f0[:, None, :]).transpose(-1, -2)).transpose(-1, -2)
124
+ x = self.conv_pre(x)
125
+ if g is not None: x += self.cond(g)
126
+
127
+ for ups, mrf, noise_conv in zip(self.upsamples, self.mrfs, self.noise_convs):
128
+ x = F.leaky_relu(x, LRELU_SLOPE)
129
+
130
+ if self.training and self.checkpointing:
131
+ x = checkpoint(ups, x, use_reentrant=False) + noise_conv(har_source)
132
+ xs = sum([checkpoint(layer, x, use_reentrant=False) for layer in mrf])
133
+ else:
134
+ x = ups(x) + noise_conv(har_source)
135
+ xs = sum([layer(x) for layer in mrf])
136
+
137
+ x = xs / self.num_kernels
138
+
139
+ return torch.tanh(self.conv_post(F.leaky_relu(x)))
140
+
141
+ def remove_weight_norm(self):
142
+ remove_weight_norm(self.conv_pre)
143
+
144
+ for up in self.upsamples:
145
+ remove_weight_norm(up)
146
+
147
+ for mrf in self.mrfs:
148
+ mrf.remove_weight_norm()
149
+
150
+ remove_weight_norm(self.conv_post)
RVC/modules/noisereduce.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ import tempfile
3
+ import numpy as np
4
+
5
+ from joblib import Parallel, delayed
6
+ from torch.nn.functional import conv1d, conv2d
7
+
8
+ @torch.no_grad()
9
+ def amp_to_db(x, eps = torch.finfo(torch.float32).eps, top_db = 40):
10
+ x_db = 20 * torch.log10(x.abs() + eps)
11
+ return torch.max(x_db, (x_db.max(-1).values - top_db).unsqueeze(-1))
12
+
13
+ @torch.no_grad()
14
+ def temperature_sigmoid(x, x0, temp_coeff):
15
+ return torch.sigmoid((x - x0) / temp_coeff)
16
+
17
+ @torch.no_grad()
18
+ def linspace(start, stop, num = 50, endpoint = True, **kwargs):
19
+ return torch.linspace(start, stop, num, **kwargs) if endpoint else torch.linspace(start, stop, num + 1, **kwargs)[:-1]
20
+
21
+ def _smoothing_filter(n_grad_freq, n_grad_time):
22
+ smoothing_filter = np.outer(np.concatenate([np.linspace(0, 1, n_grad_freq + 1, endpoint=False), np.linspace(1, 0, n_grad_freq + 2)])[1:-1], np.concatenate([np.linspace(0, 1, n_grad_time + 1, endpoint=False), np.linspace(1, 0, n_grad_time + 2)])[1:-1])
23
+ return smoothing_filter / np.sum(smoothing_filter)
24
+
25
+ class SpectralGate:
26
+ def __init__(self, y, sr, prop_decrease, chunk_size, padding, n_fft, win_length, hop_length, time_constant_s, freq_mask_smooth_hz, time_mask_smooth_ms, tmp_folder, use_tqdm, n_jobs):
27
+ self.sr = sr
28
+ self.flat = False
29
+ y = np.array(y)
30
+
31
+ if len(y.shape) == 1:
32
+ self.y = np.expand_dims(y, 0)
33
+ self.flat = True
34
+ elif len(y.shape) > 2: raise ValueError
35
+ else: self.y = y
36
+
37
+ self._dtype = y.dtype
38
+ self.n_channels, self.n_frames = self.y.shape
39
+ self._chunk_size = chunk_size
40
+ self.padding = padding
41
+ self.n_jobs = n_jobs
42
+ self.use_tqdm = use_tqdm
43
+ self._tmp_folder = tmp_folder
44
+ self._n_fft = n_fft
45
+ self._win_length = self._n_fft if win_length is None else win_length
46
+ self._hop_length = (self._win_length // 4) if hop_length is None else hop_length
47
+ self._time_constant_s = time_constant_s
48
+ self._prop_decrease = prop_decrease
49
+
50
+ if (freq_mask_smooth_hz is None) and (time_mask_smooth_ms is None): self.smooth_mask = False
51
+ else: self._generate_mask_smoothing_filter(freq_mask_smooth_hz, time_mask_smooth_ms)
52
+
53
+ def _generate_mask_smoothing_filter(self, freq_mask_smooth_hz, time_mask_smooth_ms):
54
+ if freq_mask_smooth_hz is None: n_grad_freq = 1
55
+ else:
56
+ n_grad_freq = int(freq_mask_smooth_hz / (self.sr / (self._n_fft / 2)))
57
+ if n_grad_freq < 1: raise ValueError
58
+
59
+ if time_mask_smooth_ms is None: n_grad_time = 1
60
+ else:
61
+ n_grad_time = int(time_mask_smooth_ms / ((self._hop_length / self.sr) * 1000))
62
+ if n_grad_time < 1: raise ValueError
63
+
64
+ if (n_grad_time == 1) and (n_grad_freq == 1): self.smooth_mask = False
65
+ else:
66
+ self.smooth_mask = True
67
+ self._smoothing_filter = _smoothing_filter(n_grad_freq, n_grad_time)
68
+
69
+ def _read_chunk(self, i1, i2):
70
+ i1b = 0 if i1 < 0 else i1
71
+ i2b = self.n_frames if i2 > self.n_frames else i2
72
+ chunk = np.zeros((self.n_channels, i2 - i1))
73
+ chunk[:, i1b - i1: i2b - i1] = self.y[:, i1b:i2b]
74
+ return chunk
75
+
76
+ def filter_chunk(self, start_frame, end_frame):
77
+ i1 = start_frame - self.padding
78
+ return self._do_filter(self._read_chunk(i1, (end_frame + self.padding)))[:, start_frame - i1: end_frame - i1]
79
+
80
+ def _get_filtered_chunk(self, ind):
81
+ start0 = ind * self._chunk_size
82
+ end0 = (ind + 1) * self._chunk_size
83
+ return self.filter_chunk(start_frame=start0, end_frame=end0)
84
+
85
+ def _do_filter(self, chunk):
86
+ pass
87
+
88
+ def _iterate_chunk(self, filtered_chunk, pos, end0, start0, ich):
89
+ filtered_chunk[:, pos: pos + end0 - start0] = self._get_filtered_chunk(ich)[:, start0:end0]
90
+ pos += end0 - start0
91
+
92
+ def get_traces(self, start_frame=None, end_frame=None):
93
+ if start_frame is None: start_frame = 0
94
+ if end_frame is None: end_frame = self.n_frames
95
+
96
+ if self._chunk_size is not None:
97
+ if end_frame - start_frame > self._chunk_size:
98
+ ich1 = int(start_frame / self._chunk_size)
99
+ ich2 = int((end_frame - 1) / self._chunk_size)
100
+
101
+ with tempfile.NamedTemporaryFile(prefix=self._tmp_folder) as fp:
102
+ filtered_chunk = np.memmap(fp, dtype=self._dtype, shape=(self.n_channels, int(end_frame - start_frame)), mode="w+")
103
+ pos_list, start_list, end_list = [], [], []
104
+ pos = 0
105
+
106
+ for ich in range(ich1, ich2 + 1):
107
+ start0 = (start_frame - ich * self._chunk_size) if ich == ich1 else 0
108
+ end0 = end_frame - ich * self._chunk_size if ich == ich2 else self._chunk_size
109
+ pos_list.append(pos)
110
+ start_list.append(start0)
111
+ end_list.append(end0)
112
+ pos += end0 - start0
113
+
114
+ Parallel(n_jobs=self.n_jobs)(delayed(self._iterate_chunk)(filtered_chunk, pos, end0, start0, ich) for pos, start0, end0, ich in zip(pos_list, start_list, end_list, range(ich1, ich2 + 1)))
115
+ return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
116
+
117
+ filtered_chunk = self.filter_chunk(start_frame=0, end_frame=end_frame)
118
+ return filtered_chunk.astype(self._dtype).flatten() if self.flat else filtered_chunk.astype(self._dtype)
119
+
120
+ class TG(torch.nn.Module):
121
+ @torch.no_grad()
122
+ def __init__(self, sr, nonstationary = False, n_std_thresh_stationary = 1.5, n_thresh_nonstationary = 1.3, temp_coeff_nonstationary = 0.1, n_movemean_nonstationary = 20, prop_decrease = 1.0, n_fft = 1024, win_length = None, hop_length = None, freq_mask_smooth_hz = 500, time_mask_smooth_ms = 50):
123
+ super().__init__()
124
+ self.sr = sr
125
+ self.nonstationary = nonstationary
126
+ assert 0.0 <= prop_decrease <= 1.0
127
+ self.prop_decrease = prop_decrease
128
+ self.n_fft = n_fft
129
+ self.win_length = self.n_fft if win_length is None else win_length
130
+ self.hop_length = self.win_length // 4 if hop_length is None else hop_length
131
+ self.n_std_thresh_stationary = n_std_thresh_stationary
132
+ self.temp_coeff_nonstationary = temp_coeff_nonstationary
133
+ self.n_movemean_nonstationary = n_movemean_nonstationary
134
+ self.n_thresh_nonstationary = n_thresh_nonstationary
135
+ self.freq_mask_smooth_hz = freq_mask_smooth_hz
136
+ self.time_mask_smooth_ms = time_mask_smooth_ms
137
+ self.register_buffer("smoothing_filter", self._generate_mask_smoothing_filter())
138
+
139
+ @torch.no_grad()
140
+ def _generate_mask_smoothing_filter(self):
141
+ if self.freq_mask_smooth_hz is None and self.time_mask_smooth_ms is None: return None
142
+ n_grad_freq = (1 if self.freq_mask_smooth_hz is None else int(self.freq_mask_smooth_hz / (self.sr / (self.n_fft / 2))))
143
+ if n_grad_freq < 1: raise ValueError
144
+
145
+ n_grad_time = (1 if self.time_mask_smooth_ms is None else int(self.time_mask_smooth_ms / ((self.hop_length / self.sr) * 1000)))
146
+ if n_grad_time < 1: raise ValueError
147
+ if n_grad_time == 1 and n_grad_freq == 1: return None
148
+
149
+ smoothing_filter = torch.outer(torch.cat([linspace(0, 1, n_grad_freq + 1, endpoint=False), linspace(1, 0, n_grad_freq + 2)])[1:-1], torch.cat([linspace(0, 1, n_grad_time + 1, endpoint=False), linspace(1, 0, n_grad_time + 2)])[1:-1]).unsqueeze(0).unsqueeze(0)
150
+ return smoothing_filter / smoothing_filter.sum()
151
+
152
+ @torch.no_grad()
153
+ def _stationary_mask(self, X_db, xn = None):
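+ # Mark time-frequency bins louder than the noise mean plus n_std_thresh_stationary standard deviations per frequency, using the noise clip xn if given, else the signal itself.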
154
+ XN_db = amp_to_db(torch.stft(xn, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(xn.device))).to(dtype=X_db.dtype) if xn is not None else X_db
155
+ std_freq_noise, mean_freq_noise = torch.std_mean(XN_db, dim=-1)
156
+ return torch.gt(X_db, (mean_freq_noise + std_freq_noise * self.n_std_thresh_stationary).unsqueeze(2))
157
+
158
+ @torch.no_grad()
159
+ def _nonstationary_mask(self, X_abs):
160
+ X_smoothed = (conv1d(X_abs.reshape(-1, 1, X_abs.shape[-1]), torch.ones(self.n_movemean_nonstationary, dtype=X_abs.dtype, device=X_abs.device).view(1, 1, -1), padding="same").view(X_abs.shape) / self.n_movemean_nonstationary)
161
+ return temperature_sigmoid(((X_abs - X_smoothed) / X_smoothed), self.n_thresh_nonstationary, self.temp_coeff_nonstationary)
162
+
163
+ def forward(self, x, xn = None):
164
+ assert x.ndim == 2
165
+ if x.shape[-1] < self.win_length * 2: raise Exception
166
+ assert xn is None or xn.ndim == 1 or xn.ndim == 2
167
+ if xn is not None and xn.shape[-1] < self.win_length * 2: raise Exception
168
+
169
+ X = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, return_complex=True, pad_mode="constant", center=True, window=torch.hann_window(self.win_length).to(x.device))
170
+ sig_mask = self._nonstationary_mask(X.abs()) if self.nonstationary else self._stationary_mask(amp_to_db(X), xn)
171
+
172
+ sig_mask = self.prop_decrease * (sig_mask * 1.0 - 1.0) + 1.0
173
+ if self.smoothing_filter is not None: sig_mask = conv2d(sig_mask.unsqueeze(1), self.smoothing_filter.to(sig_mask.dtype), padding="same")
174
+
175
+ Y = X * sig_mask.squeeze(1)
176
+ return torch.istft(Y, n_fft=self.n_fft, hop_length=self.hop_length, win_length=self.win_length, center=True, window=torch.hann_window(self.win_length).to(Y.device)).to(dtype=x.dtype)
177
+
178
+ class StreamedTorchGate(SpectralGate):
179
+ def __init__(self, y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, n_std_thresh_stationary=1.5, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, n_jobs=1, device="cpu"):
180
+ super().__init__(y=y, sr=sr, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, tmp_folder=tmp_folder, prop_decrease=prop_decrease, use_tqdm=use_tqdm, n_jobs=n_jobs)
181
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
182
+
183
+ if y_noise is not None:
184
+ if y_noise.shape[-1] > y.shape[-1] and clip_noise_stationary: y_noise = y_noise[: y.shape[-1]]
185
+ y_noise = torch.from_numpy(y_noise).to(device)
186
+ if len(y_noise.shape) == 1: y_noise = y_noise.unsqueeze(0)
187
+
188
+ self.y_noise = y_noise
189
+ self.tg = TG(sr=sr, nonstationary=not stationary, n_std_thresh_stationary=n_std_thresh_stationary, n_thresh_nonstationary=thresh_n_mult_nonstationary, temp_coeff_nonstationary=1 / sigmoid_slope_nonstationary, n_movemean_nonstationary=int(time_constant_s / self._hop_length * sr), prop_decrease=prop_decrease, n_fft=self._n_fft, win_length=self._win_length, hop_length=self._hop_length, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms).to(device)
190
+
191
+ def _do_filter(self, chunk):
192
+ if type(chunk) is np.ndarray: chunk = torch.from_numpy(chunk).to(self.device)
193
+ return self.tg(x=chunk, xn=self.y_noise).cpu().detach().numpy()
194
+
195
+ def reduce_noise(y, sr, stationary=False, y_noise=None, prop_decrease=1.0, time_constant_s=2.0, freq_mask_smooth_hz=500, time_mask_smooth_ms=50, thresh_n_mult_nonstationary=2, sigmoid_slope_nonstationary=10, tmp_folder=None, chunk_size=600000, padding=30000, n_fft=1024, win_length=None, hop_length=None, clip_noise_stationary=True, use_tqdm=False, device="cpu"):
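+ # Convenience wrapper: build a StreamedTorchGate with the given settings and return the denoised audio.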
196
+ return StreamedTorchGate(y=y, sr=sr, stationary=stationary, y_noise=y_noise, prop_decrease=prop_decrease, time_constant_s=time_constant_s, freq_mask_smooth_hz=freq_mask_smooth_hz, time_mask_smooth_ms=time_mask_smooth_ms, thresh_n_mult_nonstationary=thresh_n_mult_nonstationary, sigmoid_slope_nonstationary=sigmoid_slope_nonstationary, tmp_folder=tmp_folder, chunk_size=chunk_size, padding=padding, n_fft=n_fft, win_length=win_length, hop_length=hop_length, clip_noise_stationary=clip_noise_stationary, use_tqdm=use_tqdm, n_jobs=1, device=device).get_traces()
RVC/modules/normalization.py ADDED
@@ -0,0 +1,15 @@
1
+ import torch
2
+
3
+ import torch.nn.functional as F
4
+
5
+ class LayerNorm(torch.nn.Module):
6
+ def __init__(self, channels, eps=1e-5):
7
+ super().__init__()
8
+ self.channels = channels
9
+ self.eps = eps
10
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
11
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
12
+
13
+ def forward(self, x):
14
+ x = x.transpose(1, -1)
15
+ return F.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps).transpose(1, -1)
RVC/modules/nsf_hifigan.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+
8
+ from torch.nn.utils import remove_weight_norm
9
+ from torch.utils.checkpoint import checkpoint
10
+ from torch.nn.utils.parametrizations import weight_norm
11
+
12
+ sys.path.append(os.getcwd())
13
+
14
+ from modules.commons import init_weights
15
+ from modules.residuals import ResBlock, LRELU_SLOPE
16
+
17
+ class SineGen(torch.nn.Module):
18
+ def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
19
+ super(SineGen, self).__init__()
20
+ self.sine_amp = sine_amp
21
+ self.noise_std = noise_std
22
+ self.harmonic_num = harmonic_num
23
+ self.dim = self.harmonic_num + 1
24
+ self.sampling_rate = samp_rate
25
+ self.voiced_threshold = voiced_threshold
26
+
27
+ def _f02uv(self, f0):
28
+ return torch.ones_like(f0) * (f0 > self.voiced_threshold)
29
+
30
+ def _f02sine(self, f0, upp):
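+ # Expand f0 to sample level (factor upp), accumulate phase across frames, scale by harmonic index, add random initial phases, and take the sine.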
31
+ rad = f0 / self.sampling_rate * torch.arange(1, upp + 1, dtype=f0.dtype, device=f0.device)
32
+ rad += F.pad((torch.fmod(rad[:, :-1, -1:].float() + 0.5, 1.0) - 0.5).cumsum(dim=1).fmod(1.0).to(f0), (0, 0, 1, 0), mode='constant')
33
+ rad = rad.reshape(f0.shape[0], -1, 1)
34
+ rad *= torch.arange(1, self.dim + 1, dtype=f0.dtype, device=f0.device).reshape(1, 1, -1)
35
+ rand_ini = torch.rand(1, 1, self.dim, device=f0.device)
36
+ rand_ini[..., 0] = 0
37
+ rad += rand_ini
38
+
39
+ return torch.sin(2 * np.pi * rad)
40
+
41
+ def forward(self, f0, upp):
42
+ with torch.no_grad():
43
+ f0 = f0.unsqueeze(-1)
44
+ sine_waves = self._f02sine(f0, upp) * self.sine_amp
45
+ uv = F.interpolate(self._f02uv(f0).transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
46
+ sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
47
+
48
+ return sine_waves
49
+
50
+ class SourceModuleHnNSF(torch.nn.Module):
51
+ def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
52
+ super(SourceModuleHnNSF, self).__init__()
53
+ self.sine_amp = sine_amp
54
+ self.noise_std = add_noise_std
55
+ self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
56
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
57
+ self.l_tanh = torch.nn.Tanh()
58
+
59
+ def forward(self, x, upsample_factor = 1):
60
+ return self.l_tanh(self.l_linear(self.l_sin_gen(x, upsample_factor).to(dtype=self.l_linear.weight.dtype)))
61
+
62
+ class HiFiGANNRFGenerator(torch.nn.Module):
63
+ def __init__(self, initial_channel, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, checkpointing = False):
64
+ super(HiFiGANNRFGenerator, self).__init__()
65
+ self.num_kernels = len(resblock_kernel_sizes)
66
+ self.num_upsamples = len(upsample_rates)
67
+ self.upp = math.prod(upsample_rates)
68
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=self.upp)
69
+ self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0)
70
+
71
+ self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
72
+ self.checkpointing = checkpointing
73
+
74
+ self.ups = torch.nn.ModuleList()
75
+ self.noise_convs = torch.nn.ModuleList()
76
+
77
+ channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(self.num_upsamples)]
78
+ stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < self.num_upsamples else 1 for i in range(self.num_upsamples)]
79
+
80
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
81
+ self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=((k - u) // 2) if u % 2 == 0 else (u // 2 + u % 2), output_padding=u % 2)))
82
+ stride = stride_f0s[i]
83
+ kernel = 1 if stride == 1 else stride * 2 - stride % 2
84
+ self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=kernel, stride=stride, padding=0 if stride == 1 else (kernel - stride) // 2))
85
+
86
+ self.resblocks = torch.nn.ModuleList([ResBlock(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
87
+ self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
88
+
89
+ self.ups.apply(init_weights)
90
+ if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
91
+
92
+ def forward(self, x, f0, g = None):
93
+ har_source = self.m_source(f0, self.upp).transpose(1, 2)
94
+ x = self.conv_pre(x)
95
+ if g is not None: x += self.cond(g)
96
+
97
+ for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
98
+ x = F.leaky_relu(x, LRELU_SLOPE)
99
+
100
+ if self.training and self.checkpointing:
101
+ x = checkpoint(ups, x, use_reentrant=False) + noise_convs(har_source)
102
+ xs = sum([checkpoint(resblock, x, use_reentrant=False) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
103
+ else:
104
+ x = ups(x) + noise_convs(har_source)
105
+ xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
106
+
107
+ x = xs / self.num_kernels
108
+
109
+ return torch.tanh(self.conv_post(F.leaky_relu(x)))
110
+
111
+ def remove_weight_norm(self):
112
+ for l in self.ups:
113
+ remove_weight_norm(l)
114
+
115
+ for l in self.resblocks:
116
+ l.remove_weight_norm()
RVC/modules/opencl.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch
2
+ import platform
3
+ import subprocess
4
+
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from librosa.util import pad_center
10
+ from scipy.signal import get_window
11
+
12
+ try:
13
+ import pytorch_ocl
14
+ except Exception:
15
+ pytorch_ocl = None
16
+
17
+ torch_available = pytorch_ocl is not None
18
+
19
+ def get_amd_gpu_windows():
20
+ try:
21
+ return [gpu.strip() for gpu in subprocess.check_output("wmic path win32_VideoController get name", shell=True).decode().split('\n')[1:] if 'AMD' in gpu or 'Radeon' in gpu or 'Vega' in gpu]
22
+ except Exception:
23
+ return []
24
+
25
+ def get_amd_gpu_linux():
26
+ try:
27
+ return [gpu for gpu in subprocess.check_output("lspci | grep VGA", shell=True).decode().split('\n') if 'AMD' in gpu or 'Radeon' in gpu or 'Vega' in gpu]
28
+ except Exception:
29
+ return []
30
+
31
+ def get_gpu_list():
32
+ return (get_amd_gpu_windows() if platform.system() == "Windows" else get_amd_gpu_linux()) if torch_available else []
33
+
34
+ def device_count():
35
+ return len(get_gpu_list()) if torch_available else 0
36
+
37
+ def device_name(device_id = 0):
38
+ return (get_gpu_list()[device_id] if device_id >= 0 and device_id < device_count() else "") if torch_available else ""
39
+
40
+ def is_available():
41
+ return (device_count() > 0) if torch_available else False
42
+
43
+ class STFT(torch.nn.Module):
44
+ def __init__(self, filter_length=1024, hop_length=512, win_length=None, window="hann"):
45
+ super(STFT, self).__init__()
46
+ self.filter_length = filter_length
47
+ self.hop_length = hop_length
48
+ self.pad_amount = int(self.filter_length / 2)
49
+ self.win_length = win_length
50
+ self.hann_window = {}
51
+
52
+ fourier_basis = np.fft.fft(np.eye(self.filter_length))
53
+ cutoff = int((self.filter_length / 2 + 1))
54
+ fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])])
55
+ forward_basis = torch.FloatTensor(fourier_basis)
56
+ inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
57
+
58
+ if win_length is None or not win_length: win_length = filter_length
59
+ assert filter_length >= win_length
60
+
61
+ fft_window = torch.from_numpy(pad_center(get_window(window, win_length, fftbins=True), size=filter_length)).float()
62
+ forward_basis *= fft_window
63
+ inverse_basis = (inverse_basis.T * fft_window).T
64
+
65
+ self.register_buffer("forward_basis", forward_basis.float())
66
+ self.register_buffer("inverse_basis", inverse_basis.float())
67
+ self.register_buffer("fft_window", fft_window.float())
68
+
69
+ def transform(self, input_data, eps):
70
+ input_data = F.pad(input_data, (self.pad_amount, self.pad_amount), mode="reflect")
71
+ forward_transform = torch.matmul(self.forward_basis, input_data.unfold(1, self.filter_length, self.hop_length).permute(0, 2, 1))
72
+ cutoff = int(self.filter_length / 2 + 1)
73
+
74
+ return torch.sqrt(forward_transform[:, :cutoff, :]**2 + forward_transform[:, cutoff:, :]**2 + eps)
75
+
76
+ class GRU(nn.RNNBase):
77
+ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=True, dropout=0.0, bidirectional=False, device=None, dtype=None):
78
+ super().__init__("GRU", input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional, device=device, dtype=dtype)
79
+
80
+ @staticmethod
81
+ def _gru_cell(x, hx, weight_ih, bias_ih, weight_hh, bias_hh):
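+ # Single GRU step: reset, update and candidate gates computed from the input and hidden projections.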
82
+ gate_x = F.linear(x, weight_ih, bias_ih)
83
+ gate_h = F.linear(hx, weight_hh, bias_hh)
84
+
85
+ i_r, i_i, i_n = gate_x.chunk(3, 1)
86
+ h_r, h_i, h_n = gate_h.chunk(3, 1)
87
+
88
+ resetgate = torch.sigmoid(i_r + h_r)
89
+ inputgate = torch.sigmoid(i_i + h_i)
90
+ newgate = torch.tanh(i_n + resetgate * h_n)
91
+
92
+ hy = newgate + inputgate * (hx - newgate)
93
+ return hy
94
+
95
+ def _gru_layer(self, x, hx, weights):
96
+ weight_ih, weight_hh, bias_ih, bias_hh = weights
97
+ outputs = []
98
+
99
+ for x_t in x.unbind(1):
100
+ hx = self._gru_cell(x_t, hx, weight_ih, bias_ih, weight_hh, bias_hh)
101
+ outputs.append(hx)
102
+
103
+ return torch.stack(outputs, dim=1), hx
104
+
105
+ def _gru(self, x, hx):
106
+ if not self.batch_first: x = x.permute(1, 0, 2)
107
+ num_directions = 2 if self.bidirectional else 1
108
+
109
+ h_n = []
110
+ output_fwd, output_bwd = x, x
111
+
112
+ for layer in range(self.num_layers):
113
+ fwd_idx = layer * num_directions
114
+ bwd_idx = fwd_idx + 1 if self.bidirectional else None
115
+
116
+ weights_fwd = self._get_weights(fwd_idx)
117
+ h_fwd = hx[fwd_idx]
118
+
119
+ out_fwd, h_out_fwd = self._gru_layer(output_fwd, h_fwd, weights_fwd)
120
+ h_n.append(h_out_fwd)
121
+
122
+ if self.bidirectional:
123
+ weights_bwd = self._get_weights(bwd_idx)
124
+ h_bwd = hx[bwd_idx]
125
+
126
+ reversed_input = torch.flip(output_bwd, dims=[1])
127
+ out_bwd, h_out_bwd = self._gru_layer(reversed_input, h_bwd, weights_bwd)
128
+
129
+ out_bwd = torch.flip(out_bwd, dims=[1])
130
+ h_n.append(h_out_bwd)
131
+
132
+ output_fwd = torch.cat([out_fwd, out_bwd], dim=2)
133
+ output_bwd = output_fwd
134
+ else: output_fwd = out_fwd
135
+
136
+ if layer < self.num_layers - 1 and self.dropout > 0:
137
+ output_fwd = F.dropout(output_fwd, p=self.dropout, training=self.training)
138
+ if self.bidirectional: output_bwd = output_fwd
139
+
140
+ output = output_fwd
141
+ h_n = torch.stack(h_n, dim=0)
142
+
143
+ if not self.batch_first: output = output.permute(1, 0, 2)
144
+ return output, h_n
145
+
146
+ def _get_weights(self, layer_idx):
147
+ weights = self._all_weights[layer_idx]
148
+
149
+ weight_ih = getattr(self, weights[0])
150
+ weight_hh = getattr(self, weights[1])
151
+
152
+ bias_ih = getattr(self, weights[2]) if self.bias else None
153
+ bias_hh = getattr(self, weights[3]) if self.bias else None
154
+
155
+ return weight_ih, weight_hh, bias_ih, bias_hh
156
+
157
+ def forward(self, input, hx=None):
158
+ if input.dim() != 3: raise ValueError
159
+
160
+ batch_size = input.size(0) if self.batch_first else input.size(1)
161
+ num_directions = 2 if self.bidirectional else 1
162
+
163
+ if hx is None: hx = torch.zeros(self.num_layers * num_directions, batch_size, self.hidden_size, dtype=input.dtype, device=input.device)
164
+
165
+ self.check_forward_args(input, hx, batch_sizes=None)
166
+ return self._gru(input, hx)
167
+
168
+ def group_norm(x, num_groups, weight=None, bias=None, eps=1e-5):
169
+ N, C = x.shape[:2]
170
+ assert C % num_groups == 0
171
+
172
+ shape = (N, num_groups, C // num_groups) + x.shape[2:]
173
+ x_reshaped = x.view(shape)
174
+
175
+ dims = (2,) + tuple(range(3, x_reshaped.dim()))
176
+ mean = x_reshaped.mean(dim=dims, keepdim=True)
177
+ var = x_reshaped.var(dim=dims, keepdim=True, unbiased=False)
178
+
179
+ x_norm = (x_reshaped - mean) / torch.sqrt(var + eps)
180
+ x_norm = x_norm.view_as(x)
181
+
182
+ if weight is not None:
183
+ weight = weight.view(1, C, *([1] * (x.dim() - 2)))
184
+ x_norm = x_norm * weight
185
+
186
+ if bias is not None:
187
+ bias = bias.view(1, C, *([1] * (x.dim() - 2)))
188
+ x_norm = x_norm + bias
189
+
190
+ return x_norm
191
+
192
+ def script(f, *_, **__):
193
+ f.graph = pytorch_ocl.torch._C.Graph()
194
+ return f
195
+
196
+ if torch_available:
197
+ nn.GRU = GRU
198
+ F.group_norm = group_norm
199
+ torch.jit.script = script
RVC/modules/pipeline.py ADDED
@@ -0,0 +1,215 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import faiss
5
+
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+
9
+ from scipy import signal
10
+
11
+ sys.path.append(os.getcwd())
12
+
13
+ from modules.generator import Generator
14
+ from modules.rms import RMSEnergyExtractor
15
+ from modules.utils import change_rms, clear_gpu_cache
16
+
17
+ bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
18
+
19
+ class Pipeline:
20
+ def __init__(self, tgt_sr, config):
21
+ self.x_pad, self.x_query, self.x_center, self.x_max = config.device_config()
22
+ self.sample_rate = 16000
23
+ self.window = 160
24
+ self.t_pad = self.sample_rate * self.x_pad
25
+ self.t_pad_tgt = tgt_sr * self.x_pad
26
+ self.t_pad2 = self.t_pad * 2
27
+ self.t_query = self.sample_rate * self.x_query
28
+ self.t_center = self.sample_rate * self.x_center
29
+ self.t_max = self.sample_rate * self.x_max
30
+ self.time_step = self.window / self.sample_rate * 1000
31
+ self.f0_min = 50
32
+ self.f0_max = 1100
33
+ self.device = config.device
34
+ self.is_half = config.is_half
35
+
36
+ def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect, energy):
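+ # Extract content features (HuBERT-style, layer 9 for v1 / 12 for v2), optionally blend in FAISS-retrieved features, then synthesize with net_g.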
37
+ feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())
38
+ pitch_guidance = pitch is not None and pitchf is not None
39
+ energy_use = energy is not None
40
+
41
+ if feats.dim() == 2: feats = feats.mean(-1)
42
+ assert feats.dim() == 1, feats.dim()
43
+ feats = feats.view(1, -1)
44
+
45
+ with torch.no_grad():
46
+ padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
47
+ logits = model.extract_features(**{"source": feats.to(self.device), "padding_mask": padding_mask, "output_layer": 9 if version == "v1" else 12})
48
+ feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
49
+
50
+ if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
51
+
52
+ if index is not None and big_npy is not None and index_rate != 0:
53
+ npy = feats[0].cpu().numpy()
54
+ if self.is_half: npy = npy.astype(np.float32)
55
+
56
+ score, ix = index.search(npy, k=8)
57
+ weight = np.square(1 / score)
58
+
59
+ npy = np.sum(big_npy[ix] * np.expand_dims(weight / weight.sum(axis=1, keepdims=True), axis=2), axis=1)
60
+ if self.is_half: npy = npy.astype(np.float16)
61
+
62
+ feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
63
+
64
+ feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
65
+ if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
66
+ p_len = audio0.shape[0] // self.window
67
+
68
+ if feats.shape[1] < p_len:
69
+ p_len = feats.shape[1]
70
+ if pitch_guidance: pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]
71
+ if energy_use: energy = energy[:, :p_len]
72
+
73
+ if protect < 0.5 and pitch_guidance:
74
+ pitchff = pitchf.clone()
75
+ pitchff[pitchf > 0] = 1
76
+ pitchff[pitchf < 1] = protect
77
+ pitchff = pitchff.unsqueeze(-1)
78
+
79
+ feats = (feats * pitchff + feats0 * (1 - pitchff)).to(feats0.dtype)
80
+
81
+ p_len = torch.tensor([p_len], device=self.device).long()
82
+ feats = feats.half() if self.is_half else feats.float()
83
+
84
+ if not pitch_guidance: pitch, pitchf = None, None
85
+ else: pitchf = pitchf.half() if self.is_half else pitchf.float()
86
+ if not energy_use: energy = None
87
+ else: energy = energy.half() if self.is_half else energy.float()
88
+
89
+ audio1 = (
90
+ (
91
+ net_g.infer(
92
+ feats,
93
+ p_len,
94
+ pitch,
95
+ pitchf,
96
+ sid,
97
+ energy
98
+ )[0][0, 0]
99
+ ).data.cpu().float().numpy()
100
+ )
101
+
102
+ del feats, p_len, net_g, model, padding_mask
103
+ clear_gpu_cache()
104
+ return audio1
105
+
106
+ def pipeline(
107
+ self,
108
+ model,
109
+ net_g,
110
+ sid,
111
+ audio,
112
+ f0_up_key,
113
+ f0_method,
114
+ file_index,
115
+ index_rate,
116
+ pitch_guidance,
117
+ filter_radius,
118
+ volume_envelope,
119
+ version,
120
+ protect,
121
+ hop_length,
122
+ energy_use=False,
123
+ f0_autotune=False,
124
+ f0_autotune_strength=False
125
+ ):
126
+ if file_index != "" and os.path.exists(file_index) and index_rate != 0:
127
+ try:
128
+ index = faiss.read_index(file_index)
129
+ big_npy = index.reconstruct_n(0, index.ntotal)
130
+ except Exception as e:
131
+ print(f"[ERROR] Error occurred while reading index file: {e}")
132
+ index = big_npy = None
133
+ else: index = big_npy = None
134
+
135
+ opt_ts, audio_opt = [], []
136
+ audio = signal.filtfilt(bh, ah, audio)
137
+ audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
138
+
139
+ if audio_pad.shape[0] > self.t_max:
140
+ audio_sum = np.zeros_like(audio)
141
+
142
+ for i in range(self.window):
143
+ audio_sum += audio_pad[i : i - self.window]
144
+
145
+ for t in range(self.t_center, audio.shape[0], self.t_center):
146
+ opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
147
+
148
+ s = 0
149
+ t = None
150
+ audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
151
+ sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
152
+ p_len = audio_pad.shape[0] // self.window
153
+
154
+ if pitch_guidance:
155
+ if not hasattr(self, "f0_generator"): self.f0_generator = Generator(self.sample_rate, hop_length, self.f0_min, self.f0_max, self.is_half, self.device)
156
+ pitch, pitchf = self.f0_generator.calculator(f0_method, audio_pad, f0_up_key, p_len, filter_radius, f0_autotune, f0_autotune_strength)
157
+
158
+ if self.device == "mps": pitchf = pitchf.astype(np.float32)
159
+ pitch, pitchf = torch.tensor(pitch[:p_len], device=self.device).unsqueeze(0).long(), torch.tensor(pitchf[:p_len], device=self.device).unsqueeze(0).float()
160
+
161
+ if energy_use:
162
+ if not hasattr(self, "rms_extract"): self.rms_extract = RMSEnergyExtractor(frame_length=2048, hop_length=self.window, center=True, pad_mode = "reflect").to(self.device).eval()
163
+ energy = self.rms_extract(torch.from_numpy(audio_pad).to(self.device).unsqueeze(0)).cpu().numpy()
164
+
165
+ if self.device == "mps": energy = energy.astype(np.float32)
166
+ energy = torch.tensor(energy[:p_len], device=self.device).unsqueeze(0).float()
167
+
168
+ for t in opt_ts:
169
+ t = t // self.window * self.window
170
+ audio_opt.append(
171
+ self.voice_conversion(
172
+ model,
173
+ net_g,
174
+ sid,
175
+ audio_pad[s : t + self.t_pad2 + self.window],
176
+ pitch[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None,
177
+ pitchf[:, s // self.window : (t + self.t_pad2) // self.window] if pitch_guidance else None,
178
+ index,
179
+ big_npy,
180
+ index_rate,
181
+ version,
182
+ protect,
183
+ energy[:, s // self.window : (t + self.t_pad2) // self.window] if energy_use else None
184
+ )[self.t_pad_tgt : -self.t_pad_tgt]
185
+ )
186
+ s = t
187
+
188
+ audio_opt.append(
189
+ self.voice_conversion(
190
+ model,
191
+ net_g,
192
+ sid,
193
+ audio_pad[t:],
194
+ (pitch[:, t // self.window :] if t is not None else pitch) if pitch_guidance else None,
195
+ (pitchf[:, t // self.window :] if t is not None else pitchf) if pitch_guidance else None,
196
+ index,
197
+ big_npy,
198
+ index_rate,
199
+ version,
200
+ protect,
201
+ (energy[:, t // self.window :] if t is not None else energy) if energy_use else None
202
+ )[self.t_pad_tgt : -self.t_pad_tgt]
203
+ )
204
+
205
+ audio_opt = np.concatenate(audio_opt)
206
+
207
+ if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, self.sample_rate, volume_envelope)
208
+ audio_max = np.abs(audio_opt).max() / 0.99
209
+ if audio_max > 1: audio_opt /= audio_max
210
+
211
+ if pitch_guidance: del pitch, pitchf
212
+ del sid
213
+
214
+ clear_gpu_cache()
215
+ return audio_opt
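To make the `protect < 0.5` branch above easier to follow, here is a minimal, self-contained sketch of that masking step: frames with a non-zero `pitchf` (voiced) keep the index-mixed features, while unvoiced frames are pulled back toward the untouched copy. The tensor names mirror the snippet; shapes and values are illustrative only.

```python
import torch

protect = 0.33                                     # 0..0.5; lower keeps more of the original consonants
pitchf = torch.tensor([[0.0, 0.0, 220.0, 225.0, 0.0, 180.0]])  # per-frame f0 in Hz, 0 = unvoiced
feats = torch.randn(1, 6, 768)                     # features after index mixing
feats0 = torch.randn(1, 6, 768)                    # untouched copy of the features

pitchff = pitchf.clone()
pitchff[pitchf > 0] = 1                            # voiced frames keep feats
pitchff[pitchf < 1] = protect                      # unvoiced frames lean on feats0
pitchff = pitchff.unsqueeze(-1)                    # broadcast over the feature dimension

blended = feats * pitchff + feats0 * (1 - pitchff)
print(blended.shape)                               # torch.Size([1, 6, 768])
```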
RVC/modules/pixeldrain.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ import requests
3
+
4
+ def pixeldrain(url, output_dir):
5
+ try:
6
+ response = requests.get(f"https://pixeldrain.com/api/file/{url.split('pixeldrain.com/u/')[1]}")
7
+
8
+ if response.status_code == 200:
9
+ file_path = os.path.join(output_dir, (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";')))
10
+
11
+ with open(file_path, "wb") as newfile:
12
+ newfile.write(response.content)
13
+ return file_path
14
+ else: return None
15
+ except Exception as e:
16
+ raise RuntimeError(e)
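A hypothetical call to the helper above; the share URL is a placeholder and the output directory must already exist, since the function only writes the downloaded file.

```python
import os

os.makedirs("downloads", exist_ok=True)
# Placeholder link - substitute a real pixeldrain share URL before running.
path = pixeldrain("https://pixeldrain.com/u/xxxxxxxx", "downloads")
print(path if path is not None else "download failed")
```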
RVC/modules/pyworld.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+ import pickle
3
+ import ctypes
4
+ import platform
5
+
6
+ import numpy as np
7
+
8
+ class DioOption(ctypes.Structure):
9
+ _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("ChannelsInOctave", ctypes.c_double), ("FramePeriod", ctypes.c_double), ("Speed", ctypes.c_int), ("AllowedRange", ctypes.c_double)]
10
+
11
+ class HarvestOption(ctypes.Structure):
12
+ _fields_ = [("F0Floor", ctypes.c_double), ("F0Ceil", ctypes.c_double), ("FramePeriod", ctypes.c_double)]
13
+
14
+ class PYWORLD:
15
+ def __init__(self):
16
+ self.world_path = os.path.join("models", "world")
17
+ os.makedirs(self.world_path, exist_ok=True)
18
+ model_type, suffix = (("world_64" if platform.architecture()[0] == "64bit" else "world_86"), ".dll") if platform.system() == "Windows" else ("world_linux", ".so")
19
+ self.world_file_path = os.path.join(self.world_path, f"{model_type}{suffix}")
20
+
21
+ if not os.path.exists(self.world_file_path):
22
+ with open(os.path.join("models", "world.bin"), "rb") as f:
23
+ model = pickle.load(f)
24
+
25
+ with open(self.world_file_path, "wb") as w:
26
+ w.write(model[model_type])
27
+
28
+ self.world_dll = ctypes.CDLL(self.world_file_path)
29
+
30
+ def harvest(self, x, fs, f0_floor=50, f0_ceil=1100, frame_period=10):
31
+ self.world_dll.Harvest.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(HarvestOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
32
+ self.world_dll.Harvest.restype = None
33
+ self.world_dll.InitializeHarvestOption.argtypes = [ctypes.POINTER(HarvestOption)]
34
+ self.world_dll.InitializeHarvestOption.restype = None
35
+ self.world_dll.GetSamplesForHarvest.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
36
+ self.world_dll.GetSamplesForHarvest.restype = ctypes.c_int
37
+
38
+ option = HarvestOption()
39
+ self.world_dll.InitializeHarvestOption(ctypes.byref(option))
40
+
41
+ option.F0Floor = f0_floor
42
+ option.F0Ceil = f0_ceil
43
+ option.FramePeriod = frame_period
44
+
45
+ f0_length = self.world_dll.GetSamplesForHarvest(fs, len(x), option.FramePeriod)
46
+ f0 = (ctypes.c_double * f0_length)()
47
+ tpos = (ctypes.c_double * f0_length)()
48
+
49
+ self.world_dll.Harvest((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
50
+ return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
51
+
52
+ def dio(self, x, fs, f0_floor=50, f0_ceil=1100, channels_in_octave=2, frame_period=10, speed=1, allowed_range=0.1):
53
+ self.world_dll.Dio.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(DioOption), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double)]
54
+ self.world_dll.Dio.restype = None
55
+ self.world_dll.InitializeDioOption.argtypes = [ctypes.POINTER(DioOption)]
56
+ self.world_dll.InitializeDioOption.restype = None
57
+ self.world_dll.GetSamplesForDIO.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_double]
58
+ self.world_dll.GetSamplesForDIO.restype = ctypes.c_int
59
+
60
+ option = DioOption()
61
+ self.world_dll.InitializeDioOption(ctypes.byref(option))
62
+
63
+ option.F0Floor = f0_floor
64
+ option.F0Ceil = f0_ceil
65
+ option.ChannelsInOctave = channels_in_octave
66
+ option.FramePeriod = frame_period
67
+ option.Speed = speed
68
+ option.AllowedRange = allowed_range
69
+
70
+ f0_length = self.world_dll.GetSamplesForDIO(fs, len(x), option.FramePeriod)
71
+ f0 = (ctypes.c_double * f0_length)()
72
+ tpos = (ctypes.c_double * f0_length)()
73
+
74
+ self.world_dll.Dio((ctypes.c_double * len(x))(*x), len(x), fs, ctypes.byref(option), tpos, f0)
75
+ return np.array(f0, dtype=np.float32), np.array(tpos, dtype=np.float32)
76
+
77
+ def stonemask(self, x, fs, tpos, f0):
78
+ self.world_dll.StoneMask.argtypes = [ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), ctypes.c_int, ctypes.POINTER(ctypes.c_double)]
79
+ self.world_dll.StoneMask.restype = None
80
+
81
+ out_f0 = (ctypes.c_double * len(f0))()
82
+ self.world_dll.StoneMask((ctypes.c_double * len(x))(*x), len(x), fs, (ctypes.c_double * len(tpos))(*tpos), (ctypes.c_double * len(f0))(*f0), len(f0), out_f0)
83
+
84
+ return np.array(out_f0, dtype=np.float32)
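The wrapper follows the usual WORLD flow: Harvest (or Dio) yields a coarse f0 track plus frame times, and StoneMask refines it. The sketch below assumes the bundled `models/world.bin` payload is present so the shared library can be unpacked; the 220 Hz tone is purely illustrative.

```python
import numpy as np

sr = 16000
t = np.arange(sr) / sr
x = 0.5 * np.sin(2 * np.pi * 220.0 * t)        # one second of a 220 Hz sine

world = PYWORLD()
f0, tpos = world.harvest(x, sr, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = world.stonemask(x, sr, tpos, f0)          # per-frame refinement
print(f0.shape, float(np.median(f0[f0 > 0])))  # the median should sit near 220 Hz
```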
RVC/modules/refinegan.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+ from torch.nn.utils import remove_weight_norm
12
+ from torch.nn.utils.parametrizations import weight_norm
13
+
14
+ sys.path.append(os.getcwd())
15
+
16
+ from modules.commons import init_weights, get_padding
17
+
18
+
19
+ class ResBlock(nn.Module):
20
+ def __init__(self, channels, kernel_size = 7, dilation = (1, 3, 5), leaky_relu_slope = 0.2):
21
+ super().__init__()
22
+ self.leaky_relu_slope = leaky_relu_slope
23
+ self.convs1 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=d, padding=get_padding(kernel_size, d))) for d in dilation])
24
+ self.convs1.apply(init_weights)
25
+ self.convs2 = nn.ModuleList([weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1, padding=get_padding(kernel_size, 1))) for _ in dilation])
26
+ self.convs2.apply(init_weights)
27
+
28
+ def forward(self, x):
29
+ for c1, c2 in zip(self.convs1, self.convs2):
30
+ x = c2(F.leaky_relu(c1(F.leaky_relu(x, self.leaky_relu_slope)), self.leaky_relu_slope)) + x
31
+
32
+ return x
33
+
34
+ def remove_weight_norm(self):
35
+ for c1, c2 in zip(self.convs1, self.convs2):
36
+ remove_weight_norm(c1)
37
+ remove_weight_norm(c2)
38
+
39
+ class AdaIN(nn.Module):
40
+ def __init__(self, *, channels, leaky_relu_slope = 0.2):
41
+ super().__init__()
42
+ self.weight = nn.Parameter(torch.ones(channels))
43
+ self.activation = nn.LeakyReLU(leaky_relu_slope)
44
+
45
+ def forward(self, x):
46
+ return self.activation(x + (torch.randn_like(x) * self.weight[None, :, None]))
47
+
48
+ class ParallelResBlock(nn.Module):
49
+ def __init__(self, *, in_channels, out_channels, kernel_sizes = (3, 7, 11), dilation = (1, 3, 5), leaky_relu_slope = 0.2):
50
+ super().__init__()
51
+ self.in_channels = in_channels
52
+ self.out_channels = out_channels
53
+ self.input_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=7, stride=1, padding=3)
54
+ self.input_conv.apply(init_weights)
55
+ self.blocks = nn.ModuleList([nn.Sequential(AdaIN(channels=out_channels), ResBlock(out_channels, kernel_size=kernel_size, dilation=dilation, leaky_relu_slope=leaky_relu_slope), AdaIN(channels=out_channels)) for kernel_size in kernel_sizes])
56
+
57
+ def forward(self, x):
58
+ x = self.input_conv(x)
59
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
60
+
61
+ def remove_weight_norm(self):
62
+ remove_weight_norm(self.input_conv)
63
+ for block in self.blocks:
64
+ block[1].remove_weight_norm()
65
+
66
+ class SineGenerator(nn.Module):
67
+ def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0):
68
+ super(SineGenerator, self).__init__()
69
+ self.sine_amp = sine_amp
70
+ self.noise_std = noise_std
71
+ self.harmonic_num = harmonic_num
72
+ self.dim = self.harmonic_num + 1
73
+ self.sampling_rate = samp_rate
74
+ self.voiced_threshold = voiced_threshold
75
+ self.merge = nn.Sequential(nn.Linear(self.dim, 1, bias=False), nn.Tanh())
76
+
77
+ def _f02uv(self, f0):
78
+ return torch.ones_like(f0) * (f0 > self.voiced_threshold)
79
+
80
+ def _f02sine(self, f0_values):
81
+ rad_values = (f0_values / self.sampling_rate) % 1
82
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], dtype=f0_values.dtype, device=f0_values.device)
83
+
84
+ rand_ini[:, 0] = 0
85
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
86
+
87
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
88
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
89
+
90
+ cumsum_shift = torch.zeros_like(rad_values)
91
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
92
+
93
+ return torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)
94
+
95
+ def forward(self, f0):
96
+ with torch.no_grad():
97
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, dtype=f0.dtype, device=f0.device)
98
+ f0_buf[:, :, 0] = f0[:, :, 0]
99
+
100
+ for idx in np.arange(self.harmonic_num):
101
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
102
+
103
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
104
+ uv = self._f02uv(f0)
105
+ sine_waves = sine_waves * uv + ((uv * self.noise_std + (1 - uv) * self.sine_amp / 3) * torch.randn_like(sine_waves))
106
+
107
+ return self.merge(sine_waves)
108
+
109
+ class RefineGANGenerator(nn.Module):
110
+ def __init__(self, *, sample_rate = 44100, upsample_rates = (8, 8, 2, 2), leaky_relu_slope = 0.2, num_mels = 128, gin_channels = 256, checkpointing = False, upsample_initial_channel = 512):
111
+ super().__init__()
112
+ self.upsample_rates = upsample_rates
113
+ self.checkpointing = checkpointing
114
+ self.leaky_relu_slope = leaky_relu_slope
115
+ self.upp = np.prod(upsample_rates)
116
+ self.m_source = SineGenerator(sample_rate)
117
+ self.pre_conv = weight_norm(nn.Conv1d(1, upsample_initial_channel // 2, 7, 1, padding=3))
118
+ stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
119
+
120
+ channels = upsample_initial_channel
121
+ self.downsample_blocks = nn.ModuleList([])
122
+
123
+ for i, _ in enumerate(upsample_rates):
124
+ stride = stride_f0s[i]
125
+ kernel = 1 if stride == 1 else stride * 2 - stride % 2
126
+
127
+ self.downsample_blocks.append(weight_norm(nn.Conv1d(1, channels // 2 ** (i + 2), kernel, stride, padding=0 if stride == 1 else (kernel - stride) // 2)))
128
+
129
+ self.mel_conv = weight_norm(nn.Conv1d(num_mels, channels // 2, 7, 1, padding=3))
130
+ self.mel_conv.apply(init_weights)
131
+
132
+ if gin_channels != 0: self.cond = nn.Conv1d(256, channels // 2, 1)
133
+
134
+ self.upsample_blocks = nn.ModuleList([])
135
+ self.upsample_conv_blocks = nn.ModuleList([])
136
+
137
+ for rate in upsample_rates:
138
+ new_channels = channels // 2
139
+ self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))
140
+ self.upsample_conv_blocks.append(ParallelResBlock(in_channels=channels + channels // 4, out_channels=new_channels, kernel_sizes=(3, 7, 11), dilation=(1, 3, 5), leaky_relu_slope=leaky_relu_slope))
141
+ channels = new_channels
142
+
143
+ self.conv_post = weight_norm(nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False))
144
+ self.conv_post.apply(init_weights)
145
+
146
+ def forward(self, mel, f0, g = None):
147
+ har_source = self.m_source(F.interpolate(f0.unsqueeze(1), size=mel.shape[-1] * self.upp, mode="linear").transpose(1, 2)).transpose(1, 2)
148
+ x = F.interpolate(self.pre_conv(har_source), size=mel.shape[-1], mode="linear")
149
+
150
+ mel = self.mel_conv(mel)
151
+ if g is not None: mel += self.cond(g)
152
+
153
+ x = torch.cat([mel, x], dim=1)
154
+
155
+ for ups, res, down in zip(self.upsample_blocks, self.upsample_conv_blocks, self.downsample_blocks):
156
+ x = F.leaky_relu(x, self.leaky_relu_slope)
157
+ x = checkpoint(res, torch.cat([checkpoint(ups, x, use_reentrant=False), down(har_source)], dim=1), use_reentrant=False) if self.training and self.checkpointing else res(torch.cat([ups(x), down(har_source)], dim=1))
158
+
159
+ return torch.tanh(self.conv_post(F.leaky_relu(x, self.leaky_relu_slope)))
160
+
161
+ def remove_weight_norm(self):
162
+ remove_weight_norm(self.pre_conv)
163
+ remove_weight_norm(self.mel_conv)
164
+ remove_weight_norm(self.conv_post)
165
+
166
+ for block in self.downsample_blocks:
167
+ block.remove_weight_norm()
168
+
169
+ for block in self.upsample_conv_blocks:
170
+ block.remove_weight_norm()
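A quick shape check for the generator with random inputs. With the default `upsample_rates=(8, 8, 2, 2)` the hop is 256 samples, so 50 mel frames should give 12,800 output samples; `gin_channels=0` skips the speaker-conditioning branch. This is a smoke test on untrained weights, not a usage guarantee.

```python
import torch

gen = RefineGANGenerator(sample_rate=44100, upsample_rates=(8, 8, 2, 2), num_mels=128, gin_channels=0)
mel = torch.randn(1, 128, 50)                  # (batch, n_mels, frames)
f0 = torch.rand(1, 50) * 200 + 100             # per-frame f0 in Hz

with torch.no_grad():
    wav = gen(mel, f0)
print(wav.shape)                               # torch.Size([1, 1, 12800])
```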
RVC/modules/residuals.py ADDED
@@ -0,0 +1,140 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ from torch.nn.utils import remove_weight_norm
6
+ from torch.nn.utils.parametrizations import weight_norm
7
+
8
+ sys.path.append(os.getcwd())
9
+
10
+ from .modules import WaveNet
11
+ from .commons import get_padding, init_weights
12
+
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+ def create_conv1d_layer(channels, kernel_size, dilation):
17
+ return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))
18
+
19
+ def apply_mask(tensor, mask):
20
+ return tensor * mask if mask is not None else tensor
21
+
22
+ class ResBlockBase(torch.nn.Module):
23
+ def __init__(self, channels, kernel_size, dilations):
24
+ super(ResBlockBase, self).__init__()
25
+
26
+ self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
27
+ self.convs1.apply(init_weights)
28
+
29
+ self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
30
+ self.convs2.apply(init_weights)
31
+
32
+ def forward(self, x, x_mask=None):
33
+ for c1, c2 in zip(self.convs1, self.convs2):
34
+ x = c2(apply_mask(torch.nn.functional.leaky_relu(c1(apply_mask(torch.nn.functional.leaky_relu(x, LRELU_SLOPE), x_mask)), LRELU_SLOPE), x_mask)) + x
35
+
36
+ return apply_mask(x, x_mask)
37
+
38
+ def remove_weight_norm(self):
39
+ for conv in self.convs1 + self.convs2:
40
+ remove_weight_norm(conv)
41
+
42
+ class ResBlock(ResBlockBase):
43
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
44
+ super(ResBlock, self).__init__(channels, kernel_size, dilation)
45
+
46
+ class Log(torch.nn.Module):
47
+ def forward(self, x, x_mask, reverse=False, **kwargs):
48
+ if not reverse:
49
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
50
+ return y, torch.sum(-y, [1, 2])
51
+ else: return torch.exp(x) * x_mask
52
+
53
+ class Flip(torch.nn.Module):
54
+ def forward(self, x, *args, reverse=False, **kwargs):
55
+ x = torch.flip(x, [1])
56
+
57
+ if not reverse: return x, torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
58
+ else: return x
59
+
60
+ class ElementwiseAffine(torch.nn.Module):
61
+ def __init__(self, channels):
62
+ super().__init__()
63
+ self.channels = channels
64
+ self.m = torch.nn.Parameter(torch.zeros(channels, 1))
65
+ self.logs = torch.nn.Parameter(torch.zeros(channels, 1))
66
+
67
+ def forward(self, x, x_mask, reverse=False, **kwargs):
68
+ if not reverse: return ((self.m + torch.exp(self.logs) * x) * x_mask), torch.sum(self.logs * x_mask, [1, 2])
69
+ else: return (x - self.m) * torch.exp(-self.logs) * x_mask
70
+
71
+ class ResidualCouplingBlock(torch.nn.Module):
72
+ def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
73
+ super(ResidualCouplingBlock, self).__init__()
74
+ self.channels = channels
75
+ self.hidden_channels = hidden_channels
76
+ self.kernel_size = kernel_size
77
+ self.dilation_rate = dilation_rate
78
+ self.n_layers = n_layers
79
+ self.n_flows = n_flows
80
+ self.gin_channels = gin_channels
81
+ self.flows = torch.nn.ModuleList()
82
+
83
+ for _ in range(n_flows):
84
+ self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
85
+ self.flows.append(Flip())
86
+
87
+ def forward(self, x, x_mask, g = None, reverse = False):
88
+ if not reverse:
89
+ for flow in self.flows:
90
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
91
+ else:
92
+ for flow in reversed(self.flows):
93
+ x = flow.forward(x, x_mask, g=g, reverse=reverse)
94
+
95
+ return x
96
+
97
+ def remove_weight_norm(self):
98
+ for i in range(self.n_flows):
99
+ self.flows[i * 2].remove_weight_norm()
100
+
101
+ def __prepare_scriptable__(self):
102
+ for i in range(self.n_flows):
103
+ for hook in self.flows[i * 2]._forward_pre_hooks.values():
104
+ if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
105
+
106
+ return self
107
+
108
+ class ResidualCouplingLayer(torch.nn.Module):
109
+ def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
110
+ assert channels % 2 == 0, "Channels/2"
111
+ super().__init__()
112
+ self.channels = channels
113
+ self.hidden_channels = hidden_channels
114
+ self.kernel_size = kernel_size
115
+ self.dilation_rate = dilation_rate
116
+ self.n_layers = n_layers
117
+ self.half_channels = channels // 2
118
+ self.mean_only = mean_only
119
+
120
+ self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
121
+ self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
122
+ self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
123
+
124
+ self.post.weight.data.zero_()
125
+ self.post.bias.data.zero_()
126
+
127
+ def forward(self, x, x_mask, g=None, reverse=False):
128
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
129
+ stats = self.post(self.enc((self.pre(x0) * x_mask), x_mask, g=g)) * x_mask
130
+
131
+ if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
132
+ else:
133
+ m = stats
134
+ logs = torch.zeros_like(m)
135
+
136
+ if not reverse: return torch.cat([x0, (m + x1 * torch.exp(logs) * x_mask)], 1), torch.sum(logs, [1, 2])
137
+ else: return torch.cat([x0, ((x1 - m) * torch.exp(-logs) * x_mask)], 1)
138
+
139
+ def remove_weight_norm(self):
140
+ self.enc.remove_weight_norm()
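Because each coupling layer only shifts and scales half of the channels, the whole block is invertible: running it forward and then with `reverse=True` should reproduce the input. A minimal numerical check with arbitrary sizes and an all-ones mask:

```python
import torch

flow = ResidualCouplingBlock(channels=192, hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=3)
x = torch.randn(1, 192, 40)
x_mask = torch.ones(1, 1, 40)

with torch.no_grad():
    z = flow(x, x_mask)                        # forward direction through all flows
    x_rec = flow(z, x_mask, reverse=True)      # run the flows backwards
print(torch.allclose(x, x_rec, atol=1e-4))     # expected: True
```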
RVC/modules/rms.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ import librosa
3
+
4
+ import torch.nn as nn
5
+
6
+ class RMSEnergyExtractor(nn.Module):
7
+ def __init__(self, frame_length=2048, hop_length=512, center=True, pad_mode = "reflect"):
8
+ super().__init__()
9
+ self.frame_length = frame_length
10
+ self.hop_length = hop_length
11
+ self.center = center
12
+ self.pad_mode = pad_mode
13
+
14
+ def forward(self, x):
15
+ assert x.ndim == 2
16
+ assert x.shape[0] == 1
17
+
18
+ if str(x.device).startswith("ocl"): x = x.contiguous()
19
+
20
+ rms = torch.from_numpy(
21
+ librosa.feature.rms(
22
+ y=x.squeeze(0).cpu().numpy(),
23
+ frame_length=self.frame_length,
24
+ hop_length=self.hop_length,
25
+ center=self.center,
26
+ pad_mode=self.pad_mode
27
+ )
28
+ )
29
+
30
+ return rms.squeeze(-2).to(x.device) if not str(x.device).startswith("ocl") else rms.contiguous().squeeze(-2).to(x.device)
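Minimal usage sketch for the extractor: a `(1, samples)` mono waveform in, a 1-D per-hop RMS envelope out. The 160-sample hop matches the window size the conversion pipeline passes in, but any hop works.

```python
import torch

extractor = RMSEnergyExtractor(frame_length=2048, hop_length=160, center=True, pad_mode="reflect")
wav = torch.randn(1, 16000)                    # one second of noise at 16 kHz
energy = extractor(wav)                        # 1-D tensor, one RMS value per hop (~101 frames)
print(energy.shape)
```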
RVC/modules/rmvpe.py ADDED
@@ -0,0 +1,260 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ import numpy as np
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from librosa.filters import mel
10
+
11
+ sys.path.append(os.getcwd())
12
+
13
+ from modules import opencl
14
+
15
+ N_MELS, N_CLASS = 128, 360
16
+
17
+ class ConvBlockRes(nn.Module):
18
+ def __init__(self, in_channels, out_channels, momentum=0.01):
19
+ super(ConvBlockRes, self).__init__()
20
+ self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
21
+ if in_channels != out_channels:
22
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
23
+ self.is_shortcut = True
24
+ else: self.is_shortcut = False
25
+
26
+ def forward(self, x):
27
+ return (self.conv(x) + self.shortcut(x)) if self.is_shortcut else (self.conv(x) + x)
28
+
29
+ class ResEncoderBlock(nn.Module):
30
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
31
+ super(ResEncoderBlock, self).__init__()
32
+ self.n_blocks = n_blocks
33
+ self.conv = nn.ModuleList()
34
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
35
+
36
+ for _ in range(n_blocks - 1):
37
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
38
+
39
+ self.kernel_size = kernel_size
40
+ if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
41
+
42
+ def forward(self, x):
43
+ for i in range(self.n_blocks):
44
+ x = self.conv[i](x)
45
+
46
+ if self.kernel_size is not None: return x, self.pool(x)
47
+ else: return x
48
+
49
+ class Encoder(nn.Module):
50
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
51
+ super(Encoder, self).__init__()
52
+ self.n_encoders = n_encoders
53
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
54
+ self.layers = nn.ModuleList()
55
+
56
+ for _ in range(self.n_encoders):
57
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
58
+ in_channels = out_channels
59
+ out_channels *= 2
60
+ in_size //= 2
61
+
62
+ self.out_size = in_size
63
+ self.out_channel = out_channels
64
+
65
+ def forward(self, x):
66
+ concat_tensors = []
67
+ x = self.bn(x)
68
+
69
+ for layer in self.layers:
70
+ t, x = layer(x)
71
+ concat_tensors.append(t)
72
+
73
+ return x, concat_tensors
74
+
75
+ class Intermediate(nn.Module):
76
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
77
+ super(Intermediate, self).__init__()
78
+ self.layers = nn.ModuleList()
79
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
80
+
81
+ for _ in range(n_inters - 1):
82
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
83
+
84
+ def forward(self, x):
85
+ for layer in self.layers:
86
+ x = layer(x)
87
+
88
+ return x
89
+
90
+ class ResDecoderBlock(nn.Module):
91
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
92
+ super(ResDecoderBlock, self).__init__()
93
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
94
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
95
+ self.conv2 = nn.ModuleList()
96
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
97
+
98
+ for _ in range(n_blocks - 1):
99
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
100
+
101
+ def forward(self, x, concat_tensor):
102
+ x = torch.cat((self.conv1(x), concat_tensor), dim=1)
103
+ for conv2 in self.conv2:
104
+ x = conv2(x)
105
+
106
+ return x
107
+
108
+ class Decoder(nn.Module):
109
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
110
+ super(Decoder, self).__init__()
111
+ self.layers = nn.ModuleList()
112
+
113
+ for _ in range(n_decoders):
114
+ out_channels = in_channels // 2
115
+ self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
116
+ in_channels = out_channels
117
+
118
+ def forward(self, x, concat_tensors):
119
+ for i, layer in enumerate(self.layers):
120
+ x = layer(x, concat_tensors[-1 - i])
121
+
122
+ return x
123
+
124
+ class DeepUnet(nn.Module):
125
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
126
+ super(DeepUnet, self).__init__()
127
+ self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
128
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
129
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
130
+
131
+ def forward(self, x):
132
+ x, concat_tensors = self.encoder(x)
133
+ return self.decoder(self.intermediate(x), concat_tensors)
134
+
135
+ class E2E(nn.Module):
136
+ def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
137
+ super(E2E, self).__init__()
138
+ self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
139
+ self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
140
+ self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
141
+
142
+ def forward(self, mel):
143
+ return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))
144
+
145
+ class MelSpectrogram(torch.nn.Module):
146
+ def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
147
+ super().__init__()
148
+ n_fft = win_length if n_fft is None else n_fft
149
+ self.hann_window = {}
150
+ mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
151
+ mel_basis = torch.from_numpy(mel_basis).float()
152
+ self.register_buffer("mel_basis", mel_basis)
153
+ self.n_fft = win_length if n_fft is None else n_fft
154
+ self.hop_length = hop_length
155
+ self.win_length = win_length
156
+ self.sample_rate = sample_rate
157
+ self.n_mel_channels = n_mel_channels
158
+ self.clamp = clamp
159
+ self.is_half = is_half
160
+
161
+ def forward(self, audio, keyshift=0, speed=1, center=True):
162
+ factor = 2 ** (keyshift / 12)
163
+ win_length_new = int(np.round(self.win_length * factor))
164
+ keyshift_key = str(keyshift) + "_" + str(audio.device)
165
+ if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
166
+
167
+ n_fft = int(np.round(self.n_fft * factor))
168
+ hop_length = int(np.round(self.hop_length * speed))
169
+
170
+ if str(audio.device).startswith("ocl"):
171
+ stft = opencl.STFT(filter_length=n_fft, hop_length=hop_length, win_length=win_length_new).to(audio.device)
172
+ magnitude = stft.transform(audio, 1e-9)
173
+ else:
174
+ fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
175
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
176
+
177
+ if keyshift != 0:
178
+ size = self.n_fft // 2 + 1
179
+ resize = magnitude.size(1)
180
+ if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
181
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
182
+
183
+ mel_output = torch.matmul(self.mel_basis, magnitude)
184
+ if self.is_half: mel_output = mel_output.half()
185
+
186
+ return torch.log(torch.clamp(mel_output, min=self.clamp))
187
+
188
+ class RMVPE:
189
+ def __init__(self, model_path, is_half, device=None):
190
+ self.resample_kernel = {}
191
+ self.resample_kernel = {}
192
+ model = E2E(4, 1, (2, 2))
193
+ ckpt = torch.load(model_path, map_location="cpu")
194
+ model.load_state_dict(ckpt)
195
+ model.eval()
196
+ if is_half: model = model.half()
197
+ self.model = model.to(device)
198
+ self.is_half = is_half
199
+ self.device = device
200
+ self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
201
+ cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
202
+ self.cents_mapping = np.pad(cents_mapping, (4, 4))
203
+
204
+ def mel2hidden(self, mel):
205
+ with torch.no_grad():
206
+ n_frames = mel.shape[-1]
207
+ n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
208
+ if n_pad > 0: mel = F.pad(mel, (0, n_pad), mode="constant")
209
+
210
+ hidden = self.model(mel.half() if self.is_half else mel.float())
211
+ return hidden[:, :n_frames]
212
+
213
+ def decode(self, hidden, thred=0.03):
214
+ f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
215
+ f0[f0 == 10] = 0
216
+
217
+ return f0
218
+
219
+ def infer_from_audio(self, audio, thred=0.03):
220
+ hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
221
+
222
+ return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)
223
+
224
+ def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
225
+ hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
226
+
227
+ f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)
228
+ f0[(f0 < f0_min) | (f0 > f0_max)] = 0
229
+
230
+ return f0
231
+
232
+ def to_local_average_cents(self, salience, thred=0.05):
233
+ center = np.argmax(salience, axis=1)
234
+ salience = np.pad(salience, ((0, 0), (4, 4)))
235
+ center += 4
236
+ todo_salience, todo_cents_mapping = [], []
237
+ starts = center - 4
238
+ ends = center + 5
239
+
240
+ for idx in range(salience.shape[0]):
241
+ todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
242
+ todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
243
+
244
+ todo_salience = np.array(todo_salience)
245
+ devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
246
+ devided[np.max(salience, axis=1) <= thred] = 0
247
+
248
+ return devided
249
+
250
+ class BiGRU(nn.Module):
251
+ def __init__(self, input_features, hidden_features, num_layers):
252
+ super(BiGRU, self).__init__()
253
+ self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
254
+
255
+ def forward(self, x):
256
+ try:
257
+ return self.gru(x)[0]
258
+ except:
259
+ torch.backends.cudnn.enabled = False
260
+ return self.gru(x)[0]
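Hypothetical usage of the RMVPE wrapper; the checkpoint path is a placeholder for wherever the weights actually live. The model expects 16 kHz mono audio as a float NumPy array and returns one f0 value per 10 ms hop.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rmvpe = RMVPE("models/rmvpe.pt", is_half=False, device=device)   # placeholder path

audio = np.random.randn(16000).astype(np.float32)                # one second at 16 kHz
f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
print(f0.shape)                                                  # roughly 100 frames
```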
RVC/modules/swipe.py ADDED
@@ -0,0 +1,200 @@
1
+ import math
2
+
3
+ import numba as nb
4
+ import numpy as np
5
+
6
+ from matplotlib import mlab
7
+ from scipy import interpolate
8
+ from decimal import Decimal, ROUND_HALF_UP
9
+
10
+
11
+ def swipe(x, fs, f0_floor=50, f0_ceil=1100, frame_period=10, sTHR=0.3):
12
+ plim = np.array([f0_floor, f0_ceil])
13
+ t = np.arange(0, int(1000 * len(x) / fs / (frame_period) + 1)) * (frame_period / 1000)
14
+
15
+ log2pc = np.arange(np.log2(plim[0]) * 96, np.log2(plim[-1]) * 96)
16
+ log2pc *= (1 / 96)
17
+
18
+ pc = 2 ** log2pc
19
+ S = np.zeros((len(pc), len(t)))
20
+
21
+ logWs = [round_matlab(elm) for elm in np.log2(4 * 2 * fs / plim)]
22
+ ws = 2 ** np.arange(logWs[0], logWs[1] - 1, -1)
23
+ p0 = 4 * 2 * fs / ws
24
+
25
+ d = 1 + log2pc - np.log2(4 * 2 * fs / ws[0])
26
+ fERBs = erbs2hz(np.arange(hz2erbs(pc[0] / 4), hz2erbs(fs / 2), 0.1))
27
+
28
+ for i in range(len(ws)):
29
+ dn = round_matlab(4 * fs / p0[i])
30
+ X, f, ti = mlab.specgram(x=np.r_[np.zeros(int(ws[i] / 2)), np.r_[x, np.zeros(int(dn + ws[i] / 2))]], NFFT=ws[i], Fs=fs, window=np.hanning(ws[i] + 2)[1:-1], noverlap=max(0, np.round(ws[i] - dn)), mode='complex')
31
+ ti = np.r_[0, ti[:-1]]
32
+ M = np.maximum(0, interpolate.interp1d(f, np.abs(X.T), kind='cubic')(fERBs)).T
33
+
34
+ if i == len(ws) - 1:
35
+ j = np.where(d - (i + 1) > -1)[0]
36
+ k = np.where(d[j] - (i + 1) < 0)[0]
37
+ elif i == 0:
38
+ j = np.where(d - (i + 1) < 1)[0]
39
+ k = np.where(d[j] - (i + 1) > 0)[0]
40
+ else:
41
+ j = np.where(np.abs(d - (i + 1)) < 1)[0]
42
+ k = np.arange(len(j))
43
+
44
+ Si = pitchStrengthAllCandidates(fERBs, np.sqrt(M), pc[j])
45
+ Si = interpolate.interp1d(ti, Si, bounds_error=False, fill_value='nan')(t) if Si.shape[1] > 1 else np.full((len(Si), len(t)), np.nan)
46
+
47
+ mu = np.ones(j.shape)
48
+ mu[k] = 1 - np.abs(d[j[k]] - i - 1)
49
+ S[j, :] = S[j, :] + np.tile(mu.reshape(-1, 1), (1, Si.shape[1])) * Si
50
+
51
+
52
+ p = np.full((S.shape[1], 1), np.nan)
53
+ s = np.full((S.shape[1], 1), np.nan)
54
+
55
+ for j in range(S.shape[1]):
56
+ s[j] = np.max(S[:, j])
57
+ i = np.argmax(S[:, j])
58
+
59
+ if s[j] < sTHR: continue
60
+
61
+ if i == 0: p[j] = pc[0]
62
+ elif i == len(pc) - 1: p[j] = pc[-1]
63
+ else:
64
+ I = np.arange(i-1, i+2)
65
+ tc = 1 / pc[I]
66
+
67
+ ntc = (tc / tc[1] - 1) * 2 * np.pi
68
+ idx = np.isfinite(S[I, j])
69
+
70
+ c = np.zeros(len(ntc))
71
+ c += np.nan
72
+
73
+ I_ = I[idx]
74
+
75
+ if len(I_) < 2: c[idx] = (S[I, j])[0] / ntc[0]
76
+ else: c[idx] = np.polyfit(ntc[idx], (S[I_, j]), 2)
77
+
78
+ pval = np.polyval(c, ((1 / (2 ** np.arange(np.log2(pc[I[0]]), np.log2(pc[I[2]]) + 1 / 12 / 64, 1 / 12 / 64))) / tc[1] - 1) * 2 * np.pi)
79
+ s[j] = np.max(pval)
80
+ p[j] = 2 ** (np.log2(pc[I[0]]) + (np.argmax(pval)) / 12 / 64)
81
+
82
+ p = p.flatten()
83
+ p[np.isnan(p)] = 0
84
+
85
+ return np.array(p, dtype=np.float32), np.array(t, dtype=np.float32)
86
+
87
+ def round_matlab(n):
88
+ return int(Decimal(n).quantize(0, ROUND_HALF_UP))
89
+
90
+ def pitchStrengthAllCandidates(f, L, pc):
91
+ den = np.sqrt(np.sum(L * L, axis=0))
92
+ den = np.where(den == 0, 2.220446049250313e-16, den)
93
+
94
+ L = L / den
95
+ S = np.zeros((len(pc), L.shape[1]))
96
+
97
+ for j in range(len(pc)):
98
+ S[j,:] = pitchStrengthOneCandidate(f, L, pc[j])
99
+
100
+ return S
101
+
102
+ def pitchStrengthOneCandidate(f, L, pc):
103
+ k = np.zeros(len(f))
104
+ q = f / pc
105
+
106
+ for i in ([1] + sieve(int(np.fix(f[-1] / pc - 0.75)))):
107
+ a = np.abs(q - i)
108
+ p = a < 0.25
109
+ k[p] = np.cos(2 * np.pi * q[p])
110
+
111
+ v = np.logical_and((0.25 < a), (a < 0.75))
112
+ k[v] = k[v] + np.cos(2 * np.pi * q[v]) / 2
113
+
114
+ k *= np.sqrt(1 / f)
115
+ k /= np.linalg.norm(k[k>0])
116
+
117
+ return k @ L
118
+
119
+ def hz2erbs(hz):
120
+ return 21.4 * np.log10(1 + hz / 229)
121
+
122
+ def erbs2hz(erbs):
123
+ return (10 ** (erbs / 21.4) - 1) * 229
124
+
125
+ def sieve(n):
126
+ primes = list(range(2, n + 1))
127
+ num = 2
128
+
129
+ while num < math.sqrt(n):
130
+ i = num
131
+
132
+ while i <= n:
133
+ i += num
134
+
135
+ if i in primes: primes.remove(i)
136
+
137
+ for j in primes:
138
+ if j > num:
139
+ num = j
140
+ break
141
+
142
+ return primes
143
+
144
+ def stonemask(x, fs, temporal_positions, f0):
145
+ refined_f0 = np.copy(f0)
146
+
147
+ for i in range(len(temporal_positions)):
148
+ if f0[i] != 0:
149
+ refined_f0[i] = get_refined_f0(x, fs, temporal_positions[i], f0[i])
150
+ if abs(refined_f0[i] - f0[i]) / f0[i] > 0.2: refined_f0[i] = f0[i]
151
+
152
+ return np.array(refined_f0, dtype=np.float32)
153
+
154
+ def get_refined_f0(x, fs, current_time, current_f0):
155
+ f0_initial = current_f0
156
+ half_window_length = np.ceil(3 * fs / f0_initial / 2)
157
+ window_length_in_time = (2 * half_window_length + 1) / fs
158
+
159
+ base_time = np.arange(-half_window_length, half_window_length + 1) / fs
160
+ fft_size = 2 ** math.ceil(math.log((half_window_length * 2 + 1), 2) + 1)
161
+
162
+ base_time = np.array([float("{0:.4f}".format(elm)) for elm in base_time])
163
+ index_raw = round_matlab_2((current_time + base_time) * fs)
164
+
165
+ window_time = ((index_raw - 1) / fs) - current_time
166
+ main_window = 0.42 + 0.5 * np.cos(2 * math.pi * window_time / window_length_in_time) + 0.08 * np.cos(4 * math.pi * window_time / window_length_in_time)
167
+
168
+ index = np.array(np.maximum(1, np.minimum(len(x), index_raw)), dtype=int)
169
+ spectrum = np.fft.fft(x[index - 1] * main_window, fft_size)
170
+
171
+ diff_spectrum = np.fft.fft(x[index - 1] * (-(np.diff(np.r_[0, main_window]) + np.diff(np.r_[main_window, 0])) / 2), fft_size)
172
+ power_spectrum = np.abs(spectrum) ** 2
173
+
174
+ from sys import float_info
175
+
176
+ power_spectrum[power_spectrum == 0] = float_info.epsilon
177
+ instantaneous_frequency = (np.arange(fft_size) / fft_size * fs) + (np.real(spectrum) * np.imag(diff_spectrum) - np.imag(spectrum) * np.real(diff_spectrum)) / power_spectrum * fs / 2 / math.pi
178
+
179
+ trim_index = np.array([1, 2])
180
+ index_list_trim = np.array(round_matlab_2(f0_initial * fft_size / fs * trim_index) + 1, int)
181
+
182
+ amp_list = np.sqrt(power_spectrum[index_list_trim - 1])
183
+ f0_initial = np.sum(amp_list * instantaneous_frequency[index_list_trim - 1]) / np.sum(amp_list * trim_index)
184
+
185
+ if f0_initial < 0: return 0
186
+
187
+ trim_index = np.array([1, 2, 3, 4, 5, 6])
188
+ index_list_trim = np.array(round_matlab_2(f0_initial * fft_size / fs * trim_index) + 1, int)
189
+ amp_list = np.sqrt(power_spectrum[index_list_trim - 1])
190
+
191
+ return np.sum(amp_list * instantaneous_frequency[index_list_trim - 1]) / np.sum(amp_list * trim_index)
192
+
193
+ @nb.jit((nb.float64[:],), nopython=True, cache=True)
194
+ def round_matlab_2(x):
195
+ y = x.copy()
196
+
197
+ y[x > 0] += 0.5
198
+ y[x <= 0] -= 0.5
199
+
200
+ return y
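Usage sketch for the SWIPE estimator on a synthetic tone; `frame_period` is in milliseconds, so a 10 ms hop yields roughly 100 estimates per second, and the `stonemask` pass defined above can optionally refine them.

```python
import numpy as np

sr = 16000
t = np.arange(sr) / sr
x = 0.5 * np.sin(2 * np.pi * 220.0 * t)

f0, times = swipe(x, sr, f0_floor=50, f0_ceil=1100, frame_period=10)
f0 = stonemask(x, sr, times, f0)               # optional refinement step
voiced = f0[f0 > 0]
print(len(f0), float(np.median(voiced)) if len(voiced) else 0.0)   # median near 220 Hz
```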
RVC/modules/synthesizers.py ADDED
@@ -0,0 +1,84 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+
5
+ sys.path.append(os.getcwd())
6
+
7
+ from modules.hifigan import HiFiGANGenerator
8
+ from modules.refinegan import RefineGANGenerator
9
+ from modules.residuals import ResidualCouplingBlock
10
+ from modules.mrf_hifigan import HiFiGANMRFGenerator
11
+ from modules.nsf_hifigan import HiFiGANNRFGenerator
12
+ from modules.encoders import TextEncoder, PosteriorEncoder
13
+ from modules.commons import slice_segments, rand_slice_segments
14
+
15
+ class Synthesizer(torch.nn.Module):
16
+ def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, vocoder="Default", checkpointing=False, energy=False, **kwargs):
17
+ super(Synthesizer, self).__init__()
18
+ self.spec_channels = spec_channels
19
+ self.inter_channels = inter_channels
20
+ self.hidden_channels = hidden_channels
21
+ self.filter_channels = filter_channels
22
+ self.n_heads = n_heads
23
+ self.n_layers = n_layers
24
+ self.kernel_size = kernel_size
25
+ self.p_dropout = float(p_dropout)
26
+ self.resblock_kernel_sizes = resblock_kernel_sizes
27
+ self.resblock_dilation_sizes = resblock_dilation_sizes
28
+ self.upsample_rates = upsample_rates
29
+ self.upsample_initial_channel = upsample_initial_channel
30
+ self.upsample_kernel_sizes = upsample_kernel_sizes
31
+ self.segment_size = segment_size
32
+ self.gin_channels = gin_channels
33
+ self.spk_embed_dim = spk_embed_dim
34
+ self.use_f0 = use_f0
35
+ self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0, energy=energy)
36
+
37
+ if use_f0:
38
+ if vocoder == "RefineGAN": self.dec = RefineGANGenerator(sample_rate=sr, upsample_rates=upsample_rates, num_mels=inter_channels, checkpointing=checkpointing)
39
+ elif vocoder in ["MRF-HiFi-GAN", "MRF HiFi-GAN"]: self.dec = HiFiGANMRFGenerator(in_channel=inter_channels, upsample_initial_channel=upsample_initial_channel, upsample_rates=upsample_rates, upsample_kernel_sizes=upsample_kernel_sizes, resblock_kernel_sizes=resblock_kernel_sizes, resblock_dilations=resblock_dilation_sizes, gin_channels=gin_channels, sample_rate=sr, harmonic_num=8, checkpointing=checkpointing)
40
+ else: self.dec = HiFiGANNRFGenerator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, checkpointing=checkpointing)
41
+ else: self.dec = HiFiGANGenerator(inter_channels, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
42
+
43
+ self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
44
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
45
+ self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
46
+
47
+ def remove_weight_norm(self):
48
+ self.dec.remove_weight_norm()
49
+ self.flow.remove_weight_norm()
50
+ self.enc_q.remove_weight_norm()
51
+
52
+ @torch.jit.ignore
53
+ def forward(self, phone, phone_lengths, pitch = None, pitchf = None, y = None, y_lengths = None, ds = None, energy = None):
54
+ g = self.emb_g(ds).unsqueeze(-1)
55
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, energy)
56
+
57
+ if y is not None:
58
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
59
+ z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
60
+
61
+ return (self.dec(z_slice, slice_segments(pitchf, ids_slice, self.segment_size, 2), g=g) if self.use_f0 else self.dec(z_slice, g=g)), ids_slice, x_mask, y_mask, (z, self.flow(z, y_mask, g=g), m_p, logs_p, m_q, logs_q)
62
+ else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
63
+
64
+ @torch.jit.export
65
+ def infer(self, phone, phone_lengths, pitch = None, nsff0 = None, sid = None, energy = None, rate = None):
66
+ g = self.emb_g(sid).unsqueeze(-1)
67
+ m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, energy)
68
+ z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
69
+
70
+ if rate is not None:
71
+ assert isinstance(rate, torch.Tensor)
72
+ head = int(z_p.shape[2] * (1.0 - rate.item()))
73
+ z_p = z_p[:, :, head:]
74
+ x_mask = x_mask[:, :, head:]
75
+ if self.use_f0: nsff0 = nsff0[:, head:]
76
+
77
+ if self.use_f0:
78
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
79
+ o = self.dec(z * x_mask, nsff0, g=g)
80
+ else:
81
+ z = self.flow(z_p, x_mask, g=g, reverse=True)
82
+ o = self.dec(z * x_mask, g=g)
83
+
84
+ return o, x_mask, (z, z_p, m_p, logs_p)
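A shape sketch for `Synthesizer.infer`, assuming `net_g` is an already-constructed and loaded checkpoint with `use_f0=True` and a 768-dimensional content encoder (v2-style). The tensors below are random stand-ins for the HuBERT features, coarse pitch bins and f0 curve that the pipeline above would normally pass in.

```python
import torch

frames = 200
phone = torch.randn(1, frames, 768)                 # content features
phone_lengths = torch.tensor([frames]).long()
pitch = torch.randint(1, 255, (1, frames)).long()   # coarse pitch bins
nsff0 = torch.rand(1, frames) * 300 + 100           # f0 in Hz for the NSF decoder
sid = torch.tensor([0]).long()                      # speaker id

with torch.no_grad():
    audio, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, nsff0, sid)
print(audio.shape)                                  # (1, 1, frames * hop)
```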
RVC/modules/torchcrepe.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import librosa
3
+ import functools
4
+ import scipy.stats
5
+
6
+ import numpy as np
7
+
8
+ CENTS_PER_BIN, MAX_FMAX, PITCH_BINS, SAMPLE_RATE, WINDOW_SIZE = 20, 2006, 360, 16000, 1024
9
+
10
+ def mean(signals, win_length=9):
11
+ assert signals.dim() == 2
12
+
13
+ signals = signals.unsqueeze(1)
14
+ mask = ~torch.isnan(signals)
15
+ padding = win_length // 2
16
+
17
+ ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device)
18
+ avg_pooled = torch.nn.functional.conv1d(torch.where(mask, signals, torch.zeros_like(signals)), ones_kernel, stride=1, padding=padding) / torch.nn.functional.conv1d(mask.float(), ones_kernel, stride=1, padding=padding).clamp(min=1)
19
+ avg_pooled[avg_pooled == 0] = float("nan")
20
+
21
+ return avg_pooled.squeeze(1)
22
+
23
+ def median(signals, win_length):
24
+ assert signals.dim() == 2
25
+
26
+ signals = signals.unsqueeze(1)
27
+ mask = ~torch.isnan(signals)
28
+ padding = win_length // 2
29
+
30
+ x = torch.nn.functional.pad(torch.where(mask, signals, torch.zeros_like(signals)), (padding, padding), mode="reflect")
31
+ mask = torch.nn.functional.pad(mask.float(), (padding, padding), mode="constant", value=0)
32
+
33
+ x = x.unfold(2, win_length, 1)
34
+ mask = mask.unfold(2, win_length, 1)
35
+
36
+ x = x.contiguous().view(x.size()[:3] + (-1,))
37
+ mask = mask.contiguous().view(mask.size()[:3] + (-1,))
38
+
39
+ x_sorted, _ = torch.sort(torch.where(mask.bool(), x.float(), float("inf")).to(x), dim=-1)
40
+
41
+ median_pooled = x_sorted.gather(-1, ((mask.sum(dim=-1) - 1) // 2).clamp(min=0).unsqueeze(-1).long()).squeeze(-1)
42
+ median_pooled[torch.isinf(median_pooled)] = float("nan")
43
+
44
+ return median_pooled.squeeze(1)
45
+
46
+ class CREPE_MODEL(torch.nn.Module):
47
+ def __init__(self, model='full'):
48
+ super().__init__()
49
+ in_channels = {"full": [1, 1024, 128, 128, 128, 256], "large": [1, 768, 96, 96, 96, 192], "medium": [1, 512, 64, 64, 64, 128], "small": [1, 256, 32, 32, 32, 64], "tiny": [1, 128, 16, 16, 16, 32]}[model]
50
+ out_channels = {"full": [1024, 128, 128, 128, 256, 512], "large": [768, 96, 96, 96, 192, 384], "medium": [512, 64, 64, 64, 128, 256], "small": [256, 32, 32, 32, 64, 128], "tiny": [128, 16, 16, 16, 32, 64]}[model]
51
+ self.in_features = {"full": 2048, "large": 1536, "medium": 1024, "small": 512, "tiny": 256}[model]
52
+
53
+ kernel_sizes = [(512, 1)] + 5 * [(64, 1)]
54
+ strides = [(4, 1)] + 5 * [(1, 1)]
55
+ batch_norm_fn = functools.partial(torch.nn.BatchNorm2d, eps=0.0010000000474974513, momentum=0.0)
56
+
57
+ self.conv1 = torch.nn.Conv2d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=kernel_sizes[0], stride=strides[0])
58
+ self.conv1_BN = batch_norm_fn(num_features=out_channels[0])
59
+
60
+ self.conv2 = torch.nn.Conv2d(in_channels=in_channels[1], out_channels=out_channels[1], kernel_size=kernel_sizes[1], stride=strides[1])
61
+ self.conv2_BN = batch_norm_fn(num_features=out_channels[1])
62
+
63
+ self.conv3 = torch.nn.Conv2d(in_channels=in_channels[2], out_channels=out_channels[2], kernel_size=kernel_sizes[2], stride=strides[2])
64
+ self.conv3_BN = batch_norm_fn(num_features=out_channels[2])
65
+
66
+ self.conv4 = torch.nn.Conv2d(in_channels=in_channels[3], out_channels=out_channels[3], kernel_size=kernel_sizes[3], stride=strides[3])
67
+ self.conv4_BN = batch_norm_fn(num_features=out_channels[3])
68
+
69
+ self.conv5 = torch.nn.Conv2d(in_channels=in_channels[4], out_channels=out_channels[4], kernel_size=kernel_sizes[4], stride=strides[4])
70
+ self.conv5_BN = batch_norm_fn(num_features=out_channels[4])
71
+
72
+ self.conv6 = torch.nn.Conv2d(in_channels=in_channels[5], out_channels=out_channels[5], kernel_size=kernel_sizes[5], stride=strides[5])
73
+ self.conv6_BN = batch_norm_fn(num_features=out_channels[5])
74
+
75
+ self.classifier = torch.nn.Linear(in_features=self.in_features, out_features=PITCH_BINS)
76
+
77
+ def forward(self, x, embed=False):
78
+ x = self.embed(x)
79
+ if embed: return x
80
+ return torch.sigmoid(self.classifier(self.layer(x, self.conv6, self.conv6_BN).permute(0, 2, 1, 3).reshape(-1, self.in_features)))
81
+
82
+ def embed(self, x):
83
+ x = x[:, None, :, None]
84
+ return self.layer(self.layer(self.layer(self.layer(self.layer(x, self.conv1, self.conv1_BN, (0, 0, 254, 254)), self.conv2, self.conv2_BN), self.conv3, self.conv3_BN), self.conv4, self.conv4_BN), self.conv5, self.conv5_BN)
85
+
86
+ def layer(self, x, conv, batch_norm, padding=(0, 0, 31, 32)):
87
+ return torch.nn.functional.max_pool2d(batch_norm(torch.nn.functional.relu(conv(torch.nn.functional.pad(x, padding)))), (2, 1), (2, 1))
88
+
89
+ class CREPE:
90
+ def __init__(self, model_path, model_size="full", hop_length=512, batch_size=None, f0_min=50, f0_max=1100, device=None, sample_rate=16000, return_periodicity=False):
91
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
92
+ self.hop_length = hop_length
93
+ self.batch_size = batch_size
94
+ self.sample_rate = sample_rate
95
+ self.f0_min = f0_min
96
+ self.f0_max = f0_max
97
+ self.return_periodicity = return_periodicity
98
+ model = CREPE_MODEL(model_size)
99
+ ckpt = torch.load(model_path, map_location="cpu")
100
+ model.load_state_dict(ckpt)
101
+ model.eval()
102
+ self.model = model.to(device)
103
+
104
+ def bins_to_frequency(self, bins):
105
+ if str(bins.device).startswith("ocl"): bins = bins.to(torch.float32)
106
+
107
+ cents = CENTS_PER_BIN * bins + 1997.3794084376191
108
+ return 10 * 2 ** ((cents + cents.new_tensor(scipy.stats.triang.rvs(c=0.5, loc=-CENTS_PER_BIN, scale=2 * CENTS_PER_BIN, size=cents.size()))) / 1200)
109
+
110
+ def frequency_to_bins(self, frequency, quantize_fn=torch.floor):
111
+ return quantize_fn(((1200 * torch.log2(frequency / 10)) - 1997.3794084376191) / CENTS_PER_BIN).int()
112
+
113
+ def viterbi(self, logits):
114
+ if not hasattr(self, 'transition'):
115
+ xx, yy = np.meshgrid(range(360), range(360))
116
+ transition = np.maximum(12 - abs(xx - yy), 0)
117
+ self.transition = transition / transition.sum(axis=1, keepdims=True)
118
+
119
+ with torch.no_grad():
120
+ probs = torch.nn.functional.softmax(logits, dim=1)
121
+
122
+ bins = torch.tensor(np.array([librosa.sequence.viterbi(sequence, self.transition).astype(np.int64) for sequence in probs.cpu().numpy()]), device=probs.device)
123
+ return bins, self.bins_to_frequency(bins)
124
+
125
+ def preprocess(self, audio, pad=True):
126
+ hop_length = (self.sample_rate // 100) if self.hop_length is None else self.hop_length
127
+
128
+ if self.sample_rate != SAMPLE_RATE:
129
+ audio = torch.tensor(librosa.resample(audio.detach().cpu().numpy().squeeze(0), orig_sr=self.sample_rate, target_sr=SAMPLE_RATE, res_type="soxr_vhq"), device=audio.device).unsqueeze(0)
130
+ hop_length = int(hop_length * SAMPLE_RATE / self.sample_rate)
131
+
132
+ if pad:
133
+ total_frames = 1 + int(audio.size(1) // hop_length)
134
+ audio = torch.nn.functional.pad(audio, (WINDOW_SIZE // 2, WINDOW_SIZE // 2))
135
+ else: total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length)
136
+
137
+ batch_size = total_frames if self.batch_size is None else self.batch_size
138
+
139
+ for i in range(0, total_frames, batch_size):
140
+ frames = torch.nn.functional.unfold(audio[:, None, None, max(0, i * hop_length):min(audio.size(1), (i + batch_size - 1) * hop_length + WINDOW_SIZE)], kernel_size=(1, WINDOW_SIZE), stride=(1, hop_length))
141
+
142
+ if self.device.startswith("ocl"):
143
+ frames = frames.transpose(1, 2).contiguous().reshape(-1, WINDOW_SIZE).to(self.device)
144
+ else:
145
+ frames = frames.transpose(1, 2).reshape(-1, WINDOW_SIZE).to(self.device)
146
+
147
+ frames -= frames.mean(dim=1, keepdim=True)
148
+ frames /= torch.max(torch.tensor(1e-10, device=frames.device), frames.std(dim=1, keepdim=True))
149
+
150
+ yield frames
151
+
152
+ def periodicity(self, probabilities, bins):
153
+ probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS)
154
+ periodicity = probs_stacked.gather(1, bins.reshape(-1, 1).to(torch.int64))
155
+
156
+ return periodicity.reshape(probabilities.size(0), probabilities.size(2))
157
+
158
+ def postprocess(self, probabilities):
159
+ probabilities = probabilities.detach()
160
+ probabilities[:, :self.frequency_to_bins(torch.tensor(self.f0_min))] = -float('inf')
161
+ probabilities[:, self.frequency_to_bins(torch.tensor(self.f0_max), torch.ceil):] = -float('inf')
162
+
163
+ bins, pitch = self.viterbi(probabilities)
164
+
165
+ if not self.return_periodicity: return pitch
166
+ return pitch, self.periodicity(probabilities, bins)
167
+
168
+ def compute_f0(self, audio, pad=True):
169
+ results = []
170
+
171
+ for frames in self.preprocess(audio, pad):
172
+ with torch.no_grad():
173
+ model = self.model(
174
+ frames,
175
+ embed=False
176
+ ).reshape(audio.size(0), -1, PITCH_BINS).transpose(1, 2)
177
+
178
+ result = self.postprocess(model)
179
+ results.append((result[0].to(audio.device), result[1].to(audio.device)) if isinstance(result, tuple) else result.to(audio.device))
180
+
181
+ if self.return_periodicity:
182
+ pitch, periodicity = zip(*results)
183
+ return torch.cat(pitch, 1), torch.cat(periodicity, 1)
184
+
185
+ return torch.cat(results, 1)
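Hypothetical usage of the CREPE wrapper; the checkpoint path is a placeholder. `compute_f0` takes a `(1, samples)` float tensor at `sample_rate`, and with `return_periodicity=True` it also returns a per-frame confidence.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
crepe = CREPE("models/crepe_full.pth", model_size="full", hop_length=160, f0_min=50,
              f0_max=1100, device=device, sample_rate=16000, return_periodicity=True)

audio = torch.randn(1, 16000)                       # one second at 16 kHz
pitch, periodicity = crepe.compute_f0(audio.to(device), pad=True)
print(pitch.shape, periodicity.shape)               # (1, frames), (1, frames)
```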
RVC/modules/torchfcpe.py ADDED
@@ -0,0 +1,951 @@
1
+ import os
2
+ import sys
3
+ import math
4
+ import torch
5
+
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from torch import einsum
11
+ from functools import partial
12
+ from librosa.filters import mel
13
+ from torchaudio.transforms import Resample
14
+ from einops import rearrange, repeat, pack, unpack
15
+ from torch.nn.utils.parametrizations import weight_norm
16
+
17
+ sys.path.append(os.getcwd())
18
+
19
+ from modules import opencl
20
+
21
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
22
+
23
+ def spawn_wav2mel(args, device = None):
24
+ _type = args.mel.type
25
+ if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): _type = 'default'
26
+ elif str(_type).lower() == 'stft': _type = 'stft'
27
+ wav2mel = Wav2MelModule(sr=args.mel.sr, n_mels=args.mel.num_mels, n_fft=args.mel.n_fft, win_size=args.mel.win_size, hop_length=args.mel.hop_size, fmin=args.mel.fmin, fmax=args.mel.fmax, clip_val=1e-05, mel_type=_type)
28
+
29
+ return wav2mel.to(torch.device(device))
30
+
31
+ def calc_same_padding(kernel_size):
32
+ pad = kernel_size // 2
33
+ return (pad, pad - (kernel_size + 1) % 2)
34
+
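A small check of calc_same_padding, derived directly from the definition above: odd kernels get a symmetric pad, even kernels an asymmetric one.

# calc_same_padding(k) == (k // 2, k // 2 - (k + 1) % 2)
assert calc_same_padding(31) == (15, 15)  # odd kernel: symmetric 'same' padding
assert calc_same_padding(4) == (2, 1)     # even kernel: one extra sample on the left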
35
+ def l2_regularization(model, l2_alpha):
36
+ l2_loss = []
37
+ for module in model.modules():
38
+ if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
39
+
40
+ return l2_alpha * sum(l2_loss)
41
+
42
+ def torch_interp(x, xp, fp):
43
+ sort_idx = torch.argsort(xp)
44
+ xp = xp[sort_idx]
45
+ fp = fp[sort_idx]
46
+
47
+ right_idxs = torch.searchsorted(xp, x).clamp(max=len(xp) - 1)
48
+ left_idxs = (right_idxs - 1).clamp(min=0)
49
+ x_left = xp[left_idxs]
50
+ y_left = fp[left_idxs]
51
+
52
+ interp_vals = y_left + ((x - x_left) * (fp[right_idxs] - y_left) / (xp[right_idxs] - x_left))
53
+ interp_vals[x < xp[0]] = fp[0]
54
+ interp_vals[x > xp[-1]] = fp[-1]
55
+
56
+ return interp_vals
57
+
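torch_interp is meant to mirror numpy.interp for monotonic sample points; a minimal sketch of the expected behaviour (tensors invented for illustration):

xp = torch.tensor([0.0, 1.0, 2.0])    # sample positions
fp = torch.tensor([0.0, 10.0, 20.0])  # sample values
x = torch.tensor([-1.0, 0.5, 3.0])    # query points
print(torch_interp(x, xp, fp))        # -> tensor([ 0.,  5., 20.]): clamped at the edges, linear in between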
58
+ def batch_interp_with_replacement_detach(uv, f0):
59
+ result = f0.clone()
60
+ for i in range(uv.shape[0]):
61
+ interp_vals = torch_interp(torch.where(uv[i])[-1], torch.where(~uv[i])[-1], f0[i][~uv[i]]).detach()
62
+ result[i][uv[i]] = interp_vals
63
+
64
+ return result
65
+
66
+ def ensemble_f0(f0s, key_shift_list, tta_uv_penalty):
67
+ device = f0s.device
68
+ f0s = f0s / (torch.pow(2, torch.tensor(key_shift_list, device=device).to(device).unsqueeze(0).unsqueeze(0) / 12))
69
+ notes = torch.log2(f0s / 440) * 12 + 69
70
+ notes[notes < 0] = 0
71
+
72
+ uv_penalty = tta_uv_penalty**2
73
+ dp = torch.zeros_like(notes, device=device)
74
+ backtrack = torch.zeros_like(notes, device=device).long()
75
+ dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty
76
+
77
+ for t in range(1, notes.size(1)):
78
+ penalty = torch.zeros([notes.size(0), notes.size(2), notes.size(2)], device=device)
79
+ t_uv = notes[:, t, :] <= 0
80
+ penalty += uv_penalty * t_uv.unsqueeze(1)
81
+
82
+ t1_uv = notes[:, t - 1, :] <= 0
83
+ l2 = torch.pow((notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) * (~t1_uv).unsqueeze(-1) * (~t_uv).unsqueeze(1), 2) - 0.5
84
+ l2 = l2 * (l2 > 0)
85
+
86
+ penalty += l2
87
+ penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2
88
+
89
+ min_value, min_indices = torch.min(dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1)
90
+ dp[:, t, :] = min_value
91
+ backtrack[:, t, :] = min_indices
92
+
93
+ t = f0s.size(1) - 1
94
+ f0_result = torch.zeros_like(f0s[:, :, 0], device=device)
95
+ min_indices = torch.argmin(dp[:, t, :], dim=-1)
96
+
97
+ for i in range(0, t + 1):
98
+ f0_result[:, t - i] = f0s[:, t - i, min_indices]
99
+ min_indices = backtrack[:, t - i, min_indices]
100
+
101
+ return f0_result.unsqueeze(-1)
102
+
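ensemble_f0 fuses the key-shifted candidate tracks with a Viterbi-style dynamic program, penalising unvoiced frames and large note jumps. A hedged, shape-level sketch of how it is called (contents invented):

# f0s: (batch, frames, num_key_shifts) candidates in Hz, one column per key shift
f0s = torch.tensor([[[220.0, 440.0, 110.0],
                     [225.0, 450.0, 112.0]]])
fused = ensemble_f0(f0s, key_shift_list=[0, 12, -12], tta_uv_penalty=12.0)
# expected shape: (1, 2, 1), a single fused f0 track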
103
+ def exists(val):
104
+ return val is not None
105
+
106
+ def default(value, d):
107
+ return value if exists(value) else d
108
+
109
+ def empty(tensor):
110
+ return tensor.numel() == 0
111
+
112
+ def pad_to_multiple(tensor, multiple, dim=-1, value=0):
113
+ seqlen = tensor.shape[dim]
114
+ m = seqlen / multiple
115
+ if m.is_integer(): return False, tensor
116
+ return True, F.pad(tensor, (*((0,) * (-1 - dim) * 2), 0, (math.ceil(m) * multiple - seqlen)), value = value)
117
+
118
+ def look_around(x, backward = 1, forward = 0, pad_value = -1, dim = 2):
119
+ t = x.shape[1]
120
+ dims = (len(x.shape) - dim) * (0, 0)
121
+ padded_x = F.pad(x, (*dims, backward, forward), value = pad_value)
122
+ return torch.cat([padded_x[:, ind:(ind + t), ...] for ind in range(forward + backward + 1)], dim = dim)
123
+
124
+ def rotate_half(x):
125
+ x1, x2 = rearrange(x, 'b ... (r d) -> b ... r d', r = 2).unbind(dim = -2)
126
+ return torch.cat((-x2, x1), dim = -1)
127
+
128
+ def apply_rotary_pos_emb(q, k, freqs, scale = 1):
129
+ q_len = q.shape[-2]
130
+ q_freqs = freqs[..., -q_len:, :]
131
+ inv_scale = scale ** -1
132
+ if scale.ndim == 2: scale = scale[-q_len:, :]
133
+ q = (q * q_freqs.cos() * scale) + (rotate_half(q) * q_freqs.sin() * scale)
134
+ k = (k * freqs.cos() * inv_scale) + (rotate_half(k) * freqs.sin() * inv_scale)
135
+
136
+ return q, k
137
+
138
+ def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
139
+ unstructured_block = torch.randn((cols, cols), device=device)
140
+ q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
141
+ q, r = map(lambda t: t.to(device), (q, r))
142
+ if qr_uniform_q:
143
+ d = torch.diag(r, 0)
144
+ q *= d.sign()
145
+
146
+ return q.t()
147
+
148
+ def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
149
+ nb_full_blocks = int(nb_rows / nb_columns)
150
+ block_list = []
151
+ for _ in range(nb_full_blocks):
152
+ block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device))
153
+
154
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
155
+ if remaining_rows > 0: block_list.append(orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)[:remaining_rows])
156
+ if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
157
+ elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
158
+ else: raise ValueError
159
+
160
+ return torch.diag(multiplier) @ torch.cat(block_list)
161
+
162
+ def linear_attention(q, k, v):
163
+ return einsum("...ed,...nd->...ne", k, q) if v is None else einsum("...de,...nd,...n->...ne", einsum("...nd,...ne->...de", k, v), q, 1.0 / (einsum("...nd,...d->...n", q, k.sum(dim=-2).type_as(q)) + 1e-8))
164
+
165
+ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
166
+ b, h, *_ = data.shape
167
+
168
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
169
+ ratio = projection_matrix.shape[0] ** -0.5
170
+ data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), repeat(projection_matrix, "j d -> b h j d", b=b, h=h).type_as(data))
171
+ diag_data = ((torch.sum(data**2, dim=-1) / 2.0) * (data_normalizer**2)).unsqueeze(dim=-1)
172
+
173
+ return (ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps) if is_query else ratio * (torch.exp(data_dash - diag_data + eps))).type_as(data)
174
+
175
+ class SinusoidalEmbeddings(nn.Module):
176
+ def __init__(self, dim, scale_base = None, use_xpos = False, theta = 10000):
177
+ super().__init__()
178
+ inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
179
+ self.register_buffer('inv_freq', inv_freq)
180
+ self.use_xpos = use_xpos
181
+ self.scale_base = scale_base
182
+ assert not (use_xpos and not exists(scale_base))
183
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
184
+ self.register_buffer('scale', scale, persistent = False)
185
+
186
+ def forward(self, x):
187
+ seq_len, device = x.shape[-2], x.device
188
+ t = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
189
+
190
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
191
+ freqs = torch.cat((freqs, freqs), dim = -1)
192
+
193
+ if not self.use_xpos: return freqs, torch.ones(1, device = device)
194
+
195
+ power = (t - (seq_len // 2)) / self.scale_base
196
+ scale = self.scale ** rearrange(power, 'n -> n 1')
197
+
198
+ return freqs, torch.cat((scale, scale), dim = -1)
199
+
200
+ class LocalAttention(nn.Module):
201
+ def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, dim = None, autopad = False, exact_windowsize = False, scale = None, use_rotary_pos_emb = True, use_xpos = False, xpos_scale_base = None):
202
+ super().__init__()
203
+ look_forward = default(look_forward, 0 if causal else 1)
204
+ assert not (causal and look_forward > 0)
205
+ self.scale = scale
206
+ self.window_size = window_size
207
+ self.autopad = autopad
208
+ self.exact_windowsize = exact_windowsize
209
+ self.causal = causal
210
+ self.look_backward = look_backward
211
+ self.look_forward = look_forward
212
+ self.dropout = nn.Dropout(dropout)
213
+ self.shared_qk = shared_qk
214
+ self.rel_pos = None
215
+ self.use_xpos = use_xpos
216
+ if use_rotary_pos_emb and (exists(rel_pos_emb_config) or exists(dim)):
217
+ if exists(rel_pos_emb_config): dim = rel_pos_emb_config[0]
218
+ self.rel_pos = SinusoidalEmbeddings(dim, use_xpos = use_xpos, scale_base = default(xpos_scale_base, window_size // 2))
219
+
220
+ def forward(self, q, k, v, mask = None, input_mask = None, attn_bias = None, window_size = None):
221
+ mask = default(mask, input_mask)
222
+ assert not (exists(window_size) and not self.use_xpos)
223
+
224
+ _, autopad, pad_value, window_size, causal, look_backward, look_forward, shared_qk = q.shape, self.autopad, -1, default(window_size, self.window_size), self.causal, self.look_backward, self.look_forward, self.shared_qk
225
+ (q, packed_shape), (k, _), (v, _) = map(lambda t: pack([t], '* n d'), (q, k, v))
226
+
227
+ if autopad:
228
+ orig_seq_len = q.shape[1]
229
+ (_, q), (_, k), (_, v) = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v))
230
+
231
+ b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
232
+ scale = default(self.scale, dim_head ** -0.5)
233
+
234
+ assert (n % window_size) == 0
235
+ windows = n // window_size
236
+
237
+ if shared_qk: k = F.normalize(k, dim = -1).type(k.dtype)
238
+
239
+ seq = torch.arange(n, device = device)
240
+ b_t = rearrange(seq, '(w n) -> 1 w n', w = windows, n = window_size)
241
+ bq, bk, bv = map(lambda t: rearrange(t, 'b (w n) d -> b w n d', w = windows), (q, k, v))
242
+
243
+ bq = bq * scale
244
+ look_around_kwargs = dict(backward = look_backward, forward = look_forward, pad_value = pad_value)
245
+
246
+ bk = look_around(bk, **look_around_kwargs)
247
+ bv = look_around(bv, **look_around_kwargs)
248
+
249
+ if exists(self.rel_pos):
250
+ pos_emb, xpos_scale = self.rel_pos(bk)
251
+ bq, bk = apply_rotary_pos_emb(bq, bk, pos_emb, scale = xpos_scale)
252
+
253
+ bq_t = b_t
254
+ bq_k = look_around(b_t, **look_around_kwargs)
255
+ bq_t = rearrange(bq_t, '... i -> ... i 1')
256
+ bq_k = rearrange(bq_k, '... j -> ... 1 j')
257
+
258
+ pad_mask = bq_k == pad_value
259
+ sim = einsum('b h i e, b h j e -> b h i j', bq, bk)
260
+
261
+ if exists(attn_bias):
262
+ heads = attn_bias.shape[0]
263
+ assert (b % heads) == 0
264
+
265
+ attn_bias = repeat(attn_bias, 'h i j -> (b h) 1 i j', b = b // heads)
266
+ sim = sim + attn_bias
267
+
268
+ mask_value = -torch.finfo(sim.dtype).max
269
+ if shared_qk:
270
+ self_mask = bq_t == bq_k
271
+ sim = sim.masked_fill(self_mask, -5e4)
272
+ del self_mask
273
+
274
+ if causal:
275
+ causal_mask = bq_t < bq_k
276
+ if self.exact_windowsize: causal_mask = causal_mask | (bq_t > (bq_k + (self.window_size * self.look_backward)))
277
+ sim = sim.masked_fill(causal_mask, mask_value)
278
+ del causal_mask
279
+
280
+ sim = sim.masked_fill(((bq_k - (self.window_size * self.look_forward)) > bq_t) | (bq_t > (bq_k + (self.window_size * self.look_backward))) | pad_mask, mask_value) if not causal and self.exact_windowsize else sim.masked_fill(pad_mask, mask_value)
281
+
282
+ if exists(mask):
283
+ batch = mask.shape[0]
284
+ assert (b % batch) == 0
285
+
286
+ h = b // mask.shape[0]
287
+ if autopad: _, mask = pad_to_multiple(mask, window_size, dim = -1, value = False)
288
+
289
+ mask = repeat(rearrange(look_around(rearrange(mask, '... (w n) -> (...) w n', w = windows, n = window_size), **{**look_around_kwargs, 'pad_value': False}), '... j -> ... 1 j'), 'b ... -> (b h) ...', h = h)
290
+ sim = sim.masked_fill(~mask, mask_value)
291
+
292
+ del mask
293
+
294
+ out = rearrange(einsum('b h i j, b h j e -> b h i e', self.dropout(sim.softmax(dim = -1)), bv), 'b w n d -> b (w n) d')
295
+ if autopad: out = out[:, :orig_seq_len, :]
296
+
297
+ out, *_ = unpack(out, packed_shape, '* n d')
298
+ return out
299
+
300
+ class FastAttention(nn.Module):
301
+ def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
302
+ super().__init__()
303
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
304
+ self.dim_heads = dim_heads
305
+ self.nb_features = nb_features
306
+ self.ortho_scaling = ortho_scaling
307
+ self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
308
+ projection_matrix = self.create_projection()
309
+ self.register_buffer("projection_matrix", projection_matrix)
310
+ self.generalized_attention = generalized_attention
311
+ self.kernel_fn = kernel_fn
312
+ self.no_projection = no_projection
313
+ self.causal = causal
314
+
315
+ @torch.no_grad()
316
+ def redraw_projection_matrix(self):
317
+ projections = self.create_projection()
318
+ self.projection_matrix.copy_(projections)
319
+ del projections
320
+
321
+ def forward(self, q, k, v):
322
+ if self.no_projection: q, k = q.softmax(dim=-1), (torch.exp(k) if self.causal else k.softmax(dim=-2))
323
+ else:
324
+ create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=q.device)
325
+ q, k = create_kernel(q, is_query=True), create_kernel(k, is_query=False)
326
+
327
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn
328
+ return attn_fn(q, k, None) if v is None else attn_fn(q, k, v)
329
+
330
+ class SelfAttention(nn.Module):
331
+ def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
332
+ super().__init__()
333
+ assert dim % heads == 0
334
+ dim_head = default(dim_head, dim // heads)
335
+ inner_dim = dim_head * heads
336
+ self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
337
+ self.heads = heads
338
+ self.global_heads = heads - local_heads
339
+ self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
340
+ self.to_q = nn.Linear(dim, inner_dim)
341
+ self.to_k = nn.Linear(dim, inner_dim)
342
+ self.to_v = nn.Linear(dim, inner_dim)
343
+ self.to_out = nn.Linear(inner_dim, dim)
344
+ self.dropout = nn.Dropout(dropout)
345
+
346
+ @torch.no_grad()
347
+ def redraw_projection_matrix(self):
348
+ self.fast_attention.redraw_projection_matrix()
349
+
350
+ def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
351
+ _, _, _, h, gh = *x.shape, self.heads, self.global_heads
352
+ cross_attend = exists(context)
353
+ context = default(context, x)
354
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
355
+
356
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (self.to_q(x), self.to_k(context), self.to_v(context)))
357
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
358
+
359
+ attn_outs = []
360
+
361
+ if not empty(q):
362
+ if exists(context_mask): v.masked_fill_(~context_mask[:, None, :, None], 0.0)
363
+ if cross_attend: pass
364
+ else: out = self.fast_attention(q, k, v)
365
+
366
+ attn_outs.append(out)
367
+
368
+ if not empty(lq):
369
+ assert (not cross_attend), "not cross_attend"
370
+
371
+ out = self.local_attn(lq, lk, lv, input_mask=mask)
372
+ attn_outs.append(out)
373
+
374
+ return self.dropout(self.to_out(rearrange(torch.cat(attn_outs, dim=1), "b h n d -> b n (h d)")))
375
+
376
+ class DotDict(dict):
377
+ def __getattr__(*args):
378
+ val = dict.get(*args)
379
+ return DotDict(val) if type(val) is dict else val
380
+
381
+ __setattr__ = dict.__setitem__
382
+ __delattr__ = dict.__delitem__
383
+
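DotDict simply rewraps nested dicts on attribute access; a tiny sketch with an invented config:

cfg = DotDict({"mel": {"sr": 16000, "num_mels": 128}})
print(cfg.mel.sr)   # -> 16000 (nested dicts come back as DotDict)
print(cfg.missing)  # -> None (absent keys fall back to dict.get)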
384
+ class Swish(nn.Module):
385
+ def forward(self, x):
386
+ return x * x.sigmoid()
387
+
388
+ class Transpose(nn.Module):
389
+ def __init__(self, dims):
390
+ super().__init__()
391
+ assert len(dims) == 2, "dims == 2"
392
+ self.dims = dims
393
+
394
+ def forward(self, x):
395
+ return x.transpose(*self.dims)
396
+
397
+ class GLU(nn.Module):
398
+ def __init__(self, dim):
399
+ super().__init__()
400
+ self.dim = dim
401
+
402
+ def forward(self, x):
403
+ out, gate = x.chunk(2, dim=self.dim)
404
+ return out * gate.sigmoid()
405
+
406
+ class ConformerConvModule_LEGACY(nn.Module):
407
+ def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
408
+ super().__init__()
409
+ inner_dim = dim * expansion_factor
410
+ self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d_LEGACY(inner_dim, inner_dim, kernel_size=kernel_size, padding=(calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0))), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
411
+
412
+ def forward(self, x):
413
+ return self.net(x)
414
+
415
+ class ConformerConvModule(nn.Module):
416
+ def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0):
417
+ super().__init__()
418
+ inner_dim = dim * expansion_factor
419
+ self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), nn.GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=calc_same_padding(kernel_size)[0], groups=inner_dim), nn.SiLU(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
420
+
421
+ def forward(self, x):
422
+ return self.net(x)
423
+
424
+ class DepthWiseConv1d_LEGACY(nn.Module):
425
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
426
+ super().__init__()
427
+ self.padding = padding
428
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
429
+
430
+ def forward(self, x):
431
+ return self.conv(F.pad(x, self.padding))
432
+
433
+ class DepthWiseConv1d(nn.Module):
434
+ def __init__(self, chan_in, chan_out, kernel_size, padding, groups):
435
+ super().__init__()
436
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups)
437
+
438
+ def forward(self, x):
439
+ return self.conv(x)
440
+
441
+ class EncoderLayer(nn.Module):
442
+ def __init__(self, parent):
443
+ super().__init__()
444
+ self.conformer = ConformerConvModule_LEGACY(parent.dim_model)
445
+ self.norm = nn.LayerNorm(parent.dim_model)
446
+ self.dropout = nn.Dropout(parent.residual_dropout)
447
+ self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
448
+
449
+ def forward(self, phone, mask=None):
450
+ phone = phone + (self.attn(self.norm(phone), mask=mask))
451
+ return phone + (self.conformer(phone))
452
+
453
+ class ConformerNaiveEncoder(nn.Module):
454
+ def __init__(self, num_layers, num_heads, dim_model, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
455
+ super().__init__()
456
+ self.num_layers = num_layers
457
+ self.num_heads = num_heads
458
+ self.dim_model = dim_model
459
+ self.use_norm = use_norm
460
+ self.residual_dropout = 0.1
461
+ self.attention_dropout = 0.1
462
+ self.encoder_layers = nn.ModuleList([CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) for _ in range(num_layers)])
463
+
464
+ def forward(self, x, mask=None):
465
+ for (_, layer) in enumerate(self.encoder_layers):
466
+ x = layer(x, mask)
467
+
468
+ return x
469
+
470
+ class CFNEncoderLayer(nn.Module):
471
+ def __init__(self, dim_model, num_heads = 8, use_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0):
472
+ super().__init__()
473
+ self.conformer = nn.Sequential(ConformerConvModule(dim_model), nn.Dropout(conv_dropout)) if conv_dropout > 0 else ConformerConvModule(dim_model)
474
+ self.norm = nn.LayerNorm(dim_model)
475
+ self.dropout = nn.Dropout(0.1)
476
+ self.attn = SelfAttention(dim=dim_model, heads=num_heads, causal=False, dropout=atten_dropout) if not conv_only else None  # the SelfAttention defined above has no use_norm parameter
477
+
478
+ def forward(self, x, mask=None):
479
+ if self.attn is not None: x = x + (self.attn(self.norm(x), mask=mask))
480
+ return x + (self.conformer(x))
481
+
482
+
483
+ class HannWindow(torch.nn.Module):
484
+ def __init__(self, win_size):
485
+ super().__init__()
486
+ self.register_buffer('window', torch.hann_window(win_size), persistent=False)
487
+
488
+ def forward(self):
489
+ return self.window
490
+
491
+ class MelModule(torch.nn.Module):
492
+ def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, out_stft = False):
493
+ super().__init__()
494
+ if fmin is None: fmin = 0
495
+ if fmax is None: fmax = sr / 2
496
+ self.target_sr = sr
497
+ self.n_mels = n_mels
498
+ self.n_fft = n_fft
499
+ self.win_size = win_size
500
+ self.hop_length = hop_length
501
+ self.fmin = fmin
502
+ self.fmax = fmax
503
+ self.clip_val = clip_val
504
+ self.register_buffer('mel_basis', torch.tensor(mel(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), persistent=False)
505
+ self.hann_window = torch.nn.ModuleDict()
506
+ self.out_stft = out_stft
507
+
508
+ @torch.no_grad()
509
+ def __call__(self, y, key_shift = 0, speed = 1, center = False, no_cache_window = False):
510
+ n_fft = self.n_fft
511
+ win_size = self.win_size
512
+ hop_length = self.hop_length
513
+ clip_val = self.clip_val
514
+ factor = 2 ** (key_shift / 12)
515
+ n_fft_new = int(np.round(n_fft * factor))
516
+ win_size_new = int(np.round(win_size * factor))
517
+ hop_length_new = int(np.round(hop_length * speed))
518
+
519
+ y = y.squeeze(-1)
520
+ key_shift_key = str(key_shift)
521
+
522
+ if not no_cache_window:
523
+ if key_shift_key in self.hann_window: hann_window = self.hann_window[key_shift_key]
524
+ else:
525
+ hann_window = HannWindow(win_size_new).to(self.mel_basis.device)
526
+ self.hann_window[key_shift_key] = hann_window
527
+
528
+ hann_window_tensor = hann_window()
529
+ else: hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device)
530
+
531
+ pad_left = (win_size_new - hop_length_new) // 2
532
+ pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
533
+
534
+ mode = 'reflect' if pad_right < y.size(-1) else 'constant'
535
+ pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode).squeeze(1)
536
+
537
+ if str(y.device).startswith("ocl"):
538
+ stft = opencl.STFT(filter_length=n_fft_new, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
539
+ spec = stft.transform(pad, 1e-9)
540
+ else:
541
+ spec = torch.stft(pad, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window_tensor, center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
542
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
543
+
544
+ if key_shift != 0:
545
+ size = n_fft // 2 + 1
546
+ resize = spec.size(1)
547
+
548
+ if resize < size: spec = F.pad(spec, (0, 0, 0, size - resize))
549
+ spec = spec[:, :size, :] * win_size / win_size_new
550
+
551
+ spec = spec[:, :512, :] if self.out_stft else torch.matmul(self.mel_basis, spec)
552
+ return torch.log(torch.clamp(spec, min=clip_val) * 1).transpose(-1, -2)
553
+
554
+ class Wav2MelModule(torch.nn.Module):
555
+ def __init__(self, sr, n_mels, n_fft, win_size, hop_length, fmin = None, fmax = None, clip_val = 1e-5, mel_type="default"):
556
+ super().__init__()
557
+ if fmin is None: fmin = 0
558
+ if fmax is None: fmax = sr / 2
559
+ self.sampling_rate = sr
560
+ self.n_mels = n_mels
561
+ self.n_fft = n_fft
562
+ self.win_size = win_size
563
+ self.hop_size = hop_length
564
+ self.fmin = fmin
565
+ self.fmax = fmax
566
+ self.clip_val = clip_val
567
+ self.register_buffer('tensor_device_marker', torch.tensor(1.0).float(), persistent=False)
568
+ self.resample_kernel = torch.nn.ModuleDict()
569
+ if mel_type == "default": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=False)
570
+ elif mel_type == "stft": self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, out_stft=True)
571
+ self.mel_type = mel_type
572
+
573
+ @torch.no_grad()
574
+ def __call__(self, audio, sample_rate, keyshift = 0, no_cache_window = False):
575
+ if sample_rate == self.sampling_rate: audio_res = audio
576
+ else:
577
+ key_str = str(sample_rate)
578
+ if key_str not in self.resample_kernel:
579
+ if len(self.resample_kernel) > 8: self.resample_kernel.clear()
580
+ self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128).to(self.tensor_device_marker.device)
581
+
582
+ audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1)
583
+
584
+ mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window)
585
+ n_frames = int(audio.shape[1] // self.hop_size) + 1
586
+ if n_frames > int(mel.shape[1]): mel = torch.cat((mel, mel[:, -1:, :]), 1)
587
+ if n_frames < int(mel.shape[1]): mel = mel[:, :n_frames, :]
588
+
589
+ return mel
590
+
591
+ class STFT:
592
+ def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
593
+ self.target_sr = sr
594
+ self.n_mels = n_mels
595
+ self.n_fft = n_fft
596
+ self.win_size = win_size
597
+ self.hop_length = hop_length
598
+ self.fmin = fmin
599
+ self.fmax = fmax
600
+ self.clip_val = clip_val
601
+ self.mel_basis = {}
602
+ self.hann_window = {}
603
+
604
+ def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
605
+ n_fft = self.n_fft
606
+ win_size = self.win_size
607
+ hop_length = self.hop_length
608
+ fmax = self.fmax
609
+ factor = 2 ** (keyshift / 12)
610
+ win_size_new = int(np.round(win_size * factor))
611
+ hop_length_new = int(np.round(hop_length * speed))
612
+ mel_basis = self.mel_basis if not train else {}
613
+ hann_window = self.hann_window if not train else {}
614
+ mel_basis_key = str(fmax) + "_" + str(y.device)
615
+
616
+ if mel_basis_key not in mel_basis: mel_basis[mel_basis_key] = torch.from_numpy(mel(sr=self.target_sr, n_fft=n_fft, n_mels=self.n_mels, fmin=self.fmin, fmax=fmax)).float().to(y.device)
617
+ keyshift_key = str(keyshift) + "_" + str(y.device)
618
+ if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
619
+
620
+ pad_left = (win_size_new - hop_length_new) // 2
621
+ pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
622
+
623
+ pad = F.pad(y.unsqueeze(1), (pad_left, pad_right), mode="reflect" if pad_right < y.size(-1) else "constant").squeeze(1)
624
+ n_fft = int(np.round(n_fft * factor))
625
+
626
+ if str(y.device).startswith("ocl"):
627
+ stft = opencl.STFT(filter_length=n_fft, hop_length=hop_length_new, win_length=win_size_new).to(y.device)
628
+ spec = stft.transform(pad, 1e-9)
629
+ else:
630
+ spec = torch.stft(pad, n_fft, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
631
+ spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9)
632
+
633
+ if keyshift != 0:
634
+ size = n_fft // 2 + 1
635
+ resize = spec.size(1)
636
+ spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :]) * win_size / win_size_new
637
+
638
+ return torch.log(torch.clamp(torch.matmul(mel_basis[mel_basis_key], spec), min=self.clip_val) * 1)
639
+
640
+ class Wav2Mel:
641
+ def __init__(self, device=None, dtype=torch.float32):
642
+ self.sample_rate = 16000
643
+ self.hop_size = 160
644
+ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
645
+ self.device = device
646
+ self.dtype = dtype
647
+ self.stft = STFT(16000, 128, 1024, 1024, 160, 0, 8000)
648
+ self.resample_kernel = {}
649
+
650
+ def extract_nvstft(self, audio, keyshift=0, train=False):
651
+ return self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
652
+
653
+ def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
654
+ audio = audio.to(self.dtype).to(self.device)
655
+ if sample_rate == self.sample_rate: audio_res = audio
656
+ else:
657
+ key_str = str(sample_rate)
658
+ if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
659
+ self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
660
+ audio_res = self.resample_kernel[key_str](audio)
661
+
662
+ mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
663
+ n_frames = int(audio.shape[1] // self.hop_size) + 1
664
+ mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
665
+ return mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
666
+
667
+ def __call__(self, audio, sample_rate, keyshift=0, train=False):
668
+ return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
669
+
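A hedged usage sketch of the legacy Wav2Mel front end above; the input is synthetic, but the frame count follows n_frames = samples // hop_size + 1 from extract_mel:

wav2mel = Wav2Mel(device="cpu")
audio = torch.randn(1, 16000)            # (batch, samples), 1 s at 16 kHz, invented input
mel = wav2mel(audio, sample_rate=16000)  # no resampling needed at the native 16 kHz
print(mel.shape)                         # expected torch.Size([1, 101, 128]): 16000 // 160 + 1 frames, 128 mel bins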
670
+ class PCmer(nn.Module):
671
+ def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
672
+ super().__init__()
673
+ self.num_layers = num_layers
674
+ self.num_heads = num_heads
675
+ self.dim_model = dim_model
676
+ self.dim_values = dim_values
677
+ self.dim_keys = dim_keys
678
+ self.residual_dropout = residual_dropout
679
+ self.attention_dropout = attention_dropout
680
+ self._layers = nn.ModuleList([EncoderLayer(self) for _ in range(num_layers)])
681
+
682
+ def forward(self, phone, mask=None):
683
+ for layer in self._layers:
684
+ phone = layer(phone, mask)
685
+
686
+ return phone
687
+
688
+ class CFNaiveMelPE(nn.Module):
689
+ def __init__(self, input_channels, out_dims, hidden_dims = 512, n_layers = 6, n_heads = 8, f0_max = 1975.5, f0_min = 32.70, use_fa_norm = False, conv_only = False, conv_dropout = 0, atten_dropout = 0, use_harmonic_emb = False):
690
+ super().__init__()
691
+ self.input_channels = input_channels
692
+ self.out_dims = out_dims
693
+ self.hidden_dims = hidden_dims
694
+ self.n_layers = n_layers
695
+ self.n_heads = n_heads
696
+ self.f0_max = f0_max
697
+ self.f0_min = f0_min
698
+ self.use_fa_norm = use_fa_norm
699
+ self.residual_dropout = 0.1
700
+ self.attention_dropout = 0.1
701
+ self.harmonic_emb = nn.Embedding(9, hidden_dims) if use_harmonic_emb else None
702
+ self.input_stack = nn.Sequential(nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), nn.GroupNorm(4, hidden_dims), nn.LeakyReLU(), nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1))
703
+ self.net = ConformerNaiveEncoder(num_layers=n_layers, num_heads=n_heads, dim_model=hidden_dims, use_norm=use_fa_norm, conv_only=conv_only, conv_dropout=conv_dropout, atten_dropout=atten_dropout)
704
+ self.norm = nn.LayerNorm(hidden_dims)
705
+ self.output_proj = weight_norm(nn.Linear(hidden_dims, out_dims))
706
+ self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims).detach()
707
+ self.register_buffer("cent_table", self.cent_table_b)
708
+ self.gaussian_blurred_cent_mask_b = (1200 * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach()
709
+ self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b)
710
+
711
+ def forward(self, x, _h_emb=None):
712
+ x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2)
713
+ if self.harmonic_emb is not None: x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) if _h_emb is None else x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device))
714
+ return torch.sigmoid(self.output_proj(self.norm(self.net(x))))
715
+
716
+ @torch.no_grad()
717
+ def latent2cents_decoder(self, y, threshold = 0.05, mask = True):
718
+ B, N, _ = y.size()
719
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
720
+ rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
721
+
722
+ if mask:
723
+ confident = torch.max(y, dim=-1, keepdim=True)[0]
724
+ confident_mask = torch.ones_like(confident)
725
+ confident_mask[confident <= threshold] = float("-INF")
726
+ rtn = rtn * confident_mask
727
+
728
+ return rtn
729
+
730
+ @torch.no_grad()
731
+ def latent2cents_local_decoder(self, y, threshold = 0.05, mask = True):
732
+ B, N, _ = y.size()
733
+ ci = self.cent_table[None, None, :].expand(B, N, -1)
734
+ confident, max_index = torch.max(y, dim=-1, keepdim=True)
735
+
736
+ local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
737
+ local_argmax_index[local_argmax_index < 0] = 0
738
+ local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1
739
+
740
+ y_l = torch.gather(y, -1, local_argmax_index)
741
+ rtn = torch.sum(torch.gather(ci, -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
742
+
743
+ if mask:
744
+ confident_mask = torch.ones_like(confident)
745
+ confident_mask[confident <= threshold] = float("-INF")
746
+ rtn = rtn * confident_mask
747
+
748
+ return rtn
749
+
750
+ @torch.no_grad()
751
+ def infer(self, mel, decoder = "local_argmax", threshold = 0.05):
752
+ latent = self.forward(mel)
753
+ if decoder == "argmax": cents = self.latent2cents_decoder  # global weighted-argmax decoder; local_argmax is handled below
754
+ elif decoder == "local_argmax": cents = self.latent2cents_local_decoder
755
+
756
+ return self.cent_to_f0(cents(latent, threshold=threshold))
757
+
758
+ @torch.no_grad()
759
+ def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor:
760
+ return 10 * 2 ** (cent / 1200)
761
+
762
+ @torch.no_grad()
763
+ def f0_to_cent(self, f0):
764
+ return 1200 * torch.log2(f0 / 10)
765
+
766
+ class FCPE_LEGACY(nn.Module):
767
+ def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
768
+ super().__init__()
769
+ self.loss_mse_scale = loss_mse_scale
770
+ self.loss_l2_regularization = loss_l2_regularization
771
+ self.loss_l2_regularization_scale = loss_l2_regularization_scale
772
+ self.loss_grad1_mse = loss_grad1_mse
773
+ self.loss_grad1_mse_scale = loss_grad1_mse_scale
774
+ self.f0_max = f0_max
775
+ self.f0_min = f0_min
776
+ self.confidence = confidence
777
+ self.threshold = threshold
778
+ self.use_input_conv = use_input_conv
779
+ self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
780
+ self.register_buffer("cent_table", self.cent_table_b)
781
+ self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), nn.LeakyReLU(), nn.Conv1d(n_chans, n_chans, 3, 1, 1))
782
+ self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
783
+ self.norm = nn.LayerNorm(n_chans)
784
+ self.n_out = out_dims
785
+ self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
786
+
787
+ def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax", output_interp_target_length=None):
788
+ if cdecoder == "argmax": self.cdecoder = self.cents_decoder
789
+ elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
790
+
791
+ x = torch.sigmoid(self.dense_out(self.norm(self.decoder((self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)))))
792
+
793
+ if not infer:
794
+ loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, self.gaussian_blurred_cent(self.f0_to_cent(gt_f0)))
795
+ if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
796
+ x = loss_all
797
+ else:
798
+ x = self.cent_to_f0(self.cdecoder(x))
799
+ x = (1 + x / 700).log() if not return_hz_f0 else x
800
+
801
+ if output_interp_target_length is not None:
802
+ x = F.interpolate(torch.where(x == 0, float("nan"), x).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
803
+ x = torch.where(x.isnan(), float(0.0), x)
804
+
805
+ return x
806
+
807
+ def cents_decoder(self, y, mask=True):
808
+ B, N, _ = y.size()
809
+ rtn = torch.sum(self.cent_table[None, None, :].expand(B, N, -1) * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
810
+
811
+ if mask:
812
+ confident = torch.max(y, dim=-1, keepdim=True)[0]
813
+ confident_mask = torch.ones_like(confident)
814
+ confident_mask[confident <= self.threshold] = float("-INF")
815
+ rtn = rtn * confident_mask
816
+
817
+ return (rtn, confident) if self.confidence else rtn
818
+
819
+ def cents_local_decoder(self, y, mask=True):
820
+ B, N, _ = y.size()
821
+
822
+ confident, max_index = torch.max(y, dim=-1, keepdim=True)
823
+ local_argmax_index = torch.clamp(torch.arange(0, 9).to(max_index.device) + (max_index - 4), 0, self.n_out - 1)
824
+ y_l = torch.gather(y, -1, local_argmax_index)
825
+ rtn = torch.sum(torch.gather(self.cent_table[None, None, :].expand(B, N, -1), -1, local_argmax_index) * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
826
+
827
+ if mask:
828
+ confident_mask = torch.ones_like(confident)
829
+ confident_mask[confident <= self.threshold] = float("-INF")
830
+ rtn = rtn * confident_mask
831
+
832
+ return (rtn, confident) if self.confidence else rtn
833
+
834
+ def cent_to_f0(self, cent):
835
+ return 10.0 * 2 ** (cent / 1200.0)
836
+
837
+ def f0_to_cent(self, f0):
838
+ return 1200.0 * torch.log2(f0 / 10.0)
839
+
840
+ def gaussian_blurred_cent(self, cents):
841
+ B, N, _ = cents.size()
842
+ return torch.exp(-torch.square(self.cent_table[None, None, :].expand(B, N, -1) - cents) / 1250) * ((cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))).float()  # mask grouped before the multiply; & binds looser than *, so the ungrouped form fails on float tensors
843
+
844
+ class InferCFNaiveMelPE(torch.nn.Module):
845
+ def __init__(self, args, state_dict):
846
+ super().__init__()
847
+ self.wav2mel = spawn_wav2mel(args, device="cpu")
848
+ self.model = CFNaiveMelPE(input_channels=args.mel.num_mels, out_dims=args.model.out_dims, hidden_dims=args.model.hidden_dims, n_layers=args.model.n_layers, n_heads=args.model.n_heads, f0_max=args.model.f0_max, f0_min=args.model.f0_min, use_fa_norm=args.model.use_fa_norm, conv_only=args.model.conv_only, conv_dropout=args.model.conv_dropout, atten_dropout=args.model.atten_dropout, use_harmonic_emb=False)
849
+ self.model.load_state_dict(state_dict)
850
+ self.model.eval()
851
+ self.args_dict = dict(args)
852
+ self.register_buffer("tensor_device_marker", torch.tensor(1.0).float(), persistent=False)
853
+
854
+ def forward(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, key_shifts = [0]):
855
+ with torch.no_grad():
856
+ mels = rearrange(torch.stack([self.wav2mel(wav.to(self.tensor_device_marker.device), sr, keyshift=keyshift) for keyshift in key_shifts], -1), "B T C K -> (B K) T C")
857
+ f0s = rearrange(self.model.infer(mels, decoder=decoder_mode, threshold=threshold), "(B K) T 1 -> B T (K 1)", K=len(key_shifts))
858
+
859
+ return f0s
860
+
861
+ def infer(self, wav, sr, decoder_mode = "local_argmax", threshold = 0.006, f0_min = None, f0_max = None, interp_uv = False, output_interp_target_length = None, return_uv = False, test_time_augmentation = False, tta_uv_penalty = 12.0, tta_key_shifts = [0, -12, 12], tta_use_origin_uv=False):
862
+ if test_time_augmentation:
863
+ assert len(tta_key_shifts) > 0
864
+ flag = 0
865
+ if tta_use_origin_uv:
866
+ if 0 not in tta_key_shifts:
867
+ flag = 1
868
+ tta_key_shifts.append(0)
869
+
870
+ tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2))
871
+ f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts)
872
+ f0 = ensemble_f0(f0s[:, :, flag:], tta_key_shifts[flag:], tta_uv_penalty)
873
+ f0_for_uv = f0s[:, :, [0]] if tta_use_origin_uv else f0
874
+ else:
875
+ f0 = self.__call__(wav, sr, decoder_mode, threshold)
876
+ f0_for_uv = f0
877
+
878
+ if f0_min is None: f0_min = self.args_dict["model"]["f0_min"]
879
+ uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype)
880
+ f0 = f0 * (1 - uv)
881
+
882
+ if interp_uv: f0 = batch_interp_with_replacement_detach(uv.squeeze(-1).bool(), f0.squeeze(-1)).unsqueeze(-1)
883
+ if f0_max is not None: f0[f0 > f0_max] = f0_max
884
+ if output_interp_target_length is not None:
885
+ f0 = F.interpolate(torch.where(f0 == 0, float("nan"), f0).transpose(1, 2), size=int(output_interp_target_length), mode="linear").transpose(1, 2)
886
+ f0 = torch.where(f0.isnan(), float(0.0), f0)
887
+
888
+ if return_uv: return f0, F.interpolate(uv.transpose(1, 2), size=int(output_interp_target_length), mode="nearest").transpose(1, 2)
889
+ else: return f0
890
+
891
+ class FCPEInfer_LEGACY:
892
+ def __init__(self, model_path, device=None, dtype=torch.float32, f0_min=50, f0_max=1100):
893
+ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
894
+ self.device = device
895
+ self.dtype = dtype
896
+ self.f0_min = f0_min
897
+ self.f0_max = f0_max
898
+ ckpt = torch.load(model_path, map_location=torch.device(self.device))
899
+ self.args = DotDict(ckpt["config"])
900
+ model = FCPE_LEGACY(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.f0_max, f0_min=self.f0_min, confidence=self.args.model.confidence)
901
+ model.to(self.device).to(self.dtype)
902
+ model.load_state_dict(ckpt["model"])
903
+ model.eval()
904
+ self.model = model
905
+
906
+ @torch.no_grad()
907
+ def __call__(self, audio, sr, threshold=0.05, p_len=None):
908
+ self.model.threshold = threshold
909
+ self.wav2mel = Wav2Mel(device=self.device, dtype=self.dtype)
910
+
911
+ return self.model(mel=self.wav2mel(audio=audio[None, :], sample_rate=sr).to(self.dtype), infer=True, return_hz_f0=True, output_interp_target_length=p_len)
912
+
913
+ class FCPEInfer:
914
+ def __init__(self, model_path, device=None, dtype=torch.float32, f0_min=50, f0_max=1100):
915
+ if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
916
+ self.device = device
917
+ self.dtype = dtype
918
+ self.f0_min = f0_min
919
+ self.f0_max = f0_max
920
+ ckpt = torch.load(model_path, map_location=torch.device(device))
921
+ ckpt["config_dict"]["model"]["conv_dropout"] = ckpt["config_dict"]["model"]["atten_dropout"] = 0.0
922
+ self.args = DotDict(ckpt["config_dict"])
923
+ model = InferCFNaiveMelPE(self.args, ckpt["model"])
924
+ model = model.to(device).to(self.dtype)
925
+ model.eval()
926
+ self.model = model
927
+
928
+ @torch.no_grad()
929
+ def __call__(self, audio, sr, threshold=0.05, p_len=None):
930
+ return self.model.infer(audio[None, :], sr, threshold=threshold, f0_min=self.f0_min, f0_max=self.f0_max, output_interp_target_length=p_len)
931
+
932
+ class FCPE:
933
+ def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=16000, threshold=0.05, legacy=False):
934
+ self.model = FCPEInfer_LEGACY if legacy else FCPEInfer
935
+ self.fcpe = self.model(model_path, device=device, dtype=dtype, f0_min=f0_min, f0_max=f0_max)
936
+ self.hop_length = hop_length
937
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
938
+ self.threshold = threshold
939
+ self.sample_rate = sample_rate
940
+ self.dtype = dtype
941
+ self.legacy = legacy
942
+
943
+ def compute_f0(self, wav, p_len=None):
944
+ x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
945
+ p_len = (x.shape[0] // self.hop_length) if p_len is None else p_len
946
+
947
+ f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold, p_len=p_len)
948
+ f0 = f0[:] if f0.dim() == 1 else f0[0, :, 0]
949
+
950
+ if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
951
+ return f0.cpu().numpy()
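A hedged end-to-end sketch of the FCPE wrapper defined above; the checkpoint path is an assumption (it mirrors the models/ convention used elsewhere in this repo) and the audio is synthetic:

fcpe = FCPE("models/fcpe.pt", hop_length=160, sample_rate=16000, device="cpu")  # path is an assumption
wav = np.random.randn(16000).astype(np.float32)  # 1 s of noise, illustration only
f0 = fcpe.compute_f0(wav)                        # NumPy f0 track with ~len(wav) // hop_length values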
RVC/modules/utils.py ADDED
@@ -0,0 +1,94 @@
1
+ import os
2
+ import gc
3
+ import sys
4
+ import torch
5
+ import codecs
6
+ import librosa
7
+ import requests
8
+
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import torch.nn.functional as F
12
+
13
+ sys.path.append(os.getcwd())
14
+
15
+ from modules import opencl
16
+
17
+ def change_rms(source_audio, source_rate, target_audio, target_rate, rate):
18
+ rms2 = F.interpolate(torch.from_numpy(librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
19
+ return (target_audio * (torch.pow(F.interpolate(torch.from_numpy(librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze(), 1 - rate) * torch.pow(torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6), rate - 1)).numpy())
20
+
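change_rms blends the loudness envelopes of the two signals: each target sample is scaled by rms_source ** (1 - rate) * rms_target ** (rate - 1), so rate = 1 leaves the target RMS untouched and rate = 0 fully imposes the source RMS. A small numeric check of that exponent identity (values invented):

rms_source, rms_target, rate = 0.2, 0.1, 0.0
gain = rms_source ** (1 - rate) * rms_target ** (rate - 1)
print(gain)  # -> 2.0: with rate = 0 the target is rescaled to the source loudness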
21
+ def clear_gpu_cache():
22
+ gc.collect()
23
+
24
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
25
+ elif torch.backends.mps.is_available(): torch.mps.empty_cache()
26
+ elif opencl.is_available(): opencl.pytorch_ocl.empty_cache()
27
+
28
+ def HF_download_file(url, output_path=None):
29
+ url = url.replace("/blob/", "/resolve/").replace("?download=true", "").strip()
30
+ output_path = os.path.basename(url) if output_path is None else (os.path.join(output_path, os.path.basename(url)) if os.path.isdir(output_path) else output_path)
31
+ response = requests.get(url, stream=True, timeout=300)
32
+
33
+ if response.status_code == 200:
34
+ with open(output_path, "wb") as f:
35
+ for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
36
+ f.write(chunk)
37
+
38
+ return output_path
39
+ else: raise ValueError(response.status_code)
40
+
41
+ def check_predictors(method):
42
+ def download(predictors):
43
+ if not os.path.exists(os.path.join("models", predictors)):
44
+ HF_download_file(codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/cerqvpgbef/", "rot13") + predictors, os.path.join("models", predictors))
45
+
46
+ model_dict = {
47
+ **dict.fromkeys(["rmvpe", "rmvpe-legacy"], "rmvpe.pt"),
48
+ **dict.fromkeys(["fcpe"], "fcpe.pt"),
49
+ **dict.fromkeys(["fcpe-legacy"], "fcpe_legacy.pt"),
50
+ **dict.fromkeys(["crepe-full", "mangio-crepe-full"], "crepe_full.pth"),
51
+ **dict.fromkeys(["crepe-large", "mangio-crepe-large"], "crepe_large.pth"),
52
+ **dict.fromkeys(["crepe-medium", "mangio-crepe-medium"], "crepe_medium.pth"),
53
+ **dict.fromkeys(["crepe-small", "mangio-crepe-small"], "crepe_small.pth"),
54
+ **dict.fromkeys(["crepe-tiny", "mangio-crepe-tiny"], "crepe_tiny.pth"),
55
+ }
56
+
57
+ if method in model_dict: download(model_dict[method])
58
+
59
+ def check_embedders(hubert):
60
+ if hubert in ["contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base", "portuguese_hubert_base", "spin"]:
61
+ hubert += ".pt"
62
+ model_path = os.path.join("models", hubert)
63
+ if not os.path.exists(model_path):
64
+ HF_download_file("".join([codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Ivrganzrfr-EIP-Cebwrpg/erfbyir/znva/rzorqqref/", "rot13"), "fairseq/", hubert]), model_path)
65
+
66
+ def load_audio(file, sample_rate=16000):
67
+ try:
68
+ file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
69
+ if not os.path.isfile(file): raise FileNotFoundError(f"[ERROR] Audio file not found: {file}")
70
+
71
+ try:
72
+ audio, sr = sf.read(file, dtype=np.float32)
73
+ except Exception:
74
+ audio, sr = librosa.load(file, sr=None)
75
+
76
+ if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
77
+ if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate, res_type="soxr_vhq")
78
+ except Exception as e:
79
+ raise RuntimeError(f"[ERROR] Error reading audio file: {e}")
80
+
81
+ return audio.flatten()
82
+
83
+ class Autotune:
84
+ def __init__(self, ref_freqs):
85
+ self.ref_freqs = ref_freqs
86
+ self.note_dict = self.ref_freqs
87
+
88
+ def autotune_f0(self, f0, f0_autotune_strength):
89
+ autotuned_f0 = np.zeros_like(f0)
90
+
91
+ for i, freq in enumerate(f0):
92
+ autotuned_f0[i] = freq + (min(self.note_dict, key=lambda x: abs(x - freq)) - freq) * f0_autotune_strength
93
+
94
+ return autotuned_f0
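A short sketch of the Autotune helper above; the reference frequencies are a toy subset (callers presumably pass a full note table):

autotune = Autotune(ref_freqs=[220.0, 440.0, 880.0])
f0 = np.array([0.0, 230.0, 500.0])
print(autotune.autotune_f0(f0, f0_autotune_strength=1.0))  # -> [220. 220. 440.], full snap to the nearest reference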