PushkarA07 CjangCjengh committed on
Commit
257783b
0 Parent(s):

Duplicate from CjangCjengh/Sanskrit-TTS

Browse files

Co-authored-by: CjangCjengh <CjangCjengh@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sanskrit TTS
3
+ emoji: 👀
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.3.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: gpl-3.0
11
+ duplicated_from: CjangCjengh/Sanskrit-TTS
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/attentions.cpython-37.pyc ADDED
Binary file (9.62 kB). View file
 
__pycache__/commons.cpython-37.pyc ADDED
Binary file (3.28 kB). View file
 
__pycache__/mel_processing.cpython-37.pyc ADDED
Binary file (3.07 kB). View file
 
__pycache__/models.cpython-37.pyc ADDED
Binary file (15.4 kB). View file
 
__pycache__/modules.cpython-37.pyc ADDED
Binary file (11.5 kB). View file
 
__pycache__/transforms.cpython-37.pyc ADDED
Binary file (3.85 kB). View file
 
__pycache__/utils.cpython-37.pyc ADDED
Binary file (2.74 kB). View file
 
app.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import librosa
3
+ import commons
4
+ import utils
5
+ from models import SynthesizerTrn
6
+ from text import text_to_sequence
7
+ import numpy as np
8
+ from mel_processing import spectrogram_torch
9
+ import gradio as gr
10
+ from indic_transliteration import sanscript
11
+
12
+
13
# Maps each script name offered in the UI to its indic_transliteration scheme.
SCRIPT_DICT = dict(
    Devanagari=sanscript.DEVANAGARI,
    IAST=sanscript.IAST,
    SLP1=sanscript.SLP1,
    HK=sanscript.HK,
)
19
+
20
+ DEFAULT_TEXT='संस्कृतम् जगतः एकतमा अतिप्राचीना समृद्धा शास्त्रीया च भाषासु वर्तते । संस्कृतं भारतस्य जगत: वा भाषासु एकतमा‌ प्राचीनतमा ।'
21
+
22
+
23
def get_text(text, hps, cleaned=False):
    """Convert ``text`` into a LongTensor of symbol ids.

    Args:
        text: input string.
        hps: hyper-parameter object; ``hps.symbols``, ``hps.data.text_cleaners``
            and ``hps.data.add_blank`` are read.
        cleaned: when true, no cleaners are applied (an empty cleaner list is
            passed to ``text_to_sequence``).

    Returns:
        ``torch.LongTensor`` of the symbol sequence; when ``hps.data.add_blank``
        is set, id 0 is interspersed between (and around) the symbols.
    """
    cleaners = [] if cleaned else hps.data.text_cleaners
    symbol_ids = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        symbol_ids = commons.intersperse(symbol_ids, 0)
    return torch.LongTensor(symbol_ids)
32
+
33
+
34
def default_text(script):
    """Return the sample sentence rendered in ``script``.

    ``DEFAULT_TEXT`` is stored in Devanagari, so it is returned as-is for that
    script and transliterated for any other key of ``SCRIPT_DICT``.
    """
    if script != 'Devanagari':
        return sanscript.transliterate(DEFAULT_TEXT, sanscript.DEVANAGARI, SCRIPT_DICT[script])
    return DEFAULT_TEXT
39
+
40
+
41
def speech_synthesize(text, script, speaker_id, length_scale):
    """Synthesize speech for ``text`` with the globally loaded model.

    Args:
        text: input text written in ``script``.
        script: key of ``SCRIPT_DICT``; non-Devanagari input is transliterated
            to Devanagari before synthesis.
        speaker_id: integer speaker index passed to the model as ``sid``.
        length_scale: forwarded to ``net_g_ms.infer`` as ``length_scale``.

    Returns:
        Tuple ``(sampling_rate, audio)`` where ``audio`` is a float NumPy array.
    """
    # Newlines are dropped outright (not replaced by spaces).
    text = text.replace('\n', '')
    if script != 'Devanagari':
        text = sanscript.transliterate(text, SCRIPT_DICT[script], sanscript.DEVANAGARI)
    print(text)
    token_ids = get_text(text, hps_ms)
    with torch.no_grad():
        batch = token_ids.unsqueeze(0)
        batch_lengths = torch.LongTensor([token_ids.size(0)])
        speaker = torch.LongTensor([speaker_id])
        outputs = net_g_ms.infer(
            batch, batch_lengths, sid=speaker,
            noise_scale=0.667, noise_scale_w=0.8, length_scale=length_scale)
        audio = outputs[0][0, 0].data.cpu().float().numpy()
    return (hps_ms.data.sampling_rate, audio)
53
+
54
+
55
def voice_convert(audio, origin_id, target_id):
    """Convert recorded audio from one speaker identity to another.

    Args:
        audio: tuple ``(sampling_rate, samples)`` where ``samples`` is a NumPy
            array, either 1-D or 2-D (transposed before down-mixing to mono).
        origin_id: integer index of the source speaker.
        target_id: integer index of the target speaker.

    Returns:
        Tuple ``(sampling_rate, audio)`` at the model's sampling rate, with
        ``audio`` as a float NumPy array.
    """
    sampling_rate, audio = audio
    if np.issubdtype(audio.dtype, np.integer):
        # Integer PCM: scale by the dtype's maximum value.
        audio = (audio / np.iinfo(audio.dtype).max).astype(np.float32)
    else:
        # np.iinfo raises ValueError for floating-point dtypes; such audio is
        # only cast, not rescaled.
        audio = audio.astype(np.float32)
    if len(audio.shape) > 1:
        audio = librosa.to_mono(audio.transpose(1, 0))
    if sampling_rate != hps_ms.data.sampling_rate:
        audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=hps_ms.data.sampling_rate)

    with torch.no_grad():
        y = torch.FloatTensor(audio).unsqueeze(0)
        spec = spectrogram_torch(y, hps_ms.data.filter_length,
                                 hps_ms.data.sampling_rate, hps_ms.data.hop_length, hps_ms.data.win_length,
                                 center=False)
        spec_lengths = torch.LongTensor([spec.size(-1)])
        sid_src = torch.LongTensor([origin_id])
        sid_tgt = torch.LongTensor([target_id])
        converted = net_g_ms.voice_conversion(spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt)
        audio = converted[0][0, 0].data.cpu().float().numpy()
    return (hps_ms.data.sampling_rate, audio)
