hexgrad committed on
Commit
e9b69d2
1 Parent(s): fc11260

Upload 10 files

Files changed (10)
  1. app.py +200 -0
  2. en.txt +0 -0
  3. istftnet.py +523 -0
  4. ja.txt +0 -0
  5. katsu.py +430 -0
  6. models.py +571 -0
  7. num2kana.py +317 -0
  8. packages.txt +1 -0
  9. plbert.py +15 -0
  10. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,200 @@
+ from huggingface_hub import snapshot_download
+ from katsu import Katsu
+ from models import build_model
+ import gradio as gr
+ import noisereduce as nr
+ import numpy as np
+ import os
+ import phonemizer
+ import random
+ import torch
+ import yaml
+
+ random_texts = {}
+ for lang in ['en', 'ja']:
+     with open(f'{lang}.txt', 'r') as r:
+         random_texts[lang] = [line.strip() for line in r]
+
+ def get_random_text(voice):
+     if voice[0] == 'j':
+         lang = 'ja'
+     else:
+         lang = 'en'
+     return random.choice(random_texts[lang])
+
+ def parens_to_angles(s):
+     return s.replace('(', '«').replace(')', '»')
+
+ def normalize(text):
+     # TODO: Custom text normalization rules?
+     text = text.replace('Dr.', 'Doctor')
+     text = text.replace('Mr.', 'Mister')
+     text = text.replace('Ms.', 'Miss')
+     text = text.replace('Mrs.', 'Mrs')
+     return parens_to_angles(text)
+
+ phonemizers = dict(
+     a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
+     b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
+     j=Katsu()
+ )
+
+ def phonemize(text, voice):
+     lang = voice[0]
+     text = normalize(text)
+     ps = phonemizers[lang].phonemize([text])
+     ps = ps[0] if ps else ''
+     # TODO: Custom phonemization rules?
+     ps = parens_to_angles(ps)
+     # https://en.wiktionary.org/wiki/kokoro#English
+     ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
+     ps = ''.join(filter(lambda p: p in VOCAB, ps))
+     return ps.strip()
+
+ def length_to_mask(lengths):
+     mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
+     mask = torch.gt(mask+1, lengths.unsqueeze(1))
+     return mask
+
+ def get_vocab():
+     _pad = "$"
+     _punctuation = ';:,.!?¡¿—…"«»“” '
+     _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
+     _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+     symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+     return {symbol: i for i, symbol in enumerate(symbols)}
+
+ VOCAB = get_vocab()
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
+ config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
+ model = build_model(config['model_params'])
+ _ = [model[key].eval() for key in model]
+ _ = [model[key].to(device) for key in model]
+ for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
+     assert key in model, key
+     try:
+         model[key].load_state_dict(state_dict)
+     except RuntimeError:
+         # Checkpoint was saved from a DataParallel wrapper; strip the 'module.' prefix and retry.
+         state_dict = {k[7:]: v for k, v in state_dict.items()}
+         model[key].load_state_dict(state_dict, strict=False)
+
+ CHOICES = {
+     '🇺🇸 🚺 American Female 0': 'af0',
+     '🇺🇸 🚺 Bella': 'af1',
+     '🇺🇸 🚺 Nicole': 'af2',
+     '🇺🇸 🚹 Michael': 'am0',
+     '🇺🇸 🚹 Adam': 'am1',
+     '🇬🇧 🚺 British Female 0': 'bf0',
+     '🇬🇧 🚺 British Female 1': 'bf1',
+     '🇬🇧 🚺 British Female 2': 'bf2',
+     '🇬🇧 🚹 British Male 0': 'bm0',
+     '🇬🇧 🚹 British Male 1': 'bm1',
+     '🇬🇧 🚹 British Male 2': 'bm2',
+     '🇬🇧 🚹 British Male 3': 'bm3',
+     '🇯🇵 🚺 Japanese Female 0': 'jf0',
+ }
+ VOICES = {k: torch.load(os.path.join(snapshot, 'voices', f'{k}.pt'), weights_only=True).to(device) for k in CHOICES.values()}
+
+ np_log_99 = np.log(99)
+ def s_curve(p):
+     # Logistic ease curve rescaled so that s_curve(0) == 0 and s_curve(1) == 1.
+     if p <= 0:
+         return 0
+     elif p >= 1:
+         return 1
+     s = 1 / (1 + np.exp((1-p*2)*np_log_99))
+     s = (s-0.01) * 50/49
+     return s
+
+ SAMPLE_RATE = 24000
+
+ @torch.no_grad()
+ def forward(text, voice, ps=None, speed=1.0, reduce_noise=0.5, opening_cut=5000, closing_cut=0, ease_in=3000, ease_out=0):
+     ps = ps or phonemize(text, voice)
+     tokens = [i for i in map(VOCAB.get, ps) if i is not None]
+     if not tokens:
+         return (None, '')
+     elif len(tokens) > 510:
+         # Cap at 510 tokens so the two boundary zeros added below fit a 512-token context.
+         tokens = tokens[:510]
+         ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
+     tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+     input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+     text_mask = length_to_mask(input_lengths).to(device)
+     bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+     d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+     ref_s = VOICES[voice]
+     s = ref_s[:, 128:]
+     d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
+     x, _ = model.predictor.lstm(d)
+     duration = model.predictor.duration_proj(x)
+     duration = torch.sigmoid(duration).sum(axis=-1) / speed
+     pred_dur = torch.round(duration.squeeze()).clamp(min=1)
+     pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
+     c_frame = 0
+     for i in range(pred_aln_trg.size(0)):
+         pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+         c_frame += int(pred_dur[i].data)
+     en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
+     F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
+     t_en = model.text_encoder(tokens, input_lengths, text_mask)
+     asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
+     out = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])
+     out = out.squeeze().cpu().numpy()
+     if reduce_noise > 0:
+         out = nr.reduce_noise(y=out, sr=SAMPLE_RATE, prop_decrease=reduce_noise, n_fft=512)
+     opening_cut = max(0, int(opening_cut / speed))
+     if opening_cut > 0:
+         out[:opening_cut] = 0
+     closing_cut = max(0, int(closing_cut / speed))
+     if closing_cut > 0:
+         out[-closing_cut:] = 0
+     ease_in = min(int(ease_in / speed), len(out)//2 - opening_cut)
+     for i in range(ease_in):
+         out[i+opening_cut] *= s_curve(i / ease_in)
+     ease_out = min(int(ease_out / speed), len(out)//2 - closing_cut)
+     for i in range(ease_out):
+         out[-i-1-closing_cut] *= s_curve(i / ease_out)
+     return ((SAMPLE_RATE, out), ps)
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             text = gr.Textbox(label='Input Text')
+             voice = gr.Dropdown(list(CHOICES.items()), label='Voice')
+             with gr.Row():
+                 random_btn = gr.Button('Random Text', variant='secondary')
+                 generate_btn = gr.Button('Generate', variant='primary')
+             random_btn.click(get_random_text, inputs=[voice], outputs=[text])
+             with gr.Accordion('Input Phonemes', open=False):
+                 in_ps = gr.Textbox(show_label=False, info='Override the input text with custom pronunciation. Leave this blank to use the input text instead.')
+                 with gr.Row():
+                     clear_btn = gr.ClearButton(in_ps)
+                     phonemize_btn = gr.Button('Phonemize Input Text', variant='primary')
+                 phonemize_btn.click(phonemize, inputs=[text, voice], outputs=[in_ps])
+         with gr.Column():
+             audio = gr.Audio(interactive=False, label='Output Audio')
+             with gr.Accordion('Tokens', open=True):
+                 out_ps = gr.Textbox(interactive=False, show_label=False, info='Tokens used to generate the audio. Same as input phonemes if supplied, excluding unknown characters and truncated to 510 tokens.')
+             with gr.Accordion('Advanced Settings', open=False):
+                 with gr.Row():
+                     reduce_noise = gr.Slider(minimum=0, maximum=1, value=0.5, label='Reduce Noise', info='👻 Fix it in post: non-stationary noise reduction via spectral gating.')
+                 with gr.Row():
+                     speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='Speed', info='⚡️ Adjust the speed of the audio. The trim settings below are also auto-scaled by speed.')
+                 with gr.Row():
+                     with gr.Column():
+                         opening_cut = gr.Slider(minimum=0, maximum=24000, value=5000, step=1000, label='Opening Cut', info='✂️ Zero out this many samples at the start.')
+                     with gr.Column():
+                         closing_cut = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Closing Cut', info='✂️ Zero out this many samples at the end.')
+                 with gr.Row():
+                     with gr.Column():
+                         ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='Ease In', info='🚀 Ease in for this many samples, after opening cut.')
+                     with gr.Column():
+                         ease_out = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Ease Out', info='📐 Ease out for this many samples, before closing cut.')
+     generate_btn.click(forward, inputs=[text, voice, in_ps, speed, reduce_noise, opening_cut, closing_cut, ease_in, ease_out], outputs=[audio, out_ps])
+
+ if __name__ == '__main__':
+     demo.launch()
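
For context on the trim settings wired into forward() above, a minimal standalone sketch (assuming only numpy; the all-ones array is a stand-in for generated audio) of how Opening Cut and Ease In shape the start of the waveform:

```python
import numpy as np

np_log_99 = np.log(99)

def s_curve(p):
    # Same logistic ease curve as app.py, rescaled so that s_curve(0) == 0 and s_curve(1) == 1.
    if p <= 0:
        return 0
    elif p >= 1:
        return 1
    s = 1 / (1 + np.exp((1 - p * 2) * np_log_99))
    return (s - 0.01) * 50 / 49

out = np.ones(24000, dtype=np.float32)  # stand-in for 1 second of audio at 24 kHz
opening_cut, ease_in = 5000, 3000       # the app's default slider values
out[:opening_cut] = 0                   # hard-mute the first `opening_cut` samples
for i in range(ease_in):                # then fade in over the next `ease_in` samples
    out[opening_cut + i] *= s_curve(i / ease_in)
```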
en.txt ADDED
The diff for this file is too large to render. See raw diff
 
istftnet.py ADDED
@@ -0,0 +1,523 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
2
+ from scipy.signal import get_window
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn.utils import weight_norm, remove_weight_norm
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ # https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
11
+ def init_weights(m, mean=0.0, std=0.01):
12
+ classname = m.__class__.__name__
13
+ if classname.find("Conv") != -1:
14
+ m.weight.data.normal_(mean, std)
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size*dilation - dilation)/2)
18
+
19
+ LRELU_SLOPE = 0.1
20
+
21
+ class AdaIN1d(nn.Module):
22
+ def __init__(self, style_dim, num_features):
23
+ super().__init__()
24
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
25
+ self.fc = nn.Linear(style_dim, num_features*2)
26
+
27
+ def forward(self, x, s):
28
+ h = self.fc(s)
29
+ h = h.view(h.size(0), h.size(1), 1)
30
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
31
+ return (1 + gamma) * self.norm(x) + beta
32
+
33
+ class AdaINResBlock1(torch.nn.Module):
34
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
35
+ super(AdaINResBlock1, self).__init__()
36
+ self.convs1 = nn.ModuleList([
37
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
38
+ padding=get_padding(kernel_size, dilation[0]))),
39
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
40
+ padding=get_padding(kernel_size, dilation[1]))),
41
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
42
+ padding=get_padding(kernel_size, dilation[2])))
43
+ ])
44
+ self.convs1.apply(init_weights)
45
+
46
+ self.convs2 = nn.ModuleList([
47
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
48
+ padding=get_padding(kernel_size, 1))),
49
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
50
+ padding=get_padding(kernel_size, 1))),
51
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
52
+ padding=get_padding(kernel_size, 1)))
53
+ ])
54
+ self.convs2.apply(init_weights)
55
+
56
+ self.adain1 = nn.ModuleList([
57
+ AdaIN1d(style_dim, channels),
58
+ AdaIN1d(style_dim, channels),
59
+ AdaIN1d(style_dim, channels),
60
+ ])
61
+
62
+ self.adain2 = nn.ModuleList([
63
+ AdaIN1d(style_dim, channels),
64
+ AdaIN1d(style_dim, channels),
65
+ AdaIN1d(style_dim, channels),
66
+ ])
67
+
68
+ self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
69
+ self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
70
+
71
+
72
+ def forward(self, x, s):
73
+ for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
74
+ xt = n1(x, s)
75
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
76
+ xt = c1(xt)
77
+ xt = n2(xt, s)
78
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
79
+ xt = c2(xt)
80
+ x = xt + x
81
+ return x
82
+
83
+ def remove_weight_norm(self):
84
+ for l in self.convs1:
85
+ remove_weight_norm(l)
86
+ for l in self.convs2:
87
+ remove_weight_norm(l)
88
+
89
+ class TorchSTFT(torch.nn.Module):
90
+ def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
91
+ super().__init__()
92
+ self.filter_length = filter_length
93
+ self.hop_length = hop_length
94
+ self.win_length = win_length
95
+ self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
96
+
97
+ def transform(self, input_data):
98
+ forward_transform = torch.stft(
99
+ input_data,
100
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
101
+ return_complex=True)
102
+
103
+ return torch.abs(forward_transform), torch.angle(forward_transform)
104
+
105
+ def inverse(self, magnitude, phase):
106
+ inverse_transform = torch.istft(
107
+ magnitude * torch.exp(phase * 1j),
108
+ self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
109
+
110
+ return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
111
+
112
+ def forward(self, input_data):
113
+ self.magnitude, self.phase = self.transform(input_data)
114
+ reconstruction = self.inverse(self.magnitude, self.phase)
115
+ return reconstruction
116
+
117
+ class SineGen(torch.nn.Module):
118
+ """ Definition of sine generator
119
+ SineGen(samp_rate, harmonic_num = 0,
120
+ sine_amp = 0.1, noise_std = 0.003,
121
+ voiced_threshold = 0,
122
+ flag_for_pulse=False)
123
+ samp_rate: sampling rate in Hz
124
+ harmonic_num: number of harmonic overtones (default 0)
125
+ sine_amp: amplitude of sine-waveform (default 0.1)
126
+ noise_std: std of Gaussian noise (default 0.003)
127
+ voiced_threshold: F0 threshold for U/V classification (default 0)
128
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
129
+ Note: when flag_for_pulse is True, the first time step of a voiced
130
+ segment is always sin(np.pi) or cos(0)
131
+ """
132
+
133
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
134
+ sine_amp=0.1, noise_std=0.003,
135
+ voiced_threshold=0,
136
+ flag_for_pulse=False):
137
+ super(SineGen, self).__init__()
138
+ self.sine_amp = sine_amp
139
+ self.noise_std = noise_std
140
+ self.harmonic_num = harmonic_num
141
+ self.dim = self.harmonic_num + 1
142
+ self.sampling_rate = samp_rate
143
+ self.voiced_threshold = voiced_threshold
144
+ self.flag_for_pulse = flag_for_pulse
145
+ self.upsample_scale = upsample_scale
146
+
147
+ def _f02uv(self, f0):
148
+ # generate uv signal
149
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
150
+ return uv
151
+
152
+ def _f02sine(self, f0_values):
153
+ """ f0_values: (batchsize, length, dim)
154
+ where dim indicates fundamental tone and overtones
155
+ """
156
+ # convert to F0 in rad. The integer part n can be ignored
157
+ # because 2 * np.pi * n doesn't affect phase
158
+ rad_values = (f0_values / self.sampling_rate) % 1
159
+
160
+ # initial phase noise (no noise for fundamental component)
161
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
162
+ device=f0_values.device)
163
+ rand_ini[:, 0] = 0
164
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
165
+
166
+ # instantaneous phase sine[t] = sin(2*pi \sum_{i=1}^{t} rad)
167
+ if not self.flag_for_pulse:
168
+ # # for normal case
169
+
170
+ # # To prevent torch.cumsum numerical overflow,
171
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
172
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
173
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
174
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
175
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
176
+ # cumsum_shift = torch.zeros_like(rad_values)
177
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
178
+
179
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
180
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
181
+ scale_factor=1/self.upsample_scale,
182
+ mode="linear").transpose(1, 2)
183
+
184
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
185
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
186
+ # cumsum_shift = torch.zeros_like(rad_values)
187
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
188
+
189
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
190
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
191
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
192
+ sines = torch.sin(phase)
193
+
194
+ else:
195
+ # If necessary, make sure that the first time step of every
196
+ # voiced segments is sin(pi) or cos(0)
197
+ # This is used for pulse-train generation
198
+
199
+ # identify the last time step in unvoiced segments
200
+ uv = self._f02uv(f0_values)
201
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
202
+ uv_1[:, -1, :] = 1
203
+ u_loc = (uv < 1) * (uv_1 > 0)
204
+
205
+ # get the instantaneous phase
206
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
207
+ # different batch needs to be processed differently
208
+ for idx in range(f0_values.shape[0]):
209
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
210
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
211
+ # stores the accumulation of i.phase within
212
+ # each voiced segments
213
+ tmp_cumsum[idx, :, :] = 0
214
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
215
+
216
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
217
+ # within the previous voiced segment.
218
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
219
+
220
+ # get the sines
221
+ sines = torch.cos(i_phase * 2 * np.pi)
222
+ return sines
223
+
224
+ def forward(self, f0):
225
+ """ sine_tensor, uv = forward(f0)
226
+ input F0: tensor(batchsize=1, length, dim=1)
227
+ f0 for unvoiced steps should be 0
228
+ output sine_tensor: tensor(batchsize=1, length, dim)
229
+ output uv: tensor(batchsize=1, length, 1)
230
+ """
231
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
232
+ device=f0.device)
233
+ # fundamental component
234
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
235
+
236
+ # generate sine waveforms
237
+ sine_waves = self._f02sine(fn) * self.sine_amp
238
+
239
+ # generate uv signal
240
+ # uv = torch.ones(f0.shape)
241
+ # uv = uv * (f0 > self.voiced_threshold)
242
+ uv = self._f02uv(f0)
243
+
244
+ # noise: for unvoiced should be similar to sine_amp
245
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
246
+ # . for voiced regions is self.noise_std
247
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
248
+ noise = noise_amp * torch.randn_like(sine_waves)
249
+
250
+ # first: set the unvoiced part to 0 by uv
251
+ # then: additive noise
252
+ sine_waves = sine_waves * uv + noise
253
+ return sine_waves, uv, noise
254
+
255
+
256
+ class SourceModuleHnNSF(torch.nn.Module):
257
+ """ SourceModule for hn-nsf
258
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
259
+ add_noise_std=0.003, voiced_threshod=0)
260
+ sampling_rate: sampling_rate in Hz
261
+ harmonic_num: number of harmonic above F0 (default: 0)
262
+ sine_amp: amplitude of sine source signal (default: 0.1)
263
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
264
+ note that amplitude of noise in unvoiced is decided
265
+ by sine_amp
266
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
267
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
268
+ F0_sampled (batchsize, length, 1)
269
+ Sine_source (batchsize, length, 1)
270
+ noise_source (batchsize, length 1)
271
+ uv (batchsize, length, 1)
272
+ """
273
+
274
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
275
+ add_noise_std=0.003, voiced_threshod=0):
276
+ super(SourceModuleHnNSF, self).__init__()
277
+
278
+ self.sine_amp = sine_amp
279
+ self.noise_std = add_noise_std
280
+
281
+ # to produce sine waveforms
282
+ self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
283
+ sine_amp, add_noise_std, voiced_threshod)
284
+
285
+ # to merge source harmonics into a single excitation
286
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
287
+ self.l_tanh = torch.nn.Tanh()
288
+
289
+ def forward(self, x):
290
+ """
291
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
292
+ F0_sampled (batchsize, length, 1)
293
+ Sine_source (batchsize, length, 1)
294
+ noise_source (batchsize, length 1)
295
+ """
296
+ # source for harmonic branch
297
+ with torch.no_grad():
298
+ sine_wavs, uv, _ = self.l_sin_gen(x)
299
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
300
+
301
+ # source for noise branch, in the same shape as uv
302
+ noise = torch.randn_like(uv) * self.sine_amp / 3
303
+ return sine_merge, noise, uv
304
+ def padDiff(x):
305
+ return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
306
+
307
+
308
+ class Generator(torch.nn.Module):
309
+ def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
310
+ super(Generator, self).__init__()
311
+
312
+ self.num_kernels = len(resblock_kernel_sizes)
313
+ self.num_upsamples = len(upsample_rates)
314
+ resblock = AdaINResBlock1
315
+
316
+ self.m_source = SourceModuleHnNSF(
317
+ sampling_rate=24000,
318
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
319
+ harmonic_num=8, voiced_threshod=10)
320
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
321
+ self.noise_convs = nn.ModuleList()
322
+ self.noise_res = nn.ModuleList()
323
+
324
+ self.ups = nn.ModuleList()
325
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
326
+ self.ups.append(weight_norm(
327
+ ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
328
+ k, u, padding=(k-u)//2)))
329
+
330
+ self.resblocks = nn.ModuleList()
331
+ for i in range(len(self.ups)):
332
+ ch = upsample_initial_channel//(2**(i+1))
333
+ for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
334
+ self.resblocks.append(resblock(ch, k, d, style_dim))
335
+
336
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
337
+
338
+ if i + 1 < len(upsample_rates): #
339
+ stride_f0 = np.prod(upsample_rates[i + 1:])
340
+ self.noise_convs.append(Conv1d(
341
+ gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
342
+ self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
343
+ else:
344
+ self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
345
+ self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
346
+
347
+
348
+ self.post_n_fft = gen_istft_n_fft
349
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
350
+ self.ups.apply(init_weights)
351
+ self.conv_post.apply(init_weights)
352
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
353
+ self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
354
+
355
+
356
+ def forward(self, x, s, f0):
357
+ with torch.no_grad():
358
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
359
+
360
+ har_source, noi_source, uv = self.m_source(f0)
361
+ har_source = har_source.transpose(1, 2).squeeze(1)
362
+ har_spec, har_phase = self.stft.transform(har_source)
363
+ har = torch.cat([har_spec, har_phase], dim=1)
364
+
365
+ for i in range(self.num_upsamples):
366
+ x = F.leaky_relu(x, LRELU_SLOPE)
367
+ x_source = self.noise_convs[i](har)
368
+ x_source = self.noise_res[i](x_source, s)
369
+
370
+ x = self.ups[i](x)
371
+ if i == self.num_upsamples - 1:
372
+ x = self.reflection_pad(x)
373
+
374
+ x = x + x_source
375
+ xs = None
376
+ for j in range(self.num_kernels):
377
+ if xs is None:
378
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
379
+ else:
380
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
381
+ x = xs / self.num_kernels
382
+ x = F.leaky_relu(x)
383
+ x = self.conv_post(x)
384
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
385
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
386
+ return self.stft.inverse(spec, phase)
387
+
388
+ def fw_phase(self, x, s):
389
+ for i in range(self.num_upsamples):
390
+ x = F.leaky_relu(x, LRELU_SLOPE)
391
+ x = self.ups[i](x)
392
+ xs = None
393
+ for j in range(self.num_kernels):
394
+ if xs is None:
395
+ xs = self.resblocks[i*self.num_kernels+j](x, s)
396
+ else:
397
+ xs += self.resblocks[i*self.num_kernels+j](x, s)
398
+ x = xs / self.num_kernels
399
+ x = F.leaky_relu(x)
400
+ x = self.reflection_pad(x)
401
+ x = self.conv_post(x)
402
+ spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
403
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
404
+ return spec, phase
405
+
406
+ def remove_weight_norm(self):
407
+ print('Removing weight norm...')
408
+ for l in self.ups:
409
+ remove_weight_norm(l)
410
+ for l in self.resblocks:
411
+ l.remove_weight_norm()
412
+ # remove_weight_norm(self.conv_pre)  # this Generator defines no conv_pre; calling it would raise AttributeError
413
+ remove_weight_norm(self.conv_post)
414
+
415
+
416
+ class AdainResBlk1d(nn.Module):
417
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
418
+ upsample='none', dropout_p=0.0):
419
+ super().__init__()
420
+ self.actv = actv
421
+ self.upsample_type = upsample
422
+ self.upsample = UpSample1d(upsample)
423
+ self.learned_sc = dim_in != dim_out
424
+ self._build_weights(dim_in, dim_out, style_dim)
425
+ self.dropout = nn.Dropout(dropout_p)
426
+
427
+ if upsample == 'none':
428
+ self.pool = nn.Identity()
429
+ else:
430
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
431
+
432
+
433
+ def _build_weights(self, dim_in, dim_out, style_dim):
434
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
435
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
436
+ self.norm1 = AdaIN1d(style_dim, dim_in)
437
+ self.norm2 = AdaIN1d(style_dim, dim_out)
438
+ if self.learned_sc:
439
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
440
+
441
+ def _shortcut(self, x):
442
+ x = self.upsample(x)
443
+ if self.learned_sc:
444
+ x = self.conv1x1(x)
445
+ return x
446
+
447
+ def _residual(self, x, s):
448
+ x = self.norm1(x, s)
449
+ x = self.actv(x)
450
+ x = self.pool(x)
451
+ x = self.conv1(self.dropout(x))
452
+ x = self.norm2(x, s)
453
+ x = self.actv(x)
454
+ x = self.conv2(self.dropout(x))
455
+ return x
456
+
457
+ def forward(self, x, s):
458
+ out = self._residual(x, s)
459
+ out = (out + self._shortcut(x)) / np.sqrt(2)
460
+ return out
461
+
462
+ class UpSample1d(nn.Module):
463
+ def __init__(self, layer_type):
464
+ super().__init__()
465
+ self.layer_type = layer_type
466
+
467
+ def forward(self, x):
468
+ if self.layer_type == 'none':
469
+ return x
470
+ else:
471
+ return F.interpolate(x, scale_factor=2, mode='nearest')
472
+
473
+ class Decoder(nn.Module):
474
+ def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
475
+ resblock_kernel_sizes = [3,7,11],
476
+ upsample_rates = [10, 6],
477
+ upsample_initial_channel=512,
478
+ resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
479
+ upsample_kernel_sizes=[20, 12],
480
+ gen_istft_n_fft=20, gen_istft_hop_size=5):
481
+ super().__init__()
482
+
483
+ self.decode = nn.ModuleList()
484
+
485
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
486
+
487
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
488
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
489
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
490
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
491
+
492
+ self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
493
+
494
+ self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
495
+
496
+ self.asr_res = nn.Sequential(
497
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
498
+ )
499
+
500
+
501
+ self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
502
+ upsample_initial_channel, resblock_dilation_sizes,
503
+ upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
504
+
505
+ def forward(self, asr, F0_curve, N, s):
506
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
507
+ N = self.N_conv(N.unsqueeze(1))
508
+
509
+ x = torch.cat([asr, F0, N], axis=1)
510
+ x = self.encode(x, s)
511
+
512
+ asr_res = self.asr_res(asr)
513
+
514
+ res = True
515
+ for block in self.decode:
516
+ if res:
517
+ x = torch.cat([x, asr_res, F0, N], axis=1)
518
+ x = block(x, s)
519
+ if block.upsample_type != "none":
520
+ res = False
521
+
522
+ x = self.generator(x, s, F0_curve)
523
+ return x
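
As a quick shape check, a small sketch (assuming torch and scipy are installed and this file is importable as istftnet) of the TorchSTFT analysis/synthesis round trip the Generator relies on:

```python
import torch
from istftnet import TorchSTFT

stft = TorchSTFT(filter_length=20, hop_length=5, win_length=20)  # the Decoder defaults above
x = torch.randn(1, 24000)                # one 1-second waveform at 24 kHz
mag, phase = stft.transform(x)           # magnitude and phase, each (1, n_fft//2 + 1, frames)
y = stft.inverse(mag, phase)             # (1, 1, samples), unsqueezed to match conv_transpose1d
print(mag.shape, phase.shape, y.shape)
```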
ja.txt ADDED
The diff for this file is too large to render. See raw diff
 
katsu.py ADDED
@@ -0,0 +1,430 @@
1
+ # https://github.com/polm/cutlet/blob/master/cutlet/cutlet.py
2
+ from dataclasses import dataclass
3
+ from fugashi import Tagger
4
+ from num2kana import Convert
5
+ import mojimoji
6
+ import re
7
+ import unicodedata
8
+
9
+ HEPBURN = {
10
+ chr(12449):'a', #ァ
11
+ chr(12450):'a', #ア
12
+ chr(12451):'i', #ィ
13
+ chr(12452):'i', #イ
14
+ chr(12453):'ɯ', #ゥ
15
+ chr(12454):'ɯ', #ウ
16
+ chr(12455):'e', #ェ
17
+ chr(12456):'e', #エ
18
+ chr(12457):'o', #ォ
19
+ chr(12458):'o', #オ
20
+ chr(12459):'ka', #カ
21
+ chr(12460):'ɡa', #ガ
22
+ chr(12461):'ki', #キ
23
+ chr(12462):'ɡi', #ギ
24
+ chr(12463):'kɯ', #ク
25
+ chr(12464):'ɡɯ', #グ
26
+ chr(12465):'ke', #ケ
27
+ chr(12466):'ɡe', #ゲ
28
+ chr(12467):'ko', #コ
29
+ chr(12468):'ɡo', #ゴ
30
+ chr(12469):'sa', #サ
31
+ chr(12470):'za', #ザ
32
+ chr(12471):'ɕi', #シ
33
+ chr(12472):'dʑi', #ジ
34
+ chr(12473):'sɨ', #ス
35
+ chr(12474):'zɨ', #ズ
36
+ chr(12475):'se', #セ
37
+ chr(12476):'ze', #ゼ
38
+ chr(12477):'so', #ソ
39
+ chr(12478):'zo', #ゾ
40
+ chr(12479):'ta', #タ
41
+ chr(12480):'da', #ダ
42
+ chr(12481):'tɕi', #チ
43
+ chr(12482):'dʑi', #ヂ
44
+ # chr(12483) #ッ
45
+ chr(12484):'tsɨ', #ツ
46
+ chr(12485):'zɨ', #ヅ
47
+ chr(12486):'te', #テ
48
+ chr(12487):'de', #デ
49
+ chr(12488):'to', #ト
50
+ chr(12489):'do', #ド
51
+ chr(12490):'na', #ナ
52
+ chr(12491):'ɲi', #ニ
53
+ chr(12492):'nɯ', #ヌ
54
+ chr(12493):'ne', #ネ
55
+ chr(12494):'no', #ノ
56
+ chr(12495):'ha', #ハ
57
+ chr(12496):'ba', #バ
58
+ chr(12497):'pa', #パ
59
+ chr(12498):'çi', #ヒ
60
+ chr(12499):'bi', #ビ
61
+ chr(12500):'pi', #ピ
62
+ chr(12501):'ɸɯ', #フ
63
+ chr(12502):'bɯ', #ブ
64
+ chr(12503):'pɯ', #プ
65
+ chr(12504):'he', #ヘ
66
+ chr(12505):'be', #ベ
67
+ chr(12506):'pe', #ペ
68
+ chr(12507):'ho', #ホ
69
+ chr(12508):'bo', #ボ
70
+ chr(12509):'po', #ポ
71
+ chr(12510):'ma', #マ
72
+ chr(12511):'mi', #ミ
73
+ chr(12512):'mɯ', #ム
74
+ chr(12513):'me', #メ
75
+ chr(12514):'mo', #モ
76
+ chr(12515):'ja', #ャ
77
+ chr(12516):'ja', #ヤ
78
+ chr(12517):'jɯ', #ュ
79
+ chr(12518):'jɯ', #ユ
80
+ chr(12519):'jo', #ョ
81
+ chr(12520):'jo', #ヨ
82
+ chr(12521):'ra', #ラ
83
+ chr(12522):'ri', #リ
84
+ chr(12523):'rɯ', #ル
85
+ chr(12524):'re', #レ
86
+ chr(12525):'ro', #ロ
87
+ chr(12526):'wa', #ヮ
88
+ chr(12527):'wa', #ワ
89
+ chr(12528):'i', #ヰ
90
+ chr(12529):'e', #ヱ
91
+ chr(12530):'o', #ヲ
92
+ # chr(12531) #ン
93
+ chr(12532):'vɯ', #ヴ
94
+ chr(12533):'ka', #ヵ
95
+ chr(12534):'ke', #ヶ
96
+ }
97
+ assert len(HEPBURN) == 84 and all(i in {12483, 12531} or chr(i) in HEPBURN for i in range(12449, 12535))
98
+
99
+ for k, v in list(HEPBURN.items()):
100
+ HEPBURN[chr(ord(k)-96)] = v
101
+ assert len(HEPBURN) == 84*2
102
+
103
+ HEPBURN.update({
104
+ chr(12535):'va', #ヷ
105
+ chr(12536):'vi', #ヸ
106
+ chr(12537):'ve', #ヹ
107
+ chr(12538):'vo', #ヺ
108
+ })
109
+ assert len(HEPBURN) == 84*2+4 and all(chr(i) in HEPBURN for i in range(12535, 12539))
110
+
111
+ HEPBURN.update({
112
+ chr(12784):'kɯ', #ㇰ
113
+ chr(12785):'ɕi', #ㇱ
114
+ chr(12786):'sɨ', #ㇲ
115
+ chr(12787):'to', #ㇳ
116
+ chr(12788):'nɯ', #ㇴ
117
+ chr(12789):'ha', #ㇵ
118
+ chr(12790):'çi', #ㇶ
119
+ chr(12791):'ɸɯ', #ㇷ
120
+ chr(12792):'he', #ㇸ
121
+ chr(12793):'ho', #ㇹ
122
+ chr(12794):'mɯ', #ㇺ
123
+ chr(12795):'ra', #ㇻ
124
+ chr(12796):'ri', #ㇼ
125
+ chr(12797):'rɯ', #ㇽ
126
+ chr(12798):'re', #ㇾ
127
+ chr(12799):'ro', #ㇿ
128
+ })
129
+ assert len(HEPBURN) == 84*2+4+16 and all(chr(i) in HEPBURN for i in range(12784, 12800))
130
+
131
+ HEPBURN.update({
132
+ chr(12452)+chr(12455):'je', #イェ
133
+ chr(12454)+chr(12451):'wi', #ウィ
134
+ chr(12454)+chr(12455):'we', #ウェ
135
+ chr(12454)+chr(12457):'wo', #ウォ
136
+ chr(12461)+chr(12455):'kʲe', #キェ
137
+ chr(12461)+chr(12515):'kʲa', #キャ
138
+ chr(12461)+chr(12517):'kʲɨ', #キュ
139
+ chr(12461)+chr(12519):'kʲo', #キョ
140
+ chr(12462)+chr(12515):'ɡʲa', #ギャ
141
+ chr(12462)+chr(12517):'ɡʲɨ', #ギュ
142
+ chr(12462)+chr(12519):'ɡʲo', #ギョ
143
+ chr(12463)+chr(12449):'kʷa', #クァ
144
+ chr(12463)+chr(12451):'kʷi', #クィ
145
+ chr(12463)+chr(12455):'kʷe', #クェ
146
+ chr(12463)+chr(12457):'kʷo', #クォ
147
+ chr(12464)+chr(12449):'ɡʷa', #グァ
148
+ chr(12464)+chr(12451):'ɡʷi', #グィ
149
+ chr(12464)+chr(12455):'ɡʷe', #グェ
150
+ chr(12464)+chr(12457):'ɡʷo', #グォ
151
+ chr(12471)+chr(12455):'ɕe', #シェ
152
+ chr(12471)+chr(12515):'ɕa', #シャ
153
+ chr(12471)+chr(12517):'ɕɨ', #シュ
154
+ chr(12471)+chr(12519):'ɕo', #ショ
155
+ chr(12472)+chr(12455):'dʑe', #ジェ
156
+ chr(12472)+chr(12515):'dʑa', #ジャ
157
+ chr(12472)+chr(12517):'dʑɨ', #ジュ
158
+ chr(12472)+chr(12519):'dʑo', #ジョ
159
+ chr(12481)+chr(12455):'tɕe', #チェ
160
+ chr(12481)+chr(12515):'tɕa', #チャ
161
+ chr(12481)+chr(12517):'tɕɨ', #チュ
162
+ chr(12481)+chr(12519):'tɕo', #チョ
163
+ chr(12482)+chr(12515):'dʑa', #ヂャ
164
+ chr(12482)+chr(12517):'dʑɨ', #ヂュ
165
+ chr(12482)+chr(12519):'dʑo', #ヂョ
166
+ chr(12484)+chr(12449):'tsa', #ツァ
167
+ chr(12484)+chr(12451):'tsi', #ツィ
168
+ chr(12484)+chr(12455):'tse', #ツェ
169
+ chr(12484)+chr(12457):'tso', #ツォ
170
+ chr(12486)+chr(12451):'ti', #ティ
171
+ chr(12486)+chr(12517):'tʲɨ', #テュ
172
+ chr(12487)+chr(12451):'di', #ディ
173
+ chr(12487)+chr(12517):'dʲɨ', #デュ
174
+ chr(12488)+chr(12453):'tɯ', #トゥ
175
+ chr(12489)+chr(12453):'dɯ', #ドゥ
176
+ chr(12491)+chr(12455):'ɲe', #ニェ
177
+ chr(12491)+chr(12515):'ɲa', #ニャ
178
+ chr(12491)+chr(12517):'ɲɨ', #ニュ
179
+ chr(12491)+chr(12519):'ɲo', #ニョ
180
+ chr(12498)+chr(12455):'çe', #ヒェ
181
+ chr(12498)+chr(12515):'ça', #ヒャ
182
+ chr(12498)+chr(12517):'çɨ', #ヒュ
183
+ chr(12498)+chr(12519):'ço', #ヒョ
184
+ chr(12499)+chr(12515):'bʲa', #ビャ
185
+ chr(12499)+chr(12517):'bʲɨ', #ビュ
186
+ chr(12499)+chr(12519):'bʲo', #ビョ
187
+ chr(12500)+chr(12515):'pʲa', #ピャ
188
+ chr(12500)+chr(12517):'pʲɨ', #ピュ
189
+ chr(12500)+chr(12519):'pʲo', #ピョ
190
+ chr(12501)+chr(12449):'ɸa', #ファ
191
+ chr(12501)+chr(12451):'ɸi', #フィ
192
+ chr(12501)+chr(12455):'ɸe', #フェ
193
+ chr(12501)+chr(12457):'ɸo', #フォ
194
+ chr(12501)+chr(12517):'ɸʲɨ', #フュ
195
+ chr(12501)+chr(12519):'ɸʲo', #フョ
196
+ chr(12511)+chr(12515):'mʲa', #ミャ
197
+ chr(12511)+chr(12517):'mʲɨ', #ミュ
198
+ chr(12511)+chr(12519):'mʲo', #ミョ
199
+ chr(12522)+chr(12515):'rʲa', #リャ
200
+ chr(12522)+chr(12517):'rʲɨ', #リュ
201
+ chr(12522)+chr(12519):'rʲo', #リョ
202
+ chr(12532)+chr(12449):'va', #ヴァ
203
+ chr(12532)+chr(12451):'vi', #ヴィ
204
+ chr(12532)+chr(12455):'ve', #ヴェ
205
+ chr(12532)+chr(12457):'vo', #ヴォ
206
+ chr(12532)+chr(12517):'vʲɨ', #ヴュ
207
+ chr(12532)+chr(12519):'vʲo', #ヴョ
208
+ })
209
+ assert len(HEPBURN) == 84*2+4+16+76
210
+
211
+ for k, v in list(HEPBURN.items()):
212
+ if len(k) != 2:
213
+ continue
214
+ a, b = k
215
+ assert a in HEPBURN and b in HEPBURN, (a, b)
216
+ a = chr(ord(a)-96)
217
+ b = chr(ord(b)-96)
218
+ assert a in HEPBURN and b in HEPBURN, (a, b)
219
+ HEPBURN[a+b] = v
220
+ assert len(HEPBURN) == 84*2+4+16+76*2
221
+
222
+ HEPBURN.update({
223
+ # symbols
224
+ # 'ー': '-', # 長音符, only used when repeated
225
+ '。': '.',
226
+ '、': ',',
227
+ '?': '?',
228
+ '!': '!',
229
+ '「': '"',
230
+ '」': '"',
231
+ '『': '"',
232
+ '』': '"',
233
+ ':': ':',
234
+ '(': '(',
235
+ ')': ')',
236
+ '《': '(',
237
+ '》': ')',
238
+ '【': '[',
239
+ '】': ']',
240
+ '・': ' ',#'/',
241
+ ',': ',',
242
+ '~': '—',
243
+ '〜': '—',
244
+ '—': '—',
245
+ '«': '«',
246
+ '»': '»',
247
+
248
+ # other
249
+ '゚': '', # combining handakuten by itself, just discard
250
+ '゙': '', # combining dakuten by itself
251
+ })
252
+
253
+ def add_dakuten(kk):
254
+ """Given a kana (single-character string), add a dakuten."""
255
+ try:
256
+ # ii = 'かきくけこさしすせそたちつてとはひふへほ'.index(kk)
257
+ ii = 'カキクケコサシスセソタチツテトハヒフヘホ'.index(kk)
258
+ return 'ガギグゲゴザジズゼゾダヂヅデドバビブベボ'[ii]
259
+ # return 'がぎぐげござじずぜぞだぢづでどばびぶべぼ'[ii]
260
+ except ValueError:
261
+ # this is normal if the input is nonsense
262
+ return None
263
+
264
+ SUTEGANA = 'ャュョァィゥェォ' #'ゃゅょぁぃぅぇぉ'
265
+ PUNCT = '\'".!?(),;:-'
266
+ ODORI = '々〃ゝゞヽヾ'
267
+
268
+ @dataclass
269
+ class Token:
270
+ surface: str
271
+ space: bool # if a space should follow
272
+ def __str__(self):
273
+ sp = " " if self.space else ""
274
+ return f"{self.surface}{sp}"
275
+
276
+ class Katsu:
277
+ def __init__(self):
278
+ """Create a Katsu object, which holds configuration as well as
279
+ tokenizer state.
280
+
281
+ Typical usage:
282
+
283
+ ```python
284
+ katsu = Katsu()
285
+ roma = katsu.romaji("カツカレーを食べた")
286
+ # "Cutlet curry wo tabeta"
287
+ ```
288
+ """
289
+ self.tagger = Tagger()
290
+ self.table = dict(HEPBURN) # make a copy so we can modify it
291
+ self.exceptions = {}
292
+
293
+ def romaji(self, text):
294
+ """Build a complete string from input text."""
295
+ if not text:
296
+ return ''
297
+ text = self._normalize_text(text)
298
+ words = self.tagger(text)
299
+ tokens = self._romaji_tokens(words)
300
+ out = ''.join([str(tok) for tok in tokens])
301
+ return re.sub(r'\s+', ' ', out.strip())
302
+
303
+ def phonemize(self, texts):
304
+ # espeak-ng API
305
+ return [self.romaji(text) for text in texts]
306
+
307
+ def _normalize_text(self, text):
308
+ """Given text, normalize variations in Japanese.
309
+
310
+ This specifically removes variations that are meaningless for romaji
311
+ conversion using the following steps:
312
+
313
+ - Unicode NFKC normalization
314
+ - Full-width Latin to half-width
315
+ - Half-width katakana to full-width
316
+ """
317
+ # perform unicode normalization
318
+ text = re.sub(r'[〜~](?=\d)', 'から', text) # wave dash range
319
+ text = unicodedata.normalize('NFKC', text)
320
+ # convert all full-width alphanum to half-width, since it can go out as-is
321
+ text = mojimoji.zen_to_han(text, kana=False)
322
+ # replace half-width katakana with full-width
323
+ text = mojimoji.han_to_zen(text, digit=False, ascii=False)
324
+ return ''.join([(' '+Convert(t)) if t.isdigit() else t for t in re.findall(r'\d+|\D+', text)])
325
+
326
+ def _romaji_tokens(self, words):
327
+ """Build a list of tokens from input nodes."""
328
+ out = []
329
+ for wi, word in enumerate(words):
330
+ po = out[-1] if out else None
331
+ pw = words[wi - 1] if wi > 0 else None
332
+ nw = words[wi + 1] if wi < len(words) - 1 else None
333
+ roma = self._romaji_word(word)
334
+ tok = Token(roma, False)
335
+ # handle punctuation with atypical spacing
336
+ surface = word.surface#['orig']
337
+ if surface in '「『' or roma in '([':
338
+ if po:
339
+ po.space = True
340
+ elif surface in '」』' or roma in ']).,?!:':
341
+ if po:
342
+ po.space = False
343
+ tok.space = True
344
+ elif roma == ' ':
345
+ tok.space = False
346
+ else:
347
+ tok.space = True
348
+ out.append(tok)
349
+ # remove any leftover sokuon
350
+ for tok in out:
351
+ tok.surface = tok.surface.replace(chr(12483), '')
352
+ return out
353
+
354
+ def _romaji_word(self, word):
355
+ """Return the romaji for a single word (node)."""
356
+ surface = word.surface#['orig']
357
+ if surface in self.exceptions:
358
+ return self.exceptions[surface]
359
+ assert not surface.isdigit(), surface
360
+ if surface.isascii():
361
+ return surface
362
+ kana = word.feature.pron or word.feature.kana or surface
363
+ if word.is_unk:
364
+ if word.char_type == 7: # katakana
365
+ pass
366
+ elif word.char_type == 3: # symbol
367
+ return ''.join(map(lambda c: self.table.get(c, c), surface))
368
+ else:
369
+ return '' # TODO: silently fail
370
+ out = ''
371
+ for ki, char in enumerate(kana):
372
+ nk = kana[ki + 1] if ki < len(kana) - 1 else None
373
+ pk = kana[ki - 1] if ki > 0 else None
374
+ out += self._get_single_mapping(pk, char, nk)
375
+ return out
376
+
377
+ def _get_single_mapping(self, pk, kk, nk):
378
+ """Given a single kana and its neighbors, return the mapped romaji."""
379
+ # handle odoriji
380
+ # NOTE: This is very rarely useful at present because odoriji are not
381
+ # left in readings for dictionary words, and we can't follow kana
382
+ # across word boundaries.
383
+ if kk in ODORI:
384
+ if kk in 'ゝヽ':
385
+ if pk: return pk
386
+ else: return '' # invalid but be nice
387
+ if kk in 'ゞヾ': # repeat with voicing
388
+ if not pk: return ''
389
+ vv = add_dakuten(pk)
390
+ if vv: return self.table[vv]
391
+ else: return ''
392
+ # remaining are 々 for kanji and 〃 for symbols, but we can't
393
+ # infer their span reliably (or handle rendaku)
394
+ return ''
395
+ # handle digraphs
396
+ if pk and (pk + kk) in self.table:
397
+ return self.table[pk + kk]
398
+ if nk and (kk + nk) in self.table:
399
+ return ''
400
+ if nk and nk in SUTEGANA:
401
+ if kk == 'ッ': return '' # never valid, just ignore
402
+ return self.table[kk][:-1] + self.table[nk]
403
+ if kk in SUTEGANA:
404
+ return ''
405
+ if kk == 'ー': # 長音符
406
+ return 'ː'
407
+ if ord(kk) in {12387, 12483}: # っ or ッ
408
+ tnk = self.table.get(nk)
409
+ if tnk and tnk[0] in 'bdɸɡhçijkmnɲoprstɯvwz':
410
+ return tnk[0]
411
+ return kk
412
+ if ord(kk) in {12435, 12531}: # ん or ン
413
+ # https://en.wikipedia.org/wiki/N_(kana)
414
+ # m before m,p,b
415
+ # ŋ before k,g
416
+ # ɲ before ɲ,tɕ,dʑ
417
+ # n before n,t,d,r,z
418
+ # ɴ otherwise
419
+ tnk = self.table.get(nk)
420
+ if tnk:
421
+ if tnk[0] in 'mpb':
422
+ return 'm'
423
+ elif tnk[0] in 'kɡ':
424
+ return 'ŋ'
425
+ elif any(tnk.startswith(p) for p in ('ɲ','tɕ','dʑ')):
426
+ return 'ɲ'
427
+ elif tnk[0] in 'ntdrz':
428
+ return 'n'
429
+ return 'ɴ'
430
+ return self.table.get(kk, '')
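
For reference, a small usage sketch (assuming fugashi with a UniDic dictionary, mojimoji, and the other files in this commit are installed and importable); Katsu mirrors the list-in/list-out phonemize() shape of the espeak backends, so app.py can treat all three interchangeably:

```python
from katsu import Katsu

katsu = Katsu()
# Same call shape as phonemizer's EspeakBackend.phonemize: a list of strings in, a list out.
print(katsu.phonemize(['こんにちは、世界。']))
# The docstring example above; output here is IPA-flavored rather than Hepburn romaji.
print(katsu.romaji('カツカレーを食べた'))
```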
models.py ADDED
@@ -0,0 +1,571 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from istftnet import Decoder
3
+ from munch import Munch
4
+ from plbert import load_plbert
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+ import numpy as np
7
+ import os
8
+ import os.path as osp
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ class LearnedDownSample(nn.Module):
14
+ def __init__(self, layer_type, dim_in):
15
+ super().__init__()
16
+ self.layer_type = layer_type
17
+
18
+ if self.layer_type == 'none':
19
+ self.conv = nn.Identity()
20
+ elif self.layer_type == 'timepreserve':
21
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
22
+ elif self.layer_type == 'half':
23
+ self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
24
+ else:
25
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
26
+
27
+ def forward(self, x):
28
+ return self.conv(x)
29
+
30
+ class LearnedUpSample(nn.Module):
31
+ def __init__(self, layer_type, dim_in):
32
+ super().__init__()
33
+ self.layer_type = layer_type
34
+
35
+ if self.layer_type == 'none':
36
+ self.conv = nn.Identity()
37
+ elif self.layer_type == 'timepreserve':
38
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
39
+ elif self.layer_type == 'half':
40
+ self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
41
+ else:
42
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
43
+
44
+
45
+ def forward(self, x):
46
+ return self.conv(x)
47
+
48
+ class DownSample(nn.Module):
49
+ def __init__(self, layer_type):
50
+ super().__init__()
51
+ self.layer_type = layer_type
52
+
53
+ def forward(self, x):
54
+ if self.layer_type == 'none':
55
+ return x
56
+ elif self.layer_type == 'timepreserve':
57
+ return F.avg_pool2d(x, (2, 1))
58
+ elif self.layer_type == 'half':
59
+ if x.shape[-1] % 2 != 0:
60
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
61
+ return F.avg_pool2d(x, 2)
62
+ else:
63
+ raise RuntimeError('Got unexpected downsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
64
+
65
+
66
+ class UpSample(nn.Module):
67
+ def __init__(self, layer_type):
68
+ super().__init__()
69
+ self.layer_type = layer_type
70
+
71
+ def forward(self, x):
72
+ if self.layer_type == 'none':
73
+ return x
74
+ elif self.layer_type == 'timepreserve':
75
+ return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
76
+ elif self.layer_type == 'half':
77
+ return F.interpolate(x, scale_factor=2, mode='nearest')
78
+ else:
79
+ raise RuntimeError('Got unexpected upsample type %s, expected one of [none, timepreserve, half]' % self.layer_type)
80
+
81
+
82
+ class ResBlk(nn.Module):
83
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
84
+ normalize=False, downsample='none'):
85
+ super().__init__()
86
+ self.actv = actv
87
+ self.normalize = normalize
88
+ self.downsample = DownSample(downsample)
89
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
90
+ self.learned_sc = dim_in != dim_out
91
+ self._build_weights(dim_in, dim_out)
92
+
93
+ def _build_weights(self, dim_in, dim_out):
94
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
95
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
96
+ if self.normalize:
97
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
98
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
99
+ if self.learned_sc:
100
+ self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
101
+
102
+ def _shortcut(self, x):
103
+ if self.learned_sc:
104
+ x = self.conv1x1(x)
105
+ if self.downsample:
106
+ x = self.downsample(x)
107
+ return x
108
+
109
+ def _residual(self, x):
110
+ if self.normalize:
111
+ x = self.norm1(x)
112
+ x = self.actv(x)
113
+ x = self.conv1(x)
114
+ x = self.downsample_res(x)
115
+ if self.normalize:
116
+ x = self.norm2(x)
117
+ x = self.actv(x)
118
+ x = self.conv2(x)
119
+ return x
120
+
121
+ def forward(self, x):
122
+ x = self._shortcut(x) + self._residual(x)
123
+ return x / np.sqrt(2) # unit variance
124
+
125
+ class LinearNorm(torch.nn.Module):
126
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
127
+ super(LinearNorm, self).__init__()
128
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
129
+
130
+ torch.nn.init.xavier_uniform_(
131
+ self.linear_layer.weight,
132
+ gain=torch.nn.init.calculate_gain(w_init_gain))
133
+
134
+ def forward(self, x):
135
+ return self.linear_layer(x)
136
+
137
+ class Discriminator2d(nn.Module):
138
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
139
+ super().__init__()
140
+ blocks = []
141
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
142
+
143
+ for lid in range(repeat_num):
144
+ dim_out = min(dim_in*2, max_conv_dim)
145
+ blocks += [ResBlk(dim_in, dim_out, downsample='half')]
146
+ dim_in = dim_out
147
+
148
+ blocks += [nn.LeakyReLU(0.2)]
149
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
150
+ blocks += [nn.LeakyReLU(0.2)]
151
+ blocks += [nn.AdaptiveAvgPool2d(1)]
152
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
153
+ self.main = nn.Sequential(*blocks)
154
+
155
+ def get_feature(self, x):
156
+ features = []
157
+ for l in self.main:
158
+ x = l(x)
159
+ features.append(x)
160
+ out = features[-1]
161
+ out = out.view(out.size(0), -1) # (batch, num_domains)
162
+ return out, features
163
+
164
+ def forward(self, x):
165
+ out, features = self.get_feature(x)
166
+ out = out.squeeze() # (batch)
167
+ return out, features
168
+
169
+ class ResBlk1d(nn.Module):
170
+ def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
171
+ normalize=False, downsample='none', dropout_p=0.2):
172
+ super().__init__()
173
+ self.actv = actv
174
+ self.normalize = normalize
175
+ self.downsample_type = downsample
176
+ self.learned_sc = dim_in != dim_out
177
+ self._build_weights(dim_in, dim_out)
178
+ self.dropout_p = dropout_p
179
+
180
+ if self.downsample_type == 'none':
181
+ self.pool = nn.Identity()
182
+ else:
183
+ self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
184
+
185
+ def _build_weights(self, dim_in, dim_out):
186
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
187
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
188
+ if self.normalize:
189
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
190
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
191
+ if self.learned_sc:
192
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
193
+
194
+ def downsample(self, x):
195
+ if self.downsample_type == 'none':
196
+ return x
197
+ else:
198
+ if x.shape[-1] % 2 != 0:
199
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
200
+ return F.avg_pool1d(x, 2)
201
+
202
+ def _shortcut(self, x):
203
+ if self.learned_sc:
204
+ x = self.conv1x1(x)
205
+ x = self.downsample(x)
206
+ return x
207
+
208
+ def _residual(self, x):
209
+ if self.normalize:
210
+ x = self.norm1(x)
211
+ x = self.actv(x)
212
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
213
+
214
+ x = self.conv1(x)
215
+ x = self.pool(x)
216
+ if self.normalize:
217
+ x = self.norm2(x)
218
+
219
+ x = self.actv(x)
220
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
221
+
222
+ x = self.conv2(x)
223
+ return x
224
+
225
+ def forward(self, x):
226
+ x = self._shortcut(x) + self._residual(x)
227
+ return x / np.sqrt(2) # unit variance
228
+
229
+ class LayerNorm(nn.Module):
230
+ def __init__(self, channels, eps=1e-5):
231
+ super().__init__()
232
+ self.channels = channels
233
+ self.eps = eps
234
+
235
+ self.gamma = nn.Parameter(torch.ones(channels))
236
+ self.beta = nn.Parameter(torch.zeros(channels))
237
+
238
+ def forward(self, x):
239
+ x = x.transpose(1, -1)
240
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
241
+ return x.transpose(1, -1)
242
+
243
+ class TextEncoder(nn.Module):
244
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
245
+ super().__init__()
246
+ self.embedding = nn.Embedding(n_symbols, channels)
247
+
248
+ padding = (kernel_size - 1) // 2
249
+ self.cnn = nn.ModuleList()
250
+ for _ in range(depth):
251
+ self.cnn.append(nn.Sequential(
252
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
253
+ LayerNorm(channels),
254
+ actv,
255
+ nn.Dropout(0.2),
256
+ ))
257
+ # self.cnn = nn.Sequential(*self.cnn)
258
+
259
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
260
+
261
+ def forward(self, x, input_lengths, m):
262
+ x = self.embedding(x) # [B, T, emb]
263
+ x = x.transpose(1, 2) # [B, emb, T]
264
+ m = m.to(input_lengths.device).unsqueeze(1)
265
+ x.masked_fill_(m, 0.0)
266
+
267
+ for c in self.cnn:
268
+ x = c(x)
269
+ x.masked_fill_(m, 0.0)
270
+
271
+ x = x.transpose(1, 2) # [B, T, chn]
272
+
273
+ input_lengths = input_lengths.cpu().numpy()
274
+ x = nn.utils.rnn.pack_padded_sequence(
275
+ x, input_lengths, batch_first=True, enforce_sorted=False)
276
+
277
+ self.lstm.flatten_parameters()
278
+ x, _ = self.lstm(x)
279
+ x, _ = nn.utils.rnn.pad_packed_sequence(
280
+ x, batch_first=True)
281
+
282
+ x = x.transpose(-1, -2)
283
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
284
+
285
+ x_pad[:, :, :x.shape[-1]] = x
286
+ x = x_pad.to(x.device)
287
+
288
+ x.masked_fill_(m, 0.0)
289
+
290
+ return x
291
+
292
+ def inference(self, x):
293
+ x = self.embedding(x)
294
+ x = x.transpose(1, 2)
295
+ for c in self.cnn: x = c(x)  # self.cnn is a ModuleList, so apply each conv block in turn
296
+ x = x.transpose(1, 2)
297
+ self.lstm.flatten_parameters()
298
+ x, _ = self.lstm(x)
299
+ return x
300
+
301
+ def length_to_mask(self, lengths):
302
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
303
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
304
+ return mask
305
+
306
+
307
+
308
+ class AdaIN1d(nn.Module):
309
+ def __init__(self, style_dim, num_features):
310
+ super().__init__()
311
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
312
+ self.fc = nn.Linear(style_dim, num_features*2)
313
+
314
+ def forward(self, x, s):
315
+ h = self.fc(s)
316
+ h = h.view(h.size(0), h.size(1), 1)
317
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
318
+ return (1 + gamma) * self.norm(x) + beta
319
+
320
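+ # Nearest-neighbour 2x upsampling, or a no-op when layer_type is 'none'.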
+ class UpSample1d(nn.Module):
321
+ def __init__(self, layer_type):
322
+ super().__init__()
323
+ self.layer_type = layer_type
324
+
325
+ def forward(self, x):
326
+ if self.layer_type == 'none':
327
+ return x
328
+ else:
329
+ return F.interpolate(x, scale_factor=2, mode='nearest')
330
+
331
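+ # Residual block conditioned on a style vector via AdaIN; optionally upsamples 2x with a depthwise ConvTranspose1d and scales the summed output by 1/sqrt(2) to preserve variance.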
+ class AdainResBlk1d(nn.Module):
332
+ def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
333
+ upsample='none', dropout_p=0.0):
334
+ super().__init__()
335
+ self.actv = actv
336
+ self.upsample_type = upsample
337
+ self.upsample = UpSample1d(upsample)
338
+ self.learned_sc = dim_in != dim_out
339
+ self._build_weights(dim_in, dim_out, style_dim)
340
+ self.dropout = nn.Dropout(dropout_p)
341
+
342
+ if upsample == 'none':
343
+ self.pool = nn.Identity()
344
+ else:
345
+ self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
346
+
347
+
348
+ def _build_weights(self, dim_in, dim_out, style_dim):
349
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
350
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
351
+ self.norm1 = AdaIN1d(style_dim, dim_in)
352
+ self.norm2 = AdaIN1d(style_dim, dim_out)
353
+ if self.learned_sc:
354
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
355
+
356
+ def _shortcut(self, x):
357
+ x = self.upsample(x)
358
+ if self.learned_sc:
359
+ x = self.conv1x1(x)
360
+ return x
361
+
362
+ def _residual(self, x, s):
363
+ x = self.norm1(x, s)
364
+ x = self.actv(x)
365
+ x = self.pool(x)
366
+ x = self.conv1(self.dropout(x))
367
+ x = self.norm2(x, s)
368
+ x = self.actv(x)
369
+ x = self.conv2(self.dropout(x))
370
+ return x
371
+
372
+ def forward(self, x, s):
373
+ out = self._residual(x, s)
374
+ out = (out + self._shortcut(x)) / np.sqrt(2)
375
+ return out
376
+
377
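+ # LayerNorm whose scale and shift are predicted from the style vector by a linear projection.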
+ class AdaLayerNorm(nn.Module):
378
+ def __init__(self, style_dim, channels, eps=1e-5):
379
+ super().__init__()
380
+ self.channels = channels
381
+ self.eps = eps
382
+
383
+ self.fc = nn.Linear(style_dim, channels*2)
384
+
385
+ def forward(self, x, s):
386
+ x = x.transpose(-1, -2)
387
+ x = x.transpose(1, -1)
388
+
389
+ h = self.fc(s)
390
+ h = h.view(h.size(0), h.size(1), 1)
391
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
392
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
393
+
394
+
395
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
396
+ x = (1 + gamma) * x + beta
397
+ return x.transpose(1, -1).transpose(-1, -2)
398
+
399
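+ # Predicts per-token durations (BiLSTM + linear projection over max_dur classes) and, via F0Ntrain, frame-level F0 and energy (N) curves from style-conditioned text features.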
+ class ProsodyPredictor(nn.Module):
400
+
401
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
402
+ super().__init__()
403
+
404
+ self.text_encoder = DurationEncoder(sty_dim=style_dim,
405
+ d_model=d_hid,
406
+ nlayers=nlayers,
407
+ dropout=dropout)
408
+
409
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
410
+ self.duration_proj = LinearNorm(d_hid, max_dur)
411
+
412
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
413
+ self.F0 = nn.ModuleList()
414
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
415
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
416
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
417
+
418
+ self.N = nn.ModuleList()
419
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
420
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
421
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
422
+
423
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
424
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
425
+
426
+
427
+ def forward(self, texts, style, text_lengths, alignment, m):
428
+ d = self.text_encoder(texts, style, text_lengths, m)
429
+
430
+ batch_size = d.shape[0]
431
+ text_size = d.shape[1]
432
+
433
+ # predict duration
434
+ input_lengths = text_lengths.cpu().numpy()
435
+ x = nn.utils.rnn.pack_padded_sequence(
436
+ d, input_lengths, batch_first=True, enforce_sorted=False)
437
+
438
+ m = m.to(text_lengths.device).unsqueeze(1)
439
+
440
+ self.lstm.flatten_parameters()
441
+ x, _ = self.lstm(x)
442
+ x, _ = nn.utils.rnn.pad_packed_sequence(
443
+ x, batch_first=True)
444
+
445
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
446
+
447
+ x_pad[:, :x.shape[1], :] = x
448
+ x = x_pad.to(x.device)
449
+
450
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
451
+
452
+ en = (d.transpose(-1, -2) @ alignment)
453
+
454
+ return duration.squeeze(-1), en
455
+
456
+ def F0Ntrain(self, x, s):
457
+ x, _ = self.shared(x.transpose(-1, -2))
458
+
459
+ F0 = x.transpose(-1, -2)
460
+ for block in self.F0:
461
+ F0 = block(F0, s)
462
+ F0 = self.F0_proj(F0)
463
+
464
+ N = x.transpose(-1, -2)
465
+ for block in self.N:
466
+ N = block(N, s)
467
+ N = self.N_proj(N)
468
+
469
+ return F0.squeeze(1), N.squeeze(1)
470
+
471
+ def length_to_mask(self, lengths):
472
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
473
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
474
+ return mask
475
+
476
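+ # Alternates BiLSTM layers with style-conditioned AdaLayerNorm; the style vector is concatenated back onto the features after every normalization step.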
+ class DurationEncoder(nn.Module):
477
+
478
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
479
+ super().__init__()
480
+ self.lstms = nn.ModuleList()
481
+ for _ in range(nlayers):
482
+ self.lstms.append(nn.LSTM(d_model + sty_dim,
483
+ d_model // 2,
484
+ num_layers=1,
485
+ batch_first=True,
486
+ bidirectional=True,
487
+ dropout=dropout))
488
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
489
+
490
+
491
+ self.dropout = dropout
492
+ self.d_model = d_model
493
+ self.sty_dim = sty_dim
494
+
495
+ def forward(self, x, style, text_lengths, m):
496
+ masks = m.to(text_lengths.device)
497
+
498
+ x = x.permute(2, 0, 1)
499
+ s = style.expand(x.shape[0], x.shape[1], -1)
500
+ x = torch.cat([x, s], axis=-1)
501
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
502
+
503
+ x = x.transpose(0, 1)
504
+ input_lengths = text_lengths.cpu().numpy()
505
+ x = x.transpose(-1, -2)
506
+
507
+ for block in self.lstms:
508
+ if isinstance(block, AdaLayerNorm):
509
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
510
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
511
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
512
+ else:
513
+ x = x.transpose(-1, -2)
514
+ x = nn.utils.rnn.pack_padded_sequence(
515
+ x, input_lengths, batch_first=True, enforce_sorted=False)
516
+ block.flatten_parameters()
517
+ x, _ = block(x)
518
+ x, _ = nn.utils.rnn.pad_packed_sequence(
519
+ x, batch_first=True)
520
+ x = F.dropout(x, p=self.dropout, training=self.training)
521
+ x = x.transpose(-1, -2)
522
+
523
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
524
+
525
+ x_pad[:, :, :x.shape[-1]] = x
526
+ x = x_pad.to(x.device)
527
+
528
+ return x.transpose(-1, -2)
529
+
530
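+ # Note: this inference path references self.embedding, self.pos_encoder and self.transformer_encoder, which are not defined in this class; it appears to be unused.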
+ def inference(self, x, style):
531
+ x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
532
+ style = style.expand(x.shape[0], x.shape[1], -1)
533
+ x = torch.cat([x, style], axis=-1)
534
+ src = self.pos_encoder(x)
535
+ output = self.transformer_encoder(src).transpose(0, 1)
536
+ return output
537
+
538
+ def length_to_mask(self, lengths):
539
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
540
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
541
+ return mask
542
+
543
+ # https://github.com/yl4579/StyleTTS2/blob/main/utils.py
544
+ def recursive_munch(d):
545
+ if isinstance(d, dict):
546
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
547
+ elif isinstance(d, list):
548
+ return [recursive_munch(v) for v in d]
549
+ else:
550
+ return d
551
+
552
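+ # Assembles the model components (iSTFTNet decoder, text encoder, prosody predictor, PLBERT plus a linear bert_encoder projection) into a Munch keyed by name.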
+ def build_model(args):
553
+ args = recursive_munch(args)
554
+ assert args.decoder.type == 'istftnet', 'Decoder type unknown'
555
+ decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
556
+ resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
557
+ upsample_rates = args.decoder.upsample_rates,
558
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
559
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
560
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
561
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
562
+ text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
563
+ predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
564
+ bert = load_plbert()
565
+ return Munch(
566
+ bert=bert,
567
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
568
+ predictor=predictor,
569
+ decoder=decoder,
570
+ text_encoder=text_encoder,
571
+ )
num2kana.py ADDED
@@ -0,0 +1,317 @@
1
+ # https://github.com/Greatdane/Convert-Numbers-to-Japanese/blob/master/Convert-Numbers-to-Japanese.py
2
+ # Japanese Number Converter
3
+ # - Currently uses length-based functions; a recursive approach may be possible, but the many exceptions make it awkward
4
+ # - Works up to 9 digits (999999999)
5
+
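+ # Illustrative examples (derived from the tables and functions below):
+ #   Convert(2024, 'hiragana') -> 'にせんにじゅうよん'
+ #   Convert('31', 'romaji')   -> 'san juu ichi'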
6
+ romaji_dict = {".": "ten", "0": "zero", "1": "ichi", "2": "ni", "3": "san", "4": "yon", "5": "go", "6": "roku", "7": "nana",
7
+ "8": "hachi", "9": "kyuu", "10": "juu", "100": "hyaku", "1000": "sen", "10000": "man", "100000000": "oku",
8
+ "300": "sanbyaku", "600": "roppyaku", "800": "happyaku", "3000": "sanzen", "8000":"hassen", "01000": "issen"}
9
+
10
+ kanji_dict = {".": "点", "0": "零", "1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "七",
11
+ "8": "八", "9": "九", "10": "十", "100": "百", "1000": "千", "10000": "万", "100000000": "億",
12
+ "300": "三百", "600": "六百", "800": "八百", "3000": "三千", "8000":"八千", "01000": "一千"}
13
+
14
+ hiragana_dict = {".": "てん", "0": "ゼロ", "1": "いち", "2": "に", "3": "さん", "4": "よん", "5": "ご", "6": "ろく", "7": "なな",
15
+ "8": "はち", "9": "きゅう", "10": "じゅう", "100": "ひゃく", "1000": "せん", "10000": "まん", "100000000": "おく",
16
+ "300": "さんびゃく", "600": "ろっぴゃく", "800": "はっぴゃく", "3000": "さんぜん", "8000":"はっせん", "01000": "いっせん" }
17
+
18
+ key_dict = {"kanji" : kanji_dict, "hiragana" : hiragana_dict, "romaji": romaji_dict}
19
+
20
+ def len_one(convert_num,requested_dict):
21
+ # Returns single digit conversion, 0-9
22
+ return requested_dict[convert_num]
23
+
24
+ def len_two(convert_num,requested_dict):
25
+ # Returns the conversion, when number is of length two (10-99)
26
+ if convert_num[0] == "0": #if 0 is first, return len_one
27
+ return len_one(convert_num[1],requested_dict)
28
+ if convert_num == "10":
29
+ return requested_dict["10"] # Exception: if the number is exactly 10, simply return 10
30
+ if convert_num[0] == "1": # When first number is 1, use ten plus second number
31
+ return requested_dict["10"] + " " + len_one(convert_num[1],requested_dict)
32
+ elif convert_num[1] == "0": # If ending number is zero, give first number plus 10
33
+ return len_one(convert_num[0],requested_dict) + " " + requested_dict["10"]
34
+ else:
35
+ num_list = []
36
+ for x in convert_num:
37
+ num_list.append(requested_dict[x])
38
+ num_list.insert(1, requested_dict["10"])
39
+ # Convert to a string (from a list)
40
+ output = ""
41
+ for y in num_list:
42
+ output += y + " "
43
+ output = output[:len(output) - 1] # take off the space
44
+ return output
45
+
46
+ def len_three(convert_num,requested_dict):
47
+ # Returns the conversion, when number is of length three (100-999)
48
+ num_list = []
49
+ if convert_num[0] == "1":
50
+ num_list.append(requested_dict["100"])
51
+ elif convert_num[0] == "3":
52
+ num_list.append(requested_dict["300"])
53
+ elif convert_num[0] == "6":
54
+ num_list.append(requested_dict["600"])
55
+ elif convert_num[0] == "8":
56
+ num_list.append(requested_dict["800"])
57
+ else:
58
+ num_list.append(requested_dict[convert_num[0]])
59
+ num_list.append(requested_dict["100"])
60
+ if convert_num[1:] == "00" and len(convert_num) == 3:
61
+ pass
62
+ else:
63
+ if convert_num[1] == "0":
64
+ num_list.append(requested_dict[convert_num[2]])
65
+ else:
66
+ num_list.append(len_two(convert_num[1:], requested_dict))
67
+ output = ""
68
+ for y in num_list:
69
+ output += y + " "
70
+ output = output[:len(output) - 1]
71
+ return output
72
+
73
+ def len_four(convert_num,requested_dict, stand_alone):
74
+ # Returns the conversion, when number is of length four (1000-9999)
75
+ num_list = []
76
+ # First, check for leading zeros (and deal with them)
77
+ if convert_num == "0000":
78
+ return ""
79
+ while convert_num[0] == "0":
80
+ convert_num = convert_num[1:]
81
+ if len(convert_num) == 1:
82
+ return len_one(convert_num,requested_dict)
83
+ elif len(convert_num) == 2:
84
+ return len_two(convert_num,requested_dict)
85
+ elif len(convert_num) == 3:
86
+ return len_three(convert_num,requested_dict)
87
+ # If no zeros, do the calculation
88
+ else:
89
+ # Have to handle 1000 differently depending on whether it's a standalone 1000-9999 or part of a larger number
90
+ if convert_num[0] == "1" and stand_alone:
91
+ num_list.append(requested_dict["1000"])
92
+ elif convert_num[0] == "1":
93
+ num_list.append(requested_dict["01000"])
94
+ elif convert_num[0] == "3":
95
+ num_list.append(requested_dict["3000"])
96
+ elif convert_num[0] == "8":
97
+ num_list.append(requested_dict["8000"])
98
+ else:
99
+ num_list.append(requested_dict[convert_num[0]])
100
+ num_list.append(requested_dict["1000"])
101
+ if convert_num[1:] == "000" and len(convert_num) == 4:
102
+ pass
103
+ else:
104
+ if convert_num[1] == "0":
105
+ num_list.append(len_two(convert_num[2:],requested_dict))
106
+ else:
107
+ num_list.append(len_three(convert_num[1:],requested_dict))
108
+ output = ""
109
+ for y in num_list:
110
+ output += y + " "
111
+ output = output[:len(output) - 1]
112
+ return output
113
+
114
+
115
+ def len_x(convert_num,requested_dict):
116
+ # Returns everything else (up to 9 digits)
117
+ num_list = []
118
+ if len(convert_num[0:-4]) == 1:
119
+ num_list.append(requested_dict[convert_num[0:-4]])
120
+ num_list.append(requested_dict["10000"])
121
+ elif len(convert_num[0:-4]) == 2:
122
+ num_list.append(len_two(convert_num[0:2],requested_dict))
123
+ num_list.append(requested_dict["10000"])
124
+ elif len(convert_num[0:-4]) == 3:
125
+ num_list.append(len_three(convert_num[0:3],requested_dict))
126
+ num_list.append(requested_dict["10000"])
127
+ elif len(convert_num[0:-4]) == 4:
128
+ num_list.append(len_four(convert_num[0:4],requested_dict, False))
129
+ num_list.append(requested_dict["10000"])
130
+ elif len(convert_num[0:-4]) == 5:
131
+ num_list.append(requested_dict[convert_num[0]])
132
+ num_list.append(requested_dict["100000000"])
133
+ num_list.append(len_four(convert_num[1:5],requested_dict, False))
134
+ if convert_num[1:5] == "0000":
135
+ pass
136
+ else:
137
+ num_list.append(requested_dict["10000"])
138
+ else:
139
+ assert False, "Not yet implemented, please choose a lower number."
140
+ num_list.append(len_four(convert_num[-4:],requested_dict, False))
141
+ output = ""
142
+ for y in num_list:
143
+ output += y + " "
144
+ output = output[:len(output) - 1]
145
+ return output
146
+
147
+ def remove_spaces(convert_result):
148
+ # Remove spaces in hiragana and kanji results
149
+ correction = ""
150
+ for x in convert_result:
151
+ if x == " ":
152
+ pass
153
+ else:
154
+ correction += x
155
+ return correction
156
+
157
+ def do_convert(convert_num,requested_dict):
158
+ # Check lengths and convert accordingly
159
+ if len(convert_num) == 1:
160
+ return(len_one(convert_num,requested_dict))
161
+ elif len(convert_num) == 2:
162
+ return(len_two(convert_num,requested_dict))
163
+ elif len(convert_num) == 3:
164
+ return(len_three(convert_num,requested_dict))
165
+ elif len(convert_num) == 4:
166
+ return(len_four(convert_num,requested_dict, True))
167
+ else:
168
+ return(len_x(convert_num,requested_dict))
169
+
170
+ def split_Point(split_num,dict_choice):
171
+ # Used if a decimal point is in the string.
172
+ split_num = split_num.split(".")
173
+ split_num_a = split_num[0]
174
+ split_num_b = split_num[1]
175
+ split_num_b_end = " "
176
+ for x in split_num_b:
177
+ split_num_b_end += len_one(x,key_dict[dict_choice]) + " "
178
+ # Account for the small-tsu exception when the integer part ends in juu (10), for hiragana and romaji
179
+ if split_num_a[-1] == "0" and split_num_a[-2] != "0" and dict_choice == "hiragana":
180
+ small_Tsu = Convert(split_num_a,dict_choice)
181
+ small_Tsu = small_Tsu[0:-1] + "っ"
182
+ return small_Tsu + key_dict[dict_choice]["."] + split_num_b_end
183
+ if split_num_a[-1] == "0" and split_num_a[-2] != "0" and dict_choice == "romaji":
184
+ small_Tsu = Convert(split_num_a,dict_choice)
185
+ small_Tsu = small_Tsu[0:-1] + "t"
186
+ return small_Tsu + key_dict[dict_choice]["."] + split_num_b_end
187
+
188
+ return Convert(split_num_a,dict_choice) + " " + key_dict[dict_choice]["."] + split_num_b_end
189
+
190
+
191
+ def do_kanji_convert(convert_num):
192
+ # Converts kanji to arabic number
193
+
194
+ if convert_num == "零":
195
+ return 0
196
+
197
+ # First, check for the MAN 万 and OKU 億 kanji, which need special handling: the number is split up at these intervals.
198
+ # 'key' records whether adjacent groups should be added or multiplied; the groups themselves are collected into a list in evaluation order
199
+ key = []
200
+ numberList = []
201
+ y = ""
202
+ for x in convert_num:
203
+ if x == "万" or x == "億":
204
+ numberList.append(y)
205
+ key.append("times")
206
+ numberList.append(x)
207
+ key.append("plus")
208
+ y = ""
209
+ else:
210
+ y += x
211
+ if y != "":
212
+ numberList.append(y)
213
+
214
+ numberListConverted = []
215
+ baseNumber = ["一", "二", "三", "四", "五", "六", "七", "八", "九"]
216
+ linkNumber = ["十", "百", "千", "万", "億"]
217
+
218
+ # Converts the kanji number list to Arabic numbers, using the 'base number' and 'link number' lists above. Each link number must
219
+ # link with a base number
220
+ for noX in numberList:
221
+ count = len(noX)
222
+ result = 0
223
+ skip = 1
224
+ for x in reversed(noX):
225
+ addTo = 0
226
+ skip -= 1
227
+ count = count - 1
228
+ if skip == 1:
229
+ continue
230
+ if x in baseNumber:
231
+ for y, z in kanji_dict.items():
232
+ if z == x:
233
+ result += int(y)
234
+ elif x in linkNumber:
235
+ if noX[count - 1] in baseNumber and count > 0:
236
+ for y, z in kanji_dict.items():
237
+ if z == noX[count - 1]:
238
+ tempNo = int(y)
239
+ for y, z in kanji_dict.items():
240
+ if z == x:
241
+ addTo += tempNo * int(y)
242
+ result += addTo
243
+ skip = 2
244
+ else:
245
+ for y, z in kanji_dict.items():
246
+ if z == x:
247
+ result += int(y)
248
+ numberListConverted.append(int(result))
249
+
250
+ result = numberListConverted[0]
251
+ y = 0
252
+
253
+ # Iterate over the converted list, and either multiply/add as instructed in key list
254
+ for x in range(1,len(numberListConverted)):
255
+ if key[y] == "plus":
256
+ try:
257
+ if key[y+1] == "times":
258
+ result = result + numberListConverted[x] * numberListConverted[x+1]
259
+ y += 1
260
+ else:
261
+ result += numberListConverted[x]
262
+ except IndexError:
263
+ result += numberListConverted[-1]
264
+ break
265
+ else:
266
+ result = result * numberListConverted[x]
267
+ y += 1
268
+
269
+ return result
270
+
271
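+ # Main entry point: accepts an int or numeric string (commas allowed); dict_choice is 'kanji', 'hiragana', 'romaji', or 'all' (which returns a list of all three).
+ # Values containing a decimal point are routed through split_Point.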
+ def Convert(convert_num, dict_choice='hiragana'):
272
+ # Input formatting
273
+ convert_num = str(convert_num)
274
+ convert_num = convert_num.replace(',','')
275
+ dict_choice = dict_choice.lower()
276
+
277
+ # If all is selected as dict_choice, return as a list
278
+ if dict_choice == "all":
279
+ result_list = []
280
+ for x in "kanji", "hiragana", "romaji":
281
+ result_list.append(Convert(convert_num,x))
282
+ return result_list
283
+
284
+ dictionary = key_dict[dict_choice]
285
+
286
+ # Exit if length is greater than current limit
287
+ if len(convert_num) > 9:
288
+ return("Number length too long, choose less than 10 digits")
289
+
290
+ # Remove any leading zeroes
291
+ while convert_num[0] == "0" and len(convert_num) > 1:
292
+ convert_num = convert_num[1:]
293
+
294
+ # Check for decimal places
295
+ if "." in convert_num:
296
+ result = split_Point(convert_num,dict_choice)
297
+ else:
298
+ result = do_convert(convert_num, dictionary)
299
+
300
+ # Remove spaces and return result
301
+ if key_dict[dict_choice] == romaji_dict:
302
+ pass
303
+ else:
304
+ result = remove_spaces(result)
305
+ return result
306
+
307
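+ # Reverse direction: converts a kanji numeral string (optionally containing 点 for the decimal point) back into an Arabic-digit string.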
+ def ConvertKanji(convert_num):
308
+ if convert_num[0] in kanji_dict.values():
309
+ # Check whether 点 (decimal point) is in the input; if so, split at 点 and handle the integer and fractional parts separately
310
+ if "点" in convert_num:
311
+ point = convert_num.find("点")
312
+ endNumber = ""
313
+ for x in convert_num[point+1:]:
314
+ endNumber += list(kanji_dict.keys())[list(kanji_dict.values()).index(x)]
315
+ return(str(do_kanji_convert(convert_num[0:point])) + "." + endNumber)
316
+ else:
317
+ return(str(do_kanji_convert(convert_num)))
packages.txt ADDED
@@ -0,0 +1 @@
1
+ espeak-ng
plbert.py ADDED
@@ -0,0 +1,15 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
2
+ from transformers import AlbertConfig, AlbertModel
3
+
4
+ class CustomAlbert(AlbertModel):
5
+ def forward(self, *args, **kwargs):
6
+ # Call the original forward method
7
+ outputs = super().forward(*args, **kwargs)
8
+ # Only return the last_hidden_state
9
+ return outputs.last_hidden_state
10
+
11
+ def load_plbert():
12
+ plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
13
+ albert_base_configuration = AlbertConfig(**plbert_config)
14
+ bert = CustomAlbert(albert_base_configuration)
15
+ return bert
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ fugashi
2
+ gradio
3
+ mojimoji
4
+ munch
5
+ noisereduce
6
+ phonemizer
7
+ scipy
8
+ torch
9
+ transformers
10
+ unicodedata2
11
+ unidic-lite