Spaces:
Running
on
Zero
Running
on
Zero
Upload 10 files
Browse files- app.py +200 -0
- en.txt +0 -0
- istftnet.py +523 -0
- ja.txt +0 -0
- katsu.py +430 -0
- models.py +571 -0
- num2kana.py +317 -0
- packages.txt +1 -0
- plbert.py +15 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import snapshot_download
|
2 |
+
from katsu import Katsu
|
3 |
+
from models import build_model
|
4 |
+
import gradio as gr
|
5 |
+
import noisereduce as nr
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import phonemizer
|
9 |
+
import random
|
10 |
+
import torch
|
11 |
+
import yaml
|
12 |
+
|
13 |
+
random_texts = {}
|
14 |
+
for lang in ['en', 'ja']:
|
15 |
+
with open(f'{lang}.txt', 'r') as r:
|
16 |
+
random_texts[lang] = [line.strip() for line in r]
|
17 |
+
|
18 |
+
def get_random_text(voice):
|
19 |
+
if voice[0] == 'j':
|
20 |
+
lang = 'ja'
|
21 |
+
else:
|
22 |
+
lang = 'en'
|
23 |
+
return random.choice(random_texts[lang])
|
24 |
+
|
25 |
+
def parens_to_angles(s):
|
26 |
+
return s.replace('(', '«').replace(')', '»')
|
27 |
+
|
28 |
+
def normalize(text):
|
29 |
+
# TODO: Custom text normalization rules?
|
30 |
+
text = text.replace('Dr.', 'Doctor')
|
31 |
+
text = text.replace('Mr.', 'Mister')
|
32 |
+
text = text.replace('Ms.', 'Miss')
|
33 |
+
text = text.replace('Mrs.', 'Mrs')
|
34 |
+
return parens_to_angles(text)
|
35 |
+
|
36 |
+
phonemizers = dict(
|
37 |
+
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
|
38 |
+
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
|
39 |
+
j=Katsu()
|
40 |
+
)
|
41 |
+
|
42 |
+
def phonemize(text, voice):
|
43 |
+
lang = voice[0]
|
44 |
+
text = normalize(text)
|
45 |
+
ps = phonemizers[lang].phonemize([text])
|
46 |
+
ps = ps[0] if ps else ''
|
47 |
+
# TODO: Custom phonemization rules?
|
48 |
+
ps = parens_to_angles(ps)
|
49 |
+
# https://en.wiktionary.org/wiki/kokoro#English
|
50 |
+
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
51 |
+
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
52 |
+
return ps.strip()
|
53 |
+
|
54 |
+
def length_to_mask(lengths):
|
55 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
56 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
57 |
+
return mask
|
58 |
+
|
59 |
+
def get_vocab():
|
60 |
+
_pad = "$"
|
61 |
+
_punctuation = ';:,.!?¡¿—…"«»“” '
|
62 |
+
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
63 |
+
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
64 |
+
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
65 |
+
dicts = {}
|
66 |
+
for i in range(len((symbols))):
|
67 |
+
dicts[symbols[i]] = i
|
68 |
+
return dicts
|
69 |
+
|
70 |
+
VOCAB = get_vocab()
|
71 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
72 |
+
|
73 |
+
snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
|
74 |
+
config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
|
75 |
+
model = build_model(config['model_params'])
|
76 |
+
_ = [model[key].eval() for key in model]
|
77 |
+
_ = [model[key].to(device) for key in model]
|
78 |
+
for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
|
79 |
+
assert key in model, key
|
80 |
+
try:
|
81 |
+
model[key].load_state_dict(state_dict)
|
82 |
+
except:
|
83 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
84 |
+
model[key].load_state_dict(state_dict, strict=False)
|
85 |
+
|
86 |
+
CHOICES = {
|
87 |
+
'🇺🇸 🚺 American Female 0': 'af0',
|
88 |
+
'🇺🇸 🚺 Bella': 'af1',
|
89 |
+
'🇺🇸 🚺 Nicole': 'af2',
|
90 |
+
'🇺🇸 🚹 Michael': 'am0',
|
91 |
+
'🇺🇸 🚹 Adam': 'am1',
|
92 |
+
'🇬🇧 🚺 British Female 0': 'bf0',
|
93 |
+
'🇬🇧 🚺 British Female 1': 'bf1',
|
94 |
+
'🇬🇧 🚺 British Female 2': 'bf2',
|
95 |
+
'🇬🇧 🚹 British Male 0': 'bm0',
|
96 |
+
'🇬🇧 🚹 British Male 1': 'bm1',
|
97 |
+
'🇬🇧 🚹 British Male 2': 'bm2',
|
98 |
+
'🇬🇧 🚹 British Male 3': 'bm3',
|
99 |
+
'🇯🇵 🚺 Japanese Female 0': 'jf0',
|
100 |
+
}
|
101 |
+
VOICES = {k: torch.load(os.path.join(snapshot, 'voices', f'{k}.pt'), weights_only=True).to(device) for k in CHOICES.values()}
|
102 |
+
|
103 |
+
np_log_99 = np.log(99)
|
104 |
+
def s_curve(p):
|
105 |
+
if p <= 0:
|
106 |
+
return 0
|
107 |
+
elif p >= 1:
|
108 |
+
return 1
|
109 |
+
s = 1 / (1 + np.exp((1-p*2)*np_log_99))
|
110 |
+
s = (s-0.01) * 50/49
|
111 |
+
return s
|
112 |
+
|
113 |
+
SAMPLE_RATE = 24000
|
114 |
+
|
115 |
+
@torch.no_grad()
|
116 |
+
def forward(text, voice, ps=None, speed=1.0, reduce_noise=0.5, opening_cut=5000, closing_cut=0, ease_in=3000, ease_out=0):
|
117 |
+
ps = ps or phonemize(text, voice)
|
118 |
+
tokens = [i for i in map(VOCAB.get, ps) if i is not None]
|
119 |
+
if not tokens:
|
120 |
+
return (None, '')
|
121 |
+
elif len(tokens) > 510:
|
122 |
+
tokens = tokens[:510]
|
123 |
+
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
124 |
+
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
125 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
126 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
127 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
128 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
129 |
+
ref_s = VOICES[voice]
|
130 |
+
s = ref_s[:, 128:]
|
131 |
+
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
132 |
+
x, _ = model.predictor.lstm(d)
|
133 |
+
duration = model.predictor.duration_proj(x)
|
134 |
+
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
135 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
136 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
137 |
+
c_frame = 0
|
138 |
+
for i in range(pred_aln_trg.size(0)):
|
139 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
140 |
+
c_frame += int(pred_dur[i].data)
|
141 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
142 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
143 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
144 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
145 |
+
out = model.decoder(asr, F0_pred, N_pred, ref_s[:, :128])
|
146 |
+
out = out.squeeze().cpu().numpy()
|
147 |
+
if reduce_noise > 0:
|
148 |
+
out = nr.reduce_noise(y=out, sr=SAMPLE_RATE, prop_decrease=reduce_noise, n_fft=512)
|
149 |
+
opening_cut = max(0, int(opening_cut / speed))
|
150 |
+
if opening_cut > 0:
|
151 |
+
out[:opening_cut] = 0
|
152 |
+
closing_cut = max(0, int(closing_cut / speed))
|
153 |
+
if closing_cut > 0:
|
154 |
+
out = out[-closing_cut:] = 0
|
155 |
+
ease_in = min(int(ease_in / speed), len(out)//2 - opening_cut)
|
156 |
+
for i in range(ease_in):
|
157 |
+
out[i+opening_cut] *= s_curve(i / ease_in)
|
158 |
+
ease_out = min(int(ease_out / speed), len(out)//2 - closing_cut)
|
159 |
+
for i in range(ease_out):
|
160 |
+
out[-i-1-closing_cut] *= s_curve(i / ease_out)
|
161 |
+
return ((SAMPLE_RATE, out), ps)
|
162 |
+
|
163 |
+
with gr.Blocks() as demo:
|
164 |
+
with gr.Row():
|
165 |
+
with gr.Column():
|
166 |
+
text = gr.Textbox(label='Input Text')
|
167 |
+
voice = gr.Dropdown(list(CHOICES.items()), label='Voice')
|
168 |
+
with gr.Row():
|
169 |
+
random_btn = gr.Button('Random Text', variant='secondary')
|
170 |
+
generate_btn = gr.Button('Generate', variant='primary')
|
171 |
+
random_btn.click(get_random_text, inputs=[voice], outputs=[text])
|
172 |
+
with gr.Accordion('Input Phonemes', open=False):
|
173 |
+
in_ps = gr.Textbox(show_label=False, info='Override the input text with custom pronunciation. Leave this blank to use the input text instead.')
|
174 |
+
with gr.Row():
|
175 |
+
clear_btn = gr.ClearButton(in_ps)
|
176 |
+
phonemize_btn = gr.Button('Phonemize Input Text', variant='primary')
|
177 |
+
phonemize_btn.click(phonemize, inputs=[text, voice], outputs=[in_ps])
|
178 |
+
with gr.Column():
|
179 |
+
audio = gr.Audio(interactive=False, label='Output Audio')
|
180 |
+
with gr.Accordion('Tokens', open=True):
|
181 |
+
out_ps = gr.Textbox(interactive=False, show_label=False, info='Tokens used to generate the audio. Same as input phonemes if supplied, excluding unknown characters and truncated to 510 tokens.')
|
182 |
+
with gr.Accordion('Advanced Settings', open=False):
|
183 |
+
with gr.Row():
|
184 |
+
reduce_noise = gr.Slider(minimum=0, maximum=1, value=0.5, label='Reduce Noise', info='👻 Fix it in post: non-stationary noise reduction via spectral gating.')
|
185 |
+
with gr.Row():
|
186 |
+
speed = gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label='Speed', info='⚡️ Adjust the speed of the audio. The trim settings below are also auto-scaled by speed.')
|
187 |
+
with gr.Row():
|
188 |
+
with gr.Column():
|
189 |
+
opening_cut = gr.Slider(minimum=0, maximum=24000, value=5000, step=1000, label='Opening Cut', info='✂️ Zero out this many samples at the start.')
|
190 |
+
with gr.Column():
|
191 |
+
closing_cut = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Closing Cut', info='✂️ Zero out this many samples at the end.')
|
192 |
+
with gr.Row():
|
193 |
+
with gr.Column():
|
194 |
+
ease_in = gr.Slider(minimum=0, maximum=24000, value=3000, step=1000, label='Ease In', info='🚀 Ease in for this many samples, after opening cut.')
|
195 |
+
with gr.Column():
|
196 |
+
ease_out = gr.Slider(minimum=0, maximum=24000, value=0, step=1000, label='Ease Out', info='📐 Ease out for this many samples, before closing cut.')
|
197 |
+
generate_btn.click(forward, inputs=[text, voice, in_ps, speed, reduce_noise, opening_cut, closing_cut, ease_in, ease_out], outputs=[audio, out_ps])
|
198 |
+
|
199 |
+
if __name__ == '__main__':
|
200 |
+
demo.launch()
|
en.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
istftnet.py
ADDED
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
|
2 |
+
from scipy.signal import get_window
|
3 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
4 |
+
from torch.nn.utils import weight_norm, remove_weight_norm
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
|
11 |
+
def init_weights(m, mean=0.0, std=0.01):
|
12 |
+
classname = m.__class__.__name__
|
13 |
+
if classname.find("Conv") != -1:
|
14 |
+
m.weight.data.normal_(mean, std)
|
15 |
+
|
16 |
+
def get_padding(kernel_size, dilation=1):
|
17 |
+
return int((kernel_size*dilation - dilation)/2)
|
18 |
+
|
19 |
+
LRELU_SLOPE = 0.1
|
20 |
+
|
21 |
+
class AdaIN1d(nn.Module):
|
22 |
+
def __init__(self, style_dim, num_features):
|
23 |
+
super().__init__()
|
24 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
25 |
+
self.fc = nn.Linear(style_dim, num_features*2)
|
26 |
+
|
27 |
+
def forward(self, x, s):
|
28 |
+
h = self.fc(s)
|
29 |
+
h = h.view(h.size(0), h.size(1), 1)
|
30 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
31 |
+
return (1 + gamma) * self.norm(x) + beta
|
32 |
+
|
33 |
+
class AdaINResBlock1(torch.nn.Module):
|
34 |
+
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
|
35 |
+
super(AdaINResBlock1, self).__init__()
|
36 |
+
self.convs1 = nn.ModuleList([
|
37 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
38 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
39 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
40 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
41 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
42 |
+
padding=get_padding(kernel_size, dilation[2])))
|
43 |
+
])
|
44 |
+
self.convs1.apply(init_weights)
|
45 |
+
|
46 |
+
self.convs2 = nn.ModuleList([
|
47 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
48 |
+
padding=get_padding(kernel_size, 1))),
|
49 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
50 |
+
padding=get_padding(kernel_size, 1))),
|
51 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
52 |
+
padding=get_padding(kernel_size, 1)))
|
53 |
+
])
|
54 |
+
self.convs2.apply(init_weights)
|
55 |
+
|
56 |
+
self.adain1 = nn.ModuleList([
|
57 |
+
AdaIN1d(style_dim, channels),
|
58 |
+
AdaIN1d(style_dim, channels),
|
59 |
+
AdaIN1d(style_dim, channels),
|
60 |
+
])
|
61 |
+
|
62 |
+
self.adain2 = nn.ModuleList([
|
63 |
+
AdaIN1d(style_dim, channels),
|
64 |
+
AdaIN1d(style_dim, channels),
|
65 |
+
AdaIN1d(style_dim, channels),
|
66 |
+
])
|
67 |
+
|
68 |
+
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
|
69 |
+
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
|
70 |
+
|
71 |
+
|
72 |
+
def forward(self, x, s):
|
73 |
+
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
|
74 |
+
xt = n1(x, s)
|
75 |
+
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
|
76 |
+
xt = c1(xt)
|
77 |
+
xt = n2(xt, s)
|
78 |
+
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
|
79 |
+
xt = c2(xt)
|
80 |
+
x = xt + x
|
81 |
+
return x
|
82 |
+
|
83 |
+
def remove_weight_norm(self):
|
84 |
+
for l in self.convs1:
|
85 |
+
remove_weight_norm(l)
|
86 |
+
for l in self.convs2:
|
87 |
+
remove_weight_norm(l)
|
88 |
+
|
89 |
+
class TorchSTFT(torch.nn.Module):
|
90 |
+
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
|
91 |
+
super().__init__()
|
92 |
+
self.filter_length = filter_length
|
93 |
+
self.hop_length = hop_length
|
94 |
+
self.win_length = win_length
|
95 |
+
self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
|
96 |
+
|
97 |
+
def transform(self, input_data):
|
98 |
+
forward_transform = torch.stft(
|
99 |
+
input_data,
|
100 |
+
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
|
101 |
+
return_complex=True)
|
102 |
+
|
103 |
+
return torch.abs(forward_transform), torch.angle(forward_transform)
|
104 |
+
|
105 |
+
def inverse(self, magnitude, phase):
|
106 |
+
inverse_transform = torch.istft(
|
107 |
+
magnitude * torch.exp(phase * 1j),
|
108 |
+
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
|
109 |
+
|
110 |
+
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
|
111 |
+
|
112 |
+
def forward(self, input_data):
|
113 |
+
self.magnitude, self.phase = self.transform(input_data)
|
114 |
+
reconstruction = self.inverse(self.magnitude, self.phase)
|
115 |
+
return reconstruction
|
116 |
+
|
117 |
+
class SineGen(torch.nn.Module):
|
118 |
+
""" Definition of sine generator
|
119 |
+
SineGen(samp_rate, harmonic_num = 0,
|
120 |
+
sine_amp = 0.1, noise_std = 0.003,
|
121 |
+
voiced_threshold = 0,
|
122 |
+
flag_for_pulse=False)
|
123 |
+
samp_rate: sampling rate in Hz
|
124 |
+
harmonic_num: number of harmonic overtones (default 0)
|
125 |
+
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
126 |
+
noise_std: std of Gaussian noise (default 0.003)
|
127 |
+
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
128 |
+
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
129 |
+
Note: when flag_for_pulse is True, the first time step of a voiced
|
130 |
+
segment is always sin(np.pi) or cos(0)
|
131 |
+
"""
|
132 |
+
|
133 |
+
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
134 |
+
sine_amp=0.1, noise_std=0.003,
|
135 |
+
voiced_threshold=0,
|
136 |
+
flag_for_pulse=False):
|
137 |
+
super(SineGen, self).__init__()
|
138 |
+
self.sine_amp = sine_amp
|
139 |
+
self.noise_std = noise_std
|
140 |
+
self.harmonic_num = harmonic_num
|
141 |
+
self.dim = self.harmonic_num + 1
|
142 |
+
self.sampling_rate = samp_rate
|
143 |
+
self.voiced_threshold = voiced_threshold
|
144 |
+
self.flag_for_pulse = flag_for_pulse
|
145 |
+
self.upsample_scale = upsample_scale
|
146 |
+
|
147 |
+
def _f02uv(self, f0):
|
148 |
+
# generate uv signal
|
149 |
+
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
150 |
+
return uv
|
151 |
+
|
152 |
+
def _f02sine(self, f0_values):
|
153 |
+
""" f0_values: (batchsize, length, dim)
|
154 |
+
where dim indicates fundamental tone and overtones
|
155 |
+
"""
|
156 |
+
# convert to F0 in rad. The interger part n can be ignored
|
157 |
+
# because 2 * np.pi * n doesn't affect phase
|
158 |
+
rad_values = (f0_values / self.sampling_rate) % 1
|
159 |
+
|
160 |
+
# initial phase noise (no noise for fundamental component)
|
161 |
+
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
|
162 |
+
device=f0_values.device)
|
163 |
+
rand_ini[:, 0] = 0
|
164 |
+
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
165 |
+
|
166 |
+
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
167 |
+
if not self.flag_for_pulse:
|
168 |
+
# # for normal case
|
169 |
+
|
170 |
+
# # To prevent torch.cumsum numerical overflow,
|
171 |
+
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
|
172 |
+
# # Buffer tmp_over_one_idx indicates the time step to add -1.
|
173 |
+
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
|
174 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
175 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
176 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
177 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
178 |
+
|
179 |
+
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
180 |
+
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
181 |
+
scale_factor=1/self.upsample_scale,
|
182 |
+
mode="linear").transpose(1, 2)
|
183 |
+
|
184 |
+
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
|
185 |
+
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
|
186 |
+
# cumsum_shift = torch.zeros_like(rad_values)
|
187 |
+
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
|
188 |
+
|
189 |
+
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
190 |
+
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
191 |
+
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
|
192 |
+
sines = torch.sin(phase)
|
193 |
+
|
194 |
+
else:
|
195 |
+
# If necessary, make sure that the first time step of every
|
196 |
+
# voiced segments is sin(pi) or cos(0)
|
197 |
+
# This is used for pulse-train generation
|
198 |
+
|
199 |
+
# identify the last time step in unvoiced segments
|
200 |
+
uv = self._f02uv(f0_values)
|
201 |
+
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
202 |
+
uv_1[:, -1, :] = 1
|
203 |
+
u_loc = (uv < 1) * (uv_1 > 0)
|
204 |
+
|
205 |
+
# get the instantanouse phase
|
206 |
+
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
207 |
+
# different batch needs to be processed differently
|
208 |
+
for idx in range(f0_values.shape[0]):
|
209 |
+
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
210 |
+
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
211 |
+
# stores the accumulation of i.phase within
|
212 |
+
# each voiced segments
|
213 |
+
tmp_cumsum[idx, :, :] = 0
|
214 |
+
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
215 |
+
|
216 |
+
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
217 |
+
# within the previous voiced segment.
|
218 |
+
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
219 |
+
|
220 |
+
# get the sines
|
221 |
+
sines = torch.cos(i_phase * 2 * np.pi)
|
222 |
+
return sines
|
223 |
+
|
224 |
+
def forward(self, f0):
|
225 |
+
""" sine_tensor, uv = forward(f0)
|
226 |
+
input F0: tensor(batchsize=1, length, dim=1)
|
227 |
+
f0 for unvoiced steps should be 0
|
228 |
+
output sine_tensor: tensor(batchsize=1, length, dim)
|
229 |
+
output uv: tensor(batchsize=1, length, 1)
|
230 |
+
"""
|
231 |
+
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
|
232 |
+
device=f0.device)
|
233 |
+
# fundamental component
|
234 |
+
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
235 |
+
|
236 |
+
# generate sine waveforms
|
237 |
+
sine_waves = self._f02sine(fn) * self.sine_amp
|
238 |
+
|
239 |
+
# generate uv signal
|
240 |
+
# uv = torch.ones(f0.shape)
|
241 |
+
# uv = uv * (f0 > self.voiced_threshold)
|
242 |
+
uv = self._f02uv(f0)
|
243 |
+
|
244 |
+
# noise: for unvoiced should be similar to sine_amp
|
245 |
+
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
246 |
+
# . for voiced regions is self.noise_std
|
247 |
+
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
248 |
+
noise = noise_amp * torch.randn_like(sine_waves)
|
249 |
+
|
250 |
+
# first: set the unvoiced part to 0 by uv
|
251 |
+
# then: additive noise
|
252 |
+
sine_waves = sine_waves * uv + noise
|
253 |
+
return sine_waves, uv, noise
|
254 |
+
|
255 |
+
|
256 |
+
class SourceModuleHnNSF(torch.nn.Module):
|
257 |
+
""" SourceModule for hn-nsf
|
258 |
+
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
259 |
+
add_noise_std=0.003, voiced_threshod=0)
|
260 |
+
sampling_rate: sampling_rate in Hz
|
261 |
+
harmonic_num: number of harmonic above F0 (default: 0)
|
262 |
+
sine_amp: amplitude of sine source signal (default: 0.1)
|
263 |
+
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
264 |
+
note that amplitude of noise in unvoiced is decided
|
265 |
+
by sine_amp
|
266 |
+
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
267 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
268 |
+
F0_sampled (batchsize, length, 1)
|
269 |
+
Sine_source (batchsize, length, 1)
|
270 |
+
noise_source (batchsize, length 1)
|
271 |
+
uv (batchsize, length, 1)
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
275 |
+
add_noise_std=0.003, voiced_threshod=0):
|
276 |
+
super(SourceModuleHnNSF, self).__init__()
|
277 |
+
|
278 |
+
self.sine_amp = sine_amp
|
279 |
+
self.noise_std = add_noise_std
|
280 |
+
|
281 |
+
# to produce sine waveforms
|
282 |
+
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
|
283 |
+
sine_amp, add_noise_std, voiced_threshod)
|
284 |
+
|
285 |
+
# to merge source harmonics into a single excitation
|
286 |
+
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
287 |
+
self.l_tanh = torch.nn.Tanh()
|
288 |
+
|
289 |
+
def forward(self, x):
|
290 |
+
"""
|
291 |
+
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
292 |
+
F0_sampled (batchsize, length, 1)
|
293 |
+
Sine_source (batchsize, length, 1)
|
294 |
+
noise_source (batchsize, length 1)
|
295 |
+
"""
|
296 |
+
# source for harmonic branch
|
297 |
+
with torch.no_grad():
|
298 |
+
sine_wavs, uv, _ = self.l_sin_gen(x)
|
299 |
+
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
300 |
+
|
301 |
+
# source for noise branch, in the same shape as uv
|
302 |
+
noise = torch.randn_like(uv) * self.sine_amp / 3
|
303 |
+
return sine_merge, noise, uv
|
304 |
+
def padDiff(x):
|
305 |
+
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
|
306 |
+
|
307 |
+
|
308 |
+
class Generator(torch.nn.Module):
|
309 |
+
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
|
310 |
+
super(Generator, self).__init__()
|
311 |
+
|
312 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
313 |
+
self.num_upsamples = len(upsample_rates)
|
314 |
+
resblock = AdaINResBlock1
|
315 |
+
|
316 |
+
self.m_source = SourceModuleHnNSF(
|
317 |
+
sampling_rate=24000,
|
318 |
+
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
|
319 |
+
harmonic_num=8, voiced_threshod=10)
|
320 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
|
321 |
+
self.noise_convs = nn.ModuleList()
|
322 |
+
self.noise_res = nn.ModuleList()
|
323 |
+
|
324 |
+
self.ups = nn.ModuleList()
|
325 |
+
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
326 |
+
self.ups.append(weight_norm(
|
327 |
+
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
|
328 |
+
k, u, padding=(k-u)//2)))
|
329 |
+
|
330 |
+
self.resblocks = nn.ModuleList()
|
331 |
+
for i in range(len(self.ups)):
|
332 |
+
ch = upsample_initial_channel//(2**(i+1))
|
333 |
+
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
|
334 |
+
self.resblocks.append(resblock(ch, k, d, style_dim))
|
335 |
+
|
336 |
+
c_cur = upsample_initial_channel // (2 ** (i + 1))
|
337 |
+
|
338 |
+
if i + 1 < len(upsample_rates): #
|
339 |
+
stride_f0 = np.prod(upsample_rates[i + 1:])
|
340 |
+
self.noise_convs.append(Conv1d(
|
341 |
+
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
|
342 |
+
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
|
343 |
+
else:
|
344 |
+
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
|
345 |
+
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
|
346 |
+
|
347 |
+
|
348 |
+
self.post_n_fft = gen_istft_n_fft
|
349 |
+
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
|
350 |
+
self.ups.apply(init_weights)
|
351 |
+
self.conv_post.apply(init_weights)
|
352 |
+
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
|
353 |
+
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
|
354 |
+
|
355 |
+
|
356 |
+
def forward(self, x, s, f0):
|
357 |
+
with torch.no_grad():
|
358 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
359 |
+
|
360 |
+
har_source, noi_source, uv = self.m_source(f0)
|
361 |
+
har_source = har_source.transpose(1, 2).squeeze(1)
|
362 |
+
har_spec, har_phase = self.stft.transform(har_source)
|
363 |
+
har = torch.cat([har_spec, har_phase], dim=1)
|
364 |
+
|
365 |
+
for i in range(self.num_upsamples):
|
366 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
367 |
+
x_source = self.noise_convs[i](har)
|
368 |
+
x_source = self.noise_res[i](x_source, s)
|
369 |
+
|
370 |
+
x = self.ups[i](x)
|
371 |
+
if i == self.num_upsamples - 1:
|
372 |
+
x = self.reflection_pad(x)
|
373 |
+
|
374 |
+
x = x + x_source
|
375 |
+
xs = None
|
376 |
+
for j in range(self.num_kernels):
|
377 |
+
if xs is None:
|
378 |
+
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
379 |
+
else:
|
380 |
+
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
381 |
+
x = xs / self.num_kernels
|
382 |
+
x = F.leaky_relu(x)
|
383 |
+
x = self.conv_post(x)
|
384 |
+
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
385 |
+
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
386 |
+
return self.stft.inverse(spec, phase)
|
387 |
+
|
388 |
+
def fw_phase(self, x, s):
|
389 |
+
for i in range(self.num_upsamples):
|
390 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
391 |
+
x = self.ups[i](x)
|
392 |
+
xs = None
|
393 |
+
for j in range(self.num_kernels):
|
394 |
+
if xs is None:
|
395 |
+
xs = self.resblocks[i*self.num_kernels+j](x, s)
|
396 |
+
else:
|
397 |
+
xs += self.resblocks[i*self.num_kernels+j](x, s)
|
398 |
+
x = xs / self.num_kernels
|
399 |
+
x = F.leaky_relu(x)
|
400 |
+
x = self.reflection_pad(x)
|
401 |
+
x = self.conv_post(x)
|
402 |
+
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
403 |
+
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
404 |
+
return spec, phase
|
405 |
+
|
406 |
+
def remove_weight_norm(self):
|
407 |
+
print('Removing weight norm...')
|
408 |
+
for l in self.ups:
|
409 |
+
remove_weight_norm(l)
|
410 |
+
for l in self.resblocks:
|
411 |
+
l.remove_weight_norm()
|
412 |
+
remove_weight_norm(self.conv_pre)
|
413 |
+
remove_weight_norm(self.conv_post)
|
414 |
+
|
415 |
+
|
416 |
+
class AdainResBlk1d(nn.Module):
|
417 |
+
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
418 |
+
upsample='none', dropout_p=0.0):
|
419 |
+
super().__init__()
|
420 |
+
self.actv = actv
|
421 |
+
self.upsample_type = upsample
|
422 |
+
self.upsample = UpSample1d(upsample)
|
423 |
+
self.learned_sc = dim_in != dim_out
|
424 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
425 |
+
self.dropout = nn.Dropout(dropout_p)
|
426 |
+
|
427 |
+
if upsample == 'none':
|
428 |
+
self.pool = nn.Identity()
|
429 |
+
else:
|
430 |
+
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
431 |
+
|
432 |
+
|
433 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
434 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
435 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
436 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
437 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
438 |
+
if self.learned_sc:
|
439 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
440 |
+
|
441 |
+
def _shortcut(self, x):
|
442 |
+
x = self.upsample(x)
|
443 |
+
if self.learned_sc:
|
444 |
+
x = self.conv1x1(x)
|
445 |
+
return x
|
446 |
+
|
447 |
+
def _residual(self, x, s):
|
448 |
+
x = self.norm1(x, s)
|
449 |
+
x = self.actv(x)
|
450 |
+
x = self.pool(x)
|
451 |
+
x = self.conv1(self.dropout(x))
|
452 |
+
x = self.norm2(x, s)
|
453 |
+
x = self.actv(x)
|
454 |
+
x = self.conv2(self.dropout(x))
|
455 |
+
return x
|
456 |
+
|
457 |
+
def forward(self, x, s):
|
458 |
+
out = self._residual(x, s)
|
459 |
+
out = (out + self._shortcut(x)) / np.sqrt(2)
|
460 |
+
return out
|
461 |
+
|
462 |
+
class UpSample1d(nn.Module):
|
463 |
+
def __init__(self, layer_type):
|
464 |
+
super().__init__()
|
465 |
+
self.layer_type = layer_type
|
466 |
+
|
467 |
+
def forward(self, x):
|
468 |
+
if self.layer_type == 'none':
|
469 |
+
return x
|
470 |
+
else:
|
471 |
+
return F.interpolate(x, scale_factor=2, mode='nearest')
|
472 |
+
|
473 |
+
class Decoder(nn.Module):
|
474 |
+
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
|
475 |
+
resblock_kernel_sizes = [3,7,11],
|
476 |
+
upsample_rates = [10, 6],
|
477 |
+
upsample_initial_channel=512,
|
478 |
+
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
|
479 |
+
upsample_kernel_sizes=[20, 12],
|
480 |
+
gen_istft_n_fft=20, gen_istft_hop_size=5):
|
481 |
+
super().__init__()
|
482 |
+
|
483 |
+
self.decode = nn.ModuleList()
|
484 |
+
|
485 |
+
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
|
486 |
+
|
487 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
488 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
489 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
|
490 |
+
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
|
491 |
+
|
492 |
+
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
493 |
+
|
494 |
+
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
|
495 |
+
|
496 |
+
self.asr_res = nn.Sequential(
|
497 |
+
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
|
498 |
+
)
|
499 |
+
|
500 |
+
|
501 |
+
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
|
502 |
+
upsample_initial_channel, resblock_dilation_sizes,
|
503 |
+
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
|
504 |
+
|
505 |
+
def forward(self, asr, F0_curve, N, s):
|
506 |
+
F0 = self.F0_conv(F0_curve.unsqueeze(1))
|
507 |
+
N = self.N_conv(N.unsqueeze(1))
|
508 |
+
|
509 |
+
x = torch.cat([asr, F0, N], axis=1)
|
510 |
+
x = self.encode(x, s)
|
511 |
+
|
512 |
+
asr_res = self.asr_res(asr)
|
513 |
+
|
514 |
+
res = True
|
515 |
+
for block in self.decode:
|
516 |
+
if res:
|
517 |
+
x = torch.cat([x, asr_res, F0, N], axis=1)
|
518 |
+
x = block(x, s)
|
519 |
+
if block.upsample_type != "none":
|
520 |
+
res = False
|
521 |
+
|
522 |
+
x = self.generator(x, s, F0_curve)
|
523 |
+
return x
|
ja.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
katsu.py
ADDED
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/polm/cutlet/blob/master/cutlet/cutlet.py
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from fugashi import Tagger
|
4 |
+
from num2kana import Convert
|
5 |
+
import mojimoji
|
6 |
+
import re
|
7 |
+
import unicodedata
|
8 |
+
|
9 |
+
HEPBURN = {
|
10 |
+
chr(12449):'a', #ァ
|
11 |
+
chr(12450):'a', #ア
|
12 |
+
chr(12451):'i', #ィ
|
13 |
+
chr(12452):'i', #イ
|
14 |
+
chr(12453):'ɯ', #ゥ
|
15 |
+
chr(12454):'ɯ', #ウ
|
16 |
+
chr(12455):'e', #ェ
|
17 |
+
chr(12456):'e', #エ
|
18 |
+
chr(12457):'o', #ォ
|
19 |
+
chr(12458):'o', #オ
|
20 |
+
chr(12459):'ka', #カ
|
21 |
+
chr(12460):'ɡa', #ガ
|
22 |
+
chr(12461):'ki', #キ
|
23 |
+
chr(12462):'ɡi', #ギ
|
24 |
+
chr(12463):'kɯ', #ク
|
25 |
+
chr(12464):'ɡɯ', #グ
|
26 |
+
chr(12465):'ke', #ケ
|
27 |
+
chr(12466):'ɡe', #ゲ
|
28 |
+
chr(12467):'ko', #コ
|
29 |
+
chr(12468):'ɡo', #ゴ
|
30 |
+
chr(12469):'sa', #サ
|
31 |
+
chr(12470):'za', #ザ
|
32 |
+
chr(12471):'ɕi', #シ
|
33 |
+
chr(12472):'dʑi', #ジ
|
34 |
+
chr(12473):'sɨ', #ス
|
35 |
+
chr(12474):'zɨ', #ズ
|
36 |
+
chr(12475):'se', #セ
|
37 |
+
chr(12476):'ze', #ゼ
|
38 |
+
chr(12477):'so', #ソ
|
39 |
+
chr(12478):'zo', #ゾ
|
40 |
+
chr(12479):'ta', #タ
|
41 |
+
chr(12480):'da', #ダ
|
42 |
+
chr(12481):'tɕi', #チ
|
43 |
+
chr(12482):'dʑi', #ヂ
|
44 |
+
# chr(12483) #ッ
|
45 |
+
chr(12484):'tsɨ', #ツ
|
46 |
+
chr(12485):'zɨ', #ヅ
|
47 |
+
chr(12486):'te', #テ
|
48 |
+
chr(12487):'de', #デ
|
49 |
+
chr(12488):'to', #ト
|
50 |
+
chr(12489):'do', #ド
|
51 |
+
chr(12490):'na', #ナ
|
52 |
+
chr(12491):'ɲi', #ニ
|
53 |
+
chr(12492):'nɯ', #ヌ
|
54 |
+
chr(12493):'ne', #ネ
|
55 |
+
chr(12494):'no', #ノ
|
56 |
+
chr(12495):'ha', #ハ
|
57 |
+
chr(12496):'ba', #バ
|
58 |
+
chr(12497):'pa', #パ
|
59 |
+
chr(12498):'çi', #ヒ
|
60 |
+
chr(12499):'bi', #ビ
|
61 |
+
chr(12500):'pi', #ピ
|
62 |
+
chr(12501):'ɸɯ', #フ
|
63 |
+
chr(12502):'bɯ', #ブ
|
64 |
+
chr(12503):'pɯ', #プ
|
65 |
+
chr(12504):'he', #ヘ
|
66 |
+
chr(12505):'be', #ベ
|
67 |
+
chr(12506):'pe', #ペ
|
68 |
+
chr(12507):'ho', #ホ
|
69 |
+
chr(12508):'bo', #ボ
|
70 |
+
chr(12509):'po', #ポ
|
71 |
+
chr(12510):'ma', #マ
|
72 |
+
chr(12511):'mi', #ミ
|
73 |
+
chr(12512):'mɯ', #ム
|
74 |
+
chr(12513):'me', #メ
|
75 |
+
chr(12514):'mo', #モ
|
76 |
+
chr(12515):'ja', #ャ
|
77 |
+
chr(12516):'ja', #ヤ
|
78 |
+
chr(12517):'jɯ', #ュ
|
79 |
+
chr(12518):'jɯ', #ユ
|
80 |
+
chr(12519):'jo', #ョ
|
81 |
+
chr(12520):'jo', #ヨ
|
82 |
+
chr(12521):'ra', #ラ
|
83 |
+
chr(12522):'ri', #リ
|
84 |
+
chr(12523):'rɯ', #ル
|
85 |
+
chr(12524):'re', #レ
|
86 |
+
chr(12525):'ro', #ロ
|
87 |
+
chr(12526):'wa', #ヮ
|
88 |
+
chr(12527):'wa', #ワ
|
89 |
+
chr(12528):'i', #ヰ
|
90 |
+
chr(12529):'e', #ヱ
|
91 |
+
chr(12530):'o', #ヲ
|
92 |
+
# chr(12531) #ン
|
93 |
+
chr(12532):'vɯ', #ヴ
|
94 |
+
chr(12533):'ka', #ヵ
|
95 |
+
chr(12534):'ke', #ヶ
|
96 |
+
}
|
97 |
+
assert len(HEPBURN) == 84 and all(i in {12483, 12531} or chr(i) in HEPBURN for i in range(12449, 12535))
|
98 |
+
|
99 |
+
for k, v in list(HEPBURN.items()):
|
100 |
+
HEPBURN[chr(ord(k)-96)] = v
|
101 |
+
assert len(HEPBURN) == 84*2
|
102 |
+
|
103 |
+
HEPBURN.update({
|
104 |
+
chr(12535):'va', #ヷ
|
105 |
+
chr(12536):'vi', #ヸ
|
106 |
+
chr(12537):'ve', #ヹ
|
107 |
+
chr(12538):'vo', #ヺ
|
108 |
+
})
|
109 |
+
assert len(HEPBURN) == 84*2+4 and all(chr(i) in HEPBURN for i in range(12535, 12539))
|
110 |
+
|
111 |
+
HEPBURN.update({
|
112 |
+
chr(12784):'kɯ', #ㇰ
|
113 |
+
chr(12785):'ɕi', #ㇱ
|
114 |
+
chr(12786):'sɨ', #ㇲ
|
115 |
+
chr(12787):'to', #ㇳ
|
116 |
+
chr(12788):'nɯ', #ㇴ
|
117 |
+
chr(12789):'ha', #ㇵ
|
118 |
+
chr(12790):'çi', #ㇶ
|
119 |
+
chr(12791):'ɸɯ', #ㇷ
|
120 |
+
chr(12792):'he', #ㇸ
|
121 |
+
chr(12793):'ho', #ㇹ
|
122 |
+
chr(12794):'mɯ', #ㇺ
|
123 |
+
chr(12795):'ra', #ㇻ
|
124 |
+
chr(12796):'ri', #ㇼ
|
125 |
+
chr(12797):'rɯ', #ㇽ
|
126 |
+
chr(12798):'re', #ㇾ
|
127 |
+
chr(12799):'ro', #ㇿ
|
128 |
+
})
|
129 |
+
assert len(HEPBURN) == 84*2+4+16 and all(chr(i) in HEPBURN for i in range(12784, 12800))
|
130 |
+
|
131 |
+
HEPBURN.update({
|
132 |
+
chr(12452)+chr(12455):'je', #イェ
|
133 |
+
chr(12454)+chr(12451):'wi', #ウィ
|
134 |
+
chr(12454)+chr(12455):'we', #ウェ
|
135 |
+
chr(12454)+chr(12457):'wo', #ウォ
|
136 |
+
chr(12461)+chr(12455):'kʲe', #キェ
|
137 |
+
chr(12461)+chr(12515):'kʲa', #キャ
|
138 |
+
chr(12461)+chr(12517):'kʲɨ', #キュ
|
139 |
+
chr(12461)+chr(12519):'kʲo', #キョ
|
140 |
+
chr(12462)+chr(12515):'ɡʲa', #ギャ
|
141 |
+
chr(12462)+chr(12517):'ɡʲɨ', #ギュ
|
142 |
+
chr(12462)+chr(12519):'ɡʲo', #ギョ
|
143 |
+
chr(12463)+chr(12449):'kʷa', #クァ
|
144 |
+
chr(12463)+chr(12451):'kʷi', #クィ
|
145 |
+
chr(12463)+chr(12455):'kʷe', #クェ
|
146 |
+
chr(12463)+chr(12457):'kʷo', #クォ
|
147 |
+
chr(12464)+chr(12449):'ɡʷa', #グァ
|
148 |
+
chr(12464)+chr(12451):'ɡʷi', #グィ
|
149 |
+
chr(12464)+chr(12455):'ɡʷe', #グェ
|
150 |
+
chr(12464)+chr(12457):'ɡʷo', #グォ
|
151 |
+
chr(12471)+chr(12455):'ɕe', #シェ
|
152 |
+
chr(12471)+chr(12515):'ɕa', #シャ
|
153 |
+
chr(12471)+chr(12517):'ɕɨ', #シュ
|
154 |
+
chr(12471)+chr(12519):'ɕo', #ショ
|
155 |
+
chr(12472)+chr(12455):'dʑe', #ジェ
|
156 |
+
chr(12472)+chr(12515):'dʑa', #ジャ
|
157 |
+
chr(12472)+chr(12517):'dʑɨ', #ジュ
|
158 |
+
chr(12472)+chr(12519):'dʑo', #ジョ
|
159 |
+
chr(12481)+chr(12455):'tɕe', #チェ
|
160 |
+
chr(12481)+chr(12515):'tɕa', #チャ
|
161 |
+
chr(12481)+chr(12517):'tɕɨ', #チュ
|
162 |
+
chr(12481)+chr(12519):'tɕo', #チョ
|
163 |
+
chr(12482)+chr(12515):'dʑa', #ヂャ
|
164 |
+
chr(12482)+chr(12517):'dʑɨ', #ヂュ
|
165 |
+
chr(12482)+chr(12519):'dʑo', #ヂョ
|
166 |
+
chr(12484)+chr(12449):'tsa', #ツァ
|
167 |
+
chr(12484)+chr(12451):'tsi', #ツィ
|
168 |
+
chr(12484)+chr(12455):'tse', #ツェ
|
169 |
+
chr(12484)+chr(12457):'tso', #ツォ
|
170 |
+
chr(12486)+chr(12451):'ti', #ティ
|
171 |
+
chr(12486)+chr(12517):'tʲɨ', #テュ
|
172 |
+
chr(12487)+chr(12451):'di', #ディ
|
173 |
+
chr(12487)+chr(12517):'dʲɨ', #デュ
|
174 |
+
chr(12488)+chr(12453):'tɯ', #トゥ
|
175 |
+
chr(12489)+chr(12453):'dɯ', #ドゥ
|
176 |
+
chr(12491)+chr(12455):'ɲe', #ニェ
|
177 |
+
chr(12491)+chr(12515):'ɲa', #ニャ
|
178 |
+
chr(12491)+chr(12517):'ɲɨ', #ニュ
|
179 |
+
chr(12491)+chr(12519):'ɲo', #ニョ
|
180 |
+
chr(12498)+chr(12455):'çe', #ヒェ
|
181 |
+
chr(12498)+chr(12515):'ça', #ヒャ
|
182 |
+
chr(12498)+chr(12517):'çɨ', #ヒュ
|
183 |
+
chr(12498)+chr(12519):'ço', #ヒョ
|
184 |
+
chr(12499)+chr(12515):'bʲa', #ビャ
|
185 |
+
chr(12499)+chr(12517):'bʲɨ', #ビュ
|
186 |
+
chr(12499)+chr(12519):'bʲo', #ビョ
|
187 |
+
chr(12500)+chr(12515):'pʲa', #ピャ
|
188 |
+
chr(12500)+chr(12517):'pʲɨ', #ピュ
|
189 |
+
chr(12500)+chr(12519):'pʲo', #ピョ
|
190 |
+
chr(12501)+chr(12449):'ɸa', #ファ
|
191 |
+
chr(12501)+chr(12451):'ɸi', #フィ
|
192 |
+
chr(12501)+chr(12455):'ɸe', #フェ
|
193 |
+
chr(12501)+chr(12457):'ɸo', #フォ
|
194 |
+
chr(12501)+chr(12517):'ɸʲɨ', #フュ
|
195 |
+
chr(12501)+chr(12519):'ɸʲo', #フョ
|
196 |
+
chr(12511)+chr(12515):'mʲa', #ミャ
|
197 |
+
chr(12511)+chr(12517):'mʲɨ', #ミュ
|
198 |
+
chr(12511)+chr(12519):'mʲo', #ミョ
|
199 |
+
chr(12522)+chr(12515):'rʲa', #リャ
|
200 |
+
chr(12522)+chr(12517):'rʲɨ', #リュ
|
201 |
+
chr(12522)+chr(12519):'rʲo', #リョ
|
202 |
+
chr(12532)+chr(12449):'va', #ヴァ
|
203 |
+
chr(12532)+chr(12451):'vi', #ヴィ
|
204 |
+
chr(12532)+chr(12455):'ve', #ヴェ
|
205 |
+
chr(12532)+chr(12457):'vo', #ヴォ
|
206 |
+
chr(12532)+chr(12517):'vʲɨ', #ヴュ
|
207 |
+
chr(12532)+chr(12519):'vʲo', #ヴョ
|
208 |
+
})
|
209 |
+
assert len(HEPBURN) == 84*2+4+16+76
|
210 |
+
|
211 |
+
for k, v in list(HEPBURN.items()):
|
212 |
+
if len(k) != 2:
|
213 |
+
continue
|
214 |
+
a, b = k
|
215 |
+
assert a in HEPBURN and b in HEPBURN, (a, b)
|
216 |
+
a = chr(ord(a)-96)
|
217 |
+
b = chr(ord(b)-96)
|
218 |
+
assert a in HEPBURN and b in HEPBURN, (a, b)
|
219 |
+
HEPBURN[a+b] = v
|
220 |
+
assert len(HEPBURN) == 84*2+4+16+76*2
|
221 |
+
|
222 |
+
HEPBURN.update({
|
223 |
+
# symbols
|
224 |
+
# 'ー': '-', # 長音符, only used when repeated
|
225 |
+
'。': '.',
|
226 |
+
'、': ',',
|
227 |
+
'?': '?',
|
228 |
+
'!': '!',
|
229 |
+
'「': '"',
|
230 |
+
'」': '"',
|
231 |
+
'『': '"',
|
232 |
+
'』': '"',
|
233 |
+
':': ':',
|
234 |
+
'(': '(',
|
235 |
+
')': ')',
|
236 |
+
'《': '(',
|
237 |
+
'》': ')',
|
238 |
+
'【': '[',
|
239 |
+
'】': ']',
|
240 |
+
'・': ' ',#'/',
|
241 |
+
',': ',',
|
242 |
+
'~': '—',
|
243 |
+
'〜': '—',
|
244 |
+
'—': '—',
|
245 |
+
'«': '«',
|
246 |
+
'»': '»',
|
247 |
+
|
248 |
+
# other
|
249 |
+
'゚': '', # combining handakuten by itself, just discard
|
250 |
+
'゙': '', # combining dakuten by itself
|
251 |
+
})
|
252 |
+
|
253 |
+
def add_dakuten(kk):
|
254 |
+
"""Given a kana (single-character string), add a dakuten."""
|
255 |
+
try:
|
256 |
+
# ii = 'かきくけこさしすせそたちつてとはひふへほ'.index(kk)
|
257 |
+
ii = 'カキクケコサシスセソタチツテトハヒフヘホ'.index(kk)
|
258 |
+
return 'ガギグゲゴザジズゼゾダヂヅデドバビブベボ'[ii]
|
259 |
+
# return 'がぎぐげござじずぜぞだぢづでどばびぶべぼ'[ii]
|
260 |
+
except ValueError:
|
261 |
+
# this is normal if the input is nonsense
|
262 |
+
return None
|
263 |
+
|
264 |
+
SUTEGANA = 'ャュョァィゥェォ' #'ゃゅょぁぃぅぇぉ'
|
265 |
+
PUNCT = '\'".!?(),;:-'
|
266 |
+
ODORI = '々〃ゝゞヽゞ'
|
267 |
+
|
268 |
+
@dataclass
|
269 |
+
class Token:
|
270 |
+
surface: str
|
271 |
+
space: bool # if a space should follow
|
272 |
+
def __str__(self):
|
273 |
+
sp = " " if self.space else ""
|
274 |
+
return f"{self.surface}{sp}"
|
275 |
+
|
276 |
+
class Katsu:
|
277 |
+
def __init__(self):
|
278 |
+
"""Create a Katsu object, which holds configuration as well as
|
279 |
+
tokenizer state.
|
280 |
+
|
281 |
+
Typical usage:
|
282 |
+
|
283 |
+
```python
|
284 |
+
katsu = Katsu()
|
285 |
+
roma = katsu.romaji("カツカレーを食べた")
|
286 |
+
# "Cutlet curry wo tabeta"
|
287 |
+
```
|
288 |
+
"""
|
289 |
+
self.tagger = Tagger()
|
290 |
+
self.table = dict(HEPBURN) # make a copy so we can modify it
|
291 |
+
self.exceptions = {}
|
292 |
+
|
293 |
+
def romaji(self, text):
|
294 |
+
"""Build a complete string from input text."""
|
295 |
+
if not text:
|
296 |
+
return ''
|
297 |
+
text = self._normalize_text(text)
|
298 |
+
words = self.tagger(text)
|
299 |
+
tokens = self._romaji_tokens(words)
|
300 |
+
out = ''.join([str(tok) for tok in tokens])
|
301 |
+
return re.sub(r'\s+', ' ', out.strip())
|
302 |
+
|
303 |
+
def phonemize(self, texts):
|
304 |
+
# espeak-ng API
|
305 |
+
return [self.romaji(text) for text in texts]
|
306 |
+
|
307 |
+
def _normalize_text(self, text):
|
308 |
+
"""Given text, normalize variations in Japanese.
|
309 |
+
|
310 |
+
This specifically removes variations that are meaningless for romaji
|
311 |
+
conversion using the following steps:
|
312 |
+
|
313 |
+
- Unicode NFKC normalization
|
314 |
+
- Full-width Latin to half-width
|
315 |
+
- Half-width katakana to full-width
|
316 |
+
"""
|
317 |
+
# perform unicode normalization
|
318 |
+
text = re.sub(r'[〜~](?=\d)', 'から', text) # wave dash range
|
319 |
+
text = unicodedata.normalize('NFKC', text)
|
320 |
+
# convert all full-width alphanum to half-width, since it can go out as-is
|
321 |
+
text = mojimoji.zen_to_han(text, kana=False)
|
322 |
+
# replace half-width katakana with full-width
|
323 |
+
text = mojimoji.han_to_zen(text, digit=False, ascii=False)
|
324 |
+
return ''.join([(' '+Convert(t)) if t.isdigit() else t for t in re.findall(r'\d+|\D+', text)])
|
325 |
+
|
326 |
+
def _romaji_tokens(self, words):
|
327 |
+
"""Build a list of tokens from input nodes."""
|
328 |
+
out = []
|
329 |
+
for wi, word in enumerate(words):
|
330 |
+
po = out[-1] if out else None
|
331 |
+
pw = words[wi - 1] if wi > 0 else None
|
332 |
+
nw = words[wi + 1] if wi < len(words) - 1 else None
|
333 |
+
roma = self._romaji_word(word)
|
334 |
+
tok = Token(roma, False)
|
335 |
+
# handle punctuation with atypical spacing
|
336 |
+
surface = word.surface#['orig']
|
337 |
+
if surface in '「『' or roma in '([':
|
338 |
+
if po:
|
339 |
+
po.space = True
|
340 |
+
elif surface in '」』' or roma in ']).,?!:':
|
341 |
+
if po:
|
342 |
+
po.space = False
|
343 |
+
tok.space = True
|
344 |
+
elif roma == ' ':
|
345 |
+
tok.space = False
|
346 |
+
else:
|
347 |
+
tok.space = True
|
348 |
+
out.append(tok)
|
349 |
+
# remove any leftover sokuon
|
350 |
+
for tok in out:
|
351 |
+
tok.surface = tok.surface.replace(chr(12483), '')
|
352 |
+
return out
|
353 |
+
|
354 |
+
def _romaji_word(self, word):
|
355 |
+
"""Return the romaji for a single word (node)."""
|
356 |
+
surface = word.surface#['orig']
|
357 |
+
if surface in self.exceptions:
|
358 |
+
return self.exceptions[surface]
|
359 |
+
assert not surface.isdigit(), surface
|
360 |
+
if surface.isascii():
|
361 |
+
return surface
|
362 |
+
kana = word.feature.pron or word.feature.kana or surface
|
363 |
+
if word.is_unk:
|
364 |
+
if word.char_type == 7: # katakana
|
365 |
+
pass
|
366 |
+
elif word.char_type == 3: # symbol
|
367 |
+
return ''.join(map(lambda c: self.table.get(c, c), surface))
|
368 |
+
else:
|
369 |
+
return '' # TODO: silently fail
|
370 |
+
out = ''
|
371 |
+
for ki, char in enumerate(kana):
|
372 |
+
nk = kana[ki + 1] if ki < len(kana) - 1 else None
|
373 |
+
pk = kana[ki - 1] if ki > 0 else None
|
374 |
+
out += self._get_single_mapping(pk, char, nk)
|
375 |
+
return out
|
376 |
+
|
377 |
+
def _get_single_mapping(self, pk, kk, nk):
|
378 |
+
"""Given a single kana and its neighbors, return the mapped romaji."""
|
379 |
+
# handle odoriji
|
380 |
+
# NOTE: This is very rarely useful at present because odoriji are not
|
381 |
+
# left in readings for dictionary words, and we can't follow kana
|
382 |
+
# across word boundaries.
|
383 |
+
if kk in ODORI:
|
384 |
+
if kk in 'ゝヽ':
|
385 |
+
if pk: return pk
|
386 |
+
else: return '' # invalid but be nice
|
387 |
+
if kk in 'ゞヾ': # repeat with voicing
|
388 |
+
if not pk: return ''
|
389 |
+
vv = add_dakuten(pk)
|
390 |
+
if vv: return self.table[vv]
|
391 |
+
else: return ''
|
392 |
+
# remaining are 々 for kanji and 〃 for symbols, but we can't
|
393 |
+
# infer their span reliably (or handle rendaku)
|
394 |
+
return ''
|
395 |
+
# handle digraphs
|
396 |
+
if pk and (pk + kk) in self.table:
|
397 |
+
return self.table[pk + kk]
|
398 |
+
if nk and (kk + nk) in self.table:
|
399 |
+
return ''
|
400 |
+
if nk and nk in SUTEGANA:
|
401 |
+
if kk == 'ッ': return '' # never valid, just ignore
|
402 |
+
return self.table[kk][:-1] + self.table[nk]
|
403 |
+
if kk in SUTEGANA:
|
404 |
+
return ''
|
405 |
+
if kk == 'ー': # 長音符
|
406 |
+
return 'ː'
|
407 |
+
if ord(kk) in {12387, 12483}: # っ or ッ
|
408 |
+
tnk = self.table.get(nk)
|
409 |
+
if tnk and tnk[0] in 'bdɸɡhçijkmnɲoprstɯvwz':
|
410 |
+
return tnk[0]
|
411 |
+
return kk
|
412 |
+
if ord(kk) in {12435, 12531}: # ん or ン
|
413 |
+
# https://en.wikipedia.org/wiki/N_(kana)
|
414 |
+
# m before m,p,b
|
415 |
+
# ŋ before k,g
|
416 |
+
# ɲ before ɲ,tɕ,dʑ
|
417 |
+
# n before n,t,d,r,z
|
418 |
+
# ɴ otherwise
|
419 |
+
tnk = self.table.get(nk)
|
420 |
+
if tnk:
|
421 |
+
if tnk[0] in 'mpb':
|
422 |
+
return 'm'
|
423 |
+
elif tnk[0] in 'kɡ':
|
424 |
+
return 'ŋ'
|
425 |
+
elif any(tnk.startswith(p) for p in ('ɲ','tɕ','dʑ')):
|
426 |
+
return 'ɲ'
|
427 |
+
elif tnk[0] in 'ntdrz':
|
428 |
+
return 'n'
|
429 |
+
return 'ɴ'
|
430 |
+
return self.table.get(kk, '')
|
models.py
ADDED
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
|
2 |
+
from istftnet import Decoder
|
3 |
+
from munch import Munch
|
4 |
+
from plbert import load_plbert
|
5 |
+
from torch.nn.utils import weight_norm, spectral_norm
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import os.path as osp
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
class LearnedDownSample(nn.Module):
|
14 |
+
def __init__(self, layer_type, dim_in):
|
15 |
+
super().__init__()
|
16 |
+
self.layer_type = layer_type
|
17 |
+
|
18 |
+
if self.layer_type == 'none':
|
19 |
+
self.conv = nn.Identity()
|
20 |
+
elif self.layer_type == 'timepreserve':
|
21 |
+
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, padding=(1, 0)))
|
22 |
+
elif self.layer_type == 'half':
|
23 |
+
self.conv = spectral_norm(nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, padding=1))
|
24 |
+
else:
|
25 |
+
raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
26 |
+
|
27 |
+
def forward(self, x):
|
28 |
+
return self.conv(x)
|
29 |
+
|
30 |
+
class LearnedUpSample(nn.Module):
|
31 |
+
def __init__(self, layer_type, dim_in):
|
32 |
+
super().__init__()
|
33 |
+
self.layer_type = layer_type
|
34 |
+
|
35 |
+
if self.layer_type == 'none':
|
36 |
+
self.conv = nn.Identity()
|
37 |
+
elif self.layer_type == 'timepreserve':
|
38 |
+
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1), groups=dim_in, output_padding=(1, 0), padding=(1, 0))
|
39 |
+
elif self.layer_type == 'half':
|
40 |
+
self.conv = nn.ConvTranspose2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2), groups=dim_in, output_padding=1, padding=1)
|
41 |
+
else:
|
42 |
+
raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
return self.conv(x)
|
47 |
+
|
48 |
+
class DownSample(nn.Module):
|
49 |
+
def __init__(self, layer_type):
|
50 |
+
super().__init__()
|
51 |
+
self.layer_type = layer_type
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
if self.layer_type == 'none':
|
55 |
+
return x
|
56 |
+
elif self.layer_type == 'timepreserve':
|
57 |
+
return F.avg_pool2d(x, (2, 1))
|
58 |
+
elif self.layer_type == 'half':
|
59 |
+
if x.shape[-1] % 2 != 0:
|
60 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
61 |
+
return F.avg_pool2d(x, 2)
|
62 |
+
else:
|
63 |
+
raise RuntimeError('Got unexpected donwsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
64 |
+
|
65 |
+
|
66 |
+
class UpSample(nn.Module):
|
67 |
+
def __init__(self, layer_type):
|
68 |
+
super().__init__()
|
69 |
+
self.layer_type = layer_type
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
if self.layer_type == 'none':
|
73 |
+
return x
|
74 |
+
elif self.layer_type == 'timepreserve':
|
75 |
+
return F.interpolate(x, scale_factor=(2, 1), mode='nearest')
|
76 |
+
elif self.layer_type == 'half':
|
77 |
+
return F.interpolate(x, scale_factor=2, mode='nearest')
|
78 |
+
else:
|
79 |
+
raise RuntimeError('Got unexpected upsampletype %s, expected is [none, timepreserve, half]' % self.layer_type)
|
80 |
+
|
81 |
+
|
82 |
+
class ResBlk(nn.Module):
|
83 |
+
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
84 |
+
normalize=False, downsample='none'):
|
85 |
+
super().__init__()
|
86 |
+
self.actv = actv
|
87 |
+
self.normalize = normalize
|
88 |
+
self.downsample = DownSample(downsample)
|
89 |
+
self.downsample_res = LearnedDownSample(downsample, dim_in)
|
90 |
+
self.learned_sc = dim_in != dim_out
|
91 |
+
self._build_weights(dim_in, dim_out)
|
92 |
+
|
93 |
+
def _build_weights(self, dim_in, dim_out):
|
94 |
+
self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
|
95 |
+
self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
|
96 |
+
if self.normalize:
|
97 |
+
self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
|
98 |
+
self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
|
99 |
+
if self.learned_sc:
|
100 |
+
self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))
|
101 |
+
|
102 |
+
def _shortcut(self, x):
|
103 |
+
if self.learned_sc:
|
104 |
+
x = self.conv1x1(x)
|
105 |
+
if self.downsample:
|
106 |
+
x = self.downsample(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
def _residual(self, x):
|
110 |
+
if self.normalize:
|
111 |
+
x = self.norm1(x)
|
112 |
+
x = self.actv(x)
|
113 |
+
x = self.conv1(x)
|
114 |
+
x = self.downsample_res(x)
|
115 |
+
if self.normalize:
|
116 |
+
x = self.norm2(x)
|
117 |
+
x = self.actv(x)
|
118 |
+
x = self.conv2(x)
|
119 |
+
return x
|
120 |
+
|
121 |
+
def forward(self, x):
|
122 |
+
x = self._shortcut(x) + self._residual(x)
|
123 |
+
return x / np.sqrt(2) # unit variance
|
124 |
+
|
125 |
+
class LinearNorm(torch.nn.Module):
|
126 |
+
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
127 |
+
super(LinearNorm, self).__init__()
|
128 |
+
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
129 |
+
|
130 |
+
torch.nn.init.xavier_uniform_(
|
131 |
+
self.linear_layer.weight,
|
132 |
+
gain=torch.nn.init.calculate_gain(w_init_gain))
|
133 |
+
|
134 |
+
def forward(self, x):
|
135 |
+
return self.linear_layer(x)
|
136 |
+
|
137 |
+
class Discriminator2d(nn.Module):
|
138 |
+
def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
|
139 |
+
super().__init__()
|
140 |
+
blocks = []
|
141 |
+
blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
|
142 |
+
|
143 |
+
for lid in range(repeat_num):
|
144 |
+
dim_out = min(dim_in*2, max_conv_dim)
|
145 |
+
blocks += [ResBlk(dim_in, dim_out, downsample='half')]
|
146 |
+
dim_in = dim_out
|
147 |
+
|
148 |
+
blocks += [nn.LeakyReLU(0.2)]
|
149 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
|
150 |
+
blocks += [nn.LeakyReLU(0.2)]
|
151 |
+
blocks += [nn.AdaptiveAvgPool2d(1)]
|
152 |
+
blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
|
153 |
+
self.main = nn.Sequential(*blocks)
|
154 |
+
|
155 |
+
def get_feature(self, x):
|
156 |
+
features = []
|
157 |
+
for l in self.main:
|
158 |
+
x = l(x)
|
159 |
+
features.append(x)
|
160 |
+
out = features[-1]
|
161 |
+
out = out.view(out.size(0), -1) # (batch, num_domains)
|
162 |
+
return out, features
|
163 |
+
|
164 |
+
def forward(self, x):
|
165 |
+
out, features = self.get_feature(x)
|
166 |
+
out = out.squeeze() # (batch)
|
167 |
+
return out, features
|
168 |
+
|
169 |
+
class ResBlk1d(nn.Module):
|
170 |
+
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
|
171 |
+
normalize=False, downsample='none', dropout_p=0.2):
|
172 |
+
super().__init__()
|
173 |
+
self.actv = actv
|
174 |
+
self.normalize = normalize
|
175 |
+
self.downsample_type = downsample
|
176 |
+
self.learned_sc = dim_in != dim_out
|
177 |
+
self._build_weights(dim_in, dim_out)
|
178 |
+
self.dropout_p = dropout_p
|
179 |
+
|
180 |
+
if self.downsample_type == 'none':
|
181 |
+
self.pool = nn.Identity()
|
182 |
+
else:
|
183 |
+
self.pool = weight_norm(nn.Conv1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1))
|
184 |
+
|
185 |
+
def _build_weights(self, dim_in, dim_out):
|
186 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
|
187 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
188 |
+
if self.normalize:
|
189 |
+
self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
|
190 |
+
self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
|
191 |
+
if self.learned_sc:
|
192 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
193 |
+
|
194 |
+
def downsample(self, x):
|
195 |
+
if self.downsample_type == 'none':
|
196 |
+
return x
|
197 |
+
else:
|
198 |
+
if x.shape[-1] % 2 != 0:
|
199 |
+
x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
|
200 |
+
return F.avg_pool1d(x, 2)
|
201 |
+
|
202 |
+
def _shortcut(self, x):
|
203 |
+
if self.learned_sc:
|
204 |
+
x = self.conv1x1(x)
|
205 |
+
x = self.downsample(x)
|
206 |
+
return x
|
207 |
+
|
208 |
+
def _residual(self, x):
|
209 |
+
if self.normalize:
|
210 |
+
x = self.norm1(x)
|
211 |
+
x = self.actv(x)
|
212 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
213 |
+
|
214 |
+
x = self.conv1(x)
|
215 |
+
x = self.pool(x)
|
216 |
+
if self.normalize:
|
217 |
+
x = self.norm2(x)
|
218 |
+
|
219 |
+
x = self.actv(x)
|
220 |
+
x = F.dropout(x, p=self.dropout_p, training=self.training)
|
221 |
+
|
222 |
+
x = self.conv2(x)
|
223 |
+
return x
|
224 |
+
|
225 |
+
def forward(self, x):
|
226 |
+
x = self._shortcut(x) + self._residual(x)
|
227 |
+
return x / np.sqrt(2) # unit variance
|
228 |
+
|
229 |
+
class LayerNorm(nn.Module):
|
230 |
+
def __init__(self, channels, eps=1e-5):
|
231 |
+
super().__init__()
|
232 |
+
self.channels = channels
|
233 |
+
self.eps = eps
|
234 |
+
|
235 |
+
self.gamma = nn.Parameter(torch.ones(channels))
|
236 |
+
self.beta = nn.Parameter(torch.zeros(channels))
|
237 |
+
|
238 |
+
def forward(self, x):
|
239 |
+
x = x.transpose(1, -1)
|
240 |
+
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
|
241 |
+
return x.transpose(1, -1)
|
242 |
+
|
243 |
+
class TextEncoder(nn.Module):
|
244 |
+
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
|
245 |
+
super().__init__()
|
246 |
+
self.embedding = nn.Embedding(n_symbols, channels)
|
247 |
+
|
248 |
+
padding = (kernel_size - 1) // 2
|
249 |
+
self.cnn = nn.ModuleList()
|
250 |
+
for _ in range(depth):
|
251 |
+
self.cnn.append(nn.Sequential(
|
252 |
+
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
|
253 |
+
LayerNorm(channels),
|
254 |
+
actv,
|
255 |
+
nn.Dropout(0.2),
|
256 |
+
))
|
257 |
+
# self.cnn = nn.Sequential(*self.cnn)
|
258 |
+
|
259 |
+
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
|
260 |
+
|
261 |
+
def forward(self, x, input_lengths, m):
|
262 |
+
x = self.embedding(x) # [B, T, emb]
|
263 |
+
x = x.transpose(1, 2) # [B, emb, T]
|
264 |
+
m = m.to(input_lengths.device).unsqueeze(1)
|
265 |
+
x.masked_fill_(m, 0.0)
|
266 |
+
|
267 |
+
for c in self.cnn:
|
268 |
+
x = c(x)
|
269 |
+
x.masked_fill_(m, 0.0)
|
270 |
+
|
271 |
+
x = x.transpose(1, 2) # [B, T, chn]
|
272 |
+
|
273 |
+
input_lengths = input_lengths.cpu().numpy()
|
274 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
275 |
+
x, input_lengths, batch_first=True, enforce_sorted=False)
|
276 |
+
|
277 |
+
self.lstm.flatten_parameters()
|
278 |
+
x, _ = self.lstm(x)
|
279 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
280 |
+
x, batch_first=True)
|
281 |
+
|
282 |
+
x = x.transpose(-1, -2)
|
283 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
284 |
+
|
285 |
+
x_pad[:, :, :x.shape[-1]] = x
|
286 |
+
x = x_pad.to(x.device)
|
287 |
+
|
288 |
+
x.masked_fill_(m, 0.0)
|
289 |
+
|
290 |
+
return x
|
291 |
+
|
292 |
+
def inference(self, x):
|
293 |
+
x = self.embedding(x)
|
294 |
+
x = x.transpose(1, 2)
|
295 |
+
x = self.cnn(x)
|
296 |
+
x = x.transpose(1, 2)
|
297 |
+
self.lstm.flatten_parameters()
|
298 |
+
x, _ = self.lstm(x)
|
299 |
+
return x
|
300 |
+
|
301 |
+
def length_to_mask(self, lengths):
|
302 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
303 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
304 |
+
return mask
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
class AdaIN1d(nn.Module):
|
309 |
+
def __init__(self, style_dim, num_features):
|
310 |
+
super().__init__()
|
311 |
+
self.norm = nn.InstanceNorm1d(num_features, affine=False)
|
312 |
+
self.fc = nn.Linear(style_dim, num_features*2)
|
313 |
+
|
314 |
+
def forward(self, x, s):
|
315 |
+
h = self.fc(s)
|
316 |
+
h = h.view(h.size(0), h.size(1), 1)
|
317 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
318 |
+
return (1 + gamma) * self.norm(x) + beta
|
319 |
+
|
320 |
+
class UpSample1d(nn.Module):
|
321 |
+
def __init__(self, layer_type):
|
322 |
+
super().__init__()
|
323 |
+
self.layer_type = layer_type
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
if self.layer_type == 'none':
|
327 |
+
return x
|
328 |
+
else:
|
329 |
+
return F.interpolate(x, scale_factor=2, mode='nearest')
|
330 |
+
|
331 |
+
class AdainResBlk1d(nn.Module):
|
332 |
+
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
|
333 |
+
upsample='none', dropout_p=0.0):
|
334 |
+
super().__init__()
|
335 |
+
self.actv = actv
|
336 |
+
self.upsample_type = upsample
|
337 |
+
self.upsample = UpSample1d(upsample)
|
338 |
+
self.learned_sc = dim_in != dim_out
|
339 |
+
self._build_weights(dim_in, dim_out, style_dim)
|
340 |
+
self.dropout = nn.Dropout(dropout_p)
|
341 |
+
|
342 |
+
if upsample == 'none':
|
343 |
+
self.pool = nn.Identity()
|
344 |
+
else:
|
345 |
+
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
|
346 |
+
|
347 |
+
|
348 |
+
def _build_weights(self, dim_in, dim_out, style_dim):
|
349 |
+
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
|
350 |
+
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
|
351 |
+
self.norm1 = AdaIN1d(style_dim, dim_in)
|
352 |
+
self.norm2 = AdaIN1d(style_dim, dim_out)
|
353 |
+
if self.learned_sc:
|
354 |
+
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
|
355 |
+
|
356 |
+
def _shortcut(self, x):
|
357 |
+
x = self.upsample(x)
|
358 |
+
if self.learned_sc:
|
359 |
+
x = self.conv1x1(x)
|
360 |
+
return x
|
361 |
+
|
362 |
+
def _residual(self, x, s):
|
363 |
+
x = self.norm1(x, s)
|
364 |
+
x = self.actv(x)
|
365 |
+
x = self.pool(x)
|
366 |
+
x = self.conv1(self.dropout(x))
|
367 |
+
x = self.norm2(x, s)
|
368 |
+
x = self.actv(x)
|
369 |
+
x = self.conv2(self.dropout(x))
|
370 |
+
return x
|
371 |
+
|
372 |
+
def forward(self, x, s):
|
373 |
+
out = self._residual(x, s)
|
374 |
+
out = (out + self._shortcut(x)) / np.sqrt(2)
|
375 |
+
return out
|
376 |
+
|
377 |
+
class AdaLayerNorm(nn.Module):
|
378 |
+
def __init__(self, style_dim, channels, eps=1e-5):
|
379 |
+
super().__init__()
|
380 |
+
self.channels = channels
|
381 |
+
self.eps = eps
|
382 |
+
|
383 |
+
self.fc = nn.Linear(style_dim, channels*2)
|
384 |
+
|
385 |
+
def forward(self, x, s):
|
386 |
+
x = x.transpose(-1, -2)
|
387 |
+
x = x.transpose(1, -1)
|
388 |
+
|
389 |
+
h = self.fc(s)
|
390 |
+
h = h.view(h.size(0), h.size(1), 1)
|
391 |
+
gamma, beta = torch.chunk(h, chunks=2, dim=1)
|
392 |
+
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
|
393 |
+
|
394 |
+
|
395 |
+
x = F.layer_norm(x, (self.channels,), eps=self.eps)
|
396 |
+
x = (1 + gamma) * x + beta
|
397 |
+
return x.transpose(1, -1).transpose(-1, -2)
|
398 |
+
|
399 |
+
class ProsodyPredictor(nn.Module):
|
400 |
+
|
401 |
+
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
|
402 |
+
super().__init__()
|
403 |
+
|
404 |
+
self.text_encoder = DurationEncoder(sty_dim=style_dim,
|
405 |
+
d_model=d_hid,
|
406 |
+
nlayers=nlayers,
|
407 |
+
dropout=dropout)
|
408 |
+
|
409 |
+
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
410 |
+
self.duration_proj = LinearNorm(d_hid, max_dur)
|
411 |
+
|
412 |
+
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
|
413 |
+
self.F0 = nn.ModuleList()
|
414 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
415 |
+
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
416 |
+
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
417 |
+
|
418 |
+
self.N = nn.ModuleList()
|
419 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
|
420 |
+
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
|
421 |
+
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
|
422 |
+
|
423 |
+
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
424 |
+
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
|
425 |
+
|
426 |
+
|
427 |
+
def forward(self, texts, style, text_lengths, alignment, m):
|
428 |
+
d = self.text_encoder(texts, style, text_lengths, m)
|
429 |
+
|
430 |
+
batch_size = d.shape[0]
|
431 |
+
text_size = d.shape[1]
|
432 |
+
|
433 |
+
# predict duration
|
434 |
+
input_lengths = text_lengths.cpu().numpy()
|
435 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
436 |
+
d, input_lengths, batch_first=True, enforce_sorted=False)
|
437 |
+
|
438 |
+
m = m.to(text_lengths.device).unsqueeze(1)
|
439 |
+
|
440 |
+
self.lstm.flatten_parameters()
|
441 |
+
x, _ = self.lstm(x)
|
442 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
443 |
+
x, batch_first=True)
|
444 |
+
|
445 |
+
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
|
446 |
+
|
447 |
+
x_pad[:, :x.shape[1], :] = x
|
448 |
+
x = x_pad.to(x.device)
|
449 |
+
|
450 |
+
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
|
451 |
+
|
452 |
+
en = (d.transpose(-1, -2) @ alignment)
|
453 |
+
|
454 |
+
return duration.squeeze(-1), en
|
455 |
+
|
456 |
+
def F0Ntrain(self, x, s):
|
457 |
+
x, _ = self.shared(x.transpose(-1, -2))
|
458 |
+
|
459 |
+
F0 = x.transpose(-1, -2)
|
460 |
+
for block in self.F0:
|
461 |
+
F0 = block(F0, s)
|
462 |
+
F0 = self.F0_proj(F0)
|
463 |
+
|
464 |
+
N = x.transpose(-1, -2)
|
465 |
+
for block in self.N:
|
466 |
+
N = block(N, s)
|
467 |
+
N = self.N_proj(N)
|
468 |
+
|
469 |
+
return F0.squeeze(1), N.squeeze(1)
|
470 |
+
|
471 |
+
def length_to_mask(self, lengths):
|
472 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
473 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
474 |
+
return mask
|
475 |
+
|
476 |
+
class DurationEncoder(nn.Module):
|
477 |
+
|
478 |
+
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
|
479 |
+
super().__init__()
|
480 |
+
self.lstms = nn.ModuleList()
|
481 |
+
for _ in range(nlayers):
|
482 |
+
self.lstms.append(nn.LSTM(d_model + sty_dim,
|
483 |
+
d_model // 2,
|
484 |
+
num_layers=1,
|
485 |
+
batch_first=True,
|
486 |
+
bidirectional=True,
|
487 |
+
dropout=dropout))
|
488 |
+
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
|
489 |
+
|
490 |
+
|
491 |
+
self.dropout = dropout
|
492 |
+
self.d_model = d_model
|
493 |
+
self.sty_dim = sty_dim
|
494 |
+
|
495 |
+
def forward(self, x, style, text_lengths, m):
|
496 |
+
masks = m.to(text_lengths.device)
|
497 |
+
|
498 |
+
x = x.permute(2, 0, 1)
|
499 |
+
s = style.expand(x.shape[0], x.shape[1], -1)
|
500 |
+
x = torch.cat([x, s], axis=-1)
|
501 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
|
502 |
+
|
503 |
+
x = x.transpose(0, 1)
|
504 |
+
input_lengths = text_lengths.cpu().numpy()
|
505 |
+
x = x.transpose(-1, -2)
|
506 |
+
|
507 |
+
for block in self.lstms:
|
508 |
+
if isinstance(block, AdaLayerNorm):
|
509 |
+
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
|
510 |
+
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
|
511 |
+
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
|
512 |
+
else:
|
513 |
+
x = x.transpose(-1, -2)
|
514 |
+
x = nn.utils.rnn.pack_padded_sequence(
|
515 |
+
x, input_lengths, batch_first=True, enforce_sorted=False)
|
516 |
+
block.flatten_parameters()
|
517 |
+
x, _ = block(x)
|
518 |
+
x, _ = nn.utils.rnn.pad_packed_sequence(
|
519 |
+
x, batch_first=True)
|
520 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
521 |
+
x = x.transpose(-1, -2)
|
522 |
+
|
523 |
+
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
|
524 |
+
|
525 |
+
x_pad[:, :, :x.shape[-1]] = x
|
526 |
+
x = x_pad.to(x.device)
|
527 |
+
|
528 |
+
return x.transpose(-1, -2)
|
529 |
+
|
530 |
+
def inference(self, x, style):
|
531 |
+
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
|
532 |
+
style = style.expand(x.shape[0], x.shape[1], -1)
|
533 |
+
x = torch.cat([x, style], axis=-1)
|
534 |
+
src = self.pos_encoder(x)
|
535 |
+
output = self.transformer_encoder(src).transpose(0, 1)
|
536 |
+
return output
|
537 |
+
|
538 |
+
def length_to_mask(self, lengths):
|
539 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
540 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
541 |
+
return mask
|
542 |
+
|
543 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
|
544 |
+
def recursive_munch(d):
|
545 |
+
if isinstance(d, dict):
|
546 |
+
return Munch((k, recursive_munch(v)) for k, v in d.items())
|
547 |
+
elif isinstance(d, list):
|
548 |
+
return [recursive_munch(v) for v in d]
|
549 |
+
else:
|
550 |
+
return d
|
551 |
+
|
552 |
+
def build_model(args):
|
553 |
+
args = recursive_munch(args)
|
554 |
+
assert args.decoder.type == 'istftnet', 'Decoder type unknown'
|
555 |
+
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
|
556 |
+
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
|
557 |
+
upsample_rates = args.decoder.upsample_rates,
|
558 |
+
upsample_initial_channel=args.decoder.upsample_initial_channel,
|
559 |
+
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
|
560 |
+
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
|
561 |
+
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
|
562 |
+
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
|
563 |
+
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
|
564 |
+
bert = load_plbert()
|
565 |
+
return Munch(
|
566 |
+
bert=bert,
|
567 |
+
bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
|
568 |
+
predictor=predictor,
|
569 |
+
decoder=decoder,
|
570 |
+
text_encoder=text_encoder,
|
571 |
+
)
|
num2kana.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/Greatdane/Convert-Numbers-to-Japanese/blob/master/Convert-Numbers-to-Japanese.py
|
2 |
+
# Japanese Number Converter
|
3 |
+
# - Currently using length functions - possible to use Recursive Functions? Difficult with the many exceptions
|
4 |
+
# - Works up to 9 figures (999999999)
|
5 |
+
|
6 |
+
romaji_dict = {".": "ten", "0": "zero", "1": "ichi", "2": "ni", "3": "san", "4": "yon", "5": "go", "6": "roku", "7": "nana",
|
7 |
+
"8": "hachi", "9": "kyuu", "10": "juu", "100": "hyaku", "1000": "sen", "10000": "man", "100000000": "oku",
|
8 |
+
"300": "sanbyaku", "600": "roppyaku", "800": "happyaku", "3000": "sanzen", "8000":"hassen", "01000": "issen"}
|
9 |
+
|
10 |
+
kanji_dict = {".": "点", "0": "零", "1": "一", "2": "二", "3": "三", "4": "四", "5": "五", "6": "六", "7": "七",
|
11 |
+
"8": "八", "9": "九", "10": "十", "100": "百", "1000": "千", "10000": "万", "100000000": "億",
|
12 |
+
"300": "三百", "600": "六百", "800": "八百", "3000": "三千", "8000":"八千", "01000": "一千"}
|
13 |
+
|
14 |
+
hiragana_dict = {".": "てん", "0": "ゼロ", "1": "いち", "2": "に", "3": "さん", "4": "よん", "5": "ご", "6": "ろく", "7": "なな",
|
15 |
+
"8": "はち", "9": "きゅう", "10": "じゅう", "100": "ひゃく", "1000": "せん", "10000": "まん", "100000000": "おく",
|
16 |
+
"300": "さんびゃく", "600": "ろっぴゃく", "800": "はっぴゃく", "3000": "さんぜん", "8000":"はっせん", "01000": "いっせん" }
|
17 |
+
|
18 |
+
key_dict = {"kanji" : kanji_dict, "hiragana" : hiragana_dict, "romaji": romaji_dict}
|
19 |
+
|
20 |
+
def len_one(convert_num,requested_dict):
|
21 |
+
# Returns single digit conversion, 0-9
|
22 |
+
return requested_dict[convert_num]
|
23 |
+
|
24 |
+
def len_two(convert_num,requested_dict):
|
25 |
+
# Returns the conversion, when number is of length two (10-99)
|
26 |
+
if convert_num[0] == "0": #if 0 is first, return len_one
|
27 |
+
return len_one(convert_num[1],requested_dict)
|
28 |
+
if convert_num == "10":
|
29 |
+
return requested_dict["10"] # Exception, if number is 10, simple return 10
|
30 |
+
if convert_num[0] == "1": # When first number is 1, use ten plus second number
|
31 |
+
return requested_dict["10"] + " " + len_one(convert_num[1],requested_dict)
|
32 |
+
elif convert_num[1] == "0": # If ending number is zero, give first number plus 10
|
33 |
+
return len_one(convert_num[0],requested_dict) + " " + requested_dict["10"]
|
34 |
+
else:
|
35 |
+
num_list = []
|
36 |
+
for x in convert_num:
|
37 |
+
num_list.append(requested_dict[x])
|
38 |
+
num_list.insert(1, requested_dict["10"])
|
39 |
+
# Convert to a string (from a list)
|
40 |
+
output = ""
|
41 |
+
for y in num_list:
|
42 |
+
output += y + " "
|
43 |
+
output = output[:len(output) - 1] # take off the space
|
44 |
+
return output
|
45 |
+
|
46 |
+
def len_three(convert_num,requested_dict):
|
47 |
+
# Returns the conversion, when number is of length three (100-999)
|
48 |
+
num_list = []
|
49 |
+
if convert_num[0] == "1":
|
50 |
+
num_list.append(requested_dict["100"])
|
51 |
+
elif convert_num[0] == "3":
|
52 |
+
num_list.append(requested_dict["300"])
|
53 |
+
elif convert_num[0] == "6":
|
54 |
+
num_list.append(requested_dict["600"])
|
55 |
+
elif convert_num[0] == "8":
|
56 |
+
num_list.append(requested_dict["800"])
|
57 |
+
else:
|
58 |
+
num_list.append(requested_dict[convert_num[0]])
|
59 |
+
num_list.append(requested_dict["100"])
|
60 |
+
if convert_num[1:] == "00" and len(convert_num) == 3:
|
61 |
+
pass
|
62 |
+
else:
|
63 |
+
if convert_num[1] == "0":
|
64 |
+
num_list.append(requested_dict[convert_num[2]])
|
65 |
+
else:
|
66 |
+
num_list.append(len_two(convert_num[1:], requested_dict))
|
67 |
+
output = ""
|
68 |
+
for y in num_list:
|
69 |
+
output += y + " "
|
70 |
+
output = output[:len(output) - 1]
|
71 |
+
return output
|
72 |
+
|
73 |
+
def len_four(convert_num,requested_dict, stand_alone):
|
74 |
+
# Returns the conversion, when number is of length four (1000-9999)
|
75 |
+
num_list = []
|
76 |
+
# First, check for zeros (and get deal with them)
|
77 |
+
if convert_num == "0000":
|
78 |
+
return ""
|
79 |
+
while convert_num[0] == "0":
|
80 |
+
convert_num = convert_num[1:]
|
81 |
+
if len(convert_num) == 1:
|
82 |
+
return len_one(convert_num,requested_dict)
|
83 |
+
elif len(convert_num) == 2:
|
84 |
+
return len_two(convert_num,requested_dict)
|
85 |
+
elif len(convert_num) == 3:
|
86 |
+
return len_three(convert_num,requested_dict)
|
87 |
+
# If no zeros, do the calculation
|
88 |
+
else:
|
89 |
+
# Have to handle 1000, depending on if its a standalone 1000-9999 or included in a larger number
|
90 |
+
if convert_num[0] == "1" and stand_alone:
|
91 |
+
num_list.append(requested_dict["1000"])
|
92 |
+
elif convert_num[0] == "1":
|
93 |
+
num_list.append(requested_dict["01000"])
|
94 |
+
elif convert_num[0] == "3":
|
95 |
+
num_list.append(requested_dict["3000"])
|
96 |
+
elif convert_num[0] == "8":
|
97 |
+
num_list.append(requested_dict["8000"])
|
98 |
+
else:
|
99 |
+
num_list.append(requested_dict[convert_num[0]])
|
100 |
+
num_list.append(requested_dict["1000"])
|
101 |
+
if convert_num[1:] == "000" and len(convert_num) == 4:
|
102 |
+
pass
|
103 |
+
else:
|
104 |
+
if convert_num[1] == "0":
|
105 |
+
num_list.append(len_two(convert_num[2:],requested_dict))
|
106 |
+
else:
|
107 |
+
num_list.append(len_three(convert_num[1:],requested_dict))
|
108 |
+
output = ""
|
109 |
+
for y in num_list:
|
110 |
+
output += y + " "
|
111 |
+
output = output[:len(output) - 1]
|
112 |
+
return output
|
113 |
+
|
114 |
+
|
115 |
+
def len_x(convert_num,requested_dict):
|
116 |
+
#Returns everything else.. (up to 9 digits)
|
117 |
+
num_list = []
|
118 |
+
if len(convert_num[0:-4]) == 1:
|
119 |
+
num_list.append(requested_dict[convert_num[0:-4]])
|
120 |
+
num_list.append(requested_dict["10000"])
|
121 |
+
elif len(convert_num[0:-4]) == 2:
|
122 |
+
num_list.append(len_two(convert_num[0:2],requested_dict))
|
123 |
+
num_list.append(requested_dict["10000"])
|
124 |
+
elif len(convert_num[0:-4]) == 3:
|
125 |
+
num_list.append(len_three(convert_num[0:3],requested_dict))
|
126 |
+
num_list.append(requested_dict["10000"])
|
127 |
+
elif len(convert_num[0:-4]) == 4:
|
128 |
+
num_list.append(len_four(convert_num[0:4],requested_dict, False))
|
129 |
+
num_list.append(requested_dict["10000"])
|
130 |
+
elif len(convert_num[0:-4]) == 5:
|
131 |
+
num_list.append(requested_dict[convert_num[0]])
|
132 |
+
num_list.append(requested_dict["100000000"])
|
133 |
+
num_list.append(len_four(convert_num[1:5],requested_dict, False))
|
134 |
+
if convert_num[1:5] == "0000":
|
135 |
+
pass
|
136 |
+
else:
|
137 |
+
num_list.append(requested_dict["10000"])
|
138 |
+
else:
|
139 |
+
assert False, "Not yet implemented, please choose a lower number."
|
140 |
+
num_list.append(len_four(convert_num[-4:],requested_dict, False))
|
141 |
+
output = ""
|
142 |
+
for y in num_list:
|
143 |
+
output += y + " "
|
144 |
+
output = output[:len(output) - 1]
|
145 |
+
return output
|
146 |
+
|
147 |
+
def remove_spaces(convert_result):
|
148 |
+
# Remove spaces in Hirigana and Kanji results
|
149 |
+
correction = ""
|
150 |
+
for x in convert_result:
|
151 |
+
if x == " ":
|
152 |
+
pass
|
153 |
+
else:
|
154 |
+
correction += x
|
155 |
+
return correction
|
156 |
+
|
157 |
+
def do_convert(convert_num,requested_dict):
|
158 |
+
#Check lengths and convert accordingly
|
159 |
+
if len(convert_num) == 1:
|
160 |
+
return(len_one(convert_num,requested_dict))
|
161 |
+
elif len(convert_num) == 2:
|
162 |
+
return(len_two(convert_num,requested_dict))
|
163 |
+
elif len(convert_num) == 3:
|
164 |
+
return(len_three(convert_num,requested_dict))
|
165 |
+
elif len(convert_num) == 4:
|
166 |
+
return(len_four(convert_num,requested_dict, True))
|
167 |
+
else:
|
168 |
+
return(len_x(convert_num,requested_dict))
|
169 |
+
|
170 |
+
def split_Point(split_num,dict_choice):
|
171 |
+
# Used if a decmial point is in the string.
|
172 |
+
split_num = split_num.split(".")
|
173 |
+
split_num_a = split_num[0]
|
174 |
+
split_num_b = split_num[1]
|
175 |
+
split_num_b_end = " "
|
176 |
+
for x in split_num_b:
|
177 |
+
split_num_b_end += len_one(x,key_dict[dict_choice]) + " "
|
178 |
+
# To account for small exception of small tsu when ending in jyuu in hiragana/romaji
|
179 |
+
if split_num_a[-1] == "0" and split_num_a[-2] != "0" and dict_choice == "hiragana":
|
180 |
+
small_Tsu = Convert(split_num_a,dict_choice)
|
181 |
+
small_Tsu = small_Tsu[0:-1] + "っ"
|
182 |
+
return small_Tsu + key_dict[dict_choice]["."] + split_num_b_end
|
183 |
+
if split_num_a[-1] == "0" and split_num_a[-2] != "0" and dict_choice == "romaji":
|
184 |
+
small_Tsu = Convert(split_num_a,dict_choice)
|
185 |
+
small_Tsu = small_Tsu[0:-1] + "t"
|
186 |
+
return small_Tsu + key_dict[dict_choice]["."] + split_num_b_end
|
187 |
+
|
188 |
+
return Convert(split_num_a,dict_choice) + " " + key_dict[dict_choice]["."] + split_num_b_end
|
189 |
+
|
190 |
+
|
191 |
+
def do_kanji_convert(convert_num):
|
192 |
+
# Converts kanji to arabic number
|
193 |
+
|
194 |
+
if convert_num == "零":
|
195 |
+
return 0
|
196 |
+
|
197 |
+
# First, needs to check for MAN 万 and OKU 億 kanji, as need to handle differently, splitting up the numbers at these intervals.
|
198 |
+
# key tells us whether we need to add or multiply the numbers, then we create a list of numbers in an order we need to add/multiply
|
199 |
+
key = []
|
200 |
+
numberList = []
|
201 |
+
y = ""
|
202 |
+
for x in convert_num:
|
203 |
+
if x == "万" or x == "億":
|
204 |
+
numberList.append(y)
|
205 |
+
key.append("times")
|
206 |
+
numberList.append(x)
|
207 |
+
key.append("plus")
|
208 |
+
y = ""
|
209 |
+
else:
|
210 |
+
y += x
|
211 |
+
if y != "":
|
212 |
+
numberList.append(y)
|
213 |
+
|
214 |
+
numberListConverted = []
|
215 |
+
baseNumber = ["一", "二", "三", "四", "五", "六", "七", "八", "九"]
|
216 |
+
linkNumber = ["十", "百", "千", "万", "億"]
|
217 |
+
|
218 |
+
# Converts the kanji number list to arabic numbers, using the 'base number' and 'link number' list above. For a link number, we would need to
|
219 |
+
# link with a base number
|
220 |
+
for noX in numberList:
|
221 |
+
count = len(noX)
|
222 |
+
result = 0
|
223 |
+
skip = 1
|
224 |
+
for x in reversed(noX):
|
225 |
+
addTo = 0
|
226 |
+
skip -= 1
|
227 |
+
count = count - 1
|
228 |
+
if skip == 1:
|
229 |
+
continue
|
230 |
+
if x in baseNumber:
|
231 |
+
for y, z in kanji_dict.items():
|
232 |
+
if z == x:
|
233 |
+
result += int(y)
|
234 |
+
elif x in linkNumber:
|
235 |
+
if noX[count - 1] in baseNumber and count > 0:
|
236 |
+
for y, z in kanji_dict.items():
|
237 |
+
if z == noX[count - 1]:
|
238 |
+
tempNo = int(y)
|
239 |
+
for y, z in kanji_dict.items():
|
240 |
+
if z == x:
|
241 |
+
addTo += tempNo * int(y)
|
242 |
+
result += addTo
|
243 |
+
skip = 2
|
244 |
+
else:
|
245 |
+
for y, z in kanji_dict.items():
|
246 |
+
if z == x:
|
247 |
+
result += int(y)
|
248 |
+
numberListConverted.append(int(result))
|
249 |
+
|
250 |
+
result = numberListConverted[0]
|
251 |
+
y = 0
|
252 |
+
|
253 |
+
# Iterate over the converted list, and either multiply/add as instructed in key list
|
254 |
+
for x in range(1,len(numberListConverted)):
|
255 |
+
if key[y] == "plus":
|
256 |
+
try:
|
257 |
+
if key[y+1] == "times":
|
258 |
+
result = result + numberListConverted[x] * numberListConverted[x+1]
|
259 |
+
y += 1
|
260 |
+
else:
|
261 |
+
result += numberListConverted[x]
|
262 |
+
except IndexError:
|
263 |
+
result += numberListConverted[-1]
|
264 |
+
break
|
265 |
+
else:
|
266 |
+
result = result * numberListConverted[x]
|
267 |
+
y += 1
|
268 |
+
|
269 |
+
return result
|
270 |
+
|
271 |
+
def Convert(convert_num, dict_choice='hiragana'):
|
272 |
+
# Input formatting
|
273 |
+
convert_num = str(convert_num)
|
274 |
+
convert_num = convert_num.replace(',','')
|
275 |
+
dict_choice = dict_choice.lower()
|
276 |
+
|
277 |
+
# If all is selected as dict_choice, return as a list
|
278 |
+
if dict_choice == "all":
|
279 |
+
result_list = []
|
280 |
+
for x in "kanji", "hiragana", "romaji":
|
281 |
+
result_list.append(Convert(convert_num,x))
|
282 |
+
return result_list
|
283 |
+
|
284 |
+
dictionary = key_dict[dict_choice]
|
285 |
+
|
286 |
+
# Exit if length is greater than current limit
|
287 |
+
if len(convert_num) > 9:
|
288 |
+
return("Number length too long, choose less than 10 digits")
|
289 |
+
|
290 |
+
# Remove any leading zeroes
|
291 |
+
while convert_num[0] == "0" and len(convert_num) > 1:
|
292 |
+
convert_num = convert_num[1:]
|
293 |
+
|
294 |
+
# Check for decimal places
|
295 |
+
if "." in convert_num:
|
296 |
+
result = split_Point(convert_num,dict_choice)
|
297 |
+
else:
|
298 |
+
result = do_convert(convert_num, dictionary)
|
299 |
+
|
300 |
+
# Remove spaces and return result
|
301 |
+
if key_dict[dict_choice] == romaji_dict:
|
302 |
+
pass
|
303 |
+
else:
|
304 |
+
result = remove_spaces(result)
|
305 |
+
return result
|
306 |
+
|
307 |
+
def ConvertKanji(convert_num):
|
308 |
+
if convert_num[0] in kanji_dict.values():
|
309 |
+
# Check to see if 点 (point) is in the input, and handle by splitting at 点, before and after is handled separately
|
310 |
+
if "点" in convert_num:
|
311 |
+
point = convert_num.find("点")
|
312 |
+
endNumber = ""
|
313 |
+
for x in convert_num[point+1:]:
|
314 |
+
endNumber += list(kanji_dict.keys())[list(kanji_dict.values()).index(x)]
|
315 |
+
return(str(do_kanji_convert(convert_num[0:point])) + "." + endNumber)
|
316 |
+
else:
|
317 |
+
return(str(do_kanji_convert(convert_num)))
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
espeak-ng
|
plbert.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
|
2 |
+
from transformers import AlbertConfig, AlbertModel
|
3 |
+
|
4 |
+
class CustomAlbert(AlbertModel):
|
5 |
+
def forward(self, *args, **kwargs):
|
6 |
+
# Call the original forward method
|
7 |
+
outputs = super().forward(*args, **kwargs)
|
8 |
+
# Only return the last_hidden_state
|
9 |
+
return outputs.last_hidden_state
|
10 |
+
|
11 |
+
def load_plbert():
|
12 |
+
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
|
13 |
+
albert_base_configuration = AlbertConfig(**plbert_config)
|
14 |
+
bert = CustomAlbert(albert_base_configuration)
|
15 |
+
return bert
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
fugashi
|
2 |
+
gradio
|
3 |
+
mojimoji
|
4 |
+
munch
|
5 |
+
noisereduce
|
6 |
+
phonemizer
|
7 |
+
scipy
|
8 |
+
torch
|
9 |
+
transformers
|
10 |
+
unicodedata2
|
11 |
+
unidic-lite
|