XiaoHei Studio committed on
Commit
0047e35
1 Parent(s): 8907ed4

Upload 49 files

Files changed (49)
  1. modules/DSConv.py +76 -0
  2. modules/F0Predictor/CrepeF0Predictor.py +34 -0
  3. modules/F0Predictor/DioF0Predictor.py +74 -0
  4. modules/F0Predictor/F0Predictor.py +16 -0
  5. modules/F0Predictor/FCPEF0Predictor.py +109 -0
  6. modules/F0Predictor/HarvestF0Predictor.py +69 -0
  7. modules/F0Predictor/PMF0Predictor.py +72 -0
  8. modules/F0Predictor/RMVPEF0Predictor.py +107 -0
  9. modules/F0Predictor/__init__.py +0 -0
  10. modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc +0 -0
  11. modules/F0Predictor/__pycache__/DioF0Predictor.cpython-38.pyc +0 -0
  12. modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc +0 -0
  13. modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc +0 -0
  14. modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc +0 -0
  15. modules/F0Predictor/__pycache__/__init__.cpython-38.pyc +0 -0
  16. modules/F0Predictor/__pycache__/crepe.cpython-38.pyc +0 -0
  17. modules/F0Predictor/crepe.py +340 -0
  18. modules/F0Predictor/fcpe/__init__.py +3 -0
  19. modules/F0Predictor/fcpe/model.py +262 -0
  20. modules/F0Predictor/fcpe/nvSTFT.py +133 -0
  21. modules/F0Predictor/fcpe/pcmer.py +369 -0
  22. modules/F0Predictor/rmvpe/__init__.py +10 -0
  23. modules/F0Predictor/rmvpe/constants.py +9 -0
  24. modules/F0Predictor/rmvpe/deepunet.py +190 -0
  25. modules/F0Predictor/rmvpe/inference.py +57 -0
  26. modules/F0Predictor/rmvpe/model.py +67 -0
  27. modules/F0Predictor/rmvpe/seq.py +20 -0
  28. modules/F0Predictor/rmvpe/spec.py +67 -0
  29. modules/F0Predictor/rmvpe/utils.py +107 -0
  30. modules/__init__.py +0 -0
  31. modules/__pycache__/DSConv.cpython-38.pyc +0 -0
  32. modules/__pycache__/__init__.cpython-38.pyc +0 -0
  33. modules/__pycache__/attentions.cpython-38.pyc +0 -0
  34. modules/__pycache__/commons.cpython-38.pyc +0 -0
  35. modules/__pycache__/losses.cpython-38.pyc +0 -0
  36. modules/__pycache__/mel_processing.cpython-38.pyc +0 -0
  37. modules/__pycache__/modules.cpython-38.pyc +0 -0
  38. modules/__pycache__/slicer2.cpython-38.pyc +0 -0
  39. modules/attentions.py +347 -0
  40. modules/commons.py +183 -0
  41. modules/enhancer.py +107 -0
  42. modules/losses.py +58 -0
  43. modules/mel_processing.py +83 -0
  44. modules/modules.py +306 -0
  45. modules/slicer2.py +186 -0
  46. onnxexport/__pycache__/model_onnx.cpython-38.pyc +0 -0
  47. onnxexport/__pycache__/model_onnx_speaker_mix.cpython-38.pyc +0 -0
  48. onnxexport/model_onnx.py +333 -0
  49. onnxexport/model_onnx_speaker_mix.py +262 -0
modules/DSConv.py ADDED
@@ -0,0 +1,76 @@
+ import torch.nn as nn
+ from torch.nn.utils import remove_weight_norm, weight_norm
+
+
+ class Depthwise_Separable_Conv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride = 1,
+         padding = 0,
+         dilation = 1,
+         bias = True,
+         padding_mode = 'zeros', # TODO: refine this type
+         device=None,
+         dtype=None
+     ):
+         super().__init__()
+         self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
+         self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
+         self.point_conv = weight_norm(self.point_conv, name = 'weight')
+
+     def remove_weight_norm(self):
+         self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight')
+         self.point_conv = remove_weight_norm(self.point_conv, name = 'weight')
+
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride = 1,
+         padding = 0,
+         output_padding = 0,
+         bias = True,
+         dilation = 1,
+         padding_mode = 'zeros', # TODO: refine this type
+         device=None,
+         dtype=None
+     ):
+         super().__init__()
+         self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
+         self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
+         self.point_conv = weight_norm(self.point_conv, name = 'weight')
+
+     def remove_weight_norm(self):
+         remove_weight_norm(self.depth_conv, name = 'weight')
+         remove_weight_norm(self.point_conv, name = 'weight')
+
+
+ def weight_norm_modules(module, name = 'weight', dim = 0):
+     if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
+         module.weight_norm()
+         return module
+     else:
+         return weight_norm(module,name,dim)
+
+ def remove_weight_norm_modules(module, name = 'weight'):
+     if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
+         module.remove_weight_norm()
+     else:
+         remove_weight_norm(module,name)
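A minimal usage sketch (not part of the committed file) of how these helpers wrap a depthwise-separable convolution; the channel counts and tensor shapes are illustrative assumptions:

    import torch
    from modules.DSConv import Depthwise_Separable_Conv1D, remove_weight_norm_modules, weight_norm_modules

    x = torch.randn(2, 64, 100)  # (batch, channels, frames) -- illustrative shapes only
    conv = weight_norm_modules(Depthwise_Separable_Conv1D(64, 128, kernel_size=3, padding=1))
    y = conv(x)  # -> (2, 128, 100): depthwise conv followed by a 1x1 pointwise conv
    remove_weight_norm_modules(conv)  # strip weight norm again, e.g. before export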
modules/F0Predictor/CrepeF0Predictor.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+ from modules.F0Predictor.crepe import CrepePitchExtractor
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+
+ class CrepeF0Predictor(F0Predictor):
+     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,device=None,sampling_rate=44100,threshold=0.05,model="full"):
+         self.F0Creper = CrepePitchExtractor(hop_length=hop_length,f0_min=f0_min,f0_max=f0_max,device=device,threshold=threshold,model=model)
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.device = device
+         self.threshold = threshold
+         self.sampling_rate = sampling_rate
+         self.name = "crepe"
+
+     def compute_f0(self,wav,p_len=None):
+         x = torch.FloatTensor(wav).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len)
+         return f0
+
+     def compute_f0_uv(self,wav,p_len=None):
+         x = torch.FloatTensor(wav).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         f0,uv = self.F0Creper(x[None,:].float(),self.sampling_rate,pad_to=p_len)
+         return f0,uv
modules/F0Predictor/DioF0Predictor.py ADDED
@@ -0,0 +1,74 @@
+ import numpy as np
+ import pyworld
+
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+
+ class DioF0Predictor(F0Predictor):
+     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.sampling_rate = sampling_rate
+         self.name = "dio"
+
+     def interpolate_f0(self,f0):
+         '''
+         Interpolate F0 (fill in the unvoiced/zero frames)
+         '''
+         vuv_vector = np.zeros_like(f0, dtype=np.float32)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         nzindex = np.nonzero(f0)[0]
+         data = f0[nzindex]
+         nzindex = nzindex.astype(np.float32)
+         time_org = self.hop_length / self.sampling_rate * nzindex
+         time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+
+         if data.shape[0] <= 0:
+             return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+         if data.shape[0] == 1:
+             return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+         f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+         return f0,vuv_vector
+
+     def resize_f0(self,x, target_len):
+         source = np.array(x)
+         source[source<0.001] = np.nan
+         target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
+         res = np.nan_to_num(target)
+         return res
+
+     def compute_f0(self,wav,p_len=None):
+         if p_len is None:
+             p_len = wav.shape[0]//self.hop_length
+         f0, t = pyworld.dio(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_floor=self.f0_min,
+             f0_ceil=self.f0_max,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         for index, pitch in enumerate(f0):
+             f0[index] = round(pitch, 1)
+         return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
+
+     def compute_f0_uv(self,wav,p_len=None):
+         if p_len is None:
+             p_len = wav.shape[0]//self.hop_length
+         f0, t = pyworld.dio(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_floor=self.f0_min,
+             f0_ceil=self.f0_max,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         for index, pitch in enumerate(f0):
+             f0[index] = round(pitch, 1)
+         return self.interpolate_f0(self.resize_f0(f0, p_len))
modules/F0Predictor/F0Predictor.py ADDED
@@ -0,0 +1,16 @@
+ class F0Predictor(object):
+     def compute_f0(self,wav,p_len):
+         '''
+         input: wav:[signal_length]
+                p_len:int
+         output: f0:[signal_length//hop_length]
+         '''
+         pass
+
+     def compute_f0_uv(self,wav,p_len):
+         '''
+         input: wav:[signal_length]
+                p_len:int
+         output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
+         '''
+         pass
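For orientation, a hypothetical sketch of how this interface is consumed; DioF0Predictor stands in for any of the predictors in this commit, and the one-second random waveform is an illustrative assumption:

    import numpy as np
    from modules.F0Predictor.DioF0Predictor import DioF0Predictor

    wav = np.random.randn(44100)               # 1 s of audio at 44.1 kHz, illustrative only
    predictor = DioF0Predictor(hop_length=512, sampling_rate=44100)
    f0 = predictor.compute_f0(wav)             # per-frame f0, length ~ len(wav) // hop_length
    f0, uv = predictor.compute_f0_uv(wav)      # f0 plus a voiced/unvoiced flag per frame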
modules/F0Predictor/FCPEF0Predictor.py ADDED
@@ -0,0 +1,109 @@
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+ from .fcpe.model import FCPEInfer
+
+
+ class FCPEF0Predictor(F0Predictor):
+     def __init__(self, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sampling_rate=44100,
+                  threshold=0.05):
+         self.fcpe = FCPEInfer(model_path="pretrain/fcpe.pt", device=device, dtype=dtype)
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         if device is None:
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = device
+         self.threshold = threshold
+         self.sampling_rate = sampling_rate
+         self.dtype = dtype
+         self.name = "fcpe"
+
+     def repeat_expand(
+         self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
+     ):
+         ndim = content.ndim
+
+         if content.ndim == 1:
+             content = content[None, None]
+         elif content.ndim == 2:
+             content = content[None]
+
+         assert content.ndim == 3
+
+         is_np = isinstance(content, np.ndarray)
+         if is_np:
+             content = torch.from_numpy(content)
+
+         results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+
+         if is_np:
+             results = results.numpy()
+
+         if ndim == 1:
+             return results[0, 0]
+         elif ndim == 2:
+             return results[0]
+
+     def post_process(self, x, sampling_rate, f0, pad_to):
+         if isinstance(f0, np.ndarray):
+             f0 = torch.from_numpy(f0).float().to(x.device)
+
+         if pad_to is None:
+             return f0
+
+         f0 = self.repeat_expand(f0, pad_to)
+
+         vuv_vector = torch.zeros_like(f0)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         # Remove zero frequencies and interpolate linearly
+         nzindex = torch.nonzero(f0).squeeze()
+         f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+         time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
+         time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+         vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
+
+         if f0.shape[0] <= 0:
+             return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(), vuv_vector.cpu().numpy()
+         if f0.shape[0] == 1:
+             return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[
+                 0]).cpu().numpy(), vuv_vector.cpu().numpy()
+
+         # This could probably be rewritten with torch
+         f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+         # vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+
+         return f0, vuv_vector.cpu().numpy()
+
+     def compute_f0(self, wav, p_len=None):
+         x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0] // self.hop_length
+         else:
+             assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
+         f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
+         if torch.all(f0 == 0):
+             rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+             return rtn, rtn
+         return self.post_process(x, self.sampling_rate, f0, p_len)[0]
+
+     def compute_f0_uv(self, wav, p_len=None):
+         x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0] // self.hop_length
+         else:
+             assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
+         f0 = self.fcpe(x, sr=self.sampling_rate, threshold=self.threshold)[0,:,0]
+         if torch.all(f0 == 0):
+             rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+             return rtn, rtn
+         return self.post_process(x, self.sampling_rate, f0, p_len)
modules/F0Predictor/HarvestF0Predictor.py ADDED
@@ -0,0 +1,69 @@
+ import numpy as np
+ import pyworld
+
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+
+ class HarvestF0Predictor(F0Predictor):
+     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.sampling_rate = sampling_rate
+         self.name = "harvest"
+
+     def interpolate_f0(self,f0):
+         '''
+         Interpolate F0 (fill in the unvoiced/zero frames)
+         '''
+         vuv_vector = np.zeros_like(f0, dtype=np.float32)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         nzindex = np.nonzero(f0)[0]
+         data = f0[nzindex]
+         nzindex = nzindex.astype(np.float32)
+         time_org = self.hop_length / self.sampling_rate * nzindex
+         time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+
+         if data.shape[0] <= 0:
+             return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+         if data.shape[0] == 1:
+             return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+         f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+         return f0,vuv_vector
+     def resize_f0(self,x, target_len):
+         source = np.array(x)
+         source[source<0.001] = np.nan
+         target = np.interp(np.arange(0, len(source)*target_len, len(source))/ target_len, np.arange(0, len(source)), source)
+         res = np.nan_to_num(target)
+         return res
+
+     def compute_f0(self,wav,p_len=None):
+         if p_len is None:
+             p_len = wav.shape[0]//self.hop_length
+         f0, t = pyworld.harvest(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_ceil=self.f0_max,
+             f0_floor=self.f0_min,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
+
+     def compute_f0_uv(self,wav,p_len=None):
+         if p_len is None:
+             p_len = wav.shape[0]//self.hop_length
+         f0, t = pyworld.harvest(
+             wav.astype(np.double),
+             fs=self.sampling_rate,
+             f0_floor=self.f0_min,
+             f0_ceil=self.f0_max,
+             frame_period=1000 * self.hop_length / self.sampling_rate,
+         )
+         f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
+         return self.interpolate_f0(self.resize_f0(f0, p_len))
modules/F0Predictor/PMF0Predictor.py ADDED
@@ -0,0 +1,72 @@
+ import numpy as np
+ import parselmouth
+
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+
+ class PMF0Predictor(F0Predictor):
+     def __init__(self,hop_length=512,f0_min=50,f0_max=1100,sampling_rate=44100):
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.sampling_rate = sampling_rate
+         self.name = "pm"
+
+     def interpolate_f0(self,f0):
+         '''
+         Interpolate F0 (fill in the unvoiced/zero frames)
+         '''
+         vuv_vector = np.zeros_like(f0, dtype=np.float32)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         nzindex = np.nonzero(f0)[0]
+         data = f0[nzindex]
+         nzindex = nzindex.astype(np.float32)
+         time_org = self.hop_length / self.sampling_rate * nzindex
+         time_frame = np.arange(f0.shape[0]) * self.hop_length / self.sampling_rate
+
+         if data.shape[0] <= 0:
+             return np.zeros(f0.shape[0], dtype=np.float32),vuv_vector
+
+         if data.shape[0] == 1:
+             return np.ones(f0.shape[0], dtype=np.float32) * f0[0],vuv_vector
+
+         f0 = np.interp(time_frame, time_org, data, left=data[0], right=data[-1])
+
+         return f0,vuv_vector
+
+
+     def compute_f0(self,wav,p_len=None):
+         x = wav
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         time_step = self.hop_length / self.sampling_rate * 1000
+         f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
+             time_step=time_step / 1000, voicing_threshold=0.6,
+             pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
+
+         pad_size=(p_len - len(f0) + 1) // 2
+         if(pad_size>0 or p_len - len(f0) - pad_size>0):
+             f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
+         f0,uv = self.interpolate_f0(f0)
+         return f0
+
+     def compute_f0_uv(self,wav,p_len=None):
+         x = wav
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         time_step = self.hop_length / self.sampling_rate * 1000
+         f0 = parselmouth.Sound(x, self.sampling_rate).to_pitch_ac(
+             time_step=time_step / 1000, voicing_threshold=0.6,
+             pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array['frequency']
+
+         pad_size=(p_len - len(f0) + 1) // 2
+         if(pad_size>0 or p_len - len(f0) - pad_size>0):
+             f0 = np.pad(f0,[[pad_size,p_len - len(f0) - pad_size]], mode='constant')
+         f0,uv = self.interpolate_f0(f0)
+         return f0,uv
modules/F0Predictor/RMVPEF0Predictor.py ADDED
@@ -0,0 +1,107 @@
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from modules.F0Predictor.F0Predictor import F0Predictor
+
+ from .rmvpe import RMVPE
+
+
+ class RMVPEF0Predictor(F0Predictor):
+     def __init__(self,hop_length=512,f0_min=50,f0_max=1100, dtype=torch.float32, device=None,sampling_rate=44100,threshold=0.05):
+         self.rmvpe = RMVPE(model_path="pretrain/rmvpe.pt",dtype=dtype,device=device)
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         if device is None:
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = device
+         self.threshold = threshold
+         self.sampling_rate = sampling_rate
+         self.dtype = dtype
+         self.name = "rmvpe"
+
+     def repeat_expand(
+         self, content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
+     ):
+         ndim = content.ndim
+
+         if content.ndim == 1:
+             content = content[None, None]
+         elif content.ndim == 2:
+             content = content[None]
+
+         assert content.ndim == 3
+
+         is_np = isinstance(content, np.ndarray)
+         if is_np:
+             content = torch.from_numpy(content)
+
+         results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+
+         if is_np:
+             results = results.numpy()
+
+         if ndim == 1:
+             return results[0, 0]
+         elif ndim == 2:
+             return results[0]
+
+     def post_process(self, x, sampling_rate, f0, pad_to):
+         if isinstance(f0, np.ndarray):
+             f0 = torch.from_numpy(f0).float().to(x.device)
+
+         if pad_to is None:
+             return f0
+
+         f0 = self.repeat_expand(f0, pad_to)
+
+         vuv_vector = torch.zeros_like(f0)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         # Remove zero frequencies and interpolate linearly
+         nzindex = torch.nonzero(f0).squeeze()
+         f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+         time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
+         time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+         vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
+
+         if f0.shape[0] <= 0:
+             return torch.zeros(pad_to, dtype=torch.float, device=x.device).cpu().numpy(),vuv_vector.cpu().numpy()
+         if f0.shape[0] == 1:
+             return (torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0]).cpu().numpy() ,vuv_vector.cpu().numpy()
+
+         # This could probably be rewritten with torch
+         f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+         #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+
+         return f0,vuv_vector.cpu().numpy()
+
+     def compute_f0(self,wav,p_len=None):
+         x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+         if torch.all(f0 == 0):
+             rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+             return rtn,rtn
+         return self.post_process(x,self.sampling_rate,f0,p_len)[0]
+
+     def compute_f0_uv(self,wav,p_len=None):
+         x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
+         if p_len is None:
+             p_len = x.shape[0]//self.hop_length
+         else:
+             assert abs(p_len-x.shape[0]//self.hop_length) < 4, "pad length error"
+         f0 = self.rmvpe.infer_from_audio(x,self.sampling_rate,self.threshold)
+         if torch.all(f0 == 0):
+             rtn = f0.cpu().numpy() if p_len is None else np.zeros(p_len)
+             return rtn,rtn
+         return self.post_process(x,self.sampling_rate,f0,p_len)
modules/F0Predictor/__init__.py ADDED
File without changes
modules/F0Predictor/__pycache__/CrepeF0Predictor.cpython-38.pyc ADDED
Binary file (1.56 kB).
modules/F0Predictor/__pycache__/DioF0Predictor.cpython-38.pyc ADDED
Binary file (2.54 kB).
modules/F0Predictor/__pycache__/F0Predictor.cpython-38.pyc ADDED
Binary file (869 Bytes).
modules/F0Predictor/__pycache__/HarvestF0Predictor.cpython-38.pyc ADDED
Binary file (2.58 kB).
modules/F0Predictor/__pycache__/PMF0Predictor.cpython-38.pyc ADDED
Binary file (2.43 kB).
modules/F0Predictor/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (169 Bytes).
modules/F0Predictor/__pycache__/crepe.cpython-38.pyc ADDED
Binary file (9.02 kB).
modules/F0Predictor/crepe.py ADDED
@@ -0,0 +1,340 @@
+ from typing import Optional, Union
+
+ try:
+     from typing import Literal
+ except Exception:
+     from typing_extensions import Literal
+ import numpy as np
+ import torch
+ import torchcrepe
+ from torch import nn
+ from torch.nn import functional as F
+
+ #from:https://github.com/fishaudio/fish-diffusion
+
+ def repeat_expand(
+     content: Union[torch.Tensor, np.ndarray], target_len: int, mode: str = "nearest"
+ ):
+     """Repeat content to target length.
+     This is a wrapper of torch.nn.functional.interpolate.
+
+     Args:
+         content (torch.Tensor): tensor
+         target_len (int): target length
+         mode (str, optional): interpolation mode. Defaults to "nearest".
+
+     Returns:
+         torch.Tensor: tensor
+     """
+
+     ndim = content.ndim
+
+     if content.ndim == 1:
+         content = content[None, None]
+     elif content.ndim == 2:
+         content = content[None]
+
+     assert content.ndim == 3
+
+     is_np = isinstance(content, np.ndarray)
+     if is_np:
+         content = torch.from_numpy(content)
+
+     results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
+
+     if is_np:
+         results = results.numpy()
+
+     if ndim == 1:
+         return results[0, 0]
+     elif ndim == 2:
+         return results[0]
+
+
+ class BasePitchExtractor:
+     def __init__(
+         self,
+         hop_length: int = 512,
+         f0_min: float = 50.0,
+         f0_max: float = 1100.0,
+         keep_zeros: bool = True,
+     ):
+         """Base pitch extractor.
+
+         Args:
+             hop_length (int, optional): Hop length. Defaults to 512.
+             f0_min (float, optional): Minimum f0. Defaults to 50.0.
+             f0_max (float, optional): Maximum f0. Defaults to 1100.0.
+             keep_zeros (bool, optional): Whether keep zeros in pitch. Defaults to True.
+         """
+
+         self.hop_length = hop_length
+         self.f0_min = f0_min
+         self.f0_max = f0_max
+         self.keep_zeros = keep_zeros
+
+     def __call__(self, x, sampling_rate=44100, pad_to=None):
+         raise NotImplementedError("BasePitchExtractor is not callable.")
+
+     def post_process(self, x, sampling_rate, f0, pad_to):
+         if isinstance(f0, np.ndarray):
+             f0 = torch.from_numpy(f0).float().to(x.device)
+
+         if pad_to is None:
+             return f0
+
+         f0 = repeat_expand(f0, pad_to)
+
+         if self.keep_zeros:
+             return f0
+
+         vuv_vector = torch.zeros_like(f0)
+         vuv_vector[f0 > 0.0] = 1.0
+         vuv_vector[f0 <= 0.0] = 0.0
+
+         # Remove zero frequencies and interpolate linearly
+         nzindex = torch.nonzero(f0).squeeze()
+         f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
+         time_org = self.hop_length / sampling_rate * nzindex.cpu().numpy()
+         time_frame = np.arange(pad_to) * self.hop_length / sampling_rate
+
+         vuv_vector = F.interpolate(vuv_vector[None,None,:],size=pad_to)[0][0]
+
+         if f0.shape[0] <= 0:
+             return torch.zeros(pad_to, dtype=torch.float, device=x.device),vuv_vector.cpu().numpy()
+         if f0.shape[0] == 1:
+             return torch.ones(pad_to, dtype=torch.float, device=x.device) * f0[0],vuv_vector.cpu().numpy()
+
+         # This could probably be rewritten with torch
+         f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+         #vuv_vector = np.ceil(scipy.ndimage.zoom(vuv_vector,pad_to/len(vuv_vector),order = 0))
+
+         return f0,vuv_vector.cpu().numpy()
+
+
+ class MaskedAvgPool1d(nn.Module):
+     def __init__(
+         self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
+     ):
+         """An implementation of mean pooling that supports masked values.
+
+         Args:
+             kernel_size (int): The size of the median pooling window.
+             stride (int, optional): The stride of the median pooling window. Defaults to None.
+             padding (int, optional): The padding of the median pooling window. Defaults to 0.
+         """
+
+         super(MaskedAvgPool1d, self).__init__()
+         self.kernel_size = kernel_size
+         self.stride = stride or kernel_size
+         self.padding = padding
+
+     def forward(self, x, mask=None):
+         ndim = x.dim()
+         if ndim == 2:
+             x = x.unsqueeze(1)
+
+         assert (
+             x.dim() == 3
+         ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"
+
+         # Apply the mask by setting masked elements to zero, or make NaNs zero
+         if mask is None:
+             mask = ~torch.isnan(x)
+
+         # Ensure mask has the same shape as the input tensor
+         assert x.shape == mask.shape, "Input tensor and mask must have the same shape"
+
+         masked_x = torch.where(mask, x, torch.zeros_like(x))
+         # Create a ones kernel with the same number of channels as the input tensor
+         ones_kernel = torch.ones(x.size(1), 1, self.kernel_size, device=x.device)
+
+         # Perform sum pooling
+         sum_pooled = nn.functional.conv1d(
+             masked_x,
+             ones_kernel,
+             stride=self.stride,
+             padding=self.padding,
+             groups=x.size(1),
+         )
+
+         # Count the non-masked (valid) elements in each pooling window
+         valid_count = nn.functional.conv1d(
+             mask.float(),
+             ones_kernel,
+             stride=self.stride,
+             padding=self.padding,
+             groups=x.size(1),
+         )
+         valid_count = valid_count.clamp(min=1)  # Avoid division by zero
+
+         # Perform masked average pooling
+         avg_pooled = sum_pooled / valid_count
+
+         # Fill zero values with NaNs
+         avg_pooled[avg_pooled == 0] = float("nan")
+
+         if ndim == 2:
+             return avg_pooled.squeeze(1)
+
+         return avg_pooled
+
+
+ class MaskedMedianPool1d(nn.Module):
+     def __init__(
+         self, kernel_size: int, stride: Optional[int] = None, padding: Optional[int] = 0
+     ):
+         """An implementation of median pooling that supports masked values.
+
+         This implementation is inspired by the median pooling implementation in
+         https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598
+
+         Args:
+             kernel_size (int): The size of the median pooling window.
+             stride (int, optional): The stride of the median pooling window. Defaults to None.
+             padding (int, optional): The padding of the median pooling window. Defaults to 0.
+         """
+
+         super(MaskedMedianPool1d, self).__init__()
+         self.kernel_size = kernel_size
+         self.stride = stride or kernel_size
+         self.padding = padding
+
+     def forward(self, x, mask=None):
+         ndim = x.dim()
+         if ndim == 2:
+             x = x.unsqueeze(1)
+
+         assert (
+             x.dim() == 3
+         ), "Input tensor must have 2 or 3 dimensions (batch_size, channels, width)"
+
+         if mask is None:
+             mask = ~torch.isnan(x)
+
+         assert x.shape == mask.shape, "Input tensor and mask must have the same shape"
+
+         masked_x = torch.where(mask, x, torch.zeros_like(x))
+
+         x = F.pad(masked_x, (self.padding, self.padding), mode="reflect")
+         mask = F.pad(
+             mask.float(), (self.padding, self.padding), mode="constant", value=0
+         )
+
+         x = x.unfold(2, self.kernel_size, self.stride)
+         mask = mask.unfold(2, self.kernel_size, self.stride)
+
+         x = x.contiguous().view(x.size()[:3] + (-1,))
+         mask = mask.contiguous().view(mask.size()[:3] + (-1,)).to(x.device)
+
+         # Combine the mask with the input tensor
+         #x_masked = torch.where(mask.bool(), x, torch.fill_(torch.zeros_like(x),float("inf")))
+         x_masked = torch.where(mask.bool(), x, torch.FloatTensor([float("inf")]).to(x.device))
+
+         # Sort the masked tensor along the last dimension
+         x_sorted, _ = torch.sort(x_masked, dim=-1)
+
+         # Compute the count of non-masked (valid) values
+         valid_count = mask.sum(dim=-1)
+
+         # Calculate the index of the median value for each pooling window
+         median_idx = (torch.div((valid_count - 1), 2, rounding_mode='trunc')).clamp(min=0)
+
+         # Gather the median values using the calculated indices
+         median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1)
+
+         # Fill infinite values with NaNs
+         median_pooled[torch.isinf(median_pooled)] = float("nan")
+
+         if ndim == 2:
+             return median_pooled.squeeze(1)
+
+         return median_pooled
+
+
+ class CrepePitchExtractor(BasePitchExtractor):
+     def __init__(
+         self,
+         hop_length: int = 512,
+         f0_min: float = 50.0,
+         f0_max: float = 1100.0,
+         threshold: float = 0.05,
+         keep_zeros: bool = False,
+         device = None,
+         model: Literal["full", "tiny"] = "full",
+         use_fast_filters: bool = True,
+         decoder="viterbi"
+     ):
+         super().__init__(hop_length, f0_min, f0_max, keep_zeros)
+         if decoder == "viterbi":
+             self.decoder = torchcrepe.decode.viterbi
+         elif decoder == "argmax":
+             self.decoder = torchcrepe.decode.argmax
+         elif decoder == "weighted_argmax":
+             self.decoder = torchcrepe.decode.weighted_argmax
+         else:
+             raise ValueError("Unknown decoder")
+         self.threshold = threshold
+         self.model = model
+         self.use_fast_filters = use_fast_filters
+         self.hop_length = hop_length
+         if device is None:
+             self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         else:
+             self.dev = torch.device(device)
+         if self.use_fast_filters:
+             self.median_filter = MaskedMedianPool1d(3, 1, 1).to(device)
+             self.mean_filter = MaskedAvgPool1d(3, 1, 1).to(device)
+
+     def __call__(self, x, sampling_rate=44100, pad_to=None):
+         """Extract pitch using crepe.
+
+
+         Args:
+             x (torch.Tensor): Audio signal, shape (1, T).
+             sampling_rate (int, optional): Sampling rate. Defaults to 44100.
+             pad_to (int, optional): Pad to length. Defaults to None.
+
+         Returns:
+             torch.Tensor: Pitch, shape (T // hop_length,).
+         """
+
+         assert x.ndim == 2, f"Expected 2D tensor, got {x.ndim}D tensor."
+         assert x.shape[0] == 1, f"Expected 1 channel, got {x.shape[0]} channels."
+
+         x = x.to(self.dev)
+         f0, pd = torchcrepe.predict(
+             x,
+             sampling_rate,
+             self.hop_length,
+             self.f0_min,
+             self.f0_max,
+             pad=True,
+             model=self.model,
+             batch_size=1024,
+             device=x.device,
+             return_periodicity=True,
+             decoder=self.decoder
+         )
+
+         # Filter, remove silence, and apply the uv threshold; see the upstream repository's README
+         if self.use_fast_filters:
+             pd = self.median_filter(pd)
+         else:
+             pd = torchcrepe.filter.median(pd, 3)
+
+         pd = torchcrepe.threshold.Silence(-60.0)(pd, x, sampling_rate, self.hop_length)
+         f0 = torchcrepe.threshold.At(self.threshold)(f0, pd)
+
+         if self.use_fast_filters:
+             f0 = self.mean_filter(f0)
+         else:
+             f0 = torchcrepe.filter.mean(f0, 3)
+
+         f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)[0]
+
+         if torch.all(f0 == 0):
+             rtn = f0.cpu().numpy() if pad_to is None else np.zeros(pad_to)
+             return rtn,rtn
+
+         return self.post_process(x, sampling_rate, f0, pad_to)
modules/F0Predictor/fcpe/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .model import FCPEInfer # noqa: F401
+ from .nvSTFT import STFT # noqa: F401
+ from .pcmer import PCmer # noqa: F401
modules/F0Predictor/fcpe/model.py ADDED
@@ -0,0 +1,262 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.utils import weight_norm
+ from torchaudio.transforms import Resample
+
+ from .nvSTFT import STFT
+ from .pcmer import PCmer
+
+
+ def l2_regularization(model, l2_alpha):
+     l2_loss = []
+     for module in model.modules():
+         if type(module) is nn.Conv2d:
+             l2_loss.append((module.weight ** 2).sum() / 2.0)
+     return l2_alpha * sum(l2_loss)
+
+
+ class FCPE(nn.Module):
+     def __init__(
+         self,
+         input_channel=128,
+         out_dims=360,
+         n_layers=12,
+         n_chans=512,
+         use_siren=False,
+         use_full=False,
+         loss_mse_scale=10,
+         loss_l2_regularization=False,
+         loss_l2_regularization_scale=1,
+         loss_grad1_mse=False,
+         loss_grad1_mse_scale=1,
+         f0_max=1975.5,
+         f0_min=32.70,
+         confidence=False,
+         threshold=0.05,
+         use_input_conv=True
+     ):
+         super().__init__()
+         if use_siren is True:
+             raise ValueError("Siren is not supported yet.")
+         if use_full is True:
+             raise ValueError("Full model is not supported yet.")
+
+         self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
+         self.loss_l2_regularization = loss_l2_regularization if (loss_l2_regularization is not None) else False
+         self.loss_l2_regularization_scale = loss_l2_regularization_scale if (loss_l2_regularization_scale
+                                                                              is not None) else 1
+         self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
+         self.loss_grad1_mse_scale = loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1
+         self.f0_max = f0_max if (f0_max is not None) else 1975.5
+         self.f0_min = f0_min if (f0_min is not None) else 32.70
+         self.confidence = confidence if (confidence is not None) else False
+         self.threshold = threshold if (threshold is not None) else 0.05
+         self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
+
+         self.cent_table_b = torch.Tensor(
+             np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0],
+                         out_dims))
+         self.register_buffer("cent_table", self.cent_table_b)
+
+         # conv in stack
+         _leaky = nn.LeakyReLU()
+         self.stack = nn.Sequential(
+             nn.Conv1d(input_channel, n_chans, 3, 1, 1),
+             nn.GroupNorm(4, n_chans),
+             _leaky,
+             nn.Conv1d(n_chans, n_chans, 3, 1, 1))
+
+         # transformer
+         self.decoder = PCmer(
+             num_layers=n_layers,
+             num_heads=8,
+             dim_model=n_chans,
+             dim_keys=n_chans,
+             dim_values=n_chans,
+             residual_dropout=0.1,
+             attention_dropout=0.1)
+         self.norm = nn.LayerNorm(n_chans)
+
+         # out
+         self.n_out = out_dims
+         self.dense_out = weight_norm(
+             nn.Linear(n_chans, self.n_out))
+
+     def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder = "local_argmax"):
+         """
+         input:
+             B x n_frames x n_unit
+         return:
+             dict of B x n_frames x feat
+         """
+         if cdecoder == "argmax":
+             self.cdecoder = self.cents_decoder
+         elif cdecoder == "local_argmax":
+             self.cdecoder = self.cents_local_decoder
+         if self.use_input_conv:
+             x = self.stack(mel.transpose(1, 2)).transpose(1, 2)
+         else:
+             x = mel
+         x = self.decoder(x)
+         x = self.norm(x)
+         x = self.dense_out(x)  # [B,N,D]
+         x = torch.sigmoid(x)
+         if not infer:
+             gt_cent_f0 = self.f0_to_cent(gt_f0)  # mel f0 #[B,N,1]
+             gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0)  # #[B,N,out_dim]
+             loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0)  # bce loss
+             # l2 regularization
+             if self.loss_l2_regularization:
+                 loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
+             x = loss_all
+         if infer:
+             x = self.cdecoder(x)
+             x = self.cent_to_f0(x)
+             if not return_hz_f0:
+                 x = (1 + x / 700).log()
+         return x
+
+     def cents_decoder(self, y, mask=True):
+         B, N, _ = y.size()
+         ci = self.cent_table[None, None, :].expand(B, N, -1)
+         rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)  # cents: [B,N,1]
+         if mask:
+             confident = torch.max(y, dim=-1, keepdim=True)[0]
+             confident_mask = torch.ones_like(confident)
+             confident_mask[confident <= self.threshold] = float("-INF")
+             rtn = rtn * confident_mask
+         if self.confidence:
+             return rtn, confident
+         else:
+             return rtn
+
+     def cents_local_decoder(self, y, mask=True):
+         B, N, _ = y.size()
+         ci = self.cent_table[None, None, :].expand(B, N, -1)
+         confident, max_index = torch.max(y, dim=-1, keepdim=True)
+         local_argmax_index = torch.arange(0,8).to(max_index.device) + (max_index - 4)
+         local_argmax_index[local_argmax_index<0] = 0
+         local_argmax_index[local_argmax_index>=self.n_out] = self.n_out - 1
+         ci_l = torch.gather(ci,-1,local_argmax_index)
+         y_l = torch.gather(y,-1,local_argmax_index)
+         rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)  # cents: [B,N,1]
+         if mask:
+             confident_mask = torch.ones_like(confident)
+             confident_mask[confident <= self.threshold] = float("-INF")
+             rtn = rtn * confident_mask
+         if self.confidence:
+             return rtn, confident
+         else:
+             return rtn
+
+     def cent_to_f0(self, cent):
+         return 10. * 2 ** (cent / 1200.)
+
+     def f0_to_cent(self, f0):
+         return 1200. * torch.log2(f0 / 10.)
+
+     def gaussian_blurred_cent(self, cents):  # cents: [B,N,1]
+         mask = (cents > 0.1) & (cents < (1200. * np.log2(self.f0_max / 10.)))
+         B, N, _ = cents.size()
+         ci = self.cent_table[None, None, :].expand(B, N, -1)
+         return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()
+
+
+ class FCPEInfer:
+     def __init__(self, model_path, device=None, dtype=torch.float32):
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         ckpt = torch.load(model_path, map_location=torch.device(self.device))
+         self.args = DotDict(ckpt["config"])
+         self.dtype = dtype
+         model = FCPE(
+             input_channel=self.args.model.input_channel,
+             out_dims=self.args.model.out_dims,
+             n_layers=self.args.model.n_layers,
+             n_chans=self.args.model.n_chans,
+             use_siren=self.args.model.use_siren,
+             use_full=self.args.model.use_full,
+             loss_mse_scale=self.args.loss.loss_mse_scale,
+             loss_l2_regularization=self.args.loss.loss_l2_regularization,
+             loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale,
+             loss_grad1_mse=self.args.loss.loss_grad1_mse,
+             loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale,
+             f0_max=self.args.model.f0_max,
+             f0_min=self.args.model.f0_min,
+             confidence=self.args.model.confidence,
+         )
+         model.to(self.device).to(self.dtype)
+         model.load_state_dict(ckpt['model'])
+         model.eval()
+         self.model = model
+         self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)
+
+     @torch.no_grad()
+     def __call__(self, audio, sr, threshold=0.05):
+         self.model.threshold = threshold
+         audio = audio[None,:]
+         mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
+         f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
+         return f0
+
+
+ class Wav2Mel:
+
+     def __init__(self, args, device=None, dtype=torch.float32):
+         # self.args = args
+         self.sampling_rate = args.mel.sampling_rate
+         self.hop_size = args.mel.hop_size
+         if device is None:
+             device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.device = device
+         self.dtype = dtype
+         self.stft = STFT(
+             args.mel.sampling_rate,
+             args.mel.num_mels,
+             args.mel.n_fft,
+             args.mel.win_size,
+             args.mel.hop_size,
+             args.mel.fmin,
+             args.mel.fmax
+         )
+         self.resample_kernel = {}
+
+     def extract_nvstft(self, audio, keyshift=0, train=False):
+         mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)  # B, n_frames, bins
+         return mel
+
+     def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
+         audio = audio.to(self.dtype).to(self.device)
+         # resample
+         if sample_rate == self.sampling_rate:
+             audio_res = audio
+         else:
+             key_str = str(sample_rate)
+             if key_str not in self.resample_kernel:
+                 self.resample_kernel[key_str] = Resample(sample_rate, self.sampling_rate, lowpass_filter_width=128)
+             self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
+             audio_res = self.resample_kernel[key_str](audio)
+
+         # extract
+         mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)  # B, n_frames, bins
+         n_frames = int(audio.shape[1] // self.hop_size) + 1
+         if n_frames > int(mel.shape[1]):
+             mel = torch.cat((mel, mel[:, -1:, :]), 1)
+         if n_frames < int(mel.shape[1]):
+             mel = mel[:, :n_frames, :]
+         return mel
+
+     def __call__(self, audio, sample_rate, keyshift=0, train=False):
+         return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
+
+
+ class DotDict(dict):
+     def __getattr__(*args):
+         val = dict.get(*args)
+         return DotDict(val) if type(val) is dict else val
+
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
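A short sketch of calling FCPEInfer directly, assuming the checkpoint pretrain/fcpe.pt expected by FCPEF0Predictor above is available; the two-second random waveform is an illustrative assumption:

    import torch
    from modules.F0Predictor.fcpe.model import FCPEInfer

    fcpe = FCPEInfer(model_path="pretrain/fcpe.pt", device="cpu", dtype=torch.float32)
    audio = torch.randn(2 * 44100)              # mono waveform, illustrative only
    f0 = fcpe(audio, sr=44100, threshold=0.05)  # -> (1, n_frames, 1) tensor of f0 in Hz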
modules/F0Predictor/fcpe/nvSTFT.py ADDED
@@ -0,0 +1,133 @@
+ import os
+
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+
+ os.environ["LRU_CACHE_CAPACITY"] = "3"
+
+ def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
+     sampling_rate = None
+     try:
+         data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile.
+     except Exception as ex:
+         print(f"'{full_path}' failed to load.\nException:")
+         print(ex)
+         if return_empty_on_exception:
+             return [], sampling_rate or target_sr or 48000
+         else:
+             raise Exception(ex)
+
+     if len(data.shape) > 1:
+         data = data[:, 0]
+         assert len(data) > 2  # check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
+
+     if np.issubdtype(data.dtype, np.integer):  # if audio data is type int
+         max_mag = -np.iinfo(data.dtype).min  # maximum magnitude = min possible value of intXX
+     else:  # if audio data is type fp32
+         max_mag = max(np.amax(data), -np.amin(data))
+         max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0)  # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
+
+     data = torch.FloatTensor(data.astype(np.float32))/max_mag
+
+     if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:  # resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
+         return [], sampling_rate or target_sr or 48000
+     if target_sr is not None and sampling_rate != target_sr:
+         data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
+         sampling_rate = target_sr
+
+     return data, sampling_rate
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+ class STFT():
+     def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
+         self.target_sr = sr
+
+         self.n_mels = n_mels
+         self.n_fft = n_fft
+         self.win_size = win_size
+         self.hop_length = hop_length
+         self.fmin = fmin
+         self.fmax = fmax
+         self.clip_val = clip_val
+         self.mel_basis = {}
+         self.hann_window = {}
+
+     def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
+         sampling_rate = self.target_sr
+         n_mels = self.n_mels
+         n_fft = self.n_fft
+         win_size = self.win_size
+         hop_length = self.hop_length
+         fmin = self.fmin
+         fmax = self.fmax
+         clip_val = self.clip_val
+
+         factor = 2 ** (keyshift / 12)
+         n_fft_new = int(np.round(n_fft * factor))
+         win_size_new = int(np.round(win_size * factor))
+         hop_length_new = int(np.round(hop_length * speed))
+         if not train:
+             mel_basis = self.mel_basis
+             hann_window = self.hann_window
+         else:
+             mel_basis = {}
+             hann_window = {}
+
+         if torch.min(y) < -1.:
+             print('min value is ', torch.min(y))
+         if torch.max(y) > 1.:
+             print('max value is ', torch.max(y))
+
+         mel_basis_key = str(fmax)+'_'+str(y.device)
+         if mel_basis_key not in mel_basis:
+             mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+             mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
+
+         keyshift_key = str(keyshift)+'_'+str(y.device)
+         if keyshift_key not in hann_window:
+             hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
+
+         pad_left = (win_size_new - hop_length_new) //2
+         pad_right = max((win_size_new- hop_length_new + 1) //2, win_size_new - y.size(-1) - pad_left)
+         if pad_right < y.size(-1):
+             mode = 'reflect'
+         else:
+             mode = 'constant'
+         y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode = mode)
+         y = y.squeeze(1)
+
+         spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key],
+                           center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+         spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
+         if keyshift != 0:
+             size = n_fft // 2 + 1
+             resize = spec.size(1)
+             if resize < size:
+                 spec = F.pad(spec, (0, 0, 0, size-resize))
+             spec = spec[:, :size, :] * win_size / win_size_new
+         spec = torch.matmul(mel_basis[mel_basis_key], spec)
+         spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
+         return spec
+
+     def __call__(self, audiopath):
+         audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
+         spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
+         return spect
+
+ stft = STFT()
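A brief sketch of using the STFT helper on its own; the constructor arguments below are assumptions for a 44.1 kHz setup rather than the class defaults, and the input waveform is random for illustration:

    import torch
    from modules.F0Predictor.fcpe.nvSTFT import STFT

    stft = STFT(sr=44100, n_mels=128, n_fft=2048, win_size=2048, hop_length=512, fmin=40, fmax=16000)
    audio = torch.randn(1, 44100)   # (batch, samples), illustrative only
    mel = stft.get_mel(audio)       # -> (batch, n_mels, n_frames) log-mel spectrogram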
modules/F0Predictor/fcpe/pcmer.py ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+ from local_attention import LocalAttention
8
+ from torch import nn
9
+
10
+ #import fast_transformers.causal_product.causal_product_cuda
11
+
12
+ def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device = None):
13
+ b, h, *_ = data.shape
14
+ # (batch size, head, length, model_dim)
15
+
16
+ # normalize model dim
17
+ data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
18
+
19
+ # what is ration?, projection_matrix.shape[0] --> 266
20
+
21
+ ratio = (projection_matrix.shape[0] ** -0.5)
22
+
23
+ projection = repeat(projection_matrix, 'j d -> b h j d', b = b, h = h)
24
+ projection = projection.type_as(data)
25
+
26
+ #data_dash = w^T x
27
+ data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)
28
+
29
+
30
+ # diag_data = D**2
31
+ diag_data = data ** 2
32
+ diag_data = torch.sum(diag_data, dim=-1)
33
+ diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
34
+ diag_data = diag_data.unsqueeze(dim=-1)
35
+
36
+ #print ()
37
+ if is_query:
38
+ data_dash = ratio * (
39
+ torch.exp(data_dash - diag_data -
40
+ torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
41
+ else:
42
+ data_dash = ratio * (
43
+ torch.exp(data_dash - diag_data + eps))#- torch.max(data_dash)) + eps)
44
+
45
+ return data_dash.type_as(data)
46
+
47
+ def orthogonal_matrix_chunk(cols, qr_uniform_q = False, device = None):
48
+ unstructured_block = torch.randn((cols, cols), device = device)
49
+ q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced')
50
+ q, r = map(lambda t: t.to(device), (q, r))
51
+
52
+ # proposed by @Parskatt
53
+ # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf
54
+ if qr_uniform_q:
55
+ d = torch.diag(r, 0)
56
+ q *= d.sign()
57
+ return q.t()
58
+ def exists(val):
59
+ return val is not None
60
+
61
+ def empty(tensor):
62
+ return tensor.numel() == 0
63
+
64
+ def default(val, d):
65
+ return val if exists(val) else d
66
+
67
+ def cast_tuple(val):
68
+ return (val,) if not isinstance(val, tuple) else val
69
+
70
+ class PCmer(nn.Module):
71
+ """The encoder that is used in the Transformer model."""
72
+
73
+ def __init__(self,
74
+ num_layers,
75
+ num_heads,
76
+ dim_model,
77
+ dim_keys,
78
+ dim_values,
79
+ residual_dropout,
80
+ attention_dropout):
81
+ super().__init__()
82
+ self.num_layers = num_layers
83
+ self.num_heads = num_heads
84
+ self.dim_model = dim_model
85
+ self.dim_values = dim_values
86
+ self.dim_keys = dim_keys
87
+ self.residual_dropout = residual_dropout
88
+ self.attention_dropout = attention_dropout
89
+
90
+ self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
91
+
92
+ # METHODS ########################################################################################################
93
+
94
+ def forward(self, phone, mask=None):
95
+
96
+ # apply all layers to the input
97
+ for (i, layer) in enumerate(self._layers):
98
+ phone = layer(phone, mask)
99
+ # provide the final sequence
100
+ return phone
101
+
102
+
103
+ # ==================================================================================================================== #
104
+ # CLASS _ E N C O D E R L A Y E R #
105
+ # ==================================================================================================================== #
106
+
107
+
108
+ class _EncoderLayer(nn.Module):
109
+ """One layer of the encoder.
110
+
111
+ Attributes:
112
+ attn: (:class:`mha.MultiHeadAttention`): The attention mechanism that is used to read the input sequence.
113
+ feed_forward (:class:`ffl.FeedForwardLayer`): The feed-forward layer on top of the attention mechanism.
114
+ """
115
+
116
+ def __init__(self, parent: PCmer):
117
+ """Creates a new instance of ``_EncoderLayer``.
118
+
119
+ Args:
120
+ parent (Encoder): The encoder that the layer is created for.
121
+ """
122
+ super().__init__()
123
+
124
+
125
+ self.conformer = ConformerConvModule(parent.dim_model)
126
+ self.norm = nn.LayerNorm(parent.dim_model)
127
+ self.dropout = nn.Dropout(parent.residual_dropout)
128
+
129
+ # selfatt -> fastatt: performer!
130
+ self.attn = SelfAttention(dim = parent.dim_model,
131
+ heads = parent.num_heads,
132
+ causal = False)
133
+
134
+ # METHODS ########################################################################################################
135
+
136
+ def forward(self, phone, mask=None):
137
+
138
+ # compute attention sub-layer
139
+ phone = phone + (self.attn(self.norm(phone), mask=mask))
140
+
141
+ phone = phone + (self.conformer(phone))
142
+
143
+ return phone
144
+
145
+ def calc_same_padding(kernel_size):
146
+ pad = kernel_size // 2
147
+ return (pad, pad - (kernel_size + 1) % 2)
148
+
149
+ # helper classes
150
+
151
+ class Swish(nn.Module):
152
+ def forward(self, x):
153
+ return x * x.sigmoid()
154
+
155
+ class Transpose(nn.Module):
156
+ def __init__(self, dims):
157
+ super().__init__()
158
+ assert len(dims) == 2, 'dims must be a tuple of two dimensions'
159
+ self.dims = dims
160
+
161
+ def forward(self, x):
162
+ return x.transpose(*self.dims)
163
+
164
+ class GLU(nn.Module):
165
+ def __init__(self, dim):
166
+ super().__init__()
167
+ self.dim = dim
168
+
169
+ def forward(self, x):
170
+ out, gate = x.chunk(2, dim=self.dim)
171
+ return out * gate.sigmoid()
172
+
173
+ class DepthWiseConv1d(nn.Module):
174
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
175
+ super().__init__()
176
+ self.padding = padding
177
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
178
+
179
+ def forward(self, x):
180
+ x = F.pad(x, self.padding)
181
+ return self.conv(x)
182
+
183
+ class ConformerConvModule(nn.Module):
184
+ def __init__(
185
+ self,
186
+ dim,
187
+ causal = False,
188
+ expansion_factor = 2,
189
+ kernel_size = 31,
190
+ dropout = 0.):
191
+ super().__init__()
192
+
193
+ inner_dim = dim * expansion_factor
194
+ padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
195
+
196
+ self.net = nn.Sequential(
197
+ nn.LayerNorm(dim),
198
+ Transpose((1, 2)),
199
+ nn.Conv1d(dim, inner_dim * 2, 1),
200
+ GLU(dim=1),
201
+ DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
202
+ #nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
203
+ Swish(),
204
+ nn.Conv1d(inner_dim, dim, 1),
205
+ Transpose((1, 2)),
206
+ nn.Dropout(dropout)
207
+ )
208
+
209
+ def forward(self, x):
210
+ return self.net(x)
211
+
212
+ def linear_attention(q, k, v):
213
+ if v is None:
214
+ #print (k.size(), q.size())
215
+ out = torch.einsum('...ed,...nd->...ne', k, q)
216
+ return out
217
+
218
+ else:
219
+ k_cumsum = k.sum(dim = -2)
220
+ #k_cumsum = k.sum(dim = -2)
221
+ D_inv = 1. / (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8)
222
+
223
+ context = torch.einsum('...nd,...ne->...de', k, v)
224
+ #print ("TRUEEE: ", context.size(), q.size(), D_inv.size())
225
+ out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
226
+ return out
227
+
228
+ def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling = 0, qr_uniform_q = False, device = None):
229
+ nb_full_blocks = int(nb_rows / nb_columns)
230
+ #print (nb_full_blocks)
231
+ block_list = []
232
+
233
+ for _ in range(nb_full_blocks):
234
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
235
+ block_list.append(q)
236
+ # block_list[n] is a orthogonal matrix ... (model_dim * model_dim)
237
+ #print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1)))
238
+ #print (nb_rows, nb_full_blocks, nb_columns)
239
+ remaining_rows = nb_rows - nb_full_blocks * nb_columns
240
+ #print (remaining_rows)
241
+ if remaining_rows > 0:
242
+ q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q = qr_uniform_q, device = device)
243
+ #print (q[:remaining_rows].size())
244
+ block_list.append(q[:remaining_rows])
245
+
246
+ final_matrix = torch.cat(block_list)
247
+
248
+ if scaling == 0:
249
+ multiplier = torch.randn((nb_rows, nb_columns), device = device).norm(dim = 1)
250
+ elif scaling == 1:
251
+ multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device = device)
252
+ else:
253
+ raise ValueError(f'Invalid scaling {scaling}')
254
+
255
+ return torch.diag(multiplier) @ final_matrix
256
+
257
+ class FastAttention(nn.Module):
258
+ def __init__(self, dim_heads, nb_features = None, ortho_scaling = 0, causal = False, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, no_projection = False):
259
+ super().__init__()
260
+ nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
261
+
262
+ self.dim_heads = dim_heads
263
+ self.nb_features = nb_features
264
+ self.ortho_scaling = ortho_scaling
265
+
266
+ self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows = self.nb_features, nb_columns = dim_heads, scaling = ortho_scaling, qr_uniform_q = qr_uniform_q)
267
+ projection_matrix = self.create_projection()
268
+ self.register_buffer('projection_matrix', projection_matrix)
269
+
270
+ self.generalized_attention = generalized_attention
271
+ self.kernel_fn = kernel_fn
272
+
273
+ # if this is turned on, no projection will be used
274
+ # queries and keys will be softmax-ed as in the original efficient attention paper
275
+ self.no_projection = no_projection
276
+
277
+ self.causal = causal
278
+
279
+ @torch.no_grad()
280
+ def redraw_projection_matrix(self):
281
+ projections = self.create_projection()
282
+ self.projection_matrix.copy_(projections)
283
+ del projections
284
+
285
+ def forward(self, q, k, v):
286
+ device = q.device
287
+
288
+ if self.no_projection:
289
+ q = q.softmax(dim = -1)
290
+ k = torch.exp(k) if self.causal else k.softmax(dim = -2)
291
+ else:
292
+ create_kernel = partial(softmax_kernel, projection_matrix = self.projection_matrix, device = device)
293
+
294
+ q = create_kernel(q, is_query = True)
295
+ k = create_kernel(k, is_query = False)
296
+
297
+ attn_fn = linear_attention if not self.causal else self.causal_linear_fn
298
+ if v is None:
299
+ out = attn_fn(q, k, None)
300
+ return out
301
+ else:
302
+ out = attn_fn(q, k, v)
303
+ return out
304
+ class SelfAttention(nn.Module):
305
+ def __init__(self, dim, causal = False, heads = 8, dim_head = 64, local_heads = 0, local_window_size = 256, nb_features = None, feature_redraw_interval = 1000, generalized_attention = False, kernel_fn = nn.ReLU(), qr_uniform_q = False, dropout = 0., no_projection = False):
306
+ super().__init__()
307
+ assert dim % heads == 0, 'dimension must be divisible by number of heads'
308
+ dim_head = default(dim_head, dim // heads)
309
+ inner_dim = dim_head * heads
310
+ self.fast_attention = FastAttention(dim_head, nb_features, causal = causal, generalized_attention = generalized_attention, kernel_fn = kernel_fn, qr_uniform_q = qr_uniform_q, no_projection = no_projection)
311
+
312
+ self.heads = heads
313
+ self.global_heads = heads - local_heads
314
+ self.local_attn = LocalAttention(window_size = local_window_size, causal = causal, autopad = True, dropout = dropout, look_forward = int(not causal), rel_pos_emb_config = (dim_head, local_heads)) if local_heads > 0 else None
315
+
316
+ #print (heads, nb_features, dim_head)
317
+ #name_embedding = torch.zeros(110, heads, dim_head, dim_head)
318
+ #self.name_embedding = nn.Parameter(name_embedding, requires_grad=True)
319
+
320
+
321
+ self.to_q = nn.Linear(dim, inner_dim)
322
+ self.to_k = nn.Linear(dim, inner_dim)
323
+ self.to_v = nn.Linear(dim, inner_dim)
324
+ self.to_out = nn.Linear(inner_dim, dim)
325
+ self.dropout = nn.Dropout(dropout)
326
+
327
+ @torch.no_grad()
328
+ def redraw_projection_matrix(self):
329
+ self.fast_attention.redraw_projection_matrix()
330
+ #torch.nn.init.zeros_(self.name_embedding)
331
+ #print (torch.sum(self.name_embedding))
332
+ def forward(self, x, context = None, mask = None, context_mask = None, name=None, inference=False, **kwargs):
333
+ _, _, _, h, gh = *x.shape, self.heads, self.global_heads
334
+
335
+ cross_attend = exists(context)
336
+
337
+ context = default(context, x)
338
+ context_mask = default(context_mask, mask) if not cross_attend else context_mask
339
+ #print (torch.sum(self.name_embedding))
340
+ q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
341
+
342
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
343
+ (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
344
+
345
+ attn_outs = []
346
+ #print (name)
347
+ #print (self.name_embedding[name].size())
348
+ if not empty(q):
349
+ if exists(context_mask):
350
+ global_mask = context_mask[:, None, :, None]
351
+ v.masked_fill_(~global_mask, 0.)
352
+ if cross_attend:
353
+ pass
354
+ #print (torch.sum(self.name_embedding))
355
+ #out = self.fast_attention(q,self.name_embedding[name],None)
356
+ #print (torch.sum(self.name_embedding[...,-1:]))
357
+ else:
358
+ out = self.fast_attention(q, k, v)
359
+ attn_outs.append(out)
360
+
361
+ if not empty(lq):
362
+ assert not cross_attend, 'local attention is not compatible with cross attention'
363
+ out = self.local_attn(lq, lk, lv, input_mask = mask)
364
+ attn_outs.append(out)
365
+
366
+ out = torch.cat(attn_outs, dim = 1)
367
+ out = rearrange(out, 'b h n d -> b n (h d)')
368
+ out = self.to_out(out)
369
+ return self.dropout(out)
modules/F0Predictor/rmvpe/__init__.py ADDED
@@ -0,0 +1,10 @@
+ from .constants import *  # noqa: F403
+ from .inference import RMVPE  # noqa: F401
+ from .model import E2E, E2E0  # noqa: F401
+ from .spec import MelSpectrogram  # noqa: F401
+ from .utils import (  # noqa: F401
+     cycle,
+     summary,
+     to_local_average_cents,
+     to_viterbi_cents,
+ )
modules/F0Predictor/rmvpe/constants.py ADDED
@@ -0,0 +1,9 @@
+ SAMPLE_RATE = 16000
+
+ N_CLASS = 360
+
+ N_MELS = 128
+ MEL_FMIN = 30
+ MEL_FMAX = SAMPLE_RATE // 2
+ WINDOW_LENGTH = 1024
+ CONST = 1997.3794084376191
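(For reference: inference.py below converts a pitch value in cents c to frequency via f = 10 * 2**(c / 1200) Hz, and utils.py maps bin i to 20 * i + CONST cents. So bin 0 at CONST ≈ 1997.38 cents is roughly 31.7 Hz, and the 360 bins, spaced 20 cents apart, span about six octaves up to roughly 2 kHz.)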
modules/F0Predictor/rmvpe/deepunet.py ADDED
@@ -0,0 +1,190 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .constants import N_MELS
5
+
6
+
7
+ class ConvBlockRes(nn.Module):
8
+ def __init__(self, in_channels, out_channels, momentum=0.01):
9
+ super(ConvBlockRes, self).__init__()
10
+ self.conv = nn.Sequential(
11
+ nn.Conv2d(in_channels=in_channels,
12
+ out_channels=out_channels,
13
+ kernel_size=(3, 3),
14
+ stride=(1, 1),
15
+ padding=(1, 1),
16
+ bias=False),
17
+ nn.BatchNorm2d(out_channels, momentum=momentum),
18
+ nn.ReLU(),
19
+
20
+ nn.Conv2d(in_channels=out_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=(3, 3),
23
+ stride=(1, 1),
24
+ padding=(1, 1),
25
+ bias=False),
26
+ nn.BatchNorm2d(out_channels, momentum=momentum),
27
+ nn.ReLU(),
28
+ )
29
+ if in_channels != out_channels:
30
+ self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
31
+ self.is_shortcut = True
32
+ else:
33
+ self.is_shortcut = False
34
+
35
+ def forward(self, x):
36
+ if self.is_shortcut:
37
+ return self.conv(x) + self.shortcut(x)
38
+ else:
39
+ return self.conv(x) + x
40
+
41
+
42
+ class ResEncoderBlock(nn.Module):
43
+ def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
44
+ super(ResEncoderBlock, self).__init__()
45
+ self.n_blocks = n_blocks
46
+ self.conv = nn.ModuleList()
47
+ self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
48
+ for i in range(n_blocks - 1):
49
+ self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
50
+ self.kernel_size = kernel_size
51
+ if self.kernel_size is not None:
52
+ self.pool = nn.AvgPool2d(kernel_size=kernel_size)
53
+
54
+ def forward(self, x):
55
+ for i in range(self.n_blocks):
56
+ x = self.conv[i](x)
57
+ if self.kernel_size is not None:
58
+ return x, self.pool(x)
59
+ else:
60
+ return x
61
+
62
+
63
+ class ResDecoderBlock(nn.Module):
64
+ def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
65
+ super(ResDecoderBlock, self).__init__()
66
+ out_padding = (0, 1) if stride == (1, 2) else (1, 1)
67
+ self.n_blocks = n_blocks
68
+ self.conv1 = nn.Sequential(
69
+ nn.ConvTranspose2d(in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ kernel_size=(3, 3),
72
+ stride=stride,
73
+ padding=(1, 1),
74
+ output_padding=out_padding,
75
+ bias=False),
76
+ nn.BatchNorm2d(out_channels, momentum=momentum),
77
+ nn.ReLU(),
78
+ )
79
+ self.conv2 = nn.ModuleList()
80
+ self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
81
+ for i in range(n_blocks-1):
82
+ self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
83
+
84
+ def forward(self, x, concat_tensor):
85
+ x = self.conv1(x)
86
+ x = torch.cat((x, concat_tensor), dim=1)
87
+ for i in range(self.n_blocks):
88
+ x = self.conv2[i](x)
89
+ return x
90
+
91
+
92
+ class Encoder(nn.Module):
93
+ def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
94
+ super(Encoder, self).__init__()
95
+ self.n_encoders = n_encoders
96
+ self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
97
+ self.layers = nn.ModuleList()
98
+ self.latent_channels = []
99
+ for i in range(self.n_encoders):
100
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
101
+ self.latent_channels.append([out_channels, in_size])
102
+ in_channels = out_channels
103
+ out_channels *= 2
104
+ in_size //= 2
105
+ self.out_size = in_size
106
+ self.out_channel = out_channels
107
+
108
+ def forward(self, x):
109
+ concat_tensors = []
110
+ x = self.bn(x)
111
+ for i in range(self.n_encoders):
112
+ _, x = self.layers[i](x)
113
+ concat_tensors.append(_)
114
+ return x, concat_tensors
115
+
116
+
117
+ class Intermediate(nn.Module):
118
+ def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
119
+ super(Intermediate, self).__init__()
120
+ self.n_inters = n_inters
121
+ self.layers = nn.ModuleList()
122
+ self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
123
+ for i in range(self.n_inters-1):
124
+ self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
125
+
126
+ def forward(self, x):
127
+ for i in range(self.n_inters):
128
+ x = self.layers[i](x)
129
+ return x
130
+
131
+
132
+ class Decoder(nn.Module):
133
+ def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
134
+ super(Decoder, self).__init__()
135
+ self.layers = nn.ModuleList()
136
+ self.n_decoders = n_decoders
137
+ for i in range(self.n_decoders):
138
+ out_channels = in_channels // 2
139
+ self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
140
+ in_channels = out_channels
141
+
142
+ def forward(self, x, concat_tensors):
143
+ for i in range(self.n_decoders):
144
+ x = self.layers[i](x, concat_tensors[-1-i])
145
+ return x
146
+
147
+
148
+ class TimbreFilter(nn.Module):
149
+ def __init__(self, latent_rep_channels):
150
+ super(TimbreFilter, self).__init__()
151
+ self.layers = nn.ModuleList()
152
+ for latent_rep in latent_rep_channels:
153
+ self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
154
+
155
+ def forward(self, x_tensors):
156
+ out_tensors = []
157
+ for i, layer in enumerate(self.layers):
158
+ out_tensors.append(layer(x_tensors[i]))
159
+ return out_tensors
160
+
161
+
162
+ class DeepUnet(nn.Module):
163
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
164
+ super(DeepUnet, self).__init__()
165
+ self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
166
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
167
+ self.tf = TimbreFilter(self.encoder.latent_channels)
168
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
169
+
170
+ def forward(self, x):
171
+ x, concat_tensors = self.encoder(x)
172
+ x = self.intermediate(x)
173
+ concat_tensors = self.tf(concat_tensors)
174
+ x = self.decoder(x, concat_tensors)
175
+ return x
176
+
177
+
178
+ class DeepUnet0(nn.Module):
179
+ def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
180
+ super(DeepUnet0, self).__init__()
181
+ self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
182
+ self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
183
+ self.tf = TimbreFilter(self.encoder.latent_channels)
184
+ self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
185
+
186
+ def forward(self, x):
187
+ x, concat_tensors = self.encoder(x)
188
+ x = self.intermediate(x)
189
+ x = self.decoder(x, concat_tensors)
190
+ return x
modules/F0Predictor/rmvpe/inference.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import torch.nn.functional as F
+ from torchaudio.transforms import Resample
+
+ from .constants import *  # noqa: F403
+ from .model import E2E0
+ from .spec import MelSpectrogram
+ from .utils import to_local_average_cents, to_viterbi_cents
+
+
+ class RMVPE:
+     def __init__(self, model_path, device=None, dtype = torch.float32, hop_length=160):
+         self.resample_kernel = {}
+         if device is None:
+             self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         else:
+             self.device = device
+         model = E2E0(4, 1, (2, 2))
+         ckpt = torch.load(model_path, map_location=torch.device(self.device))
+         model.load_state_dict(ckpt['model'])
+         model = model.to(dtype).to(self.device)
+         model.eval()
+         self.model = model
+         self.dtype = dtype
+         self.mel_extractor = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX)  # noqa: F405
+         self.resample_kernel = {}
+
+     def mel2hidden(self, mel):
+         with torch.no_grad():
+             n_frames = mel.shape[-1]
+             mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
+             hidden = self.model(mel)
+             return hidden[:, :n_frames]
+
+     def decode(self, hidden, thred=0.03, use_viterbi=False):
+         if use_viterbi:
+             cents_pred = to_viterbi_cents(hidden, thred=thred)
+         else:
+             cents_pred = to_local_average_cents(hidden, thred=thred)
+         f0 = torch.Tensor([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]).to(self.device)
+         return f0
+
+     def infer_from_audio(self, audio, sample_rate=16000, thred=0.05, use_viterbi=False):
+         audio = audio.unsqueeze(0).to(self.dtype).to(self.device)
+         if sample_rate == 16000:
+             audio_res = audio
+         else:
+             key_str = str(sample_rate)
+             if key_str not in self.resample_kernel:
+                 self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
+             self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.dtype).to(self.device)
+             audio_res = self.resample_kernel[key_str](audio)
+         mel_extractor = self.mel_extractor.to(self.device)
+         mel = mel_extractor(audio_res, center=True).to(self.dtype)
+         hidden = self.mel2hidden(mel)
+         f0 = self.decode(hidden.squeeze(0), thred=thred, use_viterbi=use_viterbi)
+         return f0
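A minimal usage sketch for this predictor (the checkpoint path, input file, and device below are illustrative assumptions, not part of this commit):

    import librosa
    import torch
    from modules.F0Predictor.rmvpe import RMVPE

    wav, sr = librosa.load("example.wav", sr=16000)   # hypothetical input; any rate works, 16 kHz skips resampling
    rmvpe = RMVPE("pretrain/rmvpe.pt", device="cpu")  # hypothetical checkpoint location
    f0 = rmvpe.infer_from_audio(torch.from_numpy(wav), sample_rate=sr)
    # f0 holds one frequency in Hz per 160-sample hop, i.e. one value every 10 ms at 16 kHz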
modules/F0Predictor/rmvpe/model.py ADDED
@@ -0,0 +1,67 @@
+ from torch import nn
+
+ from .constants import *  # noqa: F403
+ from .deepunet import DeepUnet, DeepUnet0
+ from .seq import BiGRU
+ from .spec import MelSpectrogram
+
+
+ class E2E(nn.Module):
+     def __init__(self, hop_length, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
+                  en_out_channels=16):
+         super(E2E, self).__init__()
+         self.mel = MelSpectrogram(N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX)  # noqa: F405
+         self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+         self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+         if n_gru:
+             self.fc = nn.Sequential(
+                 BiGRU(3 * N_MELS, 256, n_gru),  # noqa: F405
+                 nn.Linear(512, N_CLASS),  # noqa: F405
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+         else:
+             self.fc = nn.Sequential(
+                 nn.Linear(3 * N_MELS, N_CLASS),  # noqa: F405
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+
+     def forward(self, x):
+         mel = self.mel(x.reshape(-1, x.shape[-1])).transpose(-1, -2).unsqueeze(1)
+         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+         # x = self.fc(x)
+         hidden_vec = 0
+         if len(self.fc) == 4:
+             for i in range(len(self.fc)):
+                 x = self.fc[i](x)
+                 if i == 0:
+                     hidden_vec = x
+         return hidden_vec, x
+
+
+ class E2E0(nn.Module):
+     def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
+                  en_out_channels=16):
+         super(E2E0, self).__init__()
+         self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+         self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+         if n_gru:
+             self.fc = nn.Sequential(
+                 BiGRU(3 * N_MELS, 256, n_gru),  # noqa: F405
+                 nn.Linear(512, N_CLASS),  # noqa: F405
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+         else:
+             self.fc = nn.Sequential(
+                 nn.Linear(3 * N_MELS, N_CLASS),  # noqa: F405
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+
+     def forward(self, mel):
+         mel = mel.transpose(-1, -2).unsqueeze(1)
+         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
+         x = self.fc(x)
+         return x
modules/F0Predictor/rmvpe/seq.py ADDED
@@ -0,0 +1,20 @@
+ import torch.nn as nn
+
+
+ class BiGRU(nn.Module):
+     def __init__(self, input_features, hidden_features, num_layers):
+         super(BiGRU, self).__init__()
+         self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+     def forward(self, x):
+         return self.gru(x)[0]
+
+
+ class BiLSTM(nn.Module):
+     def __init__(self, input_features, hidden_features, num_layers):
+         super(BiLSTM, self).__init__()
+         self.lstm = nn.LSTM(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+     def forward(self, x):
+         return self.lstm(x)[0]
+
modules/F0Predictor/rmvpe/spec.py ADDED
@@ -0,0 +1,67 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from librosa.filters import mel
5
+
6
+
7
+ class MelSpectrogram(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ n_mel_channels,
11
+ sampling_rate,
12
+ win_length,
13
+ hop_length,
14
+ n_fft=None,
15
+ mel_fmin=0,
16
+ mel_fmax=None,
17
+ clamp = 1e-5
18
+ ):
19
+ super().__init__()
20
+ n_fft = win_length if n_fft is None else n_fft
21
+ self.hann_window = {}
22
+ mel_basis = mel(
23
+ sr=sampling_rate,
24
+ n_fft=n_fft,
25
+ n_mels=n_mel_channels,
26
+ fmin=mel_fmin,
27
+ fmax=mel_fmax,
28
+ htk=True)
29
+ mel_basis = torch.from_numpy(mel_basis).float()
30
+ self.register_buffer("mel_basis", mel_basis)
31
+ self.n_fft = win_length if n_fft is None else n_fft
32
+ self.hop_length = hop_length
33
+ self.win_length = win_length
34
+ self.sampling_rate = sampling_rate
35
+ self.n_mel_channels = n_mel_channels
36
+ self.clamp = clamp
37
+
38
+ def forward(self, audio, keyshift=0, speed=1, center=True):
39
+ factor = 2 ** (keyshift / 12)
40
+ n_fft_new = int(np.round(self.n_fft * factor))
41
+ win_length_new = int(np.round(self.win_length * factor))
42
+ hop_length_new = int(np.round(self.hop_length * speed))
43
+
44
+ keyshift_key = str(keyshift)+'_'+str(audio.device)
45
+ if keyshift_key not in self.hann_window:
46
+ self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
47
+
48
+ fft = torch.stft(
49
+ audio,
50
+ n_fft=n_fft_new,
51
+ hop_length=hop_length_new,
52
+ win_length=win_length_new,
53
+ window=self.hann_window[keyshift_key],
54
+ center=center,
55
+ return_complex=True)
56
+ magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
57
+
58
+ if keyshift != 0:
59
+ size = self.n_fft // 2 + 1
60
+ resize = magnitude.size(1)
61
+ if resize < size:
62
+ magnitude = F.pad(magnitude, (0, 0, 0, size-resize))
63
+ magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
64
+
65
+ mel_output = torch.matmul(self.mel_basis, magnitude)
66
+ log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
67
+ return log_mel_spec
modules/F0Predictor/rmvpe/utils.py ADDED
@@ -0,0 +1,107 @@
1
+ import sys
2
+ from functools import reduce
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from torch.nn.modules.module import _addindent
8
+
9
+ from .constants import * # noqa: F403
10
+
11
+
12
+ def cycle(iterable):
13
+ while True:
14
+ for item in iterable:
15
+ yield item
16
+
17
+
18
+ def summary(model, file=sys.stdout):
19
+ def repr(model):
20
+ # We treat the extra repr like the sub-module, one item per line
21
+ extra_lines = []
22
+ extra_repr = model.extra_repr()
23
+ # empty string will be split into list ['']
24
+ if extra_repr:
25
+ extra_lines = extra_repr.split('\n')
26
+ child_lines = []
27
+ total_params = 0
28
+ for key, module in model._modules.items():
29
+ mod_str, num_params = repr(module)
30
+ mod_str = _addindent(mod_str, 2)
31
+ child_lines.append('(' + key + '): ' + mod_str)
32
+ total_params += num_params
33
+ lines = extra_lines + child_lines
34
+
35
+ for name, p in model._parameters.items():
36
+ if hasattr(p, 'shape'):
37
+ total_params += reduce(lambda x, y: x * y, p.shape)
38
+
39
+ main_str = model._get_name() + '('
40
+ if lines:
41
+ # simple one-liner info, which most builtin Modules will use
42
+ if len(extra_lines) == 1 and not child_lines:
43
+ main_str += extra_lines[0]
44
+ else:
45
+ main_str += '\n ' + '\n '.join(lines) + '\n'
46
+
47
+ main_str += ')'
48
+ if file is sys.stdout:
49
+ main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
50
+ else:
51
+ main_str += ', {:,} params'.format(total_params)
52
+ return main_str, total_params
53
+
54
+ string, count = repr(model)
55
+ if file is not None:
56
+ if isinstance(file, str):
57
+ file = open(file, 'w')
58
+ print(string, file=file)
59
+ file.flush()
60
+
61
+ return count
62
+
63
+
64
+ def to_local_average_cents(salience, center=None, thred=0.05):
65
+ """
66
+ find the weighted average cents near the argmax bin
67
+ """
68
+
69
+ if not hasattr(to_local_average_cents, 'cents_mapping'):
70
+ # the bin number-to-cents mapping
71
+ to_local_average_cents.cents_mapping = (
72
+ 20 * torch.arange(N_CLASS) + CONST).to(salience.device) # noqa: F405
73
+
74
+ if salience.ndim == 1:
75
+ if center is None:
76
+ center = int(torch.argmax(salience))
77
+ start = max(0, center - 4)
78
+ end = min(len(salience), center + 5)
79
+ salience = salience[start:end]
80
+ product_sum = torch.sum(
81
+ salience * to_local_average_cents.cents_mapping[start:end])
82
+ weight_sum = torch.sum(salience)
83
+ return product_sum / weight_sum if torch.max(salience) > thred else 0
84
+ if salience.ndim == 2:
85
+ return torch.Tensor([to_local_average_cents(salience[i, :], None, thred) for i in
86
+ range(salience.shape[0])]).to(salience.device)
87
+
88
+ raise Exception("label should be either 1d or 2d ndarray")
89
+
90
+ def to_viterbi_cents(salience, thred=0.05):
91
+ # Create viterbi transition matrix
92
+ if not hasattr(to_viterbi_cents, 'transition'):
93
+ xx, yy = torch.meshgrid(range(N_CLASS), range(N_CLASS)) # noqa: F405
94
+ transition = torch.maximum(30 - abs(xx - yy), 0)
95
+ transition = transition / transition.sum(axis=1, keepdims=True)
96
+ to_viterbi_cents.transition = transition
97
+
98
+ # Convert to probability
99
+ prob = salience.T
100
+ prob = prob / prob.sum(axis=0)
101
+
102
+ # Perform viterbi decoding
103
+ path = librosa.sequence.viterbi(prob.detach().cpu().numpy(), to_viterbi_cents.transition).astype(np.int64)
104
+
105
+ return torch.Tensor([to_local_average_cents(salience[i, :], path[i], thred) for i in
106
+ range(len(path))]).to(salience.device)
107
+
modules/__init__.py ADDED
File without changes
modules/__pycache__/DSConv.cpython-38.pyc ADDED
Binary file (3.07 kB)
modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (157 Bytes)
modules/__pycache__/attentions.cpython-38.pyc ADDED
Binary file (10.6 kB)
modules/__pycache__/commons.cpython-38.pyc ADDED
Binary file (6.44 kB)
modules/__pycache__/losses.cpython-38.pyc ADDED
Binary file (1.53 kB)
modules/__pycache__/mel_processing.cpython-38.pyc ADDED
Binary file (3.45 kB)
modules/__pycache__/modules.cpython-38.pyc ADDED
Binary file (9.08 kB)
modules/__pycache__/slicer2.cpython-38.pyc ADDED
Binary file (4.99 kB)
modules/attentions.py ADDED
@@ -0,0 +1,347 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ import modules.commons as commons
8
+ from modules.modules import LayerNorm
9
+
10
+
11
+ class FFT(nn.Module):
12
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
13
+ proximal_bias=False, proximal_init=True, **kwargs):
14
+ super().__init__()
15
+ self.hidden_channels = hidden_channels
16
+ self.filter_channels = filter_channels
17
+ self.n_heads = n_heads
18
+ self.n_layers = n_layers
19
+ self.kernel_size = kernel_size
20
+ self.p_dropout = p_dropout
21
+ self.proximal_bias = proximal_bias
22
+ self.proximal_init = proximal_init
23
+
24
+ self.drop = nn.Dropout(p_dropout)
25
+ self.self_attn_layers = nn.ModuleList()
26
+ self.norm_layers_0 = nn.ModuleList()
27
+ self.ffn_layers = nn.ModuleList()
28
+ self.norm_layers_1 = nn.ModuleList()
29
+ for i in range(self.n_layers):
30
+ self.self_attn_layers.append(
31
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias,
32
+ proximal_init=proximal_init))
33
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
34
+ self.ffn_layers.append(
35
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
36
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
37
+
38
+ def forward(self, x, x_mask):
39
+ """
40
+ x: decoder input
41
+ h: encoder output
42
+ """
43
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
44
+ x = x * x_mask
45
+ for i in range(self.n_layers):
46
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
47
+ y = self.drop(y)
48
+ x = self.norm_layers_0[i](x + y)
49
+
50
+ y = self.ffn_layers[i](x, x_mask)
51
+ y = self.drop(y)
52
+ x = self.norm_layers_1[i](x + y)
53
+ x = x * x_mask
54
+ return x
55
+
56
+
57
+ class Encoder(nn.Module):
58
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
59
+ super().__init__()
60
+ self.hidden_channels = hidden_channels
61
+ self.filter_channels = filter_channels
62
+ self.n_heads = n_heads
63
+ self.n_layers = n_layers
64
+ self.kernel_size = kernel_size
65
+ self.p_dropout = p_dropout
66
+ self.window_size = window_size
67
+
68
+ self.drop = nn.Dropout(p_dropout)
69
+ self.attn_layers = nn.ModuleList()
70
+ self.norm_layers_1 = nn.ModuleList()
71
+ self.ffn_layers = nn.ModuleList()
72
+ self.norm_layers_2 = nn.ModuleList()
73
+ for i in range(self.n_layers):
74
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
75
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
76
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
77
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
78
+
79
+ def forward(self, x, x_mask):
80
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
81
+ x = x * x_mask
82
+ for i in range(self.n_layers):
83
+ y = self.attn_layers[i](x, x, attn_mask)
84
+ y = self.drop(y)
85
+ x = self.norm_layers_1[i](x + y)
86
+
87
+ y = self.ffn_layers[i](x, x_mask)
88
+ y = self.drop(y)
89
+ x = self.norm_layers_2[i](x + y)
90
+ x = x * x_mask
91
+ return x
92
+
93
+
94
+ class Decoder(nn.Module):
95
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
96
+ super().__init__()
97
+ self.hidden_channels = hidden_channels
98
+ self.filter_channels = filter_channels
99
+ self.n_heads = n_heads
100
+ self.n_layers = n_layers
101
+ self.kernel_size = kernel_size
102
+ self.p_dropout = p_dropout
103
+ self.proximal_bias = proximal_bias
104
+ self.proximal_init = proximal_init
105
+
106
+ self.drop = nn.Dropout(p_dropout)
107
+ self.self_attn_layers = nn.ModuleList()
108
+ self.norm_layers_0 = nn.ModuleList()
109
+ self.encdec_attn_layers = nn.ModuleList()
110
+ self.norm_layers_1 = nn.ModuleList()
111
+ self.ffn_layers = nn.ModuleList()
112
+ self.norm_layers_2 = nn.ModuleList()
113
+ for i in range(self.n_layers):
114
+ self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
115
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
116
+ self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
117
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
118
+ self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
119
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
120
+
121
+ def forward(self, x, x_mask, h, h_mask):
122
+ """
123
+ x: decoder input
124
+ h: encoder output
125
+ """
126
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
127
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
128
+ x = x * x_mask
129
+ for i in range(self.n_layers):
130
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
131
+ y = self.drop(y)
132
+ x = self.norm_layers_0[i](x + y)
133
+
134
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
135
+ y = self.drop(y)
136
+ x = self.norm_layers_1[i](x + y)
137
+
138
+ y = self.ffn_layers[i](x, x_mask)
139
+ y = self.drop(y)
140
+ x = self.norm_layers_2[i](x + y)
141
+ x = x * x_mask
142
+ return x
143
+
144
+
145
+ class MultiHeadAttention(nn.Module):
146
+ def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
147
+ super().__init__()
148
+ assert channels % n_heads == 0
149
+
150
+ self.channels = channels
151
+ self.out_channels = out_channels
152
+ self.n_heads = n_heads
153
+ self.p_dropout = p_dropout
154
+ self.window_size = window_size
155
+ self.heads_share = heads_share
156
+ self.block_length = block_length
157
+ self.proximal_bias = proximal_bias
158
+ self.proximal_init = proximal_init
159
+ self.attn = None
160
+
161
+ self.k_channels = channels // n_heads
162
+ self.conv_q = nn.Conv1d(channels, channels, 1)
163
+ self.conv_k = nn.Conv1d(channels, channels, 1)
164
+ self.conv_v = nn.Conv1d(channels, channels, 1)
165
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
166
+ self.drop = nn.Dropout(p_dropout)
167
+
168
+ if window_size is not None:
169
+ n_heads_rel = 1 if heads_share else n_heads
170
+ rel_stddev = self.k_channels**-0.5
171
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
172
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
173
+
174
+ nn.init.xavier_uniform_(self.conv_q.weight)
175
+ nn.init.xavier_uniform_(self.conv_k.weight)
176
+ nn.init.xavier_uniform_(self.conv_v.weight)
177
+ if proximal_init:
178
+ with torch.no_grad():
179
+ self.conv_k.weight.copy_(self.conv_q.weight)
180
+ self.conv_k.bias.copy_(self.conv_q.bias)
181
+
182
+ def forward(self, x, c, attn_mask=None):
183
+ q = self.conv_q(x)
184
+ k = self.conv_k(c)
185
+ v = self.conv_v(c)
186
+
187
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
188
+
189
+ x = self.conv_o(x)
190
+ return x
191
+
192
+ def attention(self, query, key, value, mask=None):
193
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
194
+ b, d, t_s, t_t = (*key.size(), query.size(2))
195
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
196
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
197
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
198
+
199
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
200
+ if self.window_size is not None:
201
+ assert t_s == t_t, "Relative attention is only available for self-attention."
202
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
203
+ rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings)
204
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
205
+ scores = scores + scores_local
206
+ if self.proximal_bias:
207
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
208
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
209
+ if mask is not None:
210
+ scores = scores.masked_fill(mask == 0, -1e4)
211
+ if self.block_length is not None:
212
+ assert t_s == t_t, "Local attention is only available for self-attention."
213
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
214
+ scores = scores.masked_fill(block_mask == 0, -1e4)
215
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
216
+ p_attn = self.drop(p_attn)
217
+ output = torch.matmul(p_attn, value)
218
+ if self.window_size is not None:
219
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
220
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
221
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
222
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
223
+ return output, p_attn
224
+
225
+ def _matmul_with_relative_values(self, x, y):
226
+ """
227
+ x: [b, h, l, m]
228
+ y: [h or 1, m, d]
229
+ ret: [b, h, l, d]
230
+ """
231
+ ret = torch.matmul(x, y.unsqueeze(0))
232
+ return ret
233
+
234
+ def _matmul_with_relative_keys(self, x, y):
235
+ """
236
+ x: [b, h, l, d]
237
+ y: [h or 1, m, d]
238
+ ret: [b, h, l, m]
239
+ """
240
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
241
+ return ret
242
+
243
+ def _get_relative_embeddings(self, relative_embeddings, length):
244
+ 2 * self.window_size + 1
245
+ # Pad first before slice to avoid using cond ops.
246
+ pad_length = max(length - (self.window_size + 1), 0)
247
+ slice_start_position = max((self.window_size + 1) - length, 0)
248
+ slice_end_position = slice_start_position + 2 * length - 1
249
+ if pad_length > 0:
250
+ padded_relative_embeddings = F.pad(
251
+ relative_embeddings,
252
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
253
+ else:
254
+ padded_relative_embeddings = relative_embeddings
255
+ used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
256
+ return used_relative_embeddings
257
+
258
+ def _relative_position_to_absolute_position(self, x):
259
+ """
260
+ x: [b, h, l, 2*l-1]
261
+ ret: [b, h, l, l]
262
+ """
263
+ batch, heads, length, _ = x.size()
264
+ # Concat columns of pad to shift from relative to absolute indexing.
265
+ x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
266
+
267
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
268
+ x_flat = x.view([batch, heads, length * 2 * length])
269
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
270
+
271
+ # Reshape and slice out the padded elements.
272
+ x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
273
+ return x_final
274
+
275
+ def _absolute_position_to_relative_position(self, x):
276
+ """
277
+ x: [b, h, l, l]
278
+ ret: [b, h, l, 2*l-1]
279
+ """
280
+ batch, heads, length, _ = x.size()
281
+ # padd along column
282
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
283
+ x_flat = x.view([batch, heads, length**2 + length*(length -1)])
284
+ # add 0's in the beginning that will skew the elements after reshape
285
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
286
+ x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
287
+ return x_final
288
+
289
+ def _attention_bias_proximal(self, length):
290
+ """Bias for self-attention to encourage attention to close positions.
291
+ Args:
292
+ length: an integer scalar.
293
+ Returns:
294
+ a Tensor with shape [1, 1, length, length]
295
+ """
296
+ r = torch.arange(length, dtype=torch.float32)
297
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
298
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
299
+
300
+
301
+ class FFN(nn.Module):
302
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
303
+ super().__init__()
304
+ self.in_channels = in_channels
305
+ self.out_channels = out_channels
306
+ self.filter_channels = filter_channels
307
+ self.kernel_size = kernel_size
308
+ self.p_dropout = p_dropout
309
+ self.activation = activation
310
+ self.causal = causal
311
+
312
+ if causal:
313
+ self.padding = self._causal_padding
314
+ else:
315
+ self.padding = self._same_padding
316
+
317
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
318
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
319
+ self.drop = nn.Dropout(p_dropout)
320
+
321
+ def forward(self, x, x_mask):
322
+ x = self.conv_1(self.padding(x * x_mask))
323
+ if self.activation == "gelu":
324
+ x = x * torch.sigmoid(1.702 * x)
325
+ else:
326
+ x = torch.relu(x)
327
+ x = self.drop(x)
328
+ x = self.conv_2(self.padding(x * x_mask))
329
+ return x * x_mask
330
+
331
+ def _causal_padding(self, x):
332
+ if self.kernel_size == 1:
333
+ return x
334
+ pad_l = self.kernel_size - 1
335
+ pad_r = 0
336
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
337
+ x = F.pad(x, commons.convert_pad_shape(padding))
338
+ return x
339
+
340
+ def _same_padding(self, x):
341
+ if self.kernel_size == 1:
342
+ return x
343
+ pad_l = (self.kernel_size - 1) // 2
344
+ pad_r = self.kernel_size // 2
345
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
346
+ x = F.pad(x, commons.convert_pad_shape(padding))
347
+ return x
modules/commons.py ADDED
@@ -0,0 +1,183 @@
1
+ import math
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def slice_pitch_segments(x, ids_str, segment_size=4):
8
+ ret = torch.zeros_like(x[:, :segment_size])
9
+ for i in range(x.size(0)):
10
+ idx_str = ids_str[i]
11
+ idx_end = idx_str + segment_size
12
+ ret[i] = x[i, idx_str:idx_end]
13
+ return ret
14
+
15
+ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
16
+ b, d, t = x.size()
17
+ if x_lengths is None:
18
+ x_lengths = t
19
+ ids_str_max = x_lengths - segment_size + 1
20
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
21
+ ret = slice_segments(x, ids_str, segment_size)
22
+ ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
23
+ return ret, ret_pitch, ids_str
24
+
25
+ def init_weights(m, mean=0.0, std=0.01):
26
+ classname = m.__class__.__name__
27
+ if "Depthwise_Separable" in classname:
28
+ m.depth_conv.weight.data.normal_(mean, std)
29
+ m.point_conv.weight.data.normal_(mean, std)
30
+ elif classname.find("Conv") != -1:
31
+ m.weight.data.normal_(mean, std)
32
+
33
+ def get_padding(kernel_size, dilation=1):
34
+ return int((kernel_size*dilation - dilation)/2)
35
+
36
+
37
+ def convert_pad_shape(pad_shape):
38
+ l = pad_shape[::-1]
39
+ pad_shape = [item for sublist in l for item in sublist]
40
+ return pad_shape
41
+
42
+
43
+ def intersperse(lst, item):
44
+ result = [item] * (len(lst) * 2 + 1)
45
+ result[1::2] = lst
46
+ return result
47
+
48
+
49
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
50
+ """KL(P||Q)"""
51
+ kl = (logs_q - logs_p) - 0.5
52
+ kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
53
+ return kl
54
+
55
+
56
+ def rand_gumbel(shape):
57
+ """Sample from the Gumbel distribution, protect from overflows."""
58
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
59
+ return -torch.log(-torch.log(uniform_samples))
60
+
61
+
62
+ def rand_gumbel_like(x):
63
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
64
+ return g
65
+
66
+
67
+ def slice_segments(x, ids_str, segment_size=4):
68
+ ret = torch.zeros_like(x[:, :, :segment_size])
69
+ for i in range(x.size(0)):
70
+ idx_str = ids_str[i]
71
+ idx_end = idx_str + segment_size
72
+ ret[i] = x[i, :, idx_str:idx_end]
73
+ return ret
74
+
75
+
76
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
77
+ b, d, t = x.size()
78
+ if x_lengths is None:
79
+ x_lengths = t
80
+ ids_str_max = x_lengths - segment_size + 1
81
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
82
+ ret = slice_segments(x, ids_str, segment_size)
83
+ return ret, ids_str
84
+
85
+
86
+ def rand_spec_segments(x, x_lengths=None, segment_size=4):
87
+ b, d, t = x.size()
88
+ if x_lengths is None:
89
+ x_lengths = t
90
+ ids_str_max = x_lengths - segment_size
91
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
92
+ ret = slice_segments(x, ids_str, segment_size)
93
+ return ret, ids_str
94
+
95
+
96
+ def get_timing_signal_1d(
97
+ length, channels, min_timescale=1.0, max_timescale=1.0e4):
98
+ position = torch.arange(length, dtype=torch.float)
99
+ num_timescales = channels // 2
100
+ log_timescale_increment = (
101
+ math.log(float(max_timescale) / float(min_timescale)) /
102
+ (num_timescales - 1))
103
+ inv_timescales = min_timescale * torch.exp(
104
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
105
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
106
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
107
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
108
+ signal = signal.view(1, channels, length)
109
+ return signal
110
+
111
+
112
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
113
+ b, channels, length = x.size()
114
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
115
+ return x + signal.to(dtype=x.dtype, device=x.device)
116
+
117
+
118
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
119
+ b, channels, length = x.size()
120
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
121
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
122
+
123
+
124
+ def subsequent_mask(length):
125
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
126
+ return mask
127
+
128
+
129
+ @torch.jit.script
130
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
131
+ n_channels_int = n_channels[0]
132
+ in_act = input_a + input_b
133
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
134
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
135
+ acts = t_act * s_act
136
+ return acts
137
+
138
+
139
+ def shift_1d(x):
140
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
141
+ return x
142
+
143
+
144
+ def sequence_mask(length, max_length=None):
145
+ if max_length is None:
146
+ max_length = length.max()
147
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
148
+ return x.unsqueeze(0) < length.unsqueeze(1)
149
+
150
+
151
+ def generate_path(duration, mask):
152
+ """
153
+ duration: [b, 1, t_x]
154
+ mask: [b, 1, t_y, t_x]
155
+ """
156
+
157
+ b, _, t_y, t_x = mask.shape
158
+ cum_duration = torch.cumsum(duration, -1)
159
+
160
+ cum_duration_flat = cum_duration.view(b * t_x)
161
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
162
+ path = path.view(b, t_x, t_y)
163
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
164
+ path = path.unsqueeze(1).transpose(2,3) * mask
165
+ return path
166
+
167
+
168
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
169
+ if isinstance(parameters, torch.Tensor):
170
+ parameters = [parameters]
171
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
172
+ norm_type = float(norm_type)
173
+ if clip_value is not None:
174
+ clip_value = float(clip_value)
175
+
176
+ total_norm = 0
177
+ for p in parameters:
178
+ param_norm = p.grad.data.norm(norm_type)
179
+ total_norm += param_norm.item() ** norm_type
180
+ if clip_value is not None:
181
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
182
+ total_norm = total_norm ** (1. / norm_type)
183
+ return total_norm
modules/enhancer.py ADDED
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchaudio.transforms import Resample
5
+
6
+ from vdecoder.nsf_hifigan.models import load_model
7
+ from vdecoder.nsf_hifigan.nvSTFT import STFT
8
+
9
+
10
+ class Enhancer:
11
+ def __init__(self, enhancer_type, enhancer_ckpt, device=None):
12
+ if device is None:
13
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
+ self.device = device
15
+
16
+ if enhancer_type == 'nsf-hifigan':
17
+ self.enhancer = NsfHifiGAN(enhancer_ckpt, device=self.device)
18
+ else:
19
+ raise ValueError(f" [x] Unknown enhancer: {enhancer_type}")
20
+
21
+ self.resample_kernel = {}
22
+ self.enhancer_sample_rate = self.enhancer.sample_rate()
23
+ self.enhancer_hop_size = self.enhancer.hop_size()
24
+
25
+ def enhance(self,
26
+ audio, # 1, T
27
+ sample_rate,
28
+ f0, # 1, n_frames, 1
29
+ hop_size,
30
+ adaptive_key = 0,
31
+ silence_front = 0
32
+ ):
33
+ # enhancer start time
34
+ start_frame = int(silence_front * sample_rate / hop_size)
35
+ real_silence_front = start_frame * hop_size / sample_rate
36
+ audio = audio[:, int(np.round(real_silence_front * sample_rate)) : ]
37
+ f0 = f0[: , start_frame :, :]
38
+
39
+ # adaptive parameters
40
+ adaptive_factor = 2 ** ( -adaptive_key / 12)
41
+ adaptive_sample_rate = 100 * int(np.round(self.enhancer_sample_rate / adaptive_factor / 100))
42
+ real_factor = self.enhancer_sample_rate / adaptive_sample_rate
43
+
44
+ # resample the ddsp output
45
+ if sample_rate == adaptive_sample_rate:
46
+ audio_res = audio
47
+ else:
48
+ key_str = str(sample_rate) + str(adaptive_sample_rate)
49
+ if key_str not in self.resample_kernel:
50
+ self.resample_kernel[key_str] = Resample(sample_rate, adaptive_sample_rate, lowpass_filter_width = 128).to(self.device)
51
+ audio_res = self.resample_kernel[key_str](audio)
52
+
53
+ n_frames = int(audio_res.size(-1) // self.enhancer_hop_size + 1)
54
+
55
+ # resample f0
56
+ f0_np = f0.squeeze(0).squeeze(-1).cpu().numpy()
57
+ f0_np *= real_factor
58
+ time_org = (hop_size / sample_rate) * np.arange(len(f0_np)) / real_factor
59
+ time_frame = (self.enhancer_hop_size / self.enhancer_sample_rate) * np.arange(n_frames)
60
+ f0_res = np.interp(time_frame, time_org, f0_np, left=f0_np[0], right=f0_np[-1])
61
+ f0_res = torch.from_numpy(f0_res).unsqueeze(0).float().to(self.device) # 1, n_frames
62
+
63
+ # enhance
64
+ enhanced_audio, enhancer_sample_rate = self.enhancer(audio_res, f0_res)
65
+
66
+ # resample the enhanced output
67
+ if adaptive_factor != 0:
68
+ key_str = str(adaptive_sample_rate) + str(enhancer_sample_rate)
69
+ if key_str not in self.resample_kernel:
70
+ self.resample_kernel[key_str] = Resample(adaptive_sample_rate, enhancer_sample_rate, lowpass_filter_width = 128).to(self.device)
71
+ enhanced_audio = self.resample_kernel[key_str](enhanced_audio)
72
+
73
+ # pad the silence frames
74
+ if start_frame > 0:
75
+ enhanced_audio = F.pad(enhanced_audio, (int(np.round(enhancer_sample_rate * real_silence_front)), 0))
76
+
77
+ return enhanced_audio, enhancer_sample_rate
78
+
79
+
80
+ class NsfHifiGAN(torch.nn.Module):
81
+ def __init__(self, model_path, device=None):
82
+ super().__init__()
83
+ if device is None:
84
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
85
+ self.device = device
86
+ print('| Load HifiGAN: ', model_path)
87
+ self.model, self.h = load_model(model_path, device=self.device)
88
+
89
+ def sample_rate(self):
90
+ return self.h.sampling_rate
91
+
92
+ def hop_size(self):
93
+ return self.h.hop_size
94
+
95
+ def forward(self, audio, f0):
96
+ stft = STFT(
97
+ self.h.sampling_rate,
98
+ self.h.num_mels,
99
+ self.h.n_fft,
100
+ self.h.win_size,
101
+ self.h.hop_size,
102
+ self.h.fmin,
103
+ self.h.fmax)
104
+ with torch.no_grad():
105
+ mel = stft.get_mel(audio)
106
+ enhanced_audio = self.model(mel, f0[:,:mel.size(-1)]).view(-1)
107
+ return enhanced_audio, self.h.sampling_rate
modules/losses.py ADDED
@@ -0,0 +1,58 @@
1
+ import torch
2
+
3
+
4
+ def feature_loss(fmap_r, fmap_g):
5
+ loss = 0
6
+ for dr, dg in zip(fmap_r, fmap_g):
7
+ for rl, gl in zip(dr, dg):
8
+ rl = rl.float().detach()
9
+ gl = gl.float()
10
+ loss += torch.mean(torch.abs(rl - gl))
11
+
12
+ return loss * 2
13
+
14
+
15
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
16
+ loss = 0
17
+ r_losses = []
18
+ g_losses = []
19
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
20
+ dr = dr.float()
21
+ dg = dg.float()
22
+ r_loss = torch.mean((1-dr)**2)
23
+ g_loss = torch.mean(dg**2)
24
+ loss += (r_loss + g_loss)
25
+ r_losses.append(r_loss.item())
26
+ g_losses.append(g_loss.item())
27
+
28
+ return loss, r_losses, g_losses
29
+
30
+
31
+ def generator_loss(disc_outputs):
32
+ loss = 0
33
+ gen_losses = []
34
+ for dg in disc_outputs:
35
+ dg = dg.float()
36
+ l = torch.mean((1-dg)**2)
37
+ gen_losses.append(l)
38
+ loss += l
39
+
40
+ return loss, gen_losses
41
+
42
+
43
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
44
+ """
45
+ z_p, logs_q: [b, h, t_t]
46
+ m_p, logs_p: [b, h, t_t]
47
+ """
48
+ z_p = z_p.float()
49
+ logs_q = logs_q.float()
50
+ m_p = m_p.float()
51
+ logs_p = logs_p.float()
52
+ z_mask = z_mask.float()
53
+ #print(logs_p)
54
+ kl = logs_p - logs_q - 0.5
55
+ kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
56
+ kl = torch.sum(kl * z_mask)
57
+ l = kl / torch.sum(z_mask)
58
+ return l
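A small, self-contained sketch of how these functions are typically wired together in a GAN training step. The random tensors stand in for discriminator outputs and feature maps, and the unweighted sum of terms is illustrative rather than taken from this repository:

import torch

d_real = [torch.rand(4, 10) for _ in range(3)]          # per-sub-discriminator outputs on real audio
d_fake = [torch.rand(4, 10) for _ in range(3)]          # ... and on generated audio
fmap_r = [[torch.rand(4, 8, 10)] for _ in range(3)]     # intermediate feature maps, real
fmap_g = [[torch.rand(4, 8, 10)] for _ in range(3)]     # intermediate feature maps, generated

loss_d, r_losses, g_losses = discriminator_loss(d_real, d_fake)   # trains the discriminator
loss_adv, _ = generator_loss(d_fake)                               # adversarial term for the generator
loss_fm = feature_loss(fmap_r, fmap_g)                             # feature-matching term
loss_g = loss_adv + loss_fm                                        # kl_loss / mel terms would be added here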
modules/mel_processing.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+
8
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
9
+ """
10
+ PARAMS
11
+ ------
12
+ C: compression factor
13
+ """
14
+ return torch.log(torch.clamp(x, min=clip_val) * C)
15
+
16
+
17
+ def dynamic_range_decompression_torch(x, C=1):
18
+ """
19
+ PARAMS
20
+ ------
21
+ C: compression factor used to compress
22
+ """
23
+ return torch.exp(x) / C
24
+
25
+
26
+ def spectral_normalize_torch(magnitudes):
27
+ output = dynamic_range_compression_torch(magnitudes)
28
+ return output
29
+
30
+
31
+ def spectral_de_normalize_torch(magnitudes):
32
+ output = dynamic_range_decompression_torch(magnitudes)
33
+ return output
34
+
35
+
36
+ mel_basis = {}
37
+ hann_window = {}
38
+
39
+
40
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41
+ if torch.min(y) < -1.:
42
+ print('min value is ', torch.min(y))
43
+ if torch.max(y) > 1.:
44
+ print('max value is ', torch.max(y))
45
+
46
+ global hann_window
47
+ dtype_device = str(y.dtype) + '_' + str(y.device)
48
+ wnsize_dtype_device = str(win_size) + '_' + dtype_device
49
+ if wnsize_dtype_device not in hann_window:
50
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
51
+
52
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
53
+ y = y.squeeze(1)
54
+
55
+ y_dtype = y.dtype
56
+ if y.dtype == torch.bfloat16:
57
+ y = y.to(torch.float32)
58
+
59
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
60
+ center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
61
+ spec = torch.view_as_real(spec).to(y_dtype)
62
+
63
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
64
+ return spec
65
+
66
+
67
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
68
+ global mel_basis
69
+ dtype_device = str(spec.dtype) + '_' + str(spec.device)
70
+ fmax_dtype_device = str(fmax) + '_' + dtype_device
71
+ if fmax_dtype_device not in mel_basis:
72
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
73
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
74
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
75
+ spec = spectral_normalize_torch(spec)
76
+ return spec
77
+
78
+
79
+ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
80
+ spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
81
+ spec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
82
+
83
+ return spec
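Quick shape check for the pipeline above. The 44.1 kHz / 512-hop / 80-mel settings mirror a common so-vits-svc configuration but are assumptions here, not values read from any config in this commit:

import torch

y = torch.rand(1, 44100) * 2 - 1           # one second of dummy audio in [-1, 1]
mel = mel_spectrogram_torch(y, n_fft=2048, num_mels=80, sampling_rate=44100,
                            hop_size=512, win_size=2048, fmin=40, fmax=16000)
print(mel.shape)                            # (1, 80, n_frames), n_frames ~= len(y) / hop_size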
modules/modules.py ADDED
@@ -0,0 +1,306 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ import modules.commons as commons
6
+ from modules.commons import get_padding, init_weights
7
+ from modules.DSConv import (
8
+ Depthwise_Separable_Conv1D,
9
+ remove_weight_norm_modules,
10
+ weight_norm_modules,
11
+ )
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+ Conv1dModel = nn.Conv1d
16
+
17
+ def set_Conv1dModel(use_depthwise_conv):
18
+ global Conv1dModel
19
+ Conv1dModel = Depthwise_Separable_Conv1D if use_depthwise_conv else nn.Conv1d
20
+
21
+
22
+ class LayerNorm(nn.Module):
23
+ def __init__(self, channels, eps=1e-5):
24
+ super().__init__()
25
+ self.channels = channels
26
+ self.eps = eps
27
+
28
+ self.gamma = nn.Parameter(torch.ones(channels))
29
+ self.beta = nn.Parameter(torch.zeros(channels))
30
+
31
+ def forward(self, x):
32
+ x = x.transpose(1, -1)
33
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
34
+ return x.transpose(1, -1)
35
+
36
+
37
+ class ConvReluNorm(nn.Module):
38
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
39
+ super().__init__()
40
+ self.in_channels = in_channels
41
+ self.hidden_channels = hidden_channels
42
+ self.out_channels = out_channels
43
+ self.kernel_size = kernel_size
44
+ self.n_layers = n_layers
45
+ self.p_dropout = p_dropout
46
+ assert n_layers > 1, "Number of layers should be larger than 1."
47
+
48
+ self.conv_layers = nn.ModuleList()
49
+ self.norm_layers = nn.ModuleList()
50
+ self.conv_layers.append(Conv1dModel(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
51
+ self.norm_layers.append(LayerNorm(hidden_channels))
52
+ self.relu_drop = nn.Sequential(
53
+ nn.ReLU(),
54
+ nn.Dropout(p_dropout))
55
+ for _ in range(n_layers-1):
56
+ self.conv_layers.append(Conv1dModel(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
57
+ self.norm_layers.append(LayerNorm(hidden_channels))
58
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
59
+ self.proj.weight.data.zero_()
60
+ self.proj.bias.data.zero_()
61
+
62
+ def forward(self, x, x_mask):
63
+ x_org = x
64
+ for i in range(self.n_layers):
65
+ x = self.conv_layers[i](x * x_mask)
66
+ x = self.norm_layers[i](x)
67
+ x = self.relu_drop(x)
68
+ x = x_org + self.proj(x)
69
+ return x * x_mask
70
+
71
+
72
+ class WN(torch.nn.Module):
73
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
74
+ super(WN, self).__init__()
75
+ assert(kernel_size % 2 == 1)
76
+ self.hidden_channels = hidden_channels
77
+ self.kernel_size = kernel_size
78
+ self.dilation_rate = dilation_rate
79
+ self.n_layers = n_layers
80
+ self.gin_channels = gin_channels
81
+ self.p_dropout = p_dropout
82
+
83
+ self.in_layers = torch.nn.ModuleList()
84
+ self.res_skip_layers = torch.nn.ModuleList()
85
+ self.drop = nn.Dropout(p_dropout)
86
+
87
+ if gin_channels != 0:
88
+ cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
89
+ self.cond_layer = weight_norm_modules(cond_layer, name='weight')
90
+
91
+ for i in range(n_layers):
92
+ dilation = dilation_rate ** i
93
+ padding = int((kernel_size * dilation - dilation) / 2)
94
+ in_layer = Conv1dModel(hidden_channels, 2*hidden_channels, kernel_size,
95
+ dilation=dilation, padding=padding)
96
+ in_layer = weight_norm_modules(in_layer, name='weight')
97
+ self.in_layers.append(in_layer)
98
+
99
+ # last one is not necessary
100
+ if i < n_layers - 1:
101
+ res_skip_channels = 2 * hidden_channels
102
+ else:
103
+ res_skip_channels = hidden_channels
104
+
105
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
106
+ res_skip_layer = weight_norm_modules(res_skip_layer, name='weight')
107
+ self.res_skip_layers.append(res_skip_layer)
108
+
109
+ def forward(self, x, x_mask, g=None, **kwargs):
110
+ output = torch.zeros_like(x)
111
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
112
+
113
+ if g is not None:
114
+ g = self.cond_layer(g)
115
+
116
+ for i in range(self.n_layers):
117
+ x_in = self.in_layers[i](x)
118
+ if g is not None:
119
+ cond_offset = i * 2 * self.hidden_channels
120
+ g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
121
+ else:
122
+ g_l = torch.zeros_like(x_in)
123
+
124
+ acts = commons.fused_add_tanh_sigmoid_multiply(
125
+ x_in,
126
+ g_l,
127
+ n_channels_tensor)
128
+ acts = self.drop(acts)
129
+
130
+ res_skip_acts = self.res_skip_layers[i](acts)
131
+ if i < self.n_layers - 1:
132
+ res_acts = res_skip_acts[:,:self.hidden_channels,:]
133
+ x = (x + res_acts) * x_mask
134
+ output = output + res_skip_acts[:,self.hidden_channels:,:]
135
+ else:
136
+ output = output + res_skip_acts
137
+ return output * x_mask
138
+
139
+ def remove_weight_norm(self):
140
+ if self.gin_channels != 0:
141
+ remove_weight_norm_modules(self.cond_layer)
142
+ for l in self.in_layers:
143
+ remove_weight_norm_modules(l)
144
+ for l in self.res_skip_layers:
145
+ remove_weight_norm_modules(l)
146
+
147
+
148
+ class ResBlock1(torch.nn.Module):
149
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
150
+ super(ResBlock1, self).__init__()
151
+ self.convs1 = nn.ModuleList([
152
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
153
+ padding=get_padding(kernel_size, dilation[0]))),
154
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
155
+ padding=get_padding(kernel_size, dilation[1]))),
156
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[2],
157
+ padding=get_padding(kernel_size, dilation[2])))
158
+ ])
159
+ self.convs1.apply(init_weights)
160
+
161
+ self.convs2 = nn.ModuleList([
162
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
163
+ padding=get_padding(kernel_size, 1))),
164
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
165
+ padding=get_padding(kernel_size, 1))),
166
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=1,
167
+ padding=get_padding(kernel_size, 1)))
168
+ ])
169
+ self.convs2.apply(init_weights)
170
+
171
+ def forward(self, x, x_mask=None):
172
+ for c1, c2 in zip(self.convs1, self.convs2):
173
+ xt = F.leaky_relu(x, LRELU_SLOPE)
174
+ if x_mask is not None:
175
+ xt = xt * x_mask
176
+ xt = c1(xt)
177
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
178
+ if x_mask is not None:
179
+ xt = xt * x_mask
180
+ xt = c2(xt)
181
+ x = xt + x
182
+ if x_mask is not None:
183
+ x = x * x_mask
184
+ return x
185
+
186
+ def remove_weight_norm(self):
187
+ for l in self.convs1:
188
+ remove_weight_norm_modules(l)
189
+ for l in self.convs2:
190
+ remove_weight_norm_modules(l)
191
+
192
+
193
+ class ResBlock2(torch.nn.Module):
194
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
195
+ super(ResBlock2, self).__init__()
196
+ self.convs = nn.ModuleList([
197
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[0],
198
+ padding=get_padding(kernel_size, dilation[0]))),
199
+ weight_norm_modules(Conv1dModel(channels, channels, kernel_size, 1, dilation=dilation[1],
200
+ padding=get_padding(kernel_size, dilation[1])))
201
+ ])
202
+ self.convs.apply(init_weights)
203
+
204
+ def forward(self, x, x_mask=None):
205
+ for c in self.convs:
206
+ xt = F.leaky_relu(x, LRELU_SLOPE)
207
+ if x_mask is not None:
208
+ xt = xt * x_mask
209
+ xt = c(xt)
210
+ x = xt + x
211
+ if x_mask is not None:
212
+ x = x * x_mask
213
+ return x
214
+
215
+ def remove_weight_norm(self):
216
+ for l in self.convs:
217
+ remove_weight_norm_modules(l)
218
+
219
+
220
+ class Log(nn.Module):
221
+ def forward(self, x, x_mask, reverse=False, **kwargs):
222
+ if not reverse:
223
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
224
+ logdet = torch.sum(-y, [1, 2])
225
+ return y, logdet
226
+ else:
227
+ x = torch.exp(x) * x_mask
228
+ return x
229
+
230
+
231
+ class Flip(nn.Module):
232
+ def forward(self, x, *args, reverse=False, **kwargs):
233
+ x = torch.flip(x, [1])
234
+ if not reverse:
235
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
236
+ return x, logdet
237
+ else:
238
+ return x
239
+
240
+
241
+ class ElementwiseAffine(nn.Module):
242
+ def __init__(self, channels):
243
+ super().__init__()
244
+ self.channels = channels
245
+ self.m = nn.Parameter(torch.zeros(channels,1))
246
+ self.logs = nn.Parameter(torch.zeros(channels,1))
247
+
248
+ def forward(self, x, x_mask, reverse=False, **kwargs):
249
+ if not reverse:
250
+ y = self.m + torch.exp(self.logs) * x
251
+ y = y * x_mask
252
+ logdet = torch.sum(self.logs * x_mask, [1,2])
253
+ return y, logdet
254
+ else:
255
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
256
+ return x
257
+
258
+
259
+ class ResidualCouplingLayer(nn.Module):
260
+ def __init__(self,
261
+ channels,
262
+ hidden_channels,
263
+ kernel_size,
264
+ dilation_rate,
265
+ n_layers,
266
+ p_dropout=0,
267
+ gin_channels=0,
268
+ mean_only=False,
269
+ wn_sharing_parameter=None
270
+ ):
271
+ assert channels % 2 == 0, "channels should be divisible by 2"
272
+ super().__init__()
273
+ self.channels = channels
274
+ self.hidden_channels = hidden_channels
275
+ self.kernel_size = kernel_size
276
+ self.dilation_rate = dilation_rate
277
+ self.n_layers = n_layers
278
+ self.half_channels = channels // 2
279
+ self.mean_only = mean_only
280
+
281
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
282
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) if wn_sharing_parameter is None else wn_sharing_parameter
283
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
284
+ self.post.weight.data.zero_()
285
+ self.post.bias.data.zero_()
286
+
287
+ def forward(self, x, x_mask, g=None, reverse=False):
288
+ x0, x1 = torch.split(x, [self.half_channels]*2, 1)
289
+ h = self.pre(x0) * x_mask
290
+ h = self.enc(h, x_mask, g=g)
291
+ stats = self.post(h) * x_mask
292
+ if not self.mean_only:
293
+ m, logs = torch.split(stats, [self.half_channels]*2, 1)
294
+ else:
295
+ m = stats
296
+ logs = torch.zeros_like(m)
297
+
298
+ if not reverse:
299
+ x1 = m + x1 * torch.exp(logs) * x_mask
300
+ x = torch.cat([x0, x1], 1)
301
+ logdet = torch.sum(logs, [1,2])
302
+ return x, logdet
303
+ else:
304
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
305
+ x = torch.cat([x0, x1], 1)
306
+ return x
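A round-trip sketch for the flow layers above: running ResidualCouplingLayer forward and then in reverse should reconstruct its input up to floating-point error. The channel counts and sequence length below are arbitrary toy values:

import torch

layer = ResidualCouplingLayer(channels=4, hidden_channels=8, kernel_size=5,
                              dilation_rate=1, n_layers=2, mean_only=True)
x = torch.randn(1, 4, 16)                   # [batch, channels, time]
x_mask = torch.ones(1, 1, 16)

y, logdet = layer(x, x_mask)                # forward direction returns (y, log-determinant)
x_rec = layer(y, x_mask, reverse=True)      # reverse direction undoes the transform
print(torch.allclose(x, x_rec, atol=1e-5))  # expected: True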
modules/slicer2.py ADDED
@@ -0,0 +1,186 @@
1
+ import numpy as np
2
+
3
+
4
+ # This function is obtained from librosa.
5
+ def get_rms(
6
+ y,
7
+ *,
8
+ frame_length=2048,
9
+ hop_length=512,
10
+ pad_mode="constant",
11
+ ):
12
+ padding = (int(frame_length // 2), int(frame_length // 2))
13
+ y = np.pad(y, padding, mode=pad_mode)
14
+
15
+ axis = -1
16
+ # put our new within-frame axis at the end for now
17
+ out_strides = y.strides + tuple([y.strides[axis]])
18
+ # Reduce the shape on the framing axis
19
+ x_shape_trimmed = list(y.shape)
20
+ x_shape_trimmed[axis] -= frame_length - 1
21
+ out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
22
+ xw = np.lib.stride_tricks.as_strided(
23
+ y, shape=out_shape, strides=out_strides
24
+ )
25
+ if axis < 0:
26
+ target_axis = axis - 1
27
+ else:
28
+ target_axis = axis + 1
29
+ xw = np.moveaxis(xw, -1, target_axis)
30
+ # Downsample along the target axis
31
+ slices = [slice(None)] * xw.ndim
32
+ slices[axis] = slice(0, None, hop_length)
33
+ x = xw[tuple(slices)]
34
+
35
+ # Calculate power
36
+ power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
37
+
38
+ return np.sqrt(power)
39
+
40
+
41
+ class Slicer:
42
+ def __init__(self,
43
+ sr: int,
44
+ threshold: float = -40.,
45
+ min_length: int = 5000,
46
+ min_interval: int = 300,
47
+ hop_size: int = 20,
48
+ max_sil_kept: int = 5000):
49
+ if not min_length >= min_interval >= hop_size:
50
+ raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
51
+ if not max_sil_kept >= hop_size:
52
+ raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
53
+ min_interval = sr * min_interval / 1000
54
+ self.threshold = 10 ** (threshold / 20.)
55
+ self.hop_size = round(sr * hop_size / 1000)
56
+ self.win_size = min(round(min_interval), 4 * self.hop_size)
57
+ self.min_length = round(sr * min_length / 1000 / self.hop_size)
58
+ self.min_interval = round(min_interval / self.hop_size)
59
+ self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
60
+
61
+ def _apply_slice(self, waveform, begin, end):
62
+ if len(waveform.shape) > 1:
63
+ return waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
64
+ else:
65
+ return waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
66
+
67
+ # @timeit
68
+ def slice(self, waveform):
69
+ if len(waveform.shape) > 1:
70
+ samples = waveform.mean(axis=0)
71
+ else:
72
+ samples = waveform
73
+ if samples.shape[0] <= self.min_length:
74
+ return [waveform]
75
+ rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
76
+ sil_tags = []
77
+ silence_start = None
78
+ clip_start = 0
79
+ for i, rms in enumerate(rms_list):
80
+ # Keep looping while frame is silent.
81
+ if rms < self.threshold:
82
+ # Record start of silent frames.
83
+ if silence_start is None:
84
+ silence_start = i
85
+ continue
86
+ # Keep looping while frame is not silent and silence start has not been recorded.
87
+ if silence_start is None:
88
+ continue
89
+ # Clear recorded silence start if interval is not enough or clip is too short
90
+ is_leading_silence = silence_start == 0 and i > self.max_sil_kept
91
+ need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
92
+ if not is_leading_silence and not need_slice_middle:
93
+ silence_start = None
94
+ continue
95
+ # Need slicing. Record the range of silent frames to be removed.
96
+ if i - silence_start <= self.max_sil_kept:
97
+ pos = rms_list[silence_start: i + 1].argmin() + silence_start
98
+ if silence_start == 0:
99
+ sil_tags.append((0, pos))
100
+ else:
101
+ sil_tags.append((pos, pos))
102
+ clip_start = pos
103
+ elif i - silence_start <= self.max_sil_kept * 2:
104
+ pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
105
+ pos += i - self.max_sil_kept
106
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
107
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
108
+ if silence_start == 0:
109
+ sil_tags.append((0, pos_r))
110
+ clip_start = pos_r
111
+ else:
112
+ sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
113
+ clip_start = max(pos_r, pos)
114
+ else:
115
+ pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
116
+ pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
117
+ if silence_start == 0:
118
+ sil_tags.append((0, pos_r))
119
+ else:
120
+ sil_tags.append((pos_l, pos_r))
121
+ clip_start = pos_r
122
+ silence_start = None
123
+ # Deal with trailing silence.
124
+ total_frames = rms_list.shape[0]
125
+ if silence_start is not None and total_frames - silence_start >= self.min_interval:
126
+ silence_end = min(total_frames, silence_start + self.max_sil_kept)
127
+ pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
128
+ sil_tags.append((pos, total_frames + 1))
129
+ # Apply and return slices.
130
+ if len(sil_tags) == 0:
131
+ return [waveform]
132
+ else:
133
+ chunks = []
134
+ if sil_tags[0][0] > 0:
135
+ chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
136
+ for i in range(len(sil_tags) - 1):
137
+ chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
138
+ if sil_tags[-1][1] < total_frames:
139
+ chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
140
+ return chunks
141
+
142
+
143
+ def main():
144
+ import os.path
145
+ from argparse import ArgumentParser
146
+
147
+ import librosa
148
+ import soundfile
149
+
150
+ parser = ArgumentParser()
151
+ parser.add_argument('audio', type=str, help='The audio to be sliced')
152
+ parser.add_argument('--out', type=str, help='Output directory of the sliced audio clips')
153
+ parser.add_argument('--db_thresh', type=float, required=False, default=-40,
154
+ help='The dB threshold for silence detection')
155
+ parser.add_argument('--min_length', type=int, required=False, default=5000,
156
+ help='The minimum milliseconds required for each sliced audio clip')
157
+ parser.add_argument('--min_interval', type=int, required=False, default=300,
158
+ help='The minimum milliseconds for a silence part to be sliced')
159
+ parser.add_argument('--hop_size', type=int, required=False, default=10,
160
+ help='Frame length in milliseconds')
161
+ parser.add_argument('--max_sil_kept', type=int, required=False, default=500,
162
+ help='The maximum silence length kept around the sliced clip, presented in milliseconds')
163
+ args = parser.parse_args()
164
+ out = args.out
165
+ if out is None:
166
+ out = os.path.dirname(os.path.abspath(args.audio))
167
+ audio, sr = librosa.load(args.audio, sr=None, mono=False)
168
+ slicer = Slicer(
169
+ sr=sr,
170
+ threshold=args.db_thresh,
171
+ min_length=args.min_length,
172
+ min_interval=args.min_interval,
173
+ hop_size=args.hop_size,
174
+ max_sil_kept=args.max_sil_kept
175
+ )
176
+ chunks = slicer.slice(audio)
177
+ if not os.path.exists(out):
178
+ os.makedirs(out)
179
+ for i, chunk in enumerate(chunks):
180
+ if len(chunk.shape) > 1:
181
+ chunk = chunk.T
182
+ soundfile.write(os.path.join(out, '%s_%d.wav' % (os.path.basename(args.audio).rsplit('.', maxsplit=1)[0], i)), chunk, sr)
183
+
184
+
185
+ if __name__ == '__main__':
186
+ main()
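Programmatic counterpart to the CLI above; the input and output file names are placeholders:

import librosa
import soundfile

audio, sr = librosa.load('input.wav', sr=None, mono=False)      # hypothetical input file
slicer = Slicer(sr=sr, threshold=-40., min_length=5000,
                min_interval=300, hop_size=10, max_sil_kept=500)
for i, chunk in enumerate(slicer.slice(audio)):
    if len(chunk.shape) > 1:
        chunk = chunk.T                                          # (channels, n) -> (n, channels)
    soundfile.write(f'chunk_{i}.wav', chunk, sr)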
onnxexport/__pycache__/model_onnx.cpython-38.pyc ADDED
Binary file (8.87 kB)
 
onnxexport/__pycache__/model_onnx_speaker_mix.cpython-38.pyc ADDED
Binary file (6.16 kB)
 
onnxexport/model_onnx.py ADDED
@@ -0,0 +1,333 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv1d, Conv2d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import spectral_norm, weight_norm
6
+
7
+ import modules.attentions as attentions
8
+ import modules.commons as commons
9
+ import modules.modules as modules
10
+ import utils
11
+ from modules.commons import get_padding
12
+ from utils import f0_to_coarse
13
+ from vdecoder.hifigan.models import Generator
14
+
15
+
16
+ class ResidualCouplingBlock(nn.Module):
17
+ def __init__(self,
18
+ channels,
19
+ hidden_channels,
20
+ kernel_size,
21
+ dilation_rate,
22
+ n_layers,
23
+ n_flows=4,
24
+ gin_channels=0):
25
+ super().__init__()
26
+ self.channels = channels
27
+ self.hidden_channels = hidden_channels
28
+ self.kernel_size = kernel_size
29
+ self.dilation_rate = dilation_rate
30
+ self.n_layers = n_layers
31
+ self.n_flows = n_flows
32
+ self.gin_channels = gin_channels
33
+
34
+ self.flows = nn.ModuleList()
35
+ for i in range(n_flows):
36
+ self.flows.append(
37
+ modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
38
+ gin_channels=gin_channels, mean_only=True))
39
+ self.flows.append(modules.Flip())
40
+
41
+ def forward(self, x, x_mask, g=None, reverse=False):
42
+ if not reverse:
43
+ for flow in self.flows:
44
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
45
+ else:
46
+ for flow in reversed(self.flows):
47
+ x = flow(x, x_mask, g=g, reverse=reverse)
48
+ return x
49
+
50
+
51
+ class Encoder(nn.Module):
52
+ def __init__(self,
53
+ in_channels,
54
+ out_channels,
55
+ hidden_channels,
56
+ kernel_size,
57
+ dilation_rate,
58
+ n_layers,
59
+ gin_channels=0):
60
+ super().__init__()
61
+ self.in_channels = in_channels
62
+ self.out_channels = out_channels
63
+ self.hidden_channels = hidden_channels
64
+ self.kernel_size = kernel_size
65
+ self.dilation_rate = dilation_rate
66
+ self.n_layers = n_layers
67
+ self.gin_channels = gin_channels
68
+
69
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
70
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
71
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
72
+
73
+ def forward(self, x, x_lengths, g=None):
74
+ # print(x.shape,x_lengths.shape)
75
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
76
+ x = self.pre(x) * x_mask
77
+ x = self.enc(x, x_mask, g=g)
78
+ stats = self.proj(x) * x_mask
79
+ m, logs = torch.split(stats, self.out_channels, dim=1)
80
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
81
+ return z, m, logs, x_mask
82
+
83
+
84
+ class TextEncoder(nn.Module):
85
+ def __init__(self,
86
+ out_channels,
87
+ hidden_channels,
88
+ kernel_size,
89
+ n_layers,
90
+ gin_channels=0,
91
+ filter_channels=None,
92
+ n_heads=None,
93
+ p_dropout=None):
94
+ super().__init__()
95
+ self.out_channels = out_channels
96
+ self.hidden_channels = hidden_channels
97
+ self.kernel_size = kernel_size
98
+ self.n_layers = n_layers
99
+ self.gin_channels = gin_channels
100
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
101
+ self.f0_emb = nn.Embedding(256, hidden_channels)
102
+
103
+ self.enc_ = attentions.Encoder(
104
+ hidden_channels,
105
+ filter_channels,
106
+ n_heads,
107
+ n_layers,
108
+ kernel_size,
109
+ p_dropout)
110
+
111
+ def forward(self, x, x_mask, f0=None, z=None):
112
+ x = x + self.f0_emb(f0).transpose(1, 2)
113
+ x = self.enc_(x * x_mask, x_mask)
114
+ stats = self.proj(x) * x_mask
115
+ m, logs = torch.split(stats, self.out_channels, dim=1)
116
+ z = (m + z * torch.exp(logs)) * x_mask
117
+ return z, m, logs, x_mask
118
+
119
+
120
+ class DiscriminatorP(torch.nn.Module):
121
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
122
+ super(DiscriminatorP, self).__init__()
123
+ self.period = period
124
+ self.use_spectral_norm = use_spectral_norm
125
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
126
+ self.convs = nn.ModuleList([
127
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
128
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
129
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
130
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
131
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
132
+ ])
133
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
134
+
135
+ def forward(self, x):
136
+ fmap = []
137
+
138
+ # 1d to 2d
139
+ b, c, t = x.shape
140
+ if t % self.period != 0: # pad first
141
+ n_pad = self.period - (t % self.period)
142
+ x = F.pad(x, (0, n_pad), "reflect")
143
+ t = t + n_pad
144
+ x = x.view(b, c, t // self.period, self.period)
145
+
146
+ for l in self.convs:
147
+ x = l(x)
148
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
149
+ fmap.append(x)
150
+ x = self.conv_post(x)
151
+ fmap.append(x)
152
+ x = torch.flatten(x, 1, -1)
153
+
154
+ return x, fmap
155
+
156
+
157
+ class DiscriminatorS(torch.nn.Module):
158
+ def __init__(self, use_spectral_norm=False):
159
+ super(DiscriminatorS, self).__init__()
160
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
161
+ self.convs = nn.ModuleList([
162
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
163
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
164
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
165
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
166
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
167
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
168
+ ])
169
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
170
+
171
+ def forward(self, x):
172
+ fmap = []
173
+
174
+ for l in self.convs:
175
+ x = l(x)
176
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
177
+ fmap.append(x)
178
+ x = self.conv_post(x)
179
+ fmap.append(x)
180
+ x = torch.flatten(x, 1, -1)
181
+
182
+ return x, fmap
183
+
184
+
185
+ class F0Decoder(nn.Module):
186
+ def __init__(self,
187
+ out_channels,
188
+ hidden_channels,
189
+ filter_channels,
190
+ n_heads,
191
+ n_layers,
192
+ kernel_size,
193
+ p_dropout,
194
+ spk_channels=0):
195
+ super().__init__()
196
+ self.out_channels = out_channels
197
+ self.hidden_channels = hidden_channels
198
+ self.filter_channels = filter_channels
199
+ self.n_heads = n_heads
200
+ self.n_layers = n_layers
201
+ self.kernel_size = kernel_size
202
+ self.p_dropout = p_dropout
203
+ self.spk_channels = spk_channels
204
+
205
+ self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
206
+ self.decoder = attentions.FFT(
207
+ hidden_channels,
208
+ filter_channels,
209
+ n_heads,
210
+ n_layers,
211
+ kernel_size,
212
+ p_dropout)
213
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
214
+ self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
215
+ self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
216
+
217
+ def forward(self, x, norm_f0, x_mask, spk_emb=None):
218
+ x = torch.detach(x)
219
+ if spk_emb is not None:
220
+ x = x + self.cond(spk_emb)
221
+ x += self.f0_prenet(norm_f0)
222
+ x = self.prenet(x) * x_mask
223
+ x = self.decoder(x * x_mask, x_mask)
224
+ x = self.proj(x) * x_mask
225
+ return x
226
+
227
+
228
+ class SynthesizerTrn(nn.Module):
229
+ """
230
+ Synthesizer for Training
231
+ """
232
+
233
+ def __init__(self,
234
+ spec_channels,
235
+ segment_size,
236
+ inter_channels,
237
+ hidden_channels,
238
+ filter_channels,
239
+ n_heads,
240
+ n_layers,
241
+ kernel_size,
242
+ p_dropout,
243
+ resblock,
244
+ resblock_kernel_sizes,
245
+ resblock_dilation_sizes,
246
+ upsample_rates,
247
+ upsample_initial_channel,
248
+ upsample_kernel_sizes,
249
+ gin_channels,
250
+ ssl_dim,
251
+ n_speakers,
252
+ sampling_rate=44100,
253
+ **kwargs):
254
+ super().__init__()
255
+ self.spec_channels = spec_channels
256
+ self.inter_channels = inter_channels
257
+ self.hidden_channels = hidden_channels
258
+ self.filter_channels = filter_channels
259
+ self.n_heads = n_heads
260
+ self.n_layers = n_layers
261
+ self.kernel_size = kernel_size
262
+ self.p_dropout = p_dropout
263
+ self.resblock = resblock
264
+ self.resblock_kernel_sizes = resblock_kernel_sizes
265
+ self.resblock_dilation_sizes = resblock_dilation_sizes
266
+ self.upsample_rates = upsample_rates
267
+ self.upsample_initial_channel = upsample_initial_channel
268
+ self.upsample_kernel_sizes = upsample_kernel_sizes
269
+ self.segment_size = segment_size
270
+ self.gin_channels = gin_channels
271
+ self.ssl_dim = ssl_dim
272
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
273
+
274
+ self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
275
+
276
+ self.enc_p = TextEncoder(
277
+ inter_channels,
278
+ hidden_channels,
279
+ filter_channels=filter_channels,
280
+ n_heads=n_heads,
281
+ n_layers=n_layers,
282
+ kernel_size=kernel_size,
283
+ p_dropout=p_dropout
284
+ )
285
+ hps = {
286
+ "sampling_rate": sampling_rate,
287
+ "inter_channels": inter_channels,
288
+ "resblock": resblock,
289
+ "resblock_kernel_sizes": resblock_kernel_sizes,
290
+ "resblock_dilation_sizes": resblock_dilation_sizes,
291
+ "upsample_rates": upsample_rates,
292
+ "upsample_initial_channel": upsample_initial_channel,
293
+ "upsample_kernel_sizes": upsample_kernel_sizes,
294
+ "gin_channels": gin_channels,
295
+ }
296
+ self.dec = Generator(h=hps)
297
+ self.enc_q = Encoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
298
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
299
+ self.f0_decoder = F0Decoder(
300
+ 1,
301
+ hidden_channels,
302
+ filter_channels,
303
+ n_heads,
304
+ n_layers,
305
+ kernel_size,
306
+ p_dropout,
307
+ spk_channels=gin_channels
308
+ )
309
+ self.emb_uv = nn.Embedding(2, hidden_channels)
310
+ self.predict_f0 = False
311
+
312
+ def forward(self, c, f0, mel2ph, uv, noise=None, g=None):
313
+
314
+ decoder_inp = F.pad(c, [0, 0, 1, 0])
315
+ mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])
316
+ c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H]
317
+
318
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
319
+ g = g.unsqueeze(0)
320
+ g = self.emb_g(g).transpose(1, 2)
321
+ x_mask = torch.unsqueeze(commons.sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
322
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2)
323
+
324
+ if self.predict_f0:
325
+ lf0 = 2595. * torch.log10(1. + f0.unsqueeze(1) / 700.) / 500
326
+ norm_lf0 = utils.normalize_f0(lf0, x_mask, uv, random_scale=False)
327
+ pred_lf0 = self.f0_decoder(x, norm_lf0, x_mask, spk_emb=g)
328
+ f0 = (700 * (torch.pow(10, pred_lf0 * 500 / 2595) - 1)).squeeze(1)
329
+
330
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise)
331
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
332
+ o = self.dec(z * c_mask, g=g, f0=f0)
333
+ return o
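The mel2ph gather at the top of forward() stretches each content vector over the frames it is aligned to. A tiny stand-alone demonstration of the same operations with toy numbers (nothing here comes from a real model):

import torch
import torch.nn.functional as F

c = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)     # [B=1, T_unit=3, H=2]
mel2ph = torch.tensor([[1, 1, 2, 3, 3]])                      # frame-level alignment, 1-indexed (0 = pad)

decoder_inp = F.pad(c, [0, 0, 1, 0])                          # prepend a zero vector for index 0
mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])     # [B, T_frame, H]
frame_feats = torch.gather(decoder_inp, 1, mel2ph_)           # [B, T_frame, H]
print(frame_feats.squeeze(0))
# tensor([[0., 1.], [0., 1.], [2., 3.], [4., 5.], [4., 5.]])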
onnxexport/model_onnx_speaker_mix.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ import modules.attentions as attentions
6
+ import modules.modules as modules
7
+ from utils import f0_to_coarse
8
+
9
+
10
+ class ResidualCouplingBlock(nn.Module):
11
+ def __init__(self,
12
+ channels,
13
+ hidden_channels,
14
+ kernel_size,
15
+ dilation_rate,
16
+ n_layers,
17
+ n_flows=4,
18
+ gin_channels=0):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.hidden_channels = hidden_channels
22
+ self.kernel_size = kernel_size
23
+ self.dilation_rate = dilation_rate
24
+ self.n_layers = n_layers
25
+ self.n_flows = n_flows
26
+ self.gin_channels = gin_channels
27
+
28
+ self.flows = nn.ModuleList()
29
+ for i in range(n_flows):
30
+ self.flows.append(
31
+ modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
32
+ gin_channels=gin_channels, mean_only=True))
33
+ self.flows.append(modules.Flip())
34
+
35
+ def forward(self, x, x_mask, g=None, reverse=False):
36
+ if not reverse:
37
+ for flow in self.flows:
38
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
39
+ else:
40
+ for flow in reversed(self.flows):
41
+ x = flow(x, x_mask, g=g, reverse=reverse)
42
+ return x
43
+
44
+
45
+ class TextEncoder(nn.Module):
46
+ def __init__(self,
47
+ out_channels,
48
+ hidden_channels,
49
+ kernel_size,
50
+ n_layers,
51
+ gin_channels=0,
52
+ filter_channels=None,
53
+ n_heads=None,
54
+ p_dropout=None):
55
+ super().__init__()
56
+ self.out_channels = out_channels
57
+ self.hidden_channels = hidden_channels
58
+ self.kernel_size = kernel_size
59
+ self.n_layers = n_layers
60
+ self.gin_channels = gin_channels
61
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
62
+ self.f0_emb = nn.Embedding(256, hidden_channels)
63
+
64
+ self.enc_ = attentions.Encoder(
65
+ hidden_channels,
66
+ filter_channels,
67
+ n_heads,
68
+ n_layers,
69
+ kernel_size,
70
+ p_dropout)
71
+
72
+ def forward(self, x, x_mask, f0=None, z=None):
73
+ x = x + self.f0_emb(f0).transpose(1, 2)
74
+ x = self.enc_(x * x_mask, x_mask)
75
+ stats = self.proj(x) * x_mask
76
+ m, logs = torch.split(stats, self.out_channels, dim=1)
77
+ z = (m + z * torch.exp(logs)) * x_mask
78
+
79
+ return z, m, logs, x_mask
80
+
81
+
82
+ class F0Decoder(nn.Module):
83
+ def __init__(self,
84
+ out_channels,
85
+ hidden_channels,
86
+ filter_channels,
87
+ n_heads,
88
+ n_layers,
89
+ kernel_size,
90
+ p_dropout,
91
+ spk_channels=0):
92
+ super().__init__()
93
+ self.out_channels = out_channels
94
+ self.hidden_channels = hidden_channels
95
+ self.filter_channels = filter_channels
96
+ self.n_heads = n_heads
97
+ self.n_layers = n_layers
98
+ self.kernel_size = kernel_size
99
+ self.p_dropout = p_dropout
100
+ self.spk_channels = spk_channels
101
+
102
+ self.prenet = nn.Conv1d(hidden_channels, hidden_channels, 3, padding=1)
103
+ self.decoder = attentions.FFT(
104
+ hidden_channels,
105
+ filter_channels,
106
+ n_heads,
107
+ n_layers,
108
+ kernel_size,
109
+ p_dropout)
110
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
111
+ self.f0_prenet = nn.Conv1d(1, hidden_channels, 3, padding=1)
112
+ self.cond = nn.Conv1d(spk_channels, hidden_channels, 1)
113
+
114
+ def forward(self, x, norm_f0, x_mask, spk_emb=None):
115
+ x = torch.detach(x)
116
+ if spk_emb is not None:
117
+ x = x + self.cond(spk_emb)
118
+ x += self.f0_prenet(norm_f0)
119
+ x = self.prenet(x) * x_mask
120
+ x = self.decoder(x * x_mask, x_mask)
121
+ x = self.proj(x) * x_mask
122
+ return x
123
+
124
+
125
+ class SynthesizerTrn(nn.Module):
126
+ """
127
+ Synthesizer for Training
128
+ """
129
+
130
+ def __init__(self,
131
+ spec_channels,
132
+ segment_size,
133
+ inter_channels,
134
+ hidden_channels,
135
+ filter_channels,
136
+ n_heads,
137
+ n_layers,
138
+ kernel_size,
139
+ p_dropout,
140
+ resblock,
141
+ resblock_kernel_sizes,
142
+ resblock_dilation_sizes,
143
+ upsample_rates,
144
+ upsample_initial_channel,
145
+ upsample_kernel_sizes,
146
+ gin_channels,
147
+ ssl_dim,
148
+ n_speakers,
149
+ sampling_rate=44100,
150
+ vol_embedding=False,
151
+ vocoder_name = "nsf-hifigan",
152
+ **kwargs):
153
+
154
+ super().__init__()
155
+ self.spec_channels = spec_channels
156
+ self.inter_channels = inter_channels
157
+ self.hidden_channels = hidden_channels
158
+ self.filter_channels = filter_channels
159
+ self.n_heads = n_heads
160
+ self.n_layers = n_layers
161
+ self.kernel_size = kernel_size
162
+ self.p_dropout = p_dropout
163
+ self.resblock = resblock
164
+ self.resblock_kernel_sizes = resblock_kernel_sizes
165
+ self.resblock_dilation_sizes = resblock_dilation_sizes
166
+ self.upsample_rates = upsample_rates
167
+ self.upsample_initial_channel = upsample_initial_channel
168
+ self.upsample_kernel_sizes = upsample_kernel_sizes
169
+ self.segment_size = segment_size
170
+ self.gin_channels = gin_channels
171
+ self.ssl_dim = ssl_dim
172
+ self.vol_embedding = vol_embedding
173
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
174
+ if vol_embedding:
175
+ self.emb_vol = nn.Linear(1, hidden_channels)
176
+
177
+ self.pre = nn.Conv1d(ssl_dim, hidden_channels, kernel_size=5, padding=2)
178
+
179
+ self.enc_p = TextEncoder(
180
+ inter_channels,
181
+ hidden_channels,
182
+ filter_channels=filter_channels,
183
+ n_heads=n_heads,
184
+ n_layers=n_layers,
185
+ kernel_size=kernel_size,
186
+ p_dropout=p_dropout
187
+ )
188
+ hps = {
189
+ "sampling_rate": sampling_rate,
190
+ "inter_channels": inter_channels,
191
+ "resblock": resblock,
192
+ "resblock_kernel_sizes": resblock_kernel_sizes,
193
+ "resblock_dilation_sizes": resblock_dilation_sizes,
194
+ "upsample_rates": upsample_rates,
195
+ "upsample_initial_channel": upsample_initial_channel,
196
+ "upsample_kernel_sizes": upsample_kernel_sizes,
197
+ "gin_channels": gin_channels,
198
+ }
199
+
200
+ if vocoder_name == "nsf-hifigan":
201
+ from vdecoder.hifigan.models import Generator
202
+ self.dec = Generator(h=hps)
203
+ elif vocoder_name == "nsf-snake-hifigan":
204
+ from vdecoder.hifiganwithsnake.models import Generator
205
+ self.dec = Generator(h=hps)
206
+ else:
207
+ print("[?] Unkown vocoder: use default(nsf-hifigan)")
208
+ from vdecoder.hifigan.models import Generator
209
+ self.dec = Generator(h=hps)
210
+
211
+ self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
212
+ self.f0_decoder = F0Decoder(
213
+ 1,
214
+ hidden_channels,
215
+ filter_channels,
216
+ n_heads,
217
+ n_layers,
218
+ kernel_size,
219
+ p_dropout,
220
+ spk_channels=gin_channels
221
+ )
222
+ self.emb_uv = nn.Embedding(2, hidden_channels)
223
+ self.predict_f0 = False
224
+ self.speaker_map = []
225
+ self.export_mix = False
226
+
227
+ def export_chara_mix(self, speakers_mix):
228
+ self.speaker_map = torch.zeros((len(speakers_mix), 1, 1, self.gin_channels))
229
+ i = 0
230
+ for key in speakers_mix.keys():
231
+ spkidx = speakers_mix[key]
232
+ self.speaker_map[i] = self.emb_g(torch.LongTensor([[spkidx]]))
233
+ i = i + 1
234
+ self.speaker_map = self.speaker_map.unsqueeze(0)
235
+ self.export_mix = True
236
+
237
+ def forward(self, c, f0, mel2ph, uv, noise=None, g=None, vol = None):
238
+ decoder_inp = F.pad(c, [0, 0, 1, 0])
239
+ mel2ph_ = mel2ph.unsqueeze(2).repeat([1, 1, c.shape[-1]])
240
+ c = torch.gather(decoder_inp, 1, mel2ph_).transpose(1, 2) # [B, T, H]
241
+
242
+ if self.export_mix: # [N, S] * [S, B, 1, H]
243
+ g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
244
+ g = g * self.speaker_map # [N, S, B, 1, H]
245
+ g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
246
+ g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
247
+ else:
248
+ if g.dim() == 1:
249
+ g = g.unsqueeze(0)
250
+ g = self.emb_g(g).transpose(1, 2)
251
+
252
+ x_mask = torch.unsqueeze(torch.ones_like(f0), 1).to(c.dtype)
253
+ # vol proj
254
+ vol = self.emb_vol(vol[:,:,None]).transpose(1,2) if vol is not None and self.vol_embedding else 0
255
+
256
+ x = self.pre(c) * x_mask + self.emb_uv(uv.long()).transpose(1, 2) + vol
257
+
258
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, f0=f0_to_coarse(f0), z=noise)
259
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
260
+ o = self.dec(z * c_mask, g=g, f0=f0)
261
+ return o
262
+
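The export_mix branch of forward() blends speaker embeddings with per-frame weights. A toy illustration of the same tensor algebra; the dimensions are tiny stand-ins and a real speaker_map comes from export_chara_mix:

import torch

n_frames, n_speakers, gin_channels = 4, 2, 3
speaker_map = torch.randn(1, n_speakers, 1, 1, gin_channels)    # stand-in for the exported embedding table
g = torch.tensor([[1.0, 0.0],
                  [0.5, 0.5],
                  [0.0, 1.0],
                  [0.0, 1.0]])                                   # [n_frames, n_speakers] mixing weights

g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1))                 # [N, S, 1, 1, 1]
g = g * speaker_map                                              # broadcast to [N, S, 1, 1, H]
g = torch.sum(g, dim=1)                                          # [N, 1, 1, H]
g = g.transpose(0, -1).transpose(0, -2).squeeze(0)               # [1, H, N], the shape the decoder expects
print(g.shape)                                                   # torch.Size([1, 3, 4])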