73
+
74
+
75
if __name__ == '__main__':
    # Hyper-parameters and model are module globals read by the callbacks above.
    hps_ms = utils.get_hparams_from_file('model/config.json')
    n_speakers = hps_ms.data.n_speakers
    n_symbols = len(hps_ms.symbols)
    speakers = hps_ms.speakers

    net_g_ms = SynthesizerTrn(
        n_symbols,
        hps_ms.data.filter_length // 2 + 1,
        hps_ms.train.segment_size // hps_ms.data.hop_length,
        n_speakers=n_speakers,
        **hps_ms.model)
    net_g_ms.eval()
    utils.load_checkpoint('model/model.pth', net_g_ms)

    # NOTE(review): the nesting of the widgets below was reconstructed from a
    # listing that had lost its indentation - confirm against the real file.
    with gr.Blocks() as app:
        gr.Markdown('# Sanskrit Text to Speech\n'
                    '![visitor badge](https://visitor-badge.glitch.me/badge?page_id=cjangcjengh.sanskrit-tts)')
        with gr.Tab('Text to Speech'):
            text_script = gr.Radio(['Devanagari', 'IAST', 'SLP1', 'HK'], label='Script', interactive=True, value='Devanagari')
            text_input = gr.TextArea(label='Text', placeholder='Type your text here', value=DEFAULT_TEXT)
            speaker_id = gr.Dropdown(speakers, label='Speaker', type='index', interactive=True, value=speakers[0])
            length_scale = gr.Slider(0.5, 2, 1, step=0.1, label='Speaking Speed', interactive=True)
            tts_button = gr.Button('Synthesize')
            audio_output = gr.Audio(label='Speech Synthesized')
            # Switching script swaps the text box to the sample text in that script.
            text_script.change(default_text, [text_script], [text_input])
            tts_button.click(speech_synthesize, [text_input, text_script, speaker_id, length_scale], [audio_output])
        with gr.Tab('Voice Conversion'):
            audio_input = gr.Audio(label='Audio', interactive=True)
            speaker_input = gr.Dropdown(speakers, label='Original Speaker', type='index', interactive=True, value=speakers[0])
            speaker_output = gr.Dropdown(speakers, label='Target Speaker', type='index', interactive=True, value=speakers[0])
            vc_button = gr.Button('Convert')
            audio_output_vc = gr.Audio(label='Voice Converted')
            vc_button.click(voice_convert, [audio_input, speaker_input, speaker_output], [audio_output_vc])
        gr.Markdown('## Based on\n'
                    '- [VITS](https://github.com/jaywalnut310/vits)\n\n'
                    '## Dataset\n'
                    '- [Vāksañcayaḥ](https://www.cse.iitb.ac.in/~asr/)')

    app.launch()
attentions.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ from modules import LayerNorm
8
+
9
+
10
class Encoder(nn.Module):
    """Stack of ``n_layers`` blocks, each: self-attention -> add & norm -> FFN -> add & norm."""

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer, in this order.
        for _ in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads,
                                   p_dropout=p_dropout, window_size=window_size))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size,
                    p_dropout=p_dropout))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """Run every layer on ``x``; ``x`` is multiplied by ``x_mask`` on entry and exit."""
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        stack = zip(self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2)
        for attn, norm_attn, ffn, norm_ffn in stack:
            residual = self.drop(attn(x, x, attn_mask))
            x = norm_attn(x + residual)

            residual = self.drop(ffn(x, x_mask))
            x = norm_ffn(x + residual)
        x = x * x_mask
        return x
45
+
46
+
47
class Decoder(nn.Module):
    """Stack of ``n_layers`` blocks: masked self-attention, encoder-decoder attention, causal FFN."""

    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Sub-modules are created layer by layer, in this order.
        for _ in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads,
                                   p_dropout=p_dropout, proximal_bias=proximal_bias,
                                   proximal_init=proximal_init))
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size,
                    p_dropout=p_dropout, causal=True))
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        # Lower-triangular mask for self-attention.
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        stack = zip(self.self_attn_layers, self.norm_layers_0,
                    self.encdec_attn_layers, self.norm_layers_1,
                    self.ffn_layers, self.norm_layers_2)
        for self_attn, norm_self, cross_attn, norm_cross, ffn, norm_ffn in stack:
            residual = self.drop(self_attn(x, x, self_attn_mask))
            x = norm_self(x + residual)

            residual = self.drop(cross_attn(x, h, encdec_attn_mask))
            x = norm_cross(x + residual)

            residual = self.drop(ffn(x, x_mask))
            x = norm_ffn(x + residual)
        x = x * x_mask
        return x
96
+
97
+
98
class MultiHeadAttention(nn.Module):
    """Multi-head attention over ``[b, d, t]`` tensors using 1x1 convolutions.

    Optional extras: windowed relative-position embeddings (``window_size``),
    a proximal bias favouring nearby positions (``proximal_bias``) and a banded
    block mask (``block_length``). The last attention weights are kept in
    ``self.attn``.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # One embedding table shared by all heads, or one table per head.
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            rel_shape = (n_heads_rel, window_size * 2 + 1, self.k_channels)
            self.emb_rel_k = nn.Parameter(torch.randn(*rel_shape) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(*rel_shape) * rel_stddev)

        for conv in (self.conv_q, self.conv_k, self.conv_v):
            nn.init.xavier_uniform_(conv.weight)
        if proximal_init:
            # Start the key projection as a copy of the query projection.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys/values) and project the result."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        attended, self.attn = self.attention(q, k, v, mask=attn_mask)
        return self.conv_o(attended)

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention; returns ``(output [b, d, t_t], weights [b, n_h, t_t, t_s])``."""
        b, d, t_s = key.size()
        t_t = query.size(2)

        # reshape [b, d, t] -> [b, n_h, t, d_k]
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            rel_keys = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(scaled_query, rel_keys)
            scores = scores + self._relative_position_to_absolute_position(rel_logits)
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            bias = self._attention_bias_proximal(t_s)
            scores = scores + bias.to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                # Keep only a band of +/- block_length around the diagonal.
                band = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(band == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            rel_weights = self._absolute_position_to_relative_position(p_attn)
            rel_values = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(rel_weights, rel_values)
        # [b, n_h, t_t, d_k] -> [b, d, t_t]
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad and/or slice the embedding table to ``2 * length - 1`` relative positions."""
        # Pad first before slice to avoid using cond ops.
        window = self.window_size + 1
        pad_length = max(length - window, 0)
        start = max(window - length, 0)
        end = start + 2 * length - 1
        padded = relative_embeddings
        if pad_length > 0:
            padded = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
        return padded[:, start:end]

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        flat = x.view([batch, heads, length * 2 * length])
        flat = F.pad(flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        skewed = flat.view([batch, heads, length + 1, 2 * length - 1])
        return skewed[:, :, :length, length - 1:]

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        flat = F.pad(flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        return flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        positions = torch.arange(length, dtype=torch.float32)
        diff = positions[None, :] - positions[:, None]
        return (-torch.log1p(torch.abs(diff)))[None, None]
252
+
253
+
254
class FFN(nn.Module):
    """Two-convolution feed-forward block with masking and optional causal padding."""

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Padding strategy is fixed at construction time.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply conv -> activation -> dropout -> conv, masking before each conv and on output."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            # Sigmoid-based approximation of GELU.
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        """Pad ``kernel_size - 1`` on the left of the last axis (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        left, right = self.kernel_size - 1, 0
        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [left, right]]))

    def _same_padding(self, x):
        """Pad both sides of the last axis so the conv keeps the length (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [left, right]]))
commons.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+ import torch.jit
5
+
6
+
7
def script_method(fn, _rcb=None):
    """No-op stand-in for ``torch.jit.script_method``: return ``fn`` unchanged."""
    return fn


def script(obj, optimize=True, _frames_up=0, _rcb=None):
    """No-op stand-in for ``torch.jit.script``: return ``obj`` unchanged."""
    return obj


# Replace the TorchScript entry points with the no-op stand-ins above; this
# takes effect process-wide when the module is imported.
torch.jit.script_method, torch.jit.script = script_method, script
17
+
18
+
19
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when the class name of ``m`` contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
23
+
24
+
25
def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``, truncated toward zero."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
27
+
28
+
29
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``."""
    result = [item]
    for element in lst:
        result.append(element)
        result.append(item)
    return result
33
+
34
+
35
def slice_segments(x, ids_str, segment_size=4):
    """Cut a ``segment_size``-long window out of each batch item along the last axis.

    ``ids_str[i]`` is the start index of the window for batch item ``i``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for row in range(x.size(0)):
        start = ids_str[row]
        segments[row] = x[row, :, start:start + segment_size]
    return segments
42
+
43
+
44
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice one randomly positioned ``segment_size`` window per batch item.

    Start indices are drawn uniformly from ``[0, x_lengths - segment_size]``;
    the full time length of ``x`` is used when ``x_lengths`` is None.

    Returns:
        Tuple ``(segments, start_indices)``.
    """
    batch, _, total = x.size()
    lengths = total if x_lengths is None else x_lengths
    max_start = lengths - segment_size + 1
    starts = (torch.rand([batch]).to(device=x.device) * max_start).to(dtype=torch.long)
    return slice_segments(x, starts, segment_size), starts
52
+
53
+
54
def subsequent_mask(length):
    """Return a lower-triangular mask of ones with shape ``[1, 1, length, length]``."""
    return torch.ones(length, length).tril()[None, None]
57
+
58
+
59
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation: tanh over the first n_channels[0] channels of
    # (input_a + input_b), multiplied by sigmoid over the remaining channels.
    split = n_channels[0]
    summed = input_a + input_b
    return torch.tanh(summed[:, :split, :]) * torch.sigmoid(summed[:, split:, :])
67
+
68
+
69
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[left, right]`` pairs, last dimension first.

    Reverses the order of the pairs and concatenates them into one flat list.
    """
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
73
+
74
+
75
def sequence_mask(length, max_length=None):
    """Return a boolean mask whose row ``i`` is True at positions ``< length[i]``.

    The number of columns is ``max_length``, or ``length.max()`` when omitted.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
80
+
81
+
82
def generate_path(duration, mask):
    """
    Build an alignment path from per-token durations.

    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Returns ``path * mask`` with the same shape as ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    # Row j is True up to the cumulative end of token j ...
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # ... subtracting the previous token's row leaves only token j's own span.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
mel_processing.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+
5
+ MAX_WAV_VALUE = 32768.0
6
+
7
+
8
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    Log-compress ``x`` after clamping it to at least ``clip_val``.

    PARAMS
    ------
    C: compression factor
    """
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
15
+
16
+
17
def dynamic_range_decompression_torch(x, C=1):
    """
    Undo the log compression: ``exp(x) / C``.

    PARAMS
    ------
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
24
+
25
+
26
def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default parameters."""
    return dynamic_range_compression_torch(magnitudes)
29
+
30
+
31
def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default parameters."""
    return dynamic_range_decompression_torch(magnitudes)
34
+
35
+
36
# Module-level caches: mel filter banks keyed by "<fmax>_<dtype>_<device>" and
# Hann windows keyed by "<win_size>_<dtype>_<device>".
mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Compute the linear magnitude spectrogram of waveform batch ``y``.

    Args:
        y: 2-D tensor of waveforms (unsqueezed at dim 1 for padding, then
            squeezed back); a message is printed if it leaves [-1, 1].
        n_fft: FFT size.
        sampling_rate: unused here; kept for signature compatibility.
        hop_size: hop length between frames.
        win_size: Hann window length.
        center: forwarded to ``torch.stft``.

    Returns:
        ``sqrt(re^2 + im^2 + 1e-6)`` of the STFT.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global hann_window
    window_key = str(win_size) + '_' + str(y.dtype) + '_' + str(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    # Reflect-pad both ends by (n_fft - hop_size) / 2 samples before the STFT.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # return_complex=False is deprecated in torch.stft; request the complex
    # result and view it as (..., 2) real/imag pairs, which is the same layout.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
60
+
61
+
62
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Project a linear spectrogram onto a mel filter bank and log-compress it.

    The filter bank is built once per ``"<fmax>_<dtype>_<device>"`` key and
    cached in the module-level ``mel_basis`` dict.
    """
    global mel_basis
    basis_key = str(fmax) + '_' + str(spec.dtype) + '_' + str(spec.device)
    if basis_key not in mel_basis:
        # Keyword arguments: librosa releases with keyword-only parameters
        # reject the positional form.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)
    return spec
72
+
73
+
74
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute the log-mel spectrogram of waveform batch ``y``.

    Reflect-pads ``y``, takes the STFT magnitude, projects it onto a mel filter
    bank and log-compresses the result. The filter bank and Hann window are
    cached in the module-level ``mel_basis`` / ``hann_window`` dicts. A message
    is printed if ``y`` leaves [-1, 1].
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + '_' + str(y.device)
    basis_key = str(fmax) + '_' + dtype_device
    window_key = str(win_size) + '_' + dtype_device
    if basis_key not in mel_basis:
        # Keyword arguments: librosa releases with keyword-only parameters
        # reject the positional form.
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    # Reflect-pad both ends by (n_fft - hop_size) / 2 samples before the STFT.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # torch.stft needs an explicit return_complex for real input; request the
    # complex result and view it as (..., 2) real/imag pairs.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
    spec = torch.view_as_real(spec)

    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec
model/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "train": {
3
+ "segment_size": 8192
4
+ },
5
+ "data": {
6
+ "text_cleaners":["sanskrit_cleaners"],
7
+ "max_wav_value": 32768.0,
8
+ "sampling_rate": 22050,
9
+ "filter_length": 1024,
10
+ "hop_length": 256,
11
+ "win_length": 1024,
12
+ "add_blank": true,
13
+ "n_speakers": 27
14
+ },
15
+ "model": {
16
+ "inter_channels": 192,
17
+ "hidden_channels": 192,
18
+ "filter_channels": 768,
19
+ "n_heads": 2,
20
+ "n_layers": 6,
21
+ "kernel_size": 3,
22
+ "p_dropout": 0.1,
23
+ "resblock": "1",
24
+ "resblock_kernel_sizes": [3,7,11],
25
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
26
+ "upsample_rates": [8,8,2,2],
27
+ "upsample_initial_channel": 512,
28
+ "upsample_kernel_sizes": [16,16,4,4],
29
+ "n_layers_q": 3,
30
+ "use_spectral_norm": false,
31
+ "gin_channels": 256
32
+ },
33
+ "speakers": ["Male 1", "Male 2", "Male 3", "Male 4 (Malayalam)", "Male 5", "Male 6", "Male 7", "Male 8 (Kannada)", "Female 1 (Tamil)", "Male 9 (Kannada)", "Female 2 (Marathi)", "Female 3 (Marathi)", "Female 4 (Marathi)", "Female 5 (Telugu)", "Female 6 (Telugu)", "Male 10 (Kannada)", "Male 11 (Kannada)", "Male 12", "Male 13", "Male 14", "Male 15", "Female 7", "Male 16 (Malayalam)", "Male 17 (Tamil)", "Male 18 (Hindi)", "Male 19 (Telugu)", "Male 20 (Hindi)"],
34
+ "symbols": ["_", "\u0964", "\u0901", "\u0902", "\u0903", "\u0905", "\u0906", "\u0907", "\u0908", "\u0909", "\u090a", "\u090b", "\u090f", "\u0910", "\u0913", "\u0914", "\u0915", "\u0916", "\u0917", "\u0918", "\u0919", "\u091a", "\u091b", "\u091c", "\u091d", "\u091e", "\u091f", "\u0920", "\u0921", "\u0922", "\u0923", "\u0924", "\u0925", "\u0926", "\u0927", "\u0928", "\u092a", "\u092b", "\u092c", "\u092d", "\u092e", "\u092f", "\u0930", "\u0932", "\u0933", "\u0935", "\u0936", "\u0937", "\u0938", "\u0939", "\u093d", "\u093e", "\u093f", "\u0940", "\u0941", "\u0942", "\u0943", "\u0944", "\u0947", "\u0948", "\u094b", "\u094c", "\u094d", "\u0960", "\u0962", " "]
35
+ }
model/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:badfd4dd674b5f0f686ce3eeeeff10424791b688e3b3a1f4e274e4dd5d2ea8f6
3
+ size 158921293
models.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+ from commons import init_weights, get_padding
14
+
15
+
16
class StochasticDurationPredictor(nn.Module):
    """Flow-based duration model.

    With ``reverse=False`` the forward pass returns a per-batch-item loss
    (``nll + logq``) for the given durations ``w``.  With ``reverse=True`` it
    samples noise and runs it backwards through the flows to produce ``logw``.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
        super().__init__()
        # The filter_channels argument is overridden by in_channels; the
        # original code marks this override for removal in a future version.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow stack: one affine layer, then n_flows (ConvFlow, Flip) pairs.
        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for _ in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        # Posterior branch; always four (ConvFlow, Flip) pairs regardless of n_flows.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for _ in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        # Conditioning encoder applied to the (detached) input features.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return the loss ``nll + logq`` ([b]) or, when ``reverse`` is set, ``logw``.

        ``x`` and ``g`` are detached, so no gradient flows back into whatever
        produced them.  ``w`` is required when ``reverse`` is False.
        """
        x = self.pre(torch.detach(x))
        if g is not None:
            x = x + self.cond(torch.detach(g))
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if reverse:
            inverse_flows = list(reversed(self.flows))
            inverse_flows = inverse_flows[:-2] + [inverse_flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in inverse_flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            logw, _ = torch.split(z, [1, 1], 1)
            return logw

        assert w is not None

        # Posterior: flow Gaussian noise conditioned on x and the encoded durations.
        logdet_tot_q = 0
        h_w = self.post_pre(w)
        h_w = self.post_convs(h_w, x_mask)
        h_w = self.post_proj(h_w) * x_mask
        e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
        z_q = e_q
        for post_flow in self.post_flows:
            z_q, logdet_q = post_flow(z_q, x_mask, g=(x + h_w))
            logdet_tot_q += logdet_q
        z_u, z1 = torch.split(z_q, [1, 1], 1)
        u = torch.sigmoid(z_u) * x_mask
        z0 = (w - u) * x_mask
        logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
        logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q ** 2)) * x_mask, [1, 2]) - logdet_tot_q

        # Prior: log-transform z0, then run the main flow stack forwards.
        logdet_tot = 0
        z0, logdet = self.log_flow(z0, x_mask)
        logdet_tot += logdet
        z = torch.cat([z0, z1], 1)
        for flow in self.flows:
            z, logdet = flow(z, x_mask, g=x, reverse=reverse)
            logdet_tot = logdet_tot + logdet
        nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z ** 2)) * x_mask, [1, 2]) - logdet_tot
        return nll + logq  # [b]
95
+
96
+
97
class DurationPredictor(nn.Module):
    """Deterministic duration predictor.

    Two conv -> ReLU -> LayerNorm -> dropout stages followed by a 1x1
    projection to a single channel.  Inputs are detached before use.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        half_kernel = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=half_kernel)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=half_kernel)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # Optional projection of the global conditioning onto the input channels.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction for ``x``."""
        x = torch.detach(x)
        if g is not None:
            x = x + self.cond(torch.detach(g))
        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            x = conv(x * x_mask)
            x = torch.relu(x)
            x = norm(x)
            x = self.drop(x)
        return self.proj(x * x_mask) * x_mask
132
+
133
+
134
class TextEncoder(nn.Module):
    """Encode input tokens (or features) into hidden states plus prior statistics."""

    def __init__(self,
                 n_vocab,
                 out_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        # The embedding table is only built when a vocabulary size is given.
        if self.n_vocab != 0:
            self.emb = nn.Embedding(n_vocab, hidden_channels)
            nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout)
        # Projects to 2 * out_channels, later split into mean and log-scale halves.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths):
        """Return ``(x, m, logs, x_mask)`` for the given inputs and their lengths."""
        if self.n_vocab != 0:
            x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)

        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
178
+
179
+
180
class ResidualCouplingBlock(nn.Module):
    """Sequence of ``n_flows`` (ResidualCouplingLayer, Flip) pairs, invertible via ``reverse``."""

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 n_flows=4,
                 gin_channels=0):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels, hidden_channels, kernel_size, dilation_rate, n_layers,
                gin_channels=gin_channels, mean_only=True)
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reversed order when ``reverse`` is set."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # The log-determinant returned by each forward flow is discarded here.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
211
+
212
+
213
class PosteriorEncoder(nn.Module):
    """Encode an input sequence into a sampled latent ``z`` and its statistics."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        # Projects to 2 * out_channels, split into mean and log-scale in forward().
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` where ``z = m + noise * exp(logs)``, masked."""
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)
        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
243
+
244
+
245
class Generator(torch.nn.Module):
    """Upsampling decoder: transposed-conv stages, each followed by averaged residual blocks."""

    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        # '1' selects ResBlock1; any other value selects ResBlock2.
        resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2 ** i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(in_ch, out_ch, k, u, padding=(k - u) // 2)))

        # One group of residual blocks per upsampling stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        # `ch` is the channel count left by the final upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x`` (optionally conditioned on ``g``) into a tanh-bounded single-channel output."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            first = stage * self.num_kernels
            xs = None
            for block in self.resblocks[first:first + self.num_kernels]:
                out = block(x)
                if xs is None:
                    xs = out
                else:
                    xs += out
            x = xs / self.num_kernels
        # Note: this final activation uses leaky_relu's default slope, not LRELU_SLOPE.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsampling layers and residual blocks."""
        print('Removing weight norm...')
        for up in self.ups:
            # Resolves to the module-level torch.nn.utils.remove_weight_norm import.
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
298
+
299
+
300
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the 1-D input into ``period`` columns and applies 2-D convs."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Equality test (not truthiness) is kept deliberately to match prior behaviour.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        widths = [1, 32, 128, 512, 1024]
        layers = [
            norm_f(Conv2d(c_in, c_out, (kernel_size, 1), (stride, 1),
                          padding=(get_padding(kernel_size, 1), 0)))
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        # The last hidden layer keeps 1024 channels and uses stride 1.
        layers.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1,
                                    padding=(get_padding(kernel_size, 1), 0))))
        self.convs = nn.ModuleList(layers)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for input ``x`` of shape [b, c, t]."""
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:  # pad first
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
335
+
336
+
337
class DiscriminatorS(torch.nn.Module):
    """Discriminator made of a stack of strided, grouped 1-D convolutions."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Equality test (not truthiness) is kept deliberately to match prior behaviour.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, k, s, groups=grp, padding=pad))
            for c_in, c_out, k, s, grp, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_score, feature_maps)`` for input ``x``."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
363
+
364
+
365
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one DiscriminatorS plus one DiscriminatorP per period in [2, 3, 5, 7, 11]."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(DiscriminatorP(period, use_spectral_norm=use_spectral_norm))
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every discriminator on ``y`` and ``y_hat``.

        Returns four lists, one entry per discriminator: scores for ``y``,
        scores for ``y_hat``, feature maps for ``y``, feature maps for ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)
            y_d_rs.append(score_r)
            y_d_gs.append(score_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
388
+
389
+
390
+
391
class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training

    Combines a text encoder (``enc_p``), posterior encoder (``enc_q``), flow,
    duration predictor (``dp``) and waveform decoder (``dec``).  A speaker
    embedding table ``emb_g`` exists only when ``n_speakers > 1``.
    """

    def __init__(self,
                 n_vocab,
                 spec_channels,
                 segment_size,
                 inter_channels,
                 hidden_channels,
                 filter_channels,
                 n_heads,
                 n_layers,
                 kernel_size,
                 p_dropout,
                 resblock,
                 resblock_kernel_sizes,
                 resblock_dilation_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 upsample_kernel_sizes,
                 n_speakers=0,
                 gin_channels=0,
                 use_sdp=True,
                 **kwargs):

        super().__init__()
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels

        self.use_sdp = use_sdp

        self.enc_p = TextEncoder(n_vocab,
                                 inter_channels,
                                 hidden_channels,
                                 filter_channels,
                                 n_heads,
                                 n_layers,
                                 kernel_size,
                                 p_dropout)
        self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        if use_sdp:
            self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
        else:
            self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

    def _speaker_condition(self, sid):
        """Return the speaker embedding for ``sid`` as [b, h, 1], or None for single-speaker models.

        The check mirrors the ``n_speakers > 1`` condition under which
        ``emb_g`` is created.  The previous ``n_speakers > 0`` test raised
        AttributeError when ``n_speakers == 1`` because no table existed.
        """
        if self.n_speakers > 1:
            return self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        return None

    def forward(self, x, x_lengths, y, y_lengths, sid=None):
        """Training pass.

        Returns ``(o, l_length, attn, ids_slice, x_mask, y_mask,
        (z, z_p, m_p, logs_p, m_q, logs_q))`` where ``o`` is decoded from a
        random ``segment_size`` slice of the posterior latent ``z``.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        g = self._speaker_condition(sid)

        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        # The alignment search is done without gradients.
        with torch.no_grad():
            # negative cross-entropy
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r)  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r))  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True)  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()

        # Durations: sum of the alignment over dim 2.
        w = attn.sum(2)
        if self.use_sdp:
            l_length = self.dp(x, x_mask, w, g=g)
            l_length = l_length / torch.sum(x_mask)
        else:
            logw_ = torch.log(w + 1e-6) * x_mask
            logw = self.dp(x, x_mask, g=g)
            l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(x_mask)  # for averaging

        # expand prior
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=g)
        return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
        """Synthesis pass.

        ``noise_scale`` scales the prior sampling noise, ``noise_scale_w`` the
        stochastic duration predictor noise, ``length_scale`` multiplies the
        predicted durations, and ``max_len`` truncates the latent before
        decoding.  Returns ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.
        """
        x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
        g = self._speaker_condition(sid)

        if self.use_sdp:
            logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        else:
            logw = self.dp(x, x_mask, g=g)
        w = torch.exp(logw) * x_mask * length_scale
        w_ceil = torch.ceil(w)
        # Clamp so that every item has an output length of at least 1.
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)

    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
        """Convert ``y`` from speaker ``sid_src`` to ``sid_tgt``.

        Requires a speaker embedding table, i.e. ``n_speakers > 1``.
        Returns ``(o_hat, y_mask, (z, z_p, z_hat))``.
        """
        # emb_g only exists for n_speakers > 1, so that is the real requirement.
        assert self.n_speakers > 1, "n_speakers have to be larger than 1."
        g_src = self.emb_g(sid_src).unsqueeze(-1)
        g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
535
+
modules.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ import commons
10
+ from commons import init_weights, get_padding
11
+ from transforms import piecewise_rational_quadratic_transform
12
+
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
class LayerNorm(nn.Module):
    """Layer normalisation over dim 1 (the channel dim) with learnable scale and shift."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalise ``x`` across dim 1 and return a tensor of the same layout."""
        # F.layer_norm works on trailing dims, so swap channels to the end and back.
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(channels_last, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.transpose(1, -1)
30
+
31
+
32
class ConvReluNorm(nn.Module):
    """Residual stack of ``n_layers`` conv -> LayerNorm -> ReLU -> dropout stages.

    The final 1x1 projection is zero-initialised, so the block starts out as
    the identity (times the mask).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # Message corrected: the condition requires at least two layers, but
        # the old text claimed "larger than 0".
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
        self.norm_layers.append(LayerNorm(hidden_channels))
        # A single ReLU + dropout module is shared by every stage.
        self.relu_drop = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p_dropout))
        for _ in range(n_layers-1):
            self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init makes the residual branch contribute nothing at start.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Return ``(x + proj(stack(x))) * x_mask``."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
65
+
66
+
67
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Dilation grows as kernel_size ** layer index; padding keeps the length.
            dilation = kernel_size ** layer_idx
            padding = (kernel_size * dilation - dilation) // 2
            depthwise = nn.Conv1d(channels, channels, kernel_size,
                                  groups=channels, dilation=dilation, padding=padding)
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Apply the residual layers to ``x`` (plus optional ``g``) and mask the result."""
        if g is not None:
            x = x + g
        layers = zip(self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2)
        for conv_sep, norm_1, conv_1x1, norm_2 in layers:
            y = conv_sep(x * x_mask)
            y = F.gelu(norm_1(y))
            y = conv_1x1(y)
            y = F.gelu(norm_2(y))
            y = self.drop(y)
            x = x + y
        return x * x_mask
106
+
107
+
108
class WN(torch.nn.Module):
    """Stack of weight-normalised dilated convolutions with residual and skip outputs.

    Each layer's activation comes from ``commons.fused_add_tanh_sigmoid_multiply``
    (presumably a gated tanh/sigmoid unit; confirm in ``commons``).  Optional
    conditioning ``g`` is projected once and sliced per layer.
    """

    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WN, self).__init__()
        assert(kernel_size % 2 == 1)
        self.hidden_channels = hidden_channels
        # Fixed: a stray trailing comma used to store this as the tuple (kernel_size,).
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One projection yields the conditioning for all layers at once.
            cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')

        for i in range(n_layers):
            dilation = dilation_rate ** i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
                                       dilation=dilation, padding=padding)
            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of the skip outputs of all layers."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice out this layer's 2*hidden_channels of conditioning.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(
                x_in,
                g_l,
                n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                res_acts = res_skip_acts[:,:self.hidden_channels,:]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:,self.hidden_channels:,:]
            else:
                # The last layer only produces a skip output.
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer that carries it."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
182
+
183
+
184
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs with leaky-ReLU activations."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        # First conv of each pair uses dilation[0..2].
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(3)
        ])
        self.convs1.apply(init_weights)

        # Second conv of each pair is undilated.
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
            for _ in range(3)
        ])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the residual pairs; when ``x_mask`` is given it is applied before each conv and at the end."""
        masked = x_mask is not None
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if masked:
                xt = xt * x_mask
            xt = dilated_conv(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if masked:
                xt = xt * x_mask
            xt = plain_conv(xt)
            x = xt + x
        if masked:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
227
+
228
+
229
class ResBlock2(torch.nn.Module):
    """Residual block of two dilated convolutions with leaky-ReLU activations."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(2)
        ])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply each conv residually; when ``x_mask`` is given it is applied before each conv and at the end."""
        masked = x_mask is not None
        for conv in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if masked:
                xt = xt * x_mask
            x = conv(xt) + x
        if masked:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in self.convs:
            remove_weight_norm(conv)
254
+
255
+
256
class Log(nn.Module):
    """Elementwise log flow: ``log`` forwards (with log-determinant), ``exp`` in reverse."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Return ``(y, logdet)`` forwards, or just the exponentiated tensor in reverse."""
        if reverse:
            return torch.exp(x) * x_mask
        # Inputs are clamped at 1e-5 so the log stays finite.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
265
+
266
+
267
class Flip(nn.Module):
    """Reverse the order of dim 1; forwards it also returns a zero log-determinant."""

    def forward(self, x, *args, reverse=False, **kwargs):
        """Return ``(flipped, logdet)`` forwards, or just ``flipped`` in reverse."""
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        # One zero per batch item, matching the input's dtype and device.
        logdet = torch.zeros(flipped.size(0)).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, logdet
275
+
276
+
277
class ElementwiseAffine(nn.Module):
    """Per-channel affine flow with learnable shift ``m`` and log-scale ``logs``."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        # Both start at zero, so the initial transform is the identity.
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Return ``(y, logdet)`` forwards, or the inverted tensor in reverse."""
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        y = (self.m + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
293
+
294
+
295
class ResidualCouplingLayer(nn.Module):
    """Affine coupling layer: the first half of the channels parameterises a transform of the second half."""

    def __init__(self,
                 channels,
                 hidden_channels,
                 kernel_size,
                 dilation_rate,
                 n_layers,
                 p_dropout=0,
                 gin_channels=0,
                 mean_only=False):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        # Emits mean only (one half) or mean + log-scale (two halves).
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        # Zero init: the layer starts out as the identity transform.
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Return ``(x, logdet)`` forwards, or the inverted ``x`` in reverse."""
        halves = [self.half_channels] * 2
        x0, x1 = torch.split(x, halves, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, halves, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        out = torch.cat([x0, x1], 1)
        logdet = torch.sum(logs, [1, 2])
        return out, logdet
341
+
342
+
343
class ConvFlow(nn.Module):
    """Coupling flow that transforms the second half of the channels with a piecewise rational-quadratic spline."""

    def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
        # Per transformed channel: num_bins widths, num_bins heights, num_bins - 1 derivatives.
        self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
        # Zero init so the spline parameters start out at zero.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Return ``(x, logdet)`` forwards, or only the transformed ``x`` in reverse."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        bins = self.num_bins
        scale = math.sqrt(self.filter_channels)
        unnormalized_widths = h[..., :bins] / scale
        unnormalized_heights = h[..., bins:2 * bins] / scale
        unnormalized_derivatives = h[..., 2 * bins:]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails='linear',
            tail_bound=self.tail_bound,
        )

        out = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return out
        return out, logdet
monotonic_align/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import maximum_path_jit
5
+
6
def maximum_path(neg_cent, mask):
    """ numba optimized version.
    neg_cent: [b, t_t, t_s]
    mask: [b, t_t, t_s]

    Returns a tensor with the shape, device and dtype of ``neg_cent`` that
    holds the path array filled in by ``maximum_path_jit``.
    """
    device = neg_cent.device
    dtype = neg_cent.dtype
    # maximum_path_jit is compiled for C-contiguous arrays only. astype()
    # defaults to order='K', which would keep the layout of a non-contiguous
    # (e.g. transposed) tensor and make the jitted call fail, so force 'C'.
    # astype() also copies, so the in-place update done by the kernel never
    # touches the caller's tensor.
    neg_cent = neg_cent.detach().cpu().numpy().astype(float32, order='C')
    path = zeros(neg_cent.shape, dtype=int32)

    # Per-sample valid lengths, read from the first column / row of the mask.
    t_t_max = mask.sum(1)[:, 0].detach().cpu().numpy().astype(int32, order='C')
    t_s_max = mask.sum(2)[:, 0].detach().cpu().numpy().astype(int32, order='C')
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
monotonic_align/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (823 Bytes). View file
 
monotonic_align/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (795 Bytes). View file
 
monotonic_align/__pycache__/core.cpython-37.pyc ADDED
Binary file (968 Bytes). View file
 
monotonic_align/build/temp.win-amd64-3.7/Release/core.cp37-win_amd64.exp ADDED
Binary file (740 Bytes). View file
 
monotonic_align/build/temp.win-amd64-3.7/Release/core.cp37-win_amd64.lib ADDED
Binary file (1.94 kB). View file
 
monotonic_align/build/temp.win-amd64-3.7/Release/core.obj ADDED
Binary file (864 kB). View file
 
monotonic_align/core.c ADDED
The diff for this file is too large to render. See raw diff
 
monotonic_align/core.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
@numba.jit(numba.void(numba.int32[:,:,::1], numba.float32[:,:,::1], numba.int32[::1], numba.int32[::1]), nopython=True, nogil=True)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Fill ``paths`` with a monotonic best path for every batch item.

    ``values`` is overwritten in place with cumulative scores, and ``paths``
    receives a 1 in one column of every row ``y < t_ys[i]``.
    ``t_ys`` / ``t_xs`` give the number of rows / columns to use per item.
    """
    batch = paths.shape[0]
    max_neg_val = -1e9
    for n in range(int(batch)):
        path = paths[n]
        value = values[n]
        t_y = t_ys[n]
        t_x = t_xs[n]

        v_prev = v_cur = 0.0
        index = t_x - 1

        # Forward pass: accumulate the best score reaching each cell, coming
        # either from the same column (v_cur) or the previous column (v_prev)
        # of the row above.
        for y in range(t_y):
            x_first = max(0, t_x + y - t_y)
            x_stop = min(t_x, y + 1)
            for x in range(x_first, x_stop):
                if x != y:
                    v_cur = value[y-1, x]
                else:
                    # On the diagonal the cell cannot be reached by staying
                    # in the same column.
                    v_cur = max_neg_val
                if x != 0:
                    v_prev = value[y-1, x-1]
                elif y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: start at the last column and walk the rows upwards,
        # stepping one column left when forced to or when that scores higher.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
                index = index - 1
monotonic_align/monotonic_align/core.cp37-win_amd64.pyd ADDED
Binary file (151 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ librosa
2
+ matplotlib
3
+ numpy
4
+ scipy
5
+ tensorboard
6
+ torch
7
+ torchvision
8
+ indic_transliteration
text/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+ from text import cleaners
3
+
4
+
5
def text_to_sequence(text, symbols, cleaner_names):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    Args:
      text: string to convert to a sequence
      symbols: ordered collection of symbols; a symbol's ID is its position in it
      cleaner_names: names of the cleaner functions to run the text through
    Returns:
      List of integers corresponding to the symbols in the text
    '''
    symbol_ids = {sym: idx for idx, sym in enumerate(symbols)}
    cleaned = _clean_text(text, cleaner_names)
    # Characters that are not in the symbol table are silently dropped.
    return [symbol_ids[ch] for ch in cleaned if ch in symbol_ids]
24
+
25
+
26
+ def _clean_text(text, cleaner_names):
27
+ for name in cleaner_names:
28
+ cleaner = getattr(cleaners, name)
29
+ if not cleaner:
30
+ raise Exception('Unknown cleaner: %s' % name)
31
+ text = cleaner(text)
32
+ return text
text/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.2 kB). View file
 
text/__pycache__/cleaners.cpython-37.pyc ADDED
Binary file (420 Bytes). View file
 
text/cleaners.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
def sanskrit_cleaners(text):
    """Normalise Sanskrit text.

    Replaces every '॥' with '।', spells out 'ॐ' as 'ओम्', and appends ' ।'
    when the result is empty or does not already end with '।'.
    """
    normalized = text.replace('॥', '।').replace('ॐ', 'ओम्')
    # endswith() is False for the empty string, which covers the empty case.
    if not normalized.endswith('।'):
        normalized += ' ।'
    return normalized
transforms.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def piecewise_rational_quadratic_transform(inputs,
                                           unnormalized_widths,
                                           unnormalized_heights,
                                           unnormalized_derivatives,
                                           inverse=False,
                                           tails=None,
                                           tail_bound=1.,
                                           min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                           min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                           min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Dispatch to the bounded or the unconstrained rational-quadratic spline.

    With ``tails=None`` the call goes to ``rational_quadratic_spline``;
    otherwise it goes to ``unconstrained_rational_quadratic_spline`` with
    ``tails`` and ``tail_bound`` forwarded.

    Returns:
        ``(outputs, logabsdet)`` exactly as produced by the chosen spline.
    """
    shared_kwargs = dict(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    if tails is None:
        outputs, logabsdet = rational_quadratic_spline(**shared_kwargs)
    else:
        outputs, logabsdet = unconstrained_rational_quadratic_spline(
            tails=tails,
            tail_bound=tail_bound,
            **shared_kwargs
        )
    return outputs, logabsdet
45
+
46
+
47
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin that contains it.

    Args:
        bin_locations: tensor whose last dimension holds ascending bin edges.
        inputs: tensor of values to locate; compared against the edges along
            the last dimension of ``bin_locations``.
        eps: amount added to the last edge so that an input equal to it still
            falls into the final bin.

    Returns:
        Integer tensor shaped like ``inputs``: the number of edges that are
        ``<= input``, minus one.
    """
    # Work on a copy: nudging the last edge in place would silently modify the
    # caller's tensor and let eps accumulate across repeated calls.
    bin_locations = bin_locations.clone()
    bin_locations[..., -1] += eps
    return torch.sum(
        inputs[..., None] >= bin_locations,
        dim=-1
    ) - 1
53
+
54
+
55
def unconstrained_rational_quadratic_spline(inputs,
                                            unnormalized_widths,
                                            unnormalized_heights,
                                            unnormalized_derivatives,
                                            inverse=False,
                                            tails='linear',
                                            tail_bound=1.,
                                            min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                                            min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                                            min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Rational-quadratic spline on [-tail_bound, tail_bound] with identity tails.

    Inputs inside the interval are transformed by ``rational_quadratic_spline``.
    Inputs outside it are passed through unchanged with a log-determinant of 0.

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: if ``tails`` is anything other than ``'linear'``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == 'linear':
        # Add one derivative slot at each end and fix it to a constant chosen
        # so that min_derivative + softplus(constant) == 1, i.e. the spline
        # meets the identity tails with slope 1 at both boundaries.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError('{} tails are not implemented.'.format(tails))

    # When nothing falls inside the interval the masked selection is empty and
    # the inner spline's torch.min/torch.max domain check would raise on the
    # empty tensor, so skip the call entirely.
    if inside_interval_mask.any():
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative
        )

    return outputs, logabsdet
95
+
96
def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Monotonic rational-quadratic spline from [left, right] to [bottom, top].

    The last dimension of the ``unnormalized_*`` tensors holds the per-bin
    parameters (the derivatives carry one more entry than the widths and
    heights); their leading dimensions are expected to match ``inputs``.

    Returns:
        ``(outputs, logabsdet)``. With ``inverse=True`` the spline is inverted
        and the log-determinant is negated.

    Raises:
        ValueError: if an input lies outside [left, right], or if the minimum
            bin width/height cannot fit the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    def _knots(unnormalized, min_size, low, high):
        # Softmax -> bin sizes that respect the minimum, then cumulative knot
        # positions rescaled to [low, high] with the end points pinned exactly.
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        edges = torch.cumsum(sizes, dim=-1)
        edges = F.pad(edges, pad=(1, 0), mode='constant', value=0.0)
        edges = (high - low) * edges + low
        edges[..., 0] = low
        edges[..., -1] = high
        return edges, edges[..., 1:] - edges[..., :-1]

    cumwidths, widths = _knots(unnormalized_widths, min_bin_width, left, right)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = _knots(unnormalized_heights, min_bin_height, bottom, top)

    # The inverse looks the input up along the output axis (heights), the
    # forward direction along the input axis (widths).
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    def _at_bin(per_bin):
        # Select, for every input, the entry belonging to its bin.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _at_bin(cumwidths)
    input_bin_widths = _at_bin(widths)
    input_cumheights = _at_bin(cumheights)
    input_delta = _at_bin(heights / widths)
    input_derivatives = _at_bin(derivatives)
    input_derivatives_plus_one = _at_bin(derivatives[..., 1:])
    input_heights = _at_bin(heights)

    # Term shared by the quadratic coefficients and the denominator.
    slope_term = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    def _log_derivative(theta, theta_one_minus_theta, denominator):
        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - theta).pow(2))
        return torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        # Solve a * root^2 + b * root + c = 0 for the position inside the bin.
        offset = inputs - input_cumheights
        a = offset * slope_term + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - offset * slope_term
        c = - input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + slope_term * theta_one_minus_theta
        logabsdet = _log_derivative(root, theta_one_minus_theta, denominator)

        return outputs, -logabsdet

    theta = (inputs - input_cumwidths) / input_bin_widths
    theta_one_minus_theta = theta * (1 - theta)

    numerator = input_heights * (input_delta * theta.pow(2)
                                 + input_derivatives * theta_one_minus_theta)
    denominator = input_delta + slope_term * theta_one_minus_theta
    outputs = input_cumheights + numerator / denominator

    logabsdet = _log_derivative(theta, theta_one_minus_theta, denominator)

    return outputs, logabsdet
utils.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from json import loads
3
+ from torch import load, FloatTensor
4
+ from numpy import float32
5
+ import librosa
6
+
7
+
8
class HParams():
    """Hyper-parameter container with both attribute and dict-style access.

    Keyword arguments become attributes; values that are plain dicts are
    converted recursively into nested HParams instances.
    """

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            # Only exact dicts are wrapped; other values are stored as given.
            self[name] = HParams(**value) if type(value) is dict else value

    def keys(self):
        return vars(self).keys()

    def items(self):
        return vars(self).items()

    def values(self):
        return vars(self).values()

    def __len__(self):
        return len(vars(self))

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in vars(self)

    def __repr__(self):
        return repr(vars(self))
38
+
39
+
40
def load_checkpoint(checkpoint_path, model):
    """Load model weights from a checkpoint file into ``model``.

    The checkpoint is expected to be a dict with the keys ``'iteration'`` and
    ``'model'`` (a state dict). Parameters missing from the checkpoint keep
    the model's current values and are reported through ``logging.info``.
    If ``model`` has a ``module`` attribute, the weights are loaded into
    ``model.module`` instead.

    Returns:
        None.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint_dict = load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    saved_state_dict = checkpoint_dict['model']

    # Resolve the wrapped/unwrapped model once instead of branching twice.
    target = model.module if hasattr(model, 'module') else model

    new_state_dict = {}
    for k, v in target.state_dict().items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Only a missing key is expected here; anything else should surface.
            logging.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v

    target.load_state_dict(new_state_dict)
    logging.info("Loaded checkpoint '%s' (iteration %s)", checkpoint_path, iteration)
    return
62
+
63
+
64
def get_hparams_from_file(config_path):
    """Read a JSON config file and return its contents as an HParams object.

    The top-level JSON value must be an object; its keys become the
    attributes of the returned HParams.
    """
    # JSON is UTF-8 by specification. Without an explicit encoding the
    # platform default is used, which can mis-decode non-ASCII entries.
    with open(config_path, "r", encoding="utf-8") as f:
        config = loads(f.read())

    return HParams(**config)
71
+
72
+
73
def load_audio_to_torch(full_path, target_sampling_rate):
    """Load an audio file as a mono float32 ``FloatTensor``.

    The file is read with ``librosa.load`` using ``target_sampling_rate`` as
    the requested sampling rate; the returned rate is discarded.
    """
    waveform, _ = librosa.load(full_path, sr=target_sampling_rate, mono=True)
    waveform = waveform.astype(float32)
    return FloatTensor(waveform)