ASLP-lab committed on
Commit 010341e · verified · 1 Parent(s): 9f7e23b
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. app.py +533 -0
  3. bigvgan/__init__.py +0 -0
  4. bigvgan/activations.py +126 -0
  5. bigvgan/alias_free_activation/cuda/__init__.py +0 -0
  6. bigvgan/alias_free_activation/cuda/activation1d.py +77 -0
  7. bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp +23 -0
  8. bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
  9. bigvgan/alias_free_activation/cuda/compat.h +29 -0
  10. bigvgan/alias_free_activation/cuda/load.py +86 -0
  11. bigvgan/alias_free_activation/cuda/type_shim.h +92 -0
  12. bigvgan/alias_free_activation/torch/__init__.py +6 -0
  13. bigvgan/alias_free_activation/torch/act.py +30 -0
  14. bigvgan/alias_free_activation/torch/filter.py +101 -0
  15. bigvgan/alias_free_activation/torch/resample.py +58 -0
  16. bigvgan/env.py +18 -0
  17. bigvgan/model.py +545 -0
  18. bigvgan/utils.py +59 -0
  19. diffrhythm2/__init__.py +0 -0
  20. diffrhythm2/backbones/__init__.py +0 -0
  21. diffrhythm2/backbones/dit.py +222 -0
  22. diffrhythm2/backbones/flex_attention.py +237 -0
  23. diffrhythm2/backbones/llama_attention.py +451 -0
  24. diffrhythm2/backbones/llama_nar.py +140 -0
  25. diffrhythm2/cache_utils.py +154 -0
  26. diffrhythm2/cfm.py +425 -0
  27. g2p/__init__.py +0 -0
  28. g2p/g2p/__init__.py +87 -0
  29. g2p/g2p/chinese_model_g2p.py +213 -0
  30. g2p/g2p/cleaners.py +31 -0
  31. g2p/g2p/english.py +202 -0
  32. g2p/g2p/french.py +149 -0
  33. g2p/g2p/german.py +94 -0
  34. g2p/g2p/japanese.py +816 -0
  35. g2p/g2p/korean.py +81 -0
  36. g2p/g2p/mandarin.py +597 -0
  37. g2p/g2p/text_tokenizers.py +84 -0
  38. g2p/g2p/vocab.json +372 -0
  39. g2p/g2p_generation.py +134 -0
  40. g2p/language_segmentation/LangSegment.py +865 -0
  41. g2p/language_segmentation/__init__.py +9 -0
  42. g2p/language_segmentation/utils/__init__.py +0 -0
  43. g2p/language_segmentation/utils/num.py +327 -0
  44. g2p/sources/bpmf_2_pinyin.txt +41 -0
  45. g2p/sources/chinese_lexicon.txt +3 -0
  46. g2p/sources/g2p_chinese_model/config.json +819 -0
  47. g2p/sources/g2p_chinese_model/poly_bert_model.onnx +3 -0
  48. g2p/sources/g2p_chinese_model/polychar.txt +159 -0
  49. g2p/sources/g2p_chinese_model/polydict.json +393 -0
  50. g2p/sources/g2p_chinese_model/polydict_r.json +393 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+g2p/sources/chinese_lexicon.txt filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,533 @@
+import gradio as gr
+
+import json
+import torch
+import torchaudio
+import os
+import random
+import numpy as np
+import io
+import pydub
+import base64
+from muq import MuQMuLan
+from diffrhythm2.cfm import CFM
+from diffrhythm2.backbones.dit import DiT
+from bigvgan.model import Generator
+from huggingface_hub import hf_hub_download
+
+# Token ids reserved for song-structure markers in the lyric vocabulary
+STRUCT_INFO = {
+    "[start]": 500,
+    "[end]": 501,
+    "[intro]": 502,
+    "[verse]": 503,
+    "[chorus]": 504,
+    "[outro]": 505,
+    "[inst]": 506,
+    "[solo]": 507,
+    "[bridge]": 508,
+    "[hook]": 509,
+    "[break]": 510,
+    "[stop]": 511,
+    "[space]": 512
+}
+
+class CNENTokenizer():
+    """Chinese/English grapheme-to-phoneme tokenizer backed by g2p and vocab.json."""
+    def __init__(self):
+        curr_path = os.path.abspath(__file__)
+        vocab_path = os.path.join(os.path.dirname(curr_path), "g2p/g2p/vocab.json")
+        with open(vocab_path, 'r') as file:
+            self.phone2id: dict = json.load(file)['vocab']
+        self.id2phone = {v: k for (k, v) in self.phone2id.items()}
+        from g2p.g2p_generation import chn_eng_g2p
+        self.tokenizer = chn_eng_g2p
+    def encode(self, text):
+        phone, token = self.tokenizer(text)
+        token = [x + 1 for x in token]  # shift ids by +1 so 0 stays free (likely padding)
+        return token
+    def decode(self, token):
+        return "|".join([self.id2phone[x - 1] for x in token])
+
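Not part of the commit, a small sketch of how the pieces above fit together (it assumes the repo's g2p assets are available; the lyric line is arbitrary):

```python
# Structure markers ([verse], [chorus], ...) map directly to the reserved ids
# in STRUCT_INFO; ordinary lyric lines go through the G2P tokenizer instead.
tok = CNENTokenizer()
ids = tok.encode("I miss you")           # phoneme ids, shifted by +1
line = ids + [STRUCT_INFO["[stop]"]]     # parse_lyrics() below ends each line with [stop]
print(tok.decode(ids))                   # "|"-joined phoneme string round-trips
```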
+def prepare_model(repo_id, device, dtype):
+    diffrhythm2_ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="model.safetensors",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    diffrhythm2_config_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="model.json",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    with open(diffrhythm2_config_path) as f:
+        model_config = json.load(f)
+
+    model_config['use_flex_attn'] = False
+    diffrhythm2 = CFM(
+        transformer=DiT(
+            **model_config
+        ),
+        num_channels=model_config['mel_dim'],
+        block_size=model_config['block_size'],
+    )
+
+    total_params = sum(p.numel() for p in diffrhythm2.parameters())
+
+    diffrhythm2 = diffrhythm2.to(device).to(dtype)
+    if diffrhythm2_ckpt_path.endswith('.safetensors'):
+        from safetensors.torch import load_file
+        ckpt = load_file(diffrhythm2_ckpt_path)
+    else:
+        ckpt = torch.load(diffrhythm2_ckpt_path, map_location='cpu')
+    diffrhythm2.load_state_dict(ckpt)
+    print(f"Total params: {total_params:,}")
+
+    # load MuLan
+    mulan = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./ckpt").to(device).to(dtype)
+
+    # load frontend
+    lrc_tokenizer = CNENTokenizer()
+
+    # load decoder
+    decoder_ckpt_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="decoder.bin",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    decoder_config_path = hf_hub_download(
+        repo_id=repo_id,
+        filename="decoder.json",
+        local_dir="./ckpt",
+        local_files_only=False,
+    )
+    decoder = Generator(decoder_config_path, decoder_ckpt_path)
+    decoder = decoder.to(device).to(dtype)
+
+    return diffrhythm2, mulan, lrc_tokenizer, decoder
+
+def parse_lyrics(lyrics: str):
+    lyrics_with_time = []
+    lyrics = lyrics.split("\n")
+    for line in lyrics:
+        struct_idx = STRUCT_INFO.get(line, None)
+        if struct_idx is not None:
+            lyrics_with_time.append([struct_idx, STRUCT_INFO['[stop]']])
+        else:
+            tokens = lrc_tokenizer.encode(line.strip())
+            tokens = tokens + [STRUCT_INFO['[stop]']]
+            lyrics_with_time.append(tokens)
+    return lyrics_with_time
+
+def get_audio_prompt(model, audio_file, device, dtype):
+    prompt_wav, sr = torchaudio.load(audio_file)
+    prompt_wav = torchaudio.functional.resample(prompt_wav.to(device).to(dtype), sr, 24000)
+    if prompt_wav.shape[1] > 24000 * 10:
+        # clips longer than 10 s are randomly cropped to a 10 s window
+        start = random.randint(0, prompt_wav.shape[1] - 24000 * 10)
+        prompt_wav = prompt_wav[:, start:start + 24000 * 10]
+    prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
+    with torch.no_grad():
+        style_prompt_embed = model(wavs=prompt_wav)
+    return style_prompt_embed.squeeze(0)
+
+def get_text_prompt(model, text, device, dtype):
+    with torch.no_grad():
+        style_prompt_embed = model(texts=[text])
+    return style_prompt_embed.squeeze(0)
+
+def make_fake_stereo(audio, sampling_rate):
+    # duplicate the mono channel; attenuate the copy and delay it by 10 ms
+    left_channel = audio
+    right_channel = audio.clone()
+    right_channel = right_channel * 0.8
+    delay_samples = int(0.01 * sampling_rate)
+    right_channel = torch.roll(right_channel, delay_samples)
+    right_channel[:, :delay_samples] = 0
+    # stereo_audio = np.concatenate([left_channel, right_channel], axis=0)
+    stereo_audio = torch.cat([left_channel, right_channel], dim=0)
+
+    return stereo_audio
+
+def inference(
+    model,
+    decoder,
+    text,
+    style_prompt,
+    duration,
+    cfg_strength=1.0,
+    sample_steps=32,
+    fake_stereo=True,
+    odeint_method='euler',
+    file_type="wav"
+):
+    with torch.inference_mode():
+        latent = model.sample_block_cache(
+            text=text.unsqueeze(0),
+            duration=int(duration * 5),  # seconds -> latent frames (~5 fps, inferred from this factor)
+            style_prompt=style_prompt.unsqueeze(0),
+            steps=sample_steps,
+            cfg_strength=cfg_strength,
+            odeint_method=odeint_method
+        )
+        latent = latent.transpose(1, 2)
+        audio = decoder.decode_audio(latent, overlap=5, chunk_size=20)
+
+        num_channels = 1
+        audio = audio.float().cpu().squeeze()[None, :]
+        if fake_stereo:
+            audio = make_fake_stereo(audio, decoder.h.sampling_rate)
+            num_channels = 2
+
+        if file_type == 'wav':
+            return (decoder.h.sampling_rate, audio.numpy().T)  # [channel, time]
+        else:
+            buffer = io.BytesIO()
+            torchaudio.save(buffer, audio, decoder.h.sampling_rate, format=file_type)
+            return buffer.getvalue()
+
+def inference_stream(
+    model,
+    decoder,
+    text,
+    style_prompt,
+    duration,
+    cfg_strength=1.0,
+    sample_steps=32,
+    fake_stereo=True,
+    odeint_method='euler',
+    file_type="wav"
+):
+    with torch.inference_mode():
+        for audio in model.sample_cache_stream(
+            decoder=decoder,
+            text=text.unsqueeze(0),
+            duration=int(duration * 5),
+            style_prompt=style_prompt.unsqueeze(0),
+            steps=sample_steps,
+            cfg_strength=cfg_strength,
+            chunk_size=20,
+            overlap=5,
+            odeint_method=odeint_method
+        ):
+            # keep the chunk as a torch tensor so make_fake_stereo (torch ops) works;
+            # convert to numpy only at yield time
+            audio = audio.float().cpu().squeeze()[None, :]
+            if fake_stereo:
+                audio = make_fake_stereo(audio, decoder.h.sampling_rate)
+            # encoded_audio = io.BytesIO()
+            # torchaudio.save(encoded_audio, audio, decoder.h.sampling_rate, format='wav')
+            yield (decoder.h.sampling_rate, audio.numpy().T)  # [channel, time]
+
+
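Not part of the commit, a reviewer sketch of calling the blocking path; `lrc_prompt` and `style_prompt` are hypothetical here and would be built the way `infer_music()` (defined further down) builds them:

```python
# Assumes prepare_model() has populated diffrhythm2 / decoder, and that
# lrc_prompt (LongTensor of lyric tokens) and style_prompt (MuLan embedding)
# were prepared as in infer_music() below.
sr, pcm = inference(
    model=diffrhythm2,
    decoder=decoder,
    text=lrc_prompt,
    style_prompt=style_prompt,
    duration=30,          # seconds; becomes int(30 * 5) latent frames inside inference()
    sample_steps=16,
    file_type="wav",
)
print(sr, pcm.shape)      # pcm is (time, channels) after the final transpose
```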
+lrc_tokenizer = None
+MAX_SEED = np.iinfo(np.int32).max
+device = 'cuda'
+dtype = torch.float16
+diffrhythm2, mulan, lrc_tokenizer, decoder = prepare_model("ASLP-Lab/DiffRhythm2", device, dtype)
+
+# import spaces
+# @spaces.GPU
+def infer_music(
+    lrc,
+    current_prompt_type,
+    audio_prompt=None,
+    text_prompt=None,
+    seed=42,
+    randomize_seed=False,
+    steps=16,
+    cfg_strength=1.0,
+    file_type='wav',
+    odeint_method='euler',
+    device='cuda'
+):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    torch.manual_seed(seed)
+    print(seed, current_prompt_type)
+    try:
+        lrc_prompt = parse_lyrics(lrc)
+        lrc_prompt = torch.tensor(sum(lrc_prompt, []), dtype=torch.long, device=device)  # flatten per-line token lists
+        if current_prompt_type == "audio":
+            style_prompt = get_audio_prompt(mulan, audio_prompt, device, dtype)
+        else:
+            style_prompt = get_text_prompt(mulan, text_prompt, device, dtype)
+    except Exception as e:
+        raise gr.Error(f"Error: {str(e)}")
+    style_prompt = style_prompt.to(dtype)
+    generate_song = inference(
+        model=diffrhythm2,
+        decoder=decoder,
+        text=lrc_prompt,
+        style_prompt=style_prompt,
+        sample_steps=steps,
+        cfg_strength=cfg_strength,
+        odeint_method=odeint_method,
+        duration=240,
+        file_type=file_type
+    )
+    return generate_song
+    # for block in inference_stream(
+    #     model=diffrhythm2,
+    #     decoder=decoder,
+    #     text=lrc_prompt,
+    #     style_prompt=style_prompt,
+    #     sample_steps=steps,
+    #     cfg_strength=cfg_strength,
+    #     odeint_method=odeint_method,
+    #     duration=240,
+    #     file_type=file_type
+    # ):
+    #     yield block
+
+
+css = """
+/* Fix the lyrics textarea height and force a scrollbar */
+.lyrics-scroll-box textarea {
+    height: 405px !important;       /* fixed height */
+    max-height: 500px !important;   /* maximum height */
+    overflow-y: auto !important;    /* vertical scrolling */
+    white-space: pre-wrap;          /* preserve line breaks */
+    line-height: 1.5;               /* comfortable line height */
+}
+
+.gr-examples {
+    background: transparent !important;
+    border: 1px solid #e0e0e0 !important;
+    border-radius: 8px;
+    margin: 1rem 0 !important;
+    padding: 1rem !important;
+}
+
+"""
+
+def image_to_base64(path):
+    with open(path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
+with gr.Blocks(css=css) as demo:
+    gr.HTML(f"""
+        <div style="flex: 1; text-align: center;">
+            <div style="font-size: 2em; font-weight: bold; text-align: center; margin-bottom: 5px">
+                Di♪♪Rhythm 2 (谛韵)
+            </div>
+            <div style="display:flex; justify-content: center; column-gap:4px;">
+                <a href="https://arxiv.org/pdf/2510.22950">
+                    <img src='https://img.shields.io/badge/Arxiv-Paper-blue'>
+                </a>
+                <a href="https://github.com/ASLP-lab/DiffRhythm2">
+                    <img src='https://img.shields.io/badge/GitHub-Repo-green'>
+                </a>
+                <a href="https://aslp-lab.github.io/DiffRhythm2.github.io/">
+                    <img src='https://img.shields.io/badge/Project-Page-brown'>
+                </a>
+            </div>
+        </div>
+    """)
+
+    with gr.Tabs() as tabs:
+
+        # page 1
+        with gr.Tab("Music Generate", id=0):
+            with gr.Row():
+                with gr.Column():
+                    lrc = gr.Textbox(
+                        label="Lyrics",
+                        placeholder="Input the full lyrics",
+                        lines=12,
+                        max_lines=50,
+                        elem_classes="lyrics-scroll-box",
+                        value="""[start]
+[intro]
+[verse]
+Thought I heard your voice yesterday
+When I turned around to say
+That I loved you baby
+I realize it was just my mind
+Played tricks on me
+And it seems colder lately at night
+And I try to sleep with the lights on
+Every time the phone rings
+I pray to God it's you
+And I just can't believe
+That we're through
+[chorus]
+I miss you
+There's no other way to say it
+And I can't deny it
+I miss you
+It's so easy to see
+I miss you and me
+[verse]
+Is it turning over this time
+Have we really changed our minds about each other's love
+All the feelings that we used to share
+I refuse to believe
+That you don't care
+[chorus]
+I miss you
+There's no other way to say it
+And I and I can't deny it
+I miss you
+[verse]
+It's so easy to see
+I've got to gather myself as together
+I've been through worst kinds of weather
+If it's over now
+[outro]"""
+                    )
+                    current_prompt_type = gr.State(value="text")
+                    with gr.Tabs() as inside_tabs:
+                        with gr.Tab("Text Prompt"):
+                            text_prompt = gr.Textbox(
+                                label="Text Prompt",
+                                value="Pop, Piano, Bass, Drums, Happy",
+                                placeholder="Enter the Text Prompt, e.g. emotional piano pop",
+                            )
+                        with gr.Tab("Audio Prompt"):
+                            audio_prompt = gr.Audio(label="Audio Prompt", type="filepath")
+
+                    def update_prompt_type(evt: gr.SelectData):
+                        return "text" if evt.index == 0 else "audio"
+
+                    inside_tabs.select(
+                        fn=update_prompt_type,
+                        outputs=current_prompt_type
+                    )
+
+                with gr.Column():
+
+                    with gr.Accordion("Best Practices Guide", open=True):
+                        gr.Markdown("""
+                        1. **Lyrics Format Requirements**
+                            - Each line must contain only the lyric content; structure markers such as `[verse]` go on their own lines
+                            - Example of valid format:
+                            ```
+                            [intro]
+                            [verse]
+                            Thought I heard your voice yesterday
+                            When I turned around to say
+                            ```
+
+                        2. **Audio Prompt Requirements**
+                            - Reference audio should be ≥ 1 second; audio longer than 10 seconds is randomly clipped to a 10-second segment
+                            - For optimal results, the 10-second clip should be chosen carefully
+                            - Shorter clips may lead to incoherent generation
+
+                        3. **Supported Languages**
+                            - Chinese and English
+                        """)
+                    lyrics_btn = gr.Button("Generate", variant="primary")
+                    # audio_output = gr.Gallery(label="Audio Results")
+                    audio_output = gr.Audio(label="Audio Result", elem_id="audio_output")
+                    with gr.Accordion("Advanced Settings", open=False):
+                        seed = gr.Slider(
+                            label="Seed",
+                            minimum=0,
+                            maximum=MAX_SEED,
+                            step=1,
+                            value=0,
+                        )
+                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+                        steps = gr.Slider(
+                            minimum=10,
+                            maximum=100,
+                            value=16,
+                            step=1,
+                            label="Diffusion Steps",
+                            interactive=True,
+                            elem_id="step_slider"
+                        )
+                        cfg_strength = gr.Slider(
+                            minimum=1,
+                            maximum=10,
+                            value=1.0,
+                            step=0.5,
+                            label="CFG Strength",
+                            interactive=True,
+                            elem_id="step_slider"
+                        )
+
+                        odeint_method = gr.Radio(["euler", "midpoint", "rk4", "implicit_adams"], label="ODE Solver", value="euler")
+                        file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
+
+            # gr.Examples(
+            #     examples=[
+            #         ["src/prompt/classic_cn.wav"],
+            #         ["src/prompt/classic_en.wav"],
+            #         ["src/prompt/country_cn.wav"],
+            #         ["src/prompt/country_en.wav"],
+            #         ["src/prompt/jazz_cn.wav"],
+            #         ["src/prompt/jazz_en.wav"],
+            #         ["src/prompt/pop_cn.wav"],
+            #         ["src/prompt/pop_en.wav"],
+            #         ["src/prompt/rap_cn.wav"],
+            #         ["src/prompt/rap_en.wav"],
+            #         ["src/prompt/rock_cn.wav"],
+            #         ["src/prompt/rock_en.wav"]
+            #     ],
+            #     inputs=[audio_prompt],
+            #     label="Audio Examples",
+            #     examples_per_page=12,
+            #     elem_id="audio-examples-container"
+            # )
+
+            # gr.Examples(
+            #     examples=[
+            #         ["Pop Emotional Piano"],
+            #         ["流行 情感 钢琴"],
+            #         ["Indie folk ballad, coming-of-age themes, acoustic guitar picking with harmonica interludes"],
+            #         ["独立民谣, 成长主题, 原声吉他弹奏与口琴间奏"]
+            #     ],
+            #     inputs=[text_prompt],
+            #     label="Text Examples",
+            #     examples_per_page=4,
+            #     elem_id="text-examples-container"
+            # )
+
+            # gr.Examples(
+            #     examples=[
+            #         ["""[00:10.00]Moonlight spills through broken blinds\n[00:13.20]Your shadow dances on the dashboard shrine\n[00:16.85]Neon ghosts in gasoline rain\n[00:20.40]I hear your laughter down the midnight train\n[00:24.15]Static whispers through frayed wires\n[00:27.65]Guitar strings hum our cathedral choirs\n[00:31.30]Flicker screens show reruns of June\n[00:34.90]I'm drowning in this mercury lagoon\n[00:38.55]Electric veins pulse through concrete skies\n[00:42.10]Your name echoes in the hollow where my heartbeat lies\n[00:45.75]We're satellites trapped in parallel light\n[00:49.25]Burning through the atmosphere of endless night\n[01:00.00]Dusty vinyl spins reverse\n[01:03.45]Our polaroid timeline bleeds through the verse\n[01:07.10]Telescope aimed at dead stars\n[01:10.65]Still tracing constellations through prison bars\n[01:14.30]Electric veins pulse through concrete skies\n[01:17.85]Your name echoes in the hollow where my heartbeat lies\n[01:21.50]We're satellites trapped in parallel light\n[01:25.05]Burning through the atmosphere of endless night\n[02:10.00]Clockwork gears grind moonbeams to rust\n[02:13.50]Our fingerprint smudged by interstellar dust\n[02:17.15]Velvet thunder rolls through my veins\n[02:20.70]Chasing phantom trains through solar plane\n[02:24.35]Electric veins pulse through concrete skies\n[02:27.90]Your name echoes in the hollow where my heartbeat lies"""],
+            #         ["""[00:05.00]Stardust whispers in your eyes\n[00:09.30]Moonlight paints our silhouettes\n[00:13.75]Tides bring secrets from the deep\n[00:18.20]Where forever's breath is kept\n[00:22.90]We dance through constellations' maze\n[00:27.15]Footprints melt in cosmic waves\n[00:31.65]Horizons hum our silent vow\n[00:36.10]Time unravels here and now\n[00:40.85]Eternal embers in the night oh oh oh\n[00:45.25]Healing scars with liquid light\n[00:49.70]Galaxies write our refrain\n[00:54.15]Love reborn in endless rain\n[01:15.30]Paper boats of memories\n[01:19.75]Float through veins of ancient trees\n[01:24.20]Your laughter spins aurora threads\n[01:28.65]Weaving dawn through featherbed"""],
+            #         ["""[00:04.27]只因你太美 baby\n[00:08.95]只因你实在是太美 baby\n[00:13.99]只因你太美 baby\n[00:18.89]迎面走来的你让我如此蠢蠢欲动\n[00:20.88]这种感觉我从未有\n[00:21.79]Cause I got a crush on you who you\n[00:25.74]你是我的我是你的谁\n[00:28.09]再多一眼看一眼就会爆炸\n[00:30.31]再近一点靠近点快被融化\n[00:32.49]想要把你占为己有 baby\n[00:34.60]不管走到哪里\n[00:35.44]都会想起的人是你 you you\n[00:38.12]我应该拿你怎样\n[00:39.61]Uh 所有人都在看着你\n[00:42.36]我的心总是不安\n[00:44.18]Oh 我现在已病入膏肓\n[00:46.63]Eh oh\n[00:47.84]难道真的因你而疯狂吗\n[00:51.57]我本来不是这种人\n[00:53.59]因你变成奇怪的人\n[00:55.77]第一次呀变成这样的我\n[01:01.23]不管我怎么去否认\n[01:03.21]只因你太美 baby\n[01:11.46]只因你实在是太美 baby\n[01:16.75]只因你太美 baby\n[01:21.09]Oh eh oh\n[01:22.82]现在确认地告诉我\n[01:25.26]Oh eh oh\n[01:27.31]你到底属于谁\n[01:29.98]Oh eh oh\n[01:31.70]现在确认地告诉我\n[01:34.45]Oh eh oh\n[01:36.35]你到底属于谁\n[01:37.65]就是现在告诉我\n[01:40.00]跟着那节奏 缓缓 make wave\n"""],
+            #         ["""[00:16.55]倦鸟西归 竹影余晖\n[00:23.58]禅意心扉\n[00:27.32]待清风 拂开一池春水\n[00:30.83]你的手绘 玉色难褪\n[00:37.99]我端详飘散的韵味\n[00:40.65]落款壶底的名讳\n[00:42.92]如吻西施的嘴\n[00:45.14]风雅几回 总相随\n[00:52.32]皆因你珍贵\n[00:57.85]三千弱水 煮一杯\n[01:02.21]我只饮下你的美\n[01:04.92]千年余味 紫砂壶伴我醉\n[01:09.73]酿一世无悔\n[01:12.09]沏壶春水 翠烟飞\n[01:16.62]把盏不尽你的香味\n[01:20.06]邀月相对 愿今生同宿同归\n[01:26.43]只让你陪\n[01:46.12]茗香芳菲 世俗无追\n"""]
+            #     ],
+            #     inputs=[lrc],
+            #     label="Lrc Examples",
+            #     examples_per_page=4,
+            #     elem_id="lrc-examples-container",
+            # )
+
+    tabs.select(
+        lambda s: None,
+        None,
+        None
+    )
+
+    # TODO add max_frames parameter for infer_music
+    lyrics_btn.click(
+        fn=infer_music,
+        inputs=[
+            lrc,
+            current_prompt_type,
+            audio_prompt,
+            text_prompt,
+            seed,
+            randomize_seed,
+            steps,
+            cfg_strength,
+            file_type,
+            odeint_method,
+        ],
+        outputs=audio_output,
+    )
+
+
+# demo.queue().launch(show_api=False, show_error=True)
+
+
+if __name__ == "__main__":
+    demo.launch()
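A usage note from review, not part of the commit: `device` is hard-coded to `'cuda'` and checkpoints are downloaded to `./ckpt` at import time, so on a GPU machine the demo starts with `python app.py`. A programmatic call would look roughly like this (untested sketch):

```python
import app  # runs prepare_model() as an import side effect

result = app.infer_music(
    lrc="[start]\n[verse]\nThought I heard your voice yesterday\n[end]",
    current_prompt_type="text",
    text_prompt="Pop, Piano, Bass, Drums, Happy",
    file_type="wav",
)
sr, pcm = result  # (sampling_rate, (time, channels) ndarray) for wav output
```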
bigvgan/__init__.py ADDED
File without changes
bigvgan/activations.py ADDED
@@ -0,0 +1,126 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    """
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default; higher values = higher frequency.
+            alpha will be trained along with the rest of your model.
+        """
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # Log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # Linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2(xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    """
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+        https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    """
+
+    def __init__(
+        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
+    ):
+        """
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default; higher values = higher frequency.
+            beta is initialized to 1 by default; higher values = higher magnitude.
+            alpha will be trained along with the rest of your model.
+        """
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # Initialize alpha (beta is initialized from the same value, matching the upstream implementation)
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # Log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # Linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        """
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + 1/b * sin^2(xa)
+        """
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # Line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
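A quick shape check from review (not in the commit); `in_features` is the channel count, so it must match C:

```python
import torch
from bigvgan.activations import Snake, SnakeBeta

x = torch.randn(2, 256, 100)                         # (B, C, T)
print(Snake(256)(x).shape)                           # torch.Size([2, 256, 100])
print(SnakeBeta(256, alpha_logscale=True)(x).shape)  # same shape, separate beta parameter
```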
bigvgan/alias_free_activation/cuda/__init__.py ADDED
File without changes
bigvgan/alias_free_activation/cuda/activation1d.py ADDED
@@ -0,0 +1,77 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+import torch
+import torch.nn as nn
+from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+
+# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
+from alias_free_activation.cuda import load
+
+anti_alias_activation_cuda = load.load()
+
+
+class FusedAntiAliasActivation(torch.autograd.Function):
+    """
+    Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
+    The hyperparameters are hard-coded in the kernel to maximize speed.
+    NOTE: The fused kernel is incorrect for Activation1d with different hyperparameters.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
+        activation_results = anti_alias_activation_cuda.forward(
+            inputs, up_ftr, down_ftr, alpha, beta
+        )
+
+        return activation_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        raise NotImplementedError
+        return output_grads, None, None
+
+
+class Activation1d(nn.Module):
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+        fused: bool = True,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+        self.fused = fused  # Whether to use fused CUDA kernel or not
+
+    def forward(self, x):
+        if not self.fused:
+            x = self.upsample(x)
+            x = self.act(x)
+            x = self.downsample(x)
+            return x
+        else:
+            if self.act.__class__.__name__ == "Snake":
+                beta = self.act.alpha.data  # Snake uses same params for alpha and beta
+            else:
+                beta = (
+                    self.act.beta.data
+                )  # SnakeBeta uses different params for alpha and beta
+            alpha = self.act.alpha.data
+            if (
+                not self.act.alpha_logscale
+            ):  # Exp baked into cuda kernel, cancel it out with a log
+                alpha = torch.log(alpha)
+                beta = torch.log(beta)
+
+            x = FusedAntiAliasActivation.apply(
+                x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
+            )
+            return x
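Reviewer sketch of the fused/non-fused equivalence this class implies. Assumptions: the JIT extension compiles on the host (importing this module triggers `load.load()`), and bigvgan/ itself is on sys.path, since the file uses absolute `alias_free_activation` imports:

```python
import torch
from activations import SnakeBeta                        # bigvgan/ on sys.path
from alias_free_activation.cuda.activation1d import Activation1d

act = SnakeBeta(64, alpha_logscale=True).cuda()
layer = Activation1d(act).cuda()                          # fused=True by default
x = torch.randn(1, 64, 4096, device="cuda")
y_fused = layer(x)
layer.fused = False                                       # pure-torch fallback path
y_ref = layer(x)
print((y_fused - y_ref).abs().max())                      # expected small, not exactly zero
```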
bigvgan/alias_free_activation/cuda/anti_alias_activation.cpp ADDED
@@ -0,0 +1,23 @@
+/* coding=utf-8
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/extension.h>
+
+extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
+}
bigvgan/alias_free_activation/cuda/anti_alias_activation_cuda.cu ADDED
@@ -0,0 +1,246 @@
+/* coding=utf-8
+ * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_profiler_api.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+#include "type_shim.h"
+#include <assert.h>
+#include <cfloat>
+#include <limits>
+#include <stdint.h>
+#include <c10/macros/Macros.h>
+
+namespace
+{
+    // Hard-coded hyperparameters
+    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
+    constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
+    constexpr int BUFFER_SIZE = 32;
+    constexpr int FILTER_SIZE = 12;
+    constexpr int HALF_FILTER_SIZE = 6;
+    constexpr int UPSAMPLE_REPLICATION_PAD = 5;         // 5 on each side, matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5;  // matching torch impl
+    constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
+
+    template <typename input_t, typename output_t, typename acc_t>
+    __global__ void anti_alias_activation_forward(
+        output_t *dst,
+        const input_t *src,
+        const input_t *up_ftr,
+        const input_t *down_ftr,
+        const input_t *alpha,
+        const input_t *beta,
+        int batch_size,
+        int channels,
+        int seq_len)
+    {
+        // Up and downsample filters
+        input_t up_filter[FILTER_SIZE];
+        input_t down_filter[FILTER_SIZE];
+
+        // Load data from global memory including extra indices reserved for replication paddings
+        input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
+        input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
+
+        // Output stores downsampled output before writing to dst
+        output_t output[BUFFER_SIZE];
+
+        // blockDim/threadIdx = (128, 1, 1)
+        // gridDim/blockIdx = (seq_blocks, channels, batches)
+        int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
+        int local_offset = threadIdx.x * BUFFER_SIZE;
+        int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
+
+        // intermediates have double the seq_len
+        int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
+        int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
+
+        // Get values needed for replication padding before moving pointer
+        const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
+        input_t seq_left_most_value = right_most_pntr[0];
+        input_t seq_right_most_value = right_most_pntr[seq_len - 1];
+
+        // Move src and dst pointers
+        src += block_offset + local_offset;
+        dst += block_offset + local_offset;
+
+        // Alpha and beta values for snake activations. Applies exp by default
+        alpha = alpha + blockIdx.y;
+        input_t alpha_val = expf(alpha[0]);
+        beta = beta + blockIdx.y;
+        input_t beta_val = expf(beta[0]);
+
+        #pragma unroll
+        for (int it = 0; it < FILTER_SIZE; it += 1)
+        {
+            up_filter[it] = up_ftr[it];
+            down_filter[it] = down_ftr[it];
+        }
+
+        // Apply replication padding for upsampling, matching torch impl
+        #pragma unroll
+        for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
+        {
+            int element_index = seq_offset + it; // index for element
+            if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
+            }
+            if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
+            }
+            if ((element_index >= 0) && (element_index < seq_len))
+            {
+                elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
+            }
+        }
+
+        // Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampling conv later
+        #pragma unroll
+        for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
+        {
+            input_t acc = 0.0;
+            int element_index = intermediate_seq_offset + it; // index for intermediate
+            #pragma unroll
+            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
+            {
+                if ((element_index + f_idx) >= 0)
+                {
+                    acc += up_filter[f_idx] * elements[it + f_idx];
+                }
+            }
+            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
+        }
+
+        // Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampling conv later
+        double no_div_by_zero = 0.000000001;
+        #pragma unroll
+        for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
+        {
+            intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
+        }
+
+        // Apply replication padding before downsampling conv from intermediates
+        #pragma unroll
+        for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
+        {
+            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
+        }
+        #pragma unroll
+        for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
+        {
+            intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
+        }
+
+        // Apply downsample strided convolution (assuming stride=2) from intermediates
+        #pragma unroll
+        for (int it = 0; it < BUFFER_SIZE; it += 1)
+        {
+            input_t acc = 0.0;
+            #pragma unroll
+            for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
+            {
+                // Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
+                acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
+            }
+            output[it] = acc;
+        }
+
+        // Write output to dst
+        #pragma unroll
+        for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
+        {
+            int element_index = seq_offset + it;
+            if (element_index < seq_len)
+            {
+                dst[it] = output[it];
+            }
+        }
+    }
+
+    template <typename input_t, typename output_t, typename acc_t>
+    void dispatch_anti_alias_activation_forward(
+        output_t *dst,
+        const input_t *src,
+        const input_t *up_ftr,
+        const input_t *down_ftr,
+        const input_t *alpha,
+        const input_t *beta,
+        int batch_size,
+        int channels,
+        int seq_len)
+    {
+        if (seq_len == 0)
+        {
+            return;
+        }
+        else
+        {
+            // Use 128 threads per block to maximize gpu utilization
+            constexpr int threads_per_block = 128;
+            constexpr int seq_len_per_block = 4096;
+            int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
+            dim3 blocks(blocks_per_seq_len, channels, batch_size);
+            dim3 threads(threads_per_block, 1, 1);
+
+            anti_alias_activation_forward<input_t, output_t, acc_t>
+                <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
+        }
+    }
+}
+
+extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
+{
+    // Input is a 3d tensor with dimensions [batches, channels, seq_len]
+    const int batches = input.size(0);
+    const int channels = input.size(1);
+    const int seq_len = input.size(2);
+
+    // Output
+    auto act_options = input.options().requires_grad(false);
+
+    torch::Tensor anti_alias_activation_results =
+        torch::empty({batches, channels, seq_len}, act_options);
+
+    void *input_ptr = static_cast<void *>(input.data_ptr());
+    void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
+    void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
+    void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
+    void *beta_ptr = static_cast<void *>(beta.data_ptr());
+    void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
+
+    DISPATCH_FLOAT_HALF_AND_BFLOAT(
+        input.scalar_type(),
+        "dispatch anti alias activation_forward",
+        dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
+            reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
+            reinterpret_cast<const scalar_t *>(input_ptr),
+            reinterpret_cast<const scalar_t *>(up_filter_ptr),
+            reinterpret_cast<const scalar_t *>(down_filter_ptr),
+            reinterpret_cast<const scalar_t *>(alpha_ptr),
+            reinterpret_cast<const scalar_t *>(beta_ptr),
+            batches,
+            channels,
+            seq_len););
+    return anti_alias_activation_results;
+}
bigvgan/alias_free_activation/cuda/compat.h ADDED
@@ -0,0 +1,29 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* This code is copied from NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */
+
+#ifndef TORCH_CHECK
+#define TORCH_CHECK AT_CHECK
+#endif
+
+#ifdef VERSION_GE_1_3
+#define DATA_PTR data_ptr
+#else
+#define DATA_PTR data
+#endif
bigvgan/alias_free_activation/cuda/load.py ADDED
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+import os
+import pathlib
+import subprocess
+
+from torch.utils import cpp_extension
+
+"""
+Setting this param to a list has a problem of generating different compilation commands (with different order of architectures) and leading to recompilation of fused kernels.
+Set it to empty string to avoid recompilation and assign arch flags explicitly in extra_cuda_cflags below.
+"""
+os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+
+
+def load():
+    # Check if cuda 11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
+    if int(bare_metal_major) >= 11:
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_80,code=sm_80")
+
+    # Build path
+    srcpath = pathlib.Path(__file__).parent.absolute()
+    buildpath = srcpath / "build"
+    _create_build_dir(buildpath)
+
+    # Helper function to build the kernels.
+    def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
+        return cpp_extension.load(
+            name=name,
+            sources=sources,
+            build_directory=buildpath,
+            extra_cflags=[
+                "-O3",
+            ],
+            extra_cuda_cflags=[
+                "-O3",
+                "-gencode",
+                "arch=compute_70,code=sm_70",
+                "--use_fast_math",
+            ]
+            + extra_cuda_flags
+            + cc_flag,
+            verbose=True,
+        )
+
+    extra_cuda_flags = [
+        "-U__CUDA_NO_HALF_OPERATORS__",
+        "-U__CUDA_NO_HALF_CONVERSIONS__",
+        "--expt-relaxed-constexpr",
+        "--expt-extended-lambda",
+    ]
+
+    sources = [
+        srcpath / "anti_alias_activation.cpp",
+        srcpath / "anti_alias_activation_cuda.cu",
+    ]
+    anti_alias_activation_cuda = _cpp_extention_load_helper(
+        "anti_alias_activation_cuda", sources, extra_cuda_flags
+    )
+
+    return anti_alias_activation_cuda
+
+
+def _get_cuda_bare_metal_version(cuda_dir):
+    raw_output = subprocess.check_output(
+        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
+    )
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = release[0]
+    bare_metal_minor = release[1][0]
+
+    return raw_output, bare_metal_major, bare_metal_minor
+
+
+def _create_build_dir(buildpath):
+    try:
+        os.mkdir(buildpath)
+    except OSError:
+        if not os.path.isdir(buildpath):
+            print(f"Creation of the build directory {buildpath} failed")
bigvgan/alias_free_activation/cuda/type_shim.h ADDED
@@ -0,0 +1,92 @@
+/* coding=utf-8
+ * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include "compat.h"
+
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...)                 \
+    switch (TYPE)                                                       \
+    {                                                                   \
+    case at::ScalarType::Float:                                         \
+    {                                                                   \
+        using scalar_t = float;                                         \
+        __VA_ARGS__;                                                    \
+        break;                                                          \
+    }                                                                   \
+    case at::ScalarType::Half:                                          \
+    {                                                                   \
+        using scalar_t = at::Half;                                      \
+        __VA_ARGS__;                                                    \
+        break;                                                          \
+    }                                                                   \
+    case at::ScalarType::BFloat16:                                      \
+    {                                                                   \
+        using scalar_t = at::BFloat16;                                  \
+        __VA_ARGS__;                                                    \
+        break;                                                          \
+    }                                                                   \
+    default:                                                            \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    }
+
+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+    switch (TYPEIN)                                                            \
+    {                                                                          \
+    case at::ScalarType::Float:                                                \
+    {                                                                          \
+        using scalar_t_in = float;                                             \
+        switch (TYPEOUT)                                                       \
+        {                                                                      \
+        case at::ScalarType::Float:                                            \
+        {                                                                      \
+            using scalar_t_out = float;                                        \
+            __VA_ARGS__;                                                       \
+            break;                                                             \
+        }                                                                      \
+        case at::ScalarType::Half:                                             \
+        {                                                                      \
+            using scalar_t_out = at::Half;                                     \
+            __VA_ARGS__;                                                       \
+            break;                                                             \
+        }                                                                      \
+        case at::ScalarType::BFloat16:                                         \
+        {                                                                      \
+            using scalar_t_out = at::BFloat16;                                 \
+            __VA_ARGS__;                                                       \
+            break;                                                             \
+        }                                                                      \
+        default:                                                               \
+            AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+        }                                                                      \
+        break;                                                                 \
+    }                                                                          \
+    case at::ScalarType::Half:                                                 \
+    {                                                                          \
+        using scalar_t_in = at::Half;                                          \
+        using scalar_t_out = at::Half;                                         \
+        __VA_ARGS__;                                                           \
+        break;                                                                 \
+    }                                                                          \
+    case at::ScalarType::BFloat16:                                             \
+    {                                                                          \
+        using scalar_t_in = at::BFloat16;                                      \
+        using scalar_t_out = at::BFloat16;                                     \
+        __VA_ARGS__;                                                           \
+        break;                                                                 \
+    }                                                                          \
+    default:                                                                   \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");      \
+    }
bigvgan/alias_free_activation/torch/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
bigvgan/alias_free_activation/torch/act.py ADDED
@@ -0,0 +1,30 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+    def __init__(
+        self,
+        activation,
+        up_ratio: int = 2,
+        down_ratio: int = 2,
+        up_kernel_size: int = 12,
+        down_kernel_size: int = 12,
+    ):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+
+        return x
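For contrast with the fused CUDA version, a minimal sketch of this pure-torch path (imports assume the bigvgan package layout in this commit; no compiled extension needed):

```python
import torch
from bigvgan.activations import Snake
from bigvgan.alias_free_activation.torch.act import Activation1d

layer = Activation1d(Snake(32))   # 2x upsample -> Snake -> 2x downsample
x = torch.randn(4, 32, 1000)      # (B, C, T)
print(layer(x).shape)             # torch.Size([4, 32, 1000]); length is preserved
```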
bigvgan/alias_free_activation/torch/filter.py ADDED
@@ -0,0 +1,101 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+if "sinc" in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adapted from adefossez's julius.core.sinc under the MIT License
+    # https://adefossez.github.io/julius/julius/core.html
+    # LICENSE is in incl_licenses directory.
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different from julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(
+            x == 0,
+            torch.tensor(1.0, device=x.device, dtype=x.dtype),
+            torch.sin(math.pi * x) / math.pi / x,
+        )
+
+
+# This code is adapted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(
+    cutoff, half_width, kernel_size
+):  # return filter [1, 1, kernel_size]
+    even = kernel_size % 2 == 0
+    half_size = kernel_size // 2
+
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.0:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.0:
+        beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
+    else:
+        beta = 0.0
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+    # ratio = 0.5 / cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = torch.arange(-half_size, half_size) + 0.5
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1; otherwise there is a small leakage of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(
+        self,
+        cutoff=0.5,
+        half_width=0.6,
+        stride: int = 1,
+        padding: bool = True,
+        padding_mode: str = "replicate",
+        kernel_size: int = 12,
+    ):
+        """
+        kernel_size should be an even number for the StyleGAN3 setup; in this implementation, odd numbers are also possible.
+        """
+        super().__init__()
+        if cutoff < -0.0:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = kernel_size % 2 == 0
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # Input [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+        return out
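Reviewer sketch: the normalization in kaiser_sinc_filter1d gives the kernel unity DC gain, and LowPassFilter1d preserves length at the default stride of 1:

```python
import torch
from bigvgan.alias_free_activation.torch.filter import (
    kaiser_sinc_filter1d,
    LowPassFilter1d,
)

f = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(f.shape, float(f.sum()))   # torch.Size([1, 1, 12]) and ~1.0

lp = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
x = torch.randn(1, 3, 64)
print(lp(x).shape)               # torch.Size([1, 3, 64]) with stride=1
```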
bigvgan/alias_free_activation/torch/resample.py ADDED
@@ -0,0 +1,58 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = (
+            self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        )
+        filter = kaiser_sinc_filter1d(
+            cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        x = F.pad(x, (self.pad, self.pad), mode="replicate")
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left : -self.pad_right]
+
+        return x
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = (
+            int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        )
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size,
+        )
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+
+        return xx
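Reviewer sketch of the resamplers' shape contract (same package-layout assumption as above):

```python
import torch
from bigvgan.alias_free_activation.torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 8, 500)        # (B, C, T)
up = UpSample1d(ratio=2)(x)       # -> (1, 8, 1000)
down = DownSample1d(ratio=2)(up)  # -> (1, 8, 500)
print(up.shape, down.shape)
```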
bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+ 
+ import os
+ import shutil
+ 
+ 
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+ 
+ 
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
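AttrDict simply aliases the instance __dict__ to the dict itself, so attribute access and key access hit the same storage. A small sketch (the keys are illustrative):

    from bigvgan.env import AttrDict

    h = AttrDict({"sampling_rate": 44100, "num_mels": 128})
    print(h.sampling_rate)     # 44100, via attribute access
    print(h["num_mels"])       # 128, via dict access
    h.use_cuda_kernel = False  # attribute writes are visible as dict entries too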
bigvgan/model.py ADDED
@@ -0,0 +1,545 @@
+ # Copyright (c) 2024 NVIDIA CORPORATION.
+ # Licensed under the MIT license.
+ 
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+ 
+ import os
+ import json
+ from pathlib import Path
+ from typing import Optional, Union, Dict
+ 
+ import torch
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d
+ from torch.nn.utils import weight_norm, remove_weight_norm
+ from safetensors.torch import load_file
+ 
+ from .activations import Snake, SnakeBeta
+ from .utils import init_weights, get_padding
+ from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
+ from .env import AttrDict
+ 
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+ 
+ 
+ def load_hparams_from_json(path) -> AttrDict:
+     with open(path) as f:
+         data = f.read()
+     return AttrDict(json.loads(data))
+ 
+ 
+ class AMPBlock1(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     AMPBlock1 has an additional self.convs2 that contains extra Conv1d layers with a fixed dilation=1, one following each layer in self.convs1.
+ 
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+ 
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+ 
+         self.h = h
+ 
+         self.convs1 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs1.apply(init_weights)
+ 
+         self.convs2 = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=1,
+                         padding=get_padding(kernel_size, 1),
+                     )
+                 )
+                 for _ in range(len(dilation))
+             ]
+         )
+         self.convs2.apply(init_weights)
+ 
+         self.num_layers = len(self.convs1) + len(
+             self.convs2
+         )  # Total number of conv layers
+ 
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+ 
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+ 
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=Snake(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=SnakeBeta(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+ 
+     def forward(self, x):
+         acts1, acts2 = self.activations[::2], self.activations[1::2]
+         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+             xt = a1(x)
+             xt = c1(xt)
+             xt = a2(xt)
+             xt = c2(xt)
+             x = xt + x
+ 
+         return x
+ 
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+ 
+ 
+ class AMPBlock2(torch.nn.Module):
+     """
+     AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
+     Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1.
+ 
+     Args:
+         h (AttrDict): Hyperparameters.
+         channels (int): Number of convolution channels.
+         kernel_size (int): Size of the convolution kernel. Default is 3.
+         dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
+         activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
+     """
+ 
+     def __init__(
+         self,
+         h: AttrDict,
+         channels: int,
+         kernel_size: int = 3,
+         dilation: tuple = (1, 3, 5),
+         activation: str = None,
+     ):
+         super().__init__()
+ 
+         self.h = h
+ 
+         self.convs = nn.ModuleList(
+             [
+                 weight_norm(
+                     Conv1d(
+                         channels,
+                         channels,
+                         kernel_size,
+                         stride=1,
+                         dilation=d,
+                         padding=get_padding(kernel_size, d),
+                     )
+                 )
+                 for d in dilation
+             ]
+         )
+         self.convs.apply(init_weights)
+ 
+         self.num_layers = len(self.convs)  # Total number of conv layers
+ 
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+ 
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+ 
+         # Activation functions
+         if activation == "snake":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=Snake(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         elif activation == "snakebeta":
+             self.activations = nn.ModuleList(
+                 [
+                     Activation1d(
+                         activation=SnakeBeta(
+                             channels, alpha_logscale=h.snake_logscale
+                         )
+                     )
+                     for _ in range(self.num_layers)
+                 ]
+             )
+         else:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+ 
+     def forward(self, x):
+         for c, a in zip(self.convs, self.activations):
+             xt = a(x)
+             xt = c(xt)
+             x = xt + x
+         return x
+ 
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+ 
+ 
+ class BigVGAN(
+     torch.nn.Module,
+     PyTorchModelHubMixin,
+     library_name="bigvgan",
+     repo_url="https://github.com/NVIDIA/BigVGAN",
+     docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
+     pipeline_tag="audio-to-audio",
+     license="mit",
+     tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
+ ):
+     def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
+         super().__init__()
+         self.h = h
+         self.h["use_cuda_kernel"] = use_cuda_kernel
+ 
+         # Select which Activation1d; lazy-load the CUDA version to ensure backward compatibility
+         if self.h.get("use_cuda_kernel", False):
+             from alias_free_activation.cuda.activation1d import (
+                 Activation1d as CudaActivation1d,
+             )
+ 
+             Activation1d = CudaActivation1d
+         else:
+             Activation1d = TorchActivation1d
+ 
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+ 
+         # Pre-conv
+         self.conv_pre = weight_norm(
+             Conv1d(h.in_channels, h.upsample_initial_channel, 7, 1, padding=3)
+         )
+ 
+         # Define which AMPBlock to use. BigVGAN uses AMPBlock1 as the default
+         if h.resblock == "1":
+             resblock_class = AMPBlock1
+         elif h.resblock == "2":
+             resblock_class = AMPBlock2
+         else:
+             raise ValueError(
+                 f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
+             )
+ 
+         # Transposed conv-based upsamplers; does not apply anti-aliasing
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(
+                 nn.ModuleList(
+                     [
+                         weight_norm(
+                             ConvTranspose1d(
+                                 h.upsample_initial_channel // (2**i),
+                                 h.upsample_initial_channel // (2 ** (i + 1)),
+                                 k,
+                                 u,
+                                 padding=(k - u) // 2,
+                             )
+                         )
+                     ]
+                 )
+             )
+ 
+         # Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel // (2 ** (i + 1))
+             for j, (k, d) in enumerate(
+                 zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+             ):
+                 self.resblocks.append(
+                     resblock_class(h, ch, k, d, activation=h.activation)
+                 )
+ 
+         # Post-conv
+         activation_post = (
+             Snake(ch, alpha_logscale=h.snake_logscale)
+             if h.activation == "snake"
+             else (
+                 SnakeBeta(ch, alpha_logscale=h.snake_logscale)
+                 if h.activation == "snakebeta"
+                 else None
+             )
+         )
+         if activation_post is None:
+             raise NotImplementedError(
+                 "activation incorrectly specified. check the config file and look for 'activation'."
+             )
+ 
+         self.activation_post = Activation1d(activation=activation_post)
+ 
+         # Whether to use bias for the final conv_post. Defaults to True for backward compatibility
+         self.use_bias_at_final = h.get("use_bias_at_final", True)
+         self.conv_post = weight_norm(
+             Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
+         )
+ 
+         # Weight initialization
+         for i in range(len(self.ups)):
+             self.ups[i].apply(init_weights)
+         self.conv_post.apply(init_weights)
+ 
+         # Final tanh activation. Defaults to True for backward compatibility
+         self.use_tanh_at_final = h.get("use_tanh_at_final", True)
+ 
+     def forward(self, x):
+         # Pre-conv
+         x = self.conv_pre(x)
+ 
+         for i in range(self.num_upsamples):
+             # Upsampling
+             for i_up in range(len(self.ups[i])):
+                 x = self.ups[i][i_up](x)
+             # AMP blocks
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+ 
+         # Post-conv
+         x = self.activation_post(x)
+         x = self.conv_post(x)
+         # Final tanh activation
+         if self.use_tanh_at_final:
+             x = torch.tanh(x)
+         else:
+             x = torch.clamp(x, min=-1.0, max=1.0)  # Bound the output to [-1, 1]
+ 
+         return x
+ 
+     def remove_weight_norm(self):
+         try:
+             print("Removing weight norm...")
+             for l in self.ups:
+                 for l_i in l:
+                     remove_weight_norm(l_i)
+             for l in self.resblocks:
+                 l.remove_weight_norm()
+             remove_weight_norm(self.conv_pre)
+             remove_weight_norm(self.conv_post)
+         except ValueError:
+             print("[INFO] Model already removed weight norm. Skipping!")
+ 
+     # Additional methods for huggingface_hub support
+     def _save_pretrained(self, save_directory: Path) -> None:
+         """Save weights and config.json from a Pytorch model to a local directory."""
+ 
+         model_path = save_directory / "bigvgan_generator.pt"
+         torch.save({"generator": self.state_dict()}, model_path)
+ 
+         config_path = save_directory / "config.json"
+         with open(config_path, "w") as config_file:
+             json.dump(self.h, config_file, indent=4)
+ 
+     @classmethod
+     def _from_pretrained(
+         cls,
+         *,
+         model_id: str,
+         revision: str,
+         cache_dir: str,
+         force_download: bool,
+         proxies: Optional[Dict],
+         resume_download: bool,
+         local_files_only: bool,
+         token: Union[str, bool, None],
+         map_location: str = "cpu",  # Additional argument
+         strict: bool = False,  # Additional argument
+         use_cuda_kernel: bool = False,
+         **model_kwargs,
+     ):
+         """Load Pytorch pretrained weights and return the loaded model."""
+ 
+         # Download and load hyperparameters (h) used by BigVGAN
+         if os.path.isdir(model_id):
+             print("Loading config.json from local directory")
+             config_file = os.path.join(model_id, "config.json")
+         else:
+             config_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="config.json",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+         h = load_hparams_from_json(config_file)
+ 
+         # Instantiate BigVGAN using h
+         if use_cuda_kernel:
+             print(
+                 "[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
+             )
+             print(
+                 "[WARNING] You need nvcc and ninja installed on your system, matching the CUDA toolkit your PyTorch build uses, to build the kernel. Otherwise, the model will fail to initialize or generate incorrect waveforms!"
+             )
+             print(
+                 "[WARNING] For details, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
+             )
+         model = cls(h, use_cuda_kernel=use_cuda_kernel)
+ 
+         # Download and load pretrained generator weight
+         if os.path.isdir(model_id):
+             print("Loading weights from local directory")
+             model_file = os.path.join(model_id, "bigvgan_generator.pt")
+         else:
+             print(f"Loading weights from {model_id}")
+             model_file = hf_hub_download(
+                 repo_id=model_id,
+                 filename="bigvgan_generator.pt",
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 proxies=proxies,
+                 resume_download=resume_download,
+                 token=token,
+                 local_files_only=local_files_only,
+             )
+ 
+         checkpoint_dict = torch.load(model_file, map_location=map_location)
+ 
+         try:
+             model.load_state_dict(checkpoint_dict["generator"])
+         except RuntimeError:
+             print(
+                 "[INFO] The pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
+             )
+             model.remove_weight_norm()
+             model.load_state_dict(checkpoint_dict["generator"])
+ 
+         return model
+ 
+ 
+ class Generator(torch.nn.Module):
+     def __init__(self, config_file, ckpt_path):
+         super().__init__()
+         with open(config_file) as f:
+             json_config = json.load(f)
+         self.h = AttrDict(json_config)
+         self.decoder = BigVGAN(self.h)
+         if ckpt_path.endswith(".safetensors"):
+             checkpoint_dict = load_file(ckpt_path)
+         else:
+             checkpoint_dict = torch.load(ckpt_path, map_location="cpu")
+         self.decoder.load_state_dict(checkpoint_dict["generator"])
+         self.decoder.remove_weight_norm()
+         self.decoder.eval()
+ 
+     def decode_audio(self, latents, overlap=5, chunk_size=20):
+         # Chunked decoding
+         hop_size = chunk_size - overlap
+         total_size = latents.shape[2]
+         batch_size = latents.shape[0]
+         chunks = []
+         i = 0  # guard in case total_size < chunk_size
+         for i in range(0, total_size - chunk_size + 1, hop_size):
+             chunk = latents[:, :, i : i + chunk_size]
+             chunks.append(chunk)
+         if i + chunk_size != total_size:
+             # Final chunk
+             chunk = latents[:, :, -chunk_size:]
+             chunks.append(chunk)
+         chunks = torch.stack(chunks)
+         num_chunks = chunks.shape[0]
+         # samples_per_latent is just the downsampling ratio
+         samples_per_latent = 9600
+         # Create an empty waveform; we will populate it with chunks as we decode them
+         y_size = total_size * samples_per_latent
+         y_final = torch.zeros((batch_size, 1, y_size)).to(latents.device)
+         for i in range(num_chunks):
+             x_chunk = chunks[i, :]
+             # Decode the chunk
+             y_chunk = self.decoder(x_chunk)
+             # Figure out where to put the audio along the time domain
+             if i == num_chunks - 1:
+                 # The final chunk always goes at the end
+                 t_end = y_size
+                 t_start = t_end - y_chunk.shape[2]
+             else:
+                 t_start = i * hop_size * samples_per_latent
+                 t_end = t_start + chunk_size * samples_per_latent
+             # Remove the edges of the overlaps
+             ol = (overlap // 2) * samples_per_latent
+             chunk_start = 0
+             chunk_end = y_chunk.shape[2]
+             if i > 0:
+                 # No overlap trim for the start of the first chunk
+                 t_start += ol
+                 chunk_start += ol
+             if i < num_chunks - 1:
+                 # No overlap trim for the end of the last chunk
+                 t_end -= ol
+                 chunk_end -= ol
+             # Paste the chunked audio into our y_final output audio
+             y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
+         return y_final
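To make the overlap bookkeeping in decode_audio concrete, a small worked example (latent counts chosen for illustration; samples_per_latent = 9600 is the ratio hard-coded above):

    chunk_size, overlap, samples_per_latent = 20, 5, 9600
    hop_size = chunk_size - overlap                        # 15 latent frames per hop
    total_size = 50                                        # latents -> 480000 output samples
    starts = list(range(0, total_size - chunk_size + 1, hop_size))
    print(starts)                                          # [0, 15, 30]: three chunks, no tail chunk needed
    ol = (overlap // 2) * samples_per_latent               # 19200 samples trimmed per overlapped edge
    # chunk 0 writes samples [0, 172800), chunk 1 [163200, 316800),
    # chunk 2 [307200, 480000); later chunks overwrite the remaining seam.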
bigvgan/utils.py ADDED
@@ -0,0 +1,59 @@
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+ 
+ import glob
+ import os
+ import torch
+ from torch.nn.utils import weight_norm
+ 
+ 
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+ 
+ 
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+ 
+ 
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+ 
+ 
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print(f"Loading '{filepath}'")
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+ 
+ 
+ def save_checkpoint(filepath, obj):
+     print(f"Saving checkpoint to {filepath}")
+     torch.save(obj, filepath)
+     print("Complete.")
+ 
+ 
+ def scan_checkpoint(cp_dir, prefix, renamed_file=None):
+     # Fall back to the original scanning logic first
+     pattern = os.path.join(cp_dir, prefix + "????????")
+     cp_list = glob.glob(pattern)
+ 
+     if len(cp_list) > 0:
+         last_checkpoint_path = sorted(cp_list)[-1]
+         print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
+         return last_checkpoint_path
+ 
+     # If no pattern-based checkpoints are found, check for the renamed file
+     if renamed_file:
+         renamed_path = os.path.join(cp_dir, renamed_file)
+         if os.path.isfile(renamed_path):
+             print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
+             return renamed_path
+ 
+     return None
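get_padding is the usual "same"-length padding for stride-1 dilated Conv1d: pad = dilation * (kernel_size - 1) / 2. For example:

    from bigvgan.utils import get_padding

    print(get_padding(3, 1))   # 1
    print(get_padding(3, 5))   # 5 -> Conv1d(k=3, dilation=5, padding=5) preserves T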
diffrhythm2/__init__.py ADDED
File without changes
diffrhythm2/backbones/__init__.py ADDED
File without changes
diffrhythm2/backbones/dit.py ADDED
@@ -0,0 +1,222 @@
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ 
+ from __future__ import annotations
+ 
+ import torch
+ import math
+ from torch import nn
+ 
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, LlamaConfig
+ from .llama_nar import LlamaNARDecoderLayer
+ 
+ 
+ class TextEmbedding(nn.Module):
+     def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
+         super().__init__()
+         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+ 
+     def forward(self, text: int["b nt"]):  # noqa: F722
+         text = self.text_embed(text)  # b n -> b n d
+         return text
+ 
+ 
+ class InputEmbedding(nn.Module):
+     def __init__(self, cond_dim, out_dim):
+         super().__init__()
+         self.proj = nn.Linear(cond_dim, cond_dim)
+         self.proj_2 = nn.Linear(cond_dim, out_dim)
+ 
+     def forward(self, x, style_emb, time_emb):  # noqa: F722
+         style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
+         x_orig = x
+         x = x + style_emb + time_emb
+         x = self.proj(x) + x_orig
+         x = self.proj_2(x)
+         return x
+ 
+ 
+ class AdaLayerNormZero_Final(nn.Module):
+     def __init__(self, dim, cond_dim):
+         super().__init__()
+ 
+         self.silu = nn.SiLU()
+         self.linear = nn.Linear(cond_dim, dim * 2)
+ 
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ 
+     def forward(self, x, emb):
+         emb = self.linear(self.silu(emb))
+         scale, shift = torch.chunk(emb, 2, dim=-1)
+ 
+         x = self.norm(x) * (1 + scale) + shift
+         return x
+ 
+ 
+ class SinusPositionEmbedding(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+ 
+     def forward(self, x, scale=1000):
+         device = x.device
+         half_dim = self.dim // 2
+         emb = math.log(10000) / (half_dim - 1)
+         emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
+         emb = scale * x.unsqueeze(-1) * emb.unsqueeze(0).unsqueeze(0)
+         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+         return emb
+ 
+     def numel(self):
+         return 0
+ 
+ 
+ class TimestepEmbedding(nn.Module):
+     def __init__(self, dim, freq_embed_dim=256):
+         super().__init__()
+         self.time_embed = SinusPositionEmbedding(freq_embed_dim)
+         self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
+ 
+     def forward(self, timestep: float["b"]):  # noqa: F821
+         time_hidden = self.time_embed(timestep)
+         time_hidden = time_hidden.to(timestep.dtype)
+         time = self.time_mlp(time_hidden)  # b d
+         return time
+ 
+ 
+ class DiT(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         depth=8,
+         heads=8,
+         ff_mult=4,
+         mel_dim=100,
+         text_num_embeds=256,
+         conv_layers=0,
+         long_skip_connection=False,
+         use_flex_attn=False,
+         repa_depth=-1,
+         repa_dims=[1024],
+         **kwargs
+     ):
+         super().__init__()
+ 
+         cond_dim = 512
+         self.time_embed = TimestepEmbedding(cond_dim)
+         self.text_embed = TextEmbedding(text_num_embeds, cond_dim, conv_layers=conv_layers)
+         self.input_embed = InputEmbedding(cond_dim, dim)
+ 
+         self.latent_embed = torch.nn.Sequential(
+             nn.Linear(mel_dim, cond_dim),
+             nn.Linear(cond_dim, cond_dim)
+         )
+ 
+         self.dim = dim
+         self.depth = depth
+         self.use_flex_attn = use_flex_attn
+ 
+         llama_config = LlamaConfig(
+             hidden_size=dim,
+             num_attention_heads=heads,
+             intermediate_size=dim * ff_mult,
+             hidden_act='silu',
+             max_position_embeddings=4096
+         )
+         self.rotary_embed = LlamaRotaryEmbedding(config=llama_config)
+         llama_config._attn_implementation = 'sdpa'
+         self.transformer_blocks = nn.ModuleList(
+             [LlamaNARDecoderLayer(llama_config, layer_idx=i, use_flex_attn=self.use_flex_attn) for i in range(depth)]
+         )
+         self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
+ 
+         self.norm_out = AdaLayerNormZero_Final(dim, cond_dim)  # final modulation
+         self.proj_out = nn.Linear(dim, mel_dim)
+ 
+         self.repa_depth = repa_depth
+         self.repa_dims = repa_dims
+         self.projectors = None
+         if self.repa_depth > 0:
+             self.projectors = nn.ModuleList([
+                 nn.Sequential(
+                     nn.Linear(self.dim, self.dim * 2),
+                     nn.SiLU(),
+                     nn.Linear(self.dim * 2, self.dim * 2),
+                     nn.SiLU(),
+                     nn.Linear(self.dim * 2, repa_dim),
+                 ) for repa_dim in self.repa_dims
+             ])
+ 
+     def forward(
+         self,
+         x: torch.Tensor,
+         time: torch.Tensor,
+         position_ids: torch.Tensor,
+         style_prompt: torch.Tensor,
+         attn_mask: torch.Tensor,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         past_key_value=None,
+     ):
+         """
+         Args:
+             x: [b, n, d]
+             time: [b, n, 1]
+             position_ids: [b, n]
+             style_prompt: [b, 512]
+             attn_mask: [b, 1, n, n]
+         """
+         batch, seq_len = x.shape[0], x.shape[1]
+         t = self.time_embed(time)
+         c = t  # [B, T, dim]
+ 
+         x = self.input_embed(x, style_prompt, c)
+ 
+         if self.long_skip_connection is not None:
+             residual = x
+ 
+         position_embeddings = self.rotary_embed(x, position_ids)
+ 
+         attn_weights = []
+         if not use_cache:
+             past_key_value = None
+ 
+         repa_res = None
+         for i, block in enumerate(self.transformer_blocks):
+             res = block(
+                 x,
+                 attention_mask=attn_mask,
+                 position_embeddings=position_embeddings,
+                 output_attentions=output_attentions,
+                 past_key_value=past_key_value,
+                 use_cache=use_cache
+             )
+             x = res.pop(0)
+             if output_attentions:
+                 attn_weights.append(res.pop(0))
+             if use_cache:
+                 past_key_value = res.pop(0)
+             if i == self.repa_depth - 1:
+                 repa_res = x
+ 
+         if self.long_skip_connection is not None:
+             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
+ 
+         x = self.norm_out(x, c)
+         output = self.proj_out(x)
+ 
+         return output, attn_weights, past_key_value
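The final AdaLayerNorm applies a conditional affine modulation after a parameter-free LayerNorm. A standalone shape-check sketch for just that module (sizes are illustrative):

    import torch
    from diffrhythm2.backbones.dit import AdaLayerNormZero_Final

    norm = AdaLayerNormZero_Final(dim=1024, cond_dim=512)
    x = torch.randn(2, 32, 1024)    # hidden states [b, n, dim]
    emb = torch.randn(2, 32, 512)   # conditioning (time embedding) [b, n, cond_dim]
    out = norm(x, emb)              # LayerNorm(x) * (1 + scale) + shift
    print(out.shape)                # torch.Size([2, 32, 1024])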
diffrhythm2/backbones/flex_attention.py ADDED
@@ -0,0 +1,237 @@
+ # Copyright 2025 The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ 
+ 
+ from typing import Optional, Tuple, Union
+ 
+ import torch
+ 
+ from torch.nn.attention.flex_attention import BlockMask, flex_attention
+ from torch.nn.attention.flex_attention import (
+     create_block_mask as create_block_causal_mask_flex,
+ )
+ 
+ 
+ class WrappedFlexAttention:
+     """
+     A singleton class so that flex attention is compiled only once, when it is first called.
+     """
+ 
+     _instance = None
+     _is_flex_compiled = False
+     _compiled_flex_attention = None
+ 
+     def __new__(cls, *args, **kwargs):
+         if cls._instance is None:
+             # Create a new instance if one doesn't already exist
+             cls._instance = super().__new__(cls)
+         return cls._instance
+ 
+     @torch.compiler.disable(recursive=False)
+     def __init__(self, training):
+         """
+         Initialize or update the singleton instance.
+         """
+         if not self._is_flex_compiled or training != self.training:
+             # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+             # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+             # for training; see https://github.com/pytorch/pytorch/issues/146260
+             self.training = training
+             if torch.__version__.split("+")[0] == "2.6.0" and training:
+                 self._compiled_flex_attention = torch.compile(
+                     flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                 )
+             else:
+                 self._compiled_flex_attention = torch.compile(flex_attention)
+             self._is_flex_compiled = True
+ 
+     def __call__(self):
+         return self._compiled_flex_attention
+ 
+ 
+ Offset = Union[torch.Tensor, int]
+ 
+ 
+ def make_flex_block_causal_mask(
+     attention_mask_2d: torch.Tensor,
+     attention_chunk_size: Optional[int] = None,
+     query_length=None,
+     key_length=None,
+     offsets: Optional[Tuple[Offset, Offset]] = None,
+ ) -> "BlockMask":
+     """
+     Create a block-causal document mask for a batch of sequences, both packed and unpacked.
+     Creates the block-causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
+     The resulting BlockMask is a compressed representation of the full block-causal
+     mask. BlockMask is essential for performant computation of flex attention.
+     See: https://pytorch.org/blog/flexattention/
+ 
+     Args:
+         attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
+         of shape (batch_size, total_seq_len). e.g.
+ 
+         For an unpacked sequence:
+         [[1, 1, 1, 1, 0, 0, 0],
+          [1, 1, 1, 1, 1, 0, 0]]
+ 
+         For a packed sequence:
+         [[1, 1, 1, 2, 2, 2, 0],
+          [1, 1, 2, 2, 2, 3, 3]]
+ 
+     Returns:
+         BlockMask
+     """
+     batch_size, total_seq_len = attention_mask_2d.shape
+     if not key_length:
+         key_length = total_seq_len
+     if not query_length:
+         query_length = total_seq_len
+     attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
+     device = attention_mask_2d.device
+     document_ids = attention_mask_2d.clone()
+ 
+     if attention_chunk_size is not None:
+         # We create an arange, then // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
+         document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
+ 
+     # Instead of passing a tensor mask, flex attention requires a mask_mod function
+     # that determines which elements of QK^T should be included in the attention
+     # computation prior to the softmax. For sample packing, we need the logic for
+     # both the causal mask and the document mask. See PyTorch's official
+     # blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
+     def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+         """
+         Defines the logic of a block-causal mask by combining both a standard causal mask
+         and a block-diagonal document mask.
+ 
+         See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
+         for an illustration.
+         """
+         causal_mask = q_idx >= kv_idx  # not valid when decoding
+         document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
+         padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
+         final_mask = causal_mask & padding_mask & document_mask
+         return final_mask
+ 
+     if offsets is not None:
+         q_offset = offsets[0]
+         kv_offset = offsets[1]
+ 
+         def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+             offset_q = q_idx + q_offset
+             offset_kv = kv_idx + kv_offset
+             return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
+     else:
+         mask_mod = causal_mask_mod
+     return create_block_causal_mask_flex(
+         mask_mod=mask_mod,
+         B=batch_size,
+         H=None,  # attention head
+         Q_LEN=query_length,
+         KV_LEN=key_length,
+         device=device,
+         _compile=True,
+     )
+ 
+ 
+ @torch.compiler.disable(recursive=False)
+ def compile_friendly_flex_attention(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     training=False,
+     **kwargs,
+ ) -> torch.Tensor:
+     # The first call initialises the singleton wrapper object; the second call invokes
+     # the object's __call__ to return the compiled flex attention.
+     flex_attention_compiled = WrappedFlexAttention(training)()
+     return flex_attention_compiled(
+         query,
+         key,
+         value,
+         **kwargs,
+     )
+ 
+ 
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+ 
+ 
+ def flex_attention_forward(
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: Union[torch.Tensor, "BlockMask"],
+     training: bool = True,
+     scaling: Optional[float] = None,
+     softcap: Optional[float] = None,
+     head_mask: Optional[torch.Tensor] = None,
+     **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     block_mask = None
+     causal_mask = None
+ 
+     block_mask = attention_mask
+     # if isinstance(attention_mask, BlockMask):
+     #     block_mask = attention_mask
+     # else:
+     #     causal_mask = attention_mask
+ 
+     if causal_mask is not None:
+         causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+ 
+     def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+         if softcap is not None:
+             score = softcap * torch.tanh(score / softcap)
+         if causal_mask is not None:
+             score = score + causal_mask[batch_idx][0][q_idx][kv_idx]
+         if head_mask is not None:
+             score = score + head_mask[batch_idx][head_idx][0][0]
+         return score
+ 
+     enable_gqa = True
+     num_local_query_heads = query.shape[1]
+ 
+     # When running TP this helps:
+     if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+         key = repeat_kv(key, query.shape[1] // key.shape[1])
+         value = repeat_kv(value, query.shape[1] // value.shape[1])
+         enable_gqa = False
+ 
+     kernel_options = kwargs.get("kernel_options", None)
+     attn_output, attention_weights = compile_friendly_flex_attention(
+         query,
+         key,
+         value,
+         score_mod=score_mod,
+         block_mask=block_mask,
+         enable_gqa=enable_gqa,
+         scale=scaling,
+         kernel_options=kernel_options,
+         # Last checked on PyTorch == 2.5.1: flex attention always computes the lse regardless.
+         # For simplification, we thus always return it, as no additional computation is introduced.
+         return_lse=True,
+         training=training,
+     )
+     # lse is returned in float32
+     attention_weights = attention_weights.to(value.dtype)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+ 
+     return attn_output, attention_weights
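A usage sketch for the mask builder above (assumes PyTorch >= 2.5 with flex attention available; the packed mask values are illustrative, and a CUDA device is the typical target):

    import torch
    from diffrhythm2.backbones.flex_attention import make_flex_block_causal_mask

    # One row packing two documents (ids 1 and 2) followed by right padding (0)
    attention_mask_2d = torch.tensor([[1, 1, 1, 2, 2, 2, 0, 0]], device="cuda")
    block_mask = make_flex_block_causal_mask(attention_mask_2d)
    # block_mask can then be passed as `attention_mask` to flex_attention_forward,
    # restricting each query to earlier positions within its own document.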
diffrhythm2/backbones/llama_attention.py ADDED
@@ -0,0 +1,451 @@
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import math
19
+
20
+ from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRotaryEmbedding, LlamaRMSNorm
21
+ from transformers.models.llama.modeling_llama import Cache, StaticCache, FlashAttentionKwargs, Unpack
22
+ from transformers.models.llama.modeling_llama import (
23
+ apply_rotary_pos_emb,
24
+ repeat_kv,
25
+ _flash_attention_forward,
26
+ is_flash_attn_greater_or_equal_2_10
27
+ )
28
+ from transformers.models.llama.modeling_llama import logger
29
+ from typing import Optional, Tuple
30
+
31
+ try:
32
+ from .flex_attention import flex_attention_forward
33
+ except:
34
+ pass
35
+
36
+ class LlamaAttention(nn.Module):
37
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
38
+
39
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
40
+ super().__init__()
41
+ self.config = config
42
+ self.layer_idx = layer_idx
43
+ if layer_idx is None:
44
+ logger.warning_once(
45
+ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
46
+ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
47
+ "when creating this class."
48
+ )
49
+
50
+ self.attention_dropout = config.attention_dropout
51
+ self.hidden_size = config.hidden_size
52
+ self.num_heads = config.num_attention_heads
53
+ self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
54
+ self.num_key_value_heads = config.num_key_value_heads
55
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
56
+ self.max_position_embeddings = config.max_position_embeddings
57
+ self.rope_theta = config.rope_theta
58
+ self.is_causal = False
59
+
60
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
61
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
62
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
63
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
64
+
65
+ # TODO (joao): remove in v4.46 (RoPE is computed in the model, not in the decoder layers)
66
+ self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
67
+ self.q_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
68
+ self.k_norm = LlamaRMSNorm(self.head_dim, eps=config.rms_norm_eps)
69
+
70
+ def forward(
71
+ self,
72
+ hidden_states: torch.Tensor,
73
+ attention_mask: Optional[torch.Tensor] = None,
74
+ position_ids: Optional[torch.LongTensor] = None,
75
+ past_key_value: Optional[Cache] = None,
76
+ output_attentions: bool = False,
77
+ use_cache: bool = False,
78
+ cache_position: Optional[torch.LongTensor] = None,
79
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
80
+ **kwargs,
81
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
82
+ bsz, q_len, _ = hidden_states.size()
83
+
84
+ query_states = self.q_proj(hidden_states)
85
+ key_states = self.k_proj(hidden_states)
86
+ value_states = self.v_proj(hidden_states)
87
+
88
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
89
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
90
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
91
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
92
+
93
+ query_states = self.q_norm(query_states)
94
+ key_states = self.k_norm(key_states)
95
+
96
+ if position_embeddings is None:
97
+ logger.warning_once(
98
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
99
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
100
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
101
+ "removed and `position_embeddings` will be mandatory."
102
+ )
103
+ cos, sin = self.rotary_emb(value_states, position_ids)
104
+ else:
105
+ cos, sin = position_embeddings
106
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
107
+
108
+ if past_key_value is not None:
109
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
110
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
111
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
112
+
113
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
114
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
115
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
116
+
117
+ if attention_mask is not None: # no matter the length, we just slice it
118
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
119
+ if attention_mask.dtype != torch.bool:
120
+ attn_weights = attn_weights + causal_mask
121
+ else:
122
+ attn_weights = torch.masked_fill(attn_weights, ~causal_mask, float("-inf"))
123
+
124
+ # upcast attention to fp32
125
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
126
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
127
+ attn_output = torch.matmul(attn_weights, value_states)
128
+
129
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
130
+ raise ValueError(
131
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
132
+ f" {attn_output.size()}"
133
+ )
134
+
135
+ attn_output = attn_output.transpose(1, 2).contiguous()
136
+
137
+ attn_output = attn_output.reshape(bsz, q_len, -1)
138
+
139
+ attn_output = self.o_proj(attn_output)
140
+
141
+ if not output_attentions:
142
+ attn_weights = None
143
+
144
+ return attn_output, attn_weights, past_key_value
145
+
146
+
147
+ class LlamaFlashAttention2(LlamaAttention):
148
+ """
149
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
150
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
151
+ flash attention and deal with padding tokens in case the input contains any of them.
152
+ """
153
+
154
+ def __init__(self, *args, **kwargs):
155
+ super().__init__(*args, **kwargs)
156
+
157
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
158
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
159
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
160
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ attention_mask: Optional[torch.LongTensor] = None,
166
+ position_ids: Optional[torch.LongTensor] = None,
167
+ past_key_value: Optional[Cache] = None,
168
+ output_attentions: bool = False,
169
+ use_cache: bool = False,
170
+ cache_position: Optional[torch.LongTensor] = None,
171
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
172
+ **kwargs: Unpack[FlashAttentionKwargs],
173
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
174
+ if isinstance(past_key_value, StaticCache):
175
+ raise ValueError(
176
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
177
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
178
+ )
179
+
180
+ output_attentions = False
181
+
182
+ bsz, q_len, _ = hidden_states.size()
183
+
184
+ query_states = self.q_proj(hidden_states)
185
+ key_states = self.k_proj(hidden_states)
186
+ value_states = self.v_proj(hidden_states)
187
+
188
+ # Flash attention requires the input to have the shape
189
+ # batch_size x seq_length x head_dim x hidden_dim
190
+ # therefore we just need to keep the original shape
191
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
192
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
193
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
194
+
195
+ query_states = self.q_norm(query_states)
196
+ key_states = self.k_norm(key_states)
197
+
198
+ if position_embeddings is None:
199
+ logger.warning_once(
200
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
201
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
202
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
203
+ "removed and `position_embeddings` will be mandatory."
204
+ )
205
+ cos, sin = self.rotary_emb(value_states, position_ids)
206
+ else:
207
+ cos, sin = position_embeddings
208
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
209
+
210
+ if past_key_value is not None:
211
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
212
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
214
+
215
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
216
+ # to be able to avoid many of these transpose/reshape/view.
217
+ query_states = query_states.transpose(1, 2)
218
+ key_states = key_states.transpose(1, 2)
219
+ value_states = value_states.transpose(1, 2)
220
+
221
+ dropout_rate = self.attention_dropout if self.training else 0.0
222
+
223
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
224
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
225
+ # cast them back in the correct dtype just to be sure everything works as expected.
226
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
227
+ # in fp32. (LlamaRMSNorm handles it correctly)
228
+
229
+ input_dtype = query_states.dtype
230
+ if input_dtype == torch.float32:
231
+ if torch.is_autocast_enabled():
232
+ target_dtype = torch.get_autocast_gpu_dtype()
233
+ # Handle the case where the model is quantized
234
+ elif hasattr(self.config, "_pre_quantization_dtype"):
235
+ target_dtype = self.config._pre_quantization_dtype
236
+ else:
237
+ target_dtype = self.q_proj.weight.dtype
238
+
239
+ logger.warning_once(
240
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
241
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
242
+ f" {target_dtype}."
243
+ )
244
+
245
+ query_states = query_states.to(target_dtype)
246
+ key_states = key_states.to(target_dtype)
247
+ value_states = value_states.to(target_dtype)
248
+
249
+ attn_output = _flash_attention_forward(
250
+ query_states,
251
+ key_states,
252
+ value_states,
253
+ attention_mask,
254
+ q_len,
255
+ position_ids=position_ids,
256
+ dropout=dropout_rate,
257
+ sliding_window=getattr(self, "sliding_window", None),
258
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
259
+ is_causal=self.is_causal,
260
+ **kwargs,
261
+ )
262
+
263
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
264
+ attn_output = self.o_proj(attn_output)
265
+
266
+ if not output_attentions:
267
+ attn_weights = None
268
+
269
+ return attn_output, attn_weights, past_key_value
270
+
271
+
272
+ class LlamaSdpaAttention(LlamaAttention):
273
+ """
274
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
275
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
276
+ SDPA API.
277
+ """
278
+
279
+ # Adapted from LlamaAttention.forward
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ position_ids: Optional[torch.LongTensor] = None,
285
+ past_key_value: Optional[Cache] = None,
286
+ output_attentions: bool = False,
287
+ use_cache: bool = False,
288
+ cache_position: Optional[torch.LongTensor] = None,
289
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
290
+ **kwargs,
291
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
292
+ if output_attentions:
293
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
294
+ logger.warning_once(
295
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
296
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
297
+ )
298
+ return super().forward(
299
+ hidden_states=hidden_states,
300
+ attention_mask=attention_mask,
301
+ position_ids=position_ids,
302
+ past_key_value=past_key_value,
303
+ output_attentions=output_attentions,
304
+ use_cache=use_cache,
305
+ cache_position=cache_position,
306
+ position_embeddings=position_embeddings,
307
+ )
308
+
309
+ bsz, q_len, _ = hidden_states.size()
310
+
311
+ query_states = self.q_proj(hidden_states)
312
+ key_states = self.k_proj(hidden_states)
313
+ value_states = self.v_proj(hidden_states)
314
+
315
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
316
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
317
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
318
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
319
+
320
+ query_states = self.q_norm(query_states)
321
+ key_states = self.k_norm(key_states)
322
+
323
+ if position_embeddings is None:
324
+ logger.warning_once(
325
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
326
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
327
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
328
+ "removed and `position_embeddings` will be mandatory."
329
+ )
330
+ cos, sin = self.rotary_emb(value_states, position_ids)
331
+ else:
332
+ cos, sin = position_embeddings
333
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
334
+
335
+ if past_key_value is not None:
336
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
337
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
338
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
339
+
340
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
341
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
342
+
343
+ causal_mask = attention_mask
344
+ if attention_mask is not None:
345
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
346
+
347
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
348
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
349
+ if query_states.device.type == "cuda" and causal_mask is not None:
350
+ query_states = query_states.contiguous()
351
+ key_states = key_states.contiguous()
352
+ value_states = value_states.contiguous()
353
+
354
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
355
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
356
+ is_causal = True if causal_mask is None and q_len > 1 else False
357
+
358
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
359
+ query_states,
360
+ key_states,
361
+ value_states,
362
+ attn_mask=causal_mask,
363
+ dropout_p=self.attention_dropout if self.training else 0.0,
364
+ is_causal=is_causal,
365
+ )
366
+
367
+ attn_output = attn_output.transpose(1, 2).contiguous()
368
+ attn_output = attn_output.view(bsz, q_len, -1)
369
+
370
+ attn_output = self.o_proj(attn_output)
371
+
372
+ return attn_output, None, past_key_value
373
+
374
+
375
+ class LlamaFlexAttention(LlamaAttention):
376
+ """
377
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
378
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
379
+ SDPA API.
380
+ """
381
+
382
+ # Adapted from LlamaAttention.forward
383
+ def forward(
384
+ self,
385
+ hidden_states: torch.Tensor,
386
+ attention_mask: Optional[torch.Tensor] = None,
387
+ position_ids: Optional[torch.LongTensor] = None,
388
+ past_key_value: Optional[Cache] = None,
389
+ output_attentions: bool = False,
390
+ use_cache: bool = False,
391
+ cache_position: Optional[torch.LongTensor] = None,
392
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
393
+ **kwargs,
394
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
395
+ bsz, q_len, _ = hidden_states.size()
396
+
397
+ query_states = self.q_proj(hidden_states)
398
+ key_states = self.k_proj(hidden_states)
399
+ value_states = self.v_proj(hidden_states)
400
+
401
+ # use -1 to infer num_heads and num_key_value_heads as they may vary if tensor parallel is used
402
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
403
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
404
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
405
+ dtype = query_states.dtype
406
+
407
+ query_states = self.q_norm(query_states).to(dtype)
408
+ key_states = self.k_norm(key_states).to(dtype)
409
+
410
+ if position_embeddings is None:
411
+ logger.warning_once(
412
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
413
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
414
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
415
+ "removed and `position_embeddings` will be mandatory."
416
+ )
417
+ cos, sin = self.rotary_emb(value_states, position_ids)
418
+ else:
419
+ cos, sin = position_embeddings
420
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
421
+
422
+ if past_key_value is not None:
423
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
424
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
425
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
426
+
428
+ attn_output, attn_weight = flex_attention_forward(
429
+ query_states,
430
+ key_states,
431
+ value_states,
432
+ attention_mask,
433
+ training=self.training,
434
+ )
436
+
437
+ attn_output = attn_output.view(bsz, q_len, -1)
440
+
441
+ attn_output = self.o_proj(attn_output)
442
+
443
+ return attn_output, attn_weight, past_key_value
444
+
445
+
446
+ LLAMA_ATTENTION_CLASSES = {
447
+ "eager": LlamaAttention,
448
+ "flash_attention_2": LlamaFlashAttention2,
449
+ "flex_attention": LlamaFlexAttention,
450
+ "sdpa": LlamaSdpaAttention,
451
+ }
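This dispatch dict mirrors the upstream transformers pattern: `config._attn_implementation` selects the class. A hedged sketch of the selection (the config values are placeholders, and it assumes the custom `LlamaAttention` defined earlier in this file constructs from a plain `LlamaConfig`):

from transformers import LlamaConfig

config = LlamaConfig(hidden_size=1024, num_attention_heads=16)
config._attn_implementation = "sdpa"   # or "eager" / "flash_attention_2" / "flex_attention"
attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=0)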
diffrhythm2/backbones/llama_nar.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from transformers import LlamaConfig
17
+ import torch
18
+
19
+ import torch.nn as nn
20
+ from typing import Optional, Tuple
21
+ import math
22
+
23
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
24
+ from .llama_attention import LLAMA_ATTENTION_CLASSES
25
+
26
+ # sinusoidal positional encoding
27
+ class SinusoidalPosEmb(nn.Module):
28
+ def __init__(self, dim):
29
+ super().__init__()
30
+ self.dim = dim
31
+
32
+ def forward(self, x):
33
+ device = x.device
34
+ half_dim = self.dim // 2
35
+ emb = math.log(10000) / (half_dim - 1)
36
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
37
+ emb = x[:, None] * emb[None, :] * 1.0
38
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
39
+ return emb
40
+
41
+
42
+ class LlamaAdaptiveRMSNorm(nn.Module):
43
+ def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
44
+ super().__init__()
45
+ self.to_weight = nn.Linear(dim_cond, hidden_size)
46
+ nn.init.zeros_(self.to_weight.weight)
47
+ nn.init.ones_(self.to_weight.bias)
48
+ self.variance_epsilon = eps
49
+ self._is_hf_initialized = True # disable automatic init
50
+
51
+ def forward(self, hidden_states, cond_embedding):
52
+ input_dtype = hidden_states.dtype
53
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
54
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
55
+
56
+ weight = self.to_weight(cond_embedding)
57
+ if len(weight.shape) == 2:
58
+ weight = weight.unsqueeze(1)
59
+
60
+ return (weight * hidden_states).to(input_dtype)
61
+
62
+
63
+ class LlamaNARDecoderLayer(LlamaDecoderLayer):
64
+ def __init__(self, config: LlamaConfig, layer_idx: int, use_flex_attn: bool=False):
65
+ """Override to adaptive layer norm"""
66
+ super().__init__(config, layer_idx) # init attention, mlp, etc.
67
+ _attn_implementation = config._attn_implementation
68
+ if use_flex_attn:
69
+ _attn_implementation = "flex_attention"
70
+ # _attn_implementation = "flash_attention_2"
71
+ self.self_attn = LLAMA_ATTENTION_CLASSES[_attn_implementation](config=config, layer_idx=layer_idx)
72
+ # self.input_layernorm = LlamaAdaptiveRMSNorm(
73
+ # config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
74
+ # )
75
+ # self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
76
+ # config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
77
+ # )
78
+
79
+ # add `cond` in forward function
80
+ def forward(
81
+ self,
82
+ hidden_states: torch.Tensor,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
85
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
86
+ output_attentions: Optional[bool] = False,
87
+ use_cache: Optional[bool] = False,
88
+ ) -> Tuple[
89
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
90
+ ]:
91
+ """
92
+ Args:
93
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
94
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
95
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
96
+ output_attentions (`bool`, *optional*):
97
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
98
+ returned tensors for more detail.
99
+ use_cache (`bool`, *optional*):
100
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
101
+ (see `past_key_values`).
102
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
103
+ """
104
+
105
+ residual = hidden_states
107
+ hidden_states = self.input_layernorm(
108
+ hidden_states
109
+ )
111
+ # Self Attention
112
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
113
+ hidden_states=hidden_states,
114
+ attention_mask=attention_mask,
115
+ position_embeddings=position_embeddings,
116
+ past_key_value=past_key_value,
117
+ output_attentions=output_attentions,
118
+ use_cache=use_cache,
119
+ )
121
+ hidden_states = residual + hidden_states
123
+ # Fully Connected
124
+ residual = hidden_states
125
+ hidden_states = self.post_attention_layernorm(
126
+ hidden_states
127
+ )
129
+ hidden_states = self.mlp(hidden_states)
130
+ hidden_states = residual + hidden_states
132
+ outputs = [hidden_states,]
133
+
134
+ if output_attentions:
135
+ outputs += [self_attn_weights,]
136
+
137
+ if use_cache:
138
+ outputs += [present_key_value,]
139
+
140
+ return outputs
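A small usage sketch of `LlamaAdaptiveRMSNorm` defined above (currently unused by `LlamaNARDecoderLayer`, whose adaptive norms are commented out). Shapes are illustrative; note that the zero-initialized `to_weight.weight` plus ones bias means the module behaves exactly like a plain RMSNorm until the conditioning projection is trained:

import torch

norm = LlamaAdaptiveRMSNorm(hidden_size=1024, dim_cond=1024)
h = torch.randn(2, 50, 1024)   # (batch, seq_len, hidden)
cond = torch.randn(2, 1024)    # e.g. a SinusoidalPosEmb of the flow time t
out = norm(h, cond)            # (2, 50, 1024); the weight broadcasts as (2, 1, 1024)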
diffrhythm2/cache_utils.py ADDED
@@ -0,0 +1,154 @@
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import torch
17
+
18
+ from typing import Optional, List, Tuple, Dict, Any
19
+ from transformers.cache_utils import Cache
20
+ from contextlib import contextmanager
21
+
22
+ class BlockFlowMatchingCache(Cache):
23
+ def __init__(
24
+ self,
25
+ text_lengths: Optional[torch.Tensor] = None,
26
+ block_size: Optional[int] = None,
27
+ num_history_block: Optional[int] = None
28
+ ) -> None:
29
+ super().__init__()
30
+ self._seen_tokens = 0
31
+ self.text_key_cache: List[torch.Tensor] = []
32
+ self.text_value_cache: List[torch.Tensor] = []
33
+ self.key_cache: List[torch.Tensor] = []
34
+ self.value_cache: List[torch.Tensor] = []
35
+ self.text_lengths = text_lengths
36
+ self.block_size = block_size
37
+ self.num_history_block = num_history_block
38
+ self.is_cache_text = False
39
+ self.is_storage_cache = False
40
+ assert (
+     self.num_history_block is None or self.block_size is not None
+ ), "num_history_block and block_size must be set at the same time."
47
+
48
+ @contextmanager
49
+ def cache_text(self):
50
+ self.is_cache_text = True
51
+ try:
52
+ yield self
53
+ finally:
54
+ self.is_cache_text = False
55
+
56
+ @contextmanager
57
+ def cache_context(self):
58
+ self.is_storage_cache = True
59
+ try:
60
+ yield self
61
+ finally:
62
+ self.is_storage_cache = False
63
+
64
+ def update(
65
+ self,
66
+ key_states: torch.Tensor,
67
+ value_states: torch.Tensor,
68
+ layer_idx: int,
69
+ cache_kwargs: Optional[Dict[str, Any]] = None,
70
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
71
+ """
72
+ Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
73
+
74
+ Parameters:
75
+ key_states (`torch.Tensor`):
76
+ The new key states to cache.
77
+ value_states (`torch.Tensor`):
78
+ The new value states to cache.
79
+ layer_idx (`int`):
80
+ The index of the layer to cache the states for.
81
+ cache_kwargs (`Dict[str, Any]`, `optional`):
82
+ Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
83
+
84
+ Return:
85
+ A tuple containing the updated key and value states.
86
+ """
87
+ # cache text
88
+ if self.is_cache_text:
89
+ if self.text_lengths is None:
90
+ self.text_lengths = torch.LongTensor([key_states.shape[-2]] * key_states.shape[0])
91
+ self.text_key_cache.append(key_states)
92
+ self.text_value_cache.append(value_states)
93
+ return self.text_key_cache[layer_idx], self.text_value_cache[layer_idx]
94
+
95
+ # Update the number of seen tokens
96
+ if layer_idx == 0:
97
+ self._seen_tokens += key_states.shape[-2]
98
+
99
+ # Update the cache
100
+ if key_states is not None:
101
+ if len(self.key_cache) <= layer_idx:
102
+ # There may be skipped layers, fill them with empty lists
103
+ for _ in range(len(self.key_cache), layer_idx + 1):
104
+ self.key_cache.append([])
105
+ self.value_cache.append([])
106
+ cached_key_state = self.key_cache[layer_idx]
107
+ cached_value_state = self.value_cache[layer_idx]
108
+ if len(cached_key_state) != 0:
109
+ key_states = torch.cat([cached_key_state, key_states], dim=-2)
110
+ value_states = torch.cat([cached_value_state, value_states], dim=-2)
111
+ if self.num_history_block is not None:
112
+ history_length = self.block_size * (self.num_history_block + 1)
113
+ key_states = key_states[:, :, -history_length:, :]
114
+ value_states = value_states[:, :, -history_length:, :]
115
+ if self.is_storage_cache:
116
+ self.key_cache[layer_idx] = key_states
117
+ self.value_cache[layer_idx] = value_states
118
+
119
+ k_s = []
120
+ v_s = []
121
+
122
+ text_key_cache = (
123
+ self.text_key_cache[layer_idx]
124
+ if len(self.text_key_cache) > layer_idx
125
+ else torch.zeros(key_states.shape[0], key_states.shape[1], 0, key_states.shape[3], device=key_states.device, dtype=key_states.dtype)
126
+ )
127
+ text_value_cache = (
128
+ self.text_value_cache[layer_idx]
129
+ if len(self.text_value_cache) > layer_idx
130
+ else torch.zeros(value_states.shape[0], value_states.shape[1], 0, value_states.shape[3], device=value_states.device, dtype=value_states.dtype)
131
+ )
132
+ for b in range(self.text_lengths.shape[0]):
133
+ k_s.append(torch.cat([text_key_cache[b][:, :self.text_lengths[b], :], key_states[b]], dim=-2))
134
+ v_s.append(torch.cat([text_value_cache[b][:, :self.text_lengths[b], :], value_states[b]], dim=-2))
135
+ k_s = torch.nn.utils.rnn.pad_sequence(k_s, batch_first=True)
136
+ v_s = torch.nn.utils.rnn.pad_sequence(v_s, batch_first=True)
137
+
138
+ return k_s, v_s
139
+
140
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
141
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
142
+ # TODO: deprecate this function in favor of `cache_position`
143
+ is_empty_layer = (
144
+ len(self.key_cache) == 0 # no cache in any layer
145
+ or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
146
+ or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
147
+ )
148
+ layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
149
+ return layer_seq_length
150
+
151
+ def get_max_cache_shape(self) -> Optional[int]:
152
+ """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
153
+ return None
154
+
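To make the two context managers concrete, a minimal sketch with illustrative shapes:

import torch

cache = BlockFlowMatchingCache(block_size=4, num_history_block=2)
k = torch.randn(1, 8, 4, 64)   # (batch, heads, seq_len, head_dim)
v = torch.randn_like(k)

with cache.cache_text():       # prompt K/V: stored once per layer, never evicted
    cache.update(k, v, layer_idx=0)

with cache.cache_context():    # block K/V: persisted into the rolling history, which is
    k_all, v_all = cache.update(k, v, layer_idx=0)   # trimmed to the last
                                                     # block_size * (num_history_block + 1) positions
# k_all prepends the cached text K/V to the block K/V: here (1, 8, 8, 64)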
diffrhythm2/cfm.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright 2025 ASLP Lab and Xiaomi Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+ import torch
17
+ from torch import nn
18
+ from tqdm import tqdm
19
+
20
+ from torchdiffeq import odeint
21
+ from .backbones.dit import DiT
22
+ from .cache_utils import BlockFlowMatchingCache
23
+ from torch.nn.attention.flex_attention import create_block_mask
24
+
25
+ def all_mask(b, h, q_idx, kv_idx):
26
+ return q_idx == q_idx
27
+
28
+
29
+ class CFM(nn.Module):
30
+ def __init__(
31
+ self,
32
+ transformer: DiT,
33
+ sigma=0.0,
34
+ odeint_kwargs: dict = dict(
35
+ # atol = 1e-5,
36
+ # rtol = 1e-5,
37
+ method="euler" # 'midpoint'
38
+ # method="adaptive_heun"
39
+ ),
40
+ odeint_options: dict = dict(
41
+ min_step=0.05
42
+ ),
43
+ num_channels=None,
44
+ block_size=None,
45
+ num_history_block=None
46
+ ):
47
+ super().__init__()
48
+
49
+ self.num_channels = num_channels
50
+
51
+ # transformer
52
+ self.transformer = transformer
53
+ dim = transformer.dim
54
+ self.dim = dim
55
+
56
+ # conditional flow related
57
+ self.sigma = sigma
58
+
59
+ # sampling related
60
+ self.odeint_kwargs = odeint_kwargs
61
+ print(f"ODE SOLVER: {self.odeint_kwargs['method']}")
62
+
63
+ self.odeint_options = odeint_options
64
+ self.block_size = block_size
65
+ self.num_history_block = num_history_block
66
+ if self.num_history_block is not None and self.num_history_block <= 0:
67
+ self.num_history_block = None
68
+
69
+ print(f"block_size: {self.block_size}; num_history_block: {self.num_history_block}")
70
+
71
+ @property
72
+ def device(self):
73
+ return next(self.parameters()).device
74
+
75
+ @torch.no_grad()
76
+ def sample_block_cache(
77
+ self,
78
+ text,
79
+ duration, # noqa: F821
80
+ style_prompt,
81
+ steps=32,
82
+ cfg_strength=1.0,
83
+ odeint_method='euler'
84
+ ):
85
+ self.eval()
86
+
87
+ batch = text.shape[0]
88
+ device = self.device
89
+ num_blocks = duration // self.block_size + (duration % self.block_size > 0)
90
+
91
+ text_emb = self.transformer.text_embed(text)
92
+ cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
93
+ text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
94
+ clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
95
+ noisy_lens = torch.LongTensor([self.block_size]).to(device)
96
+ block_iterator = range(num_blocks)
97
+
98
+ # create cache
99
+ kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
100
+ cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
101
+ cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)
102
+
103
+ # generate text cache
104
+ text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
105
+ text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
106
+ text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
107
+ # text_attn_mask = create_block_mask(
108
+ # all_mask,
109
+ # B = batch,
110
+ # H = None,
111
+ # Q_LEN=text_emb.shape[1],
112
+ # KV_LEN=text_emb.shape[1]
113
+ # )
114
+
115
+ if text_emb.shape[1] != 0:
116
+ with kv_cache.cache_text():
117
+ _, _, kv_cache = self.transformer(
118
+ x = text_emb,
119
+ time=text_time,
120
+ attn_mask=text_attn_mask,
121
+ position_ids=text_position_ids,
122
+ style_prompt=style_prompt,
123
+ use_cache=True,
124
+ past_key_value = kv_cache
125
+ )
126
+ with cfg_kv_cache.cache_text():
127
+ _, _, cfg_kv_cache = self.transformer(
128
+ x = cfg_text_emb,
129
+ time=text_time,
130
+ attn_mask=text_attn_mask,
131
+ position_ids=text_position_ids,
132
+ style_prompt=torch.zeros_like(style_prompt),
133
+ use_cache=True,
134
+ past_key_value = cfg_kv_cache
135
+ )
136
+
137
+ end_pos = 0
138
+ for bid in block_iterator:
139
+ clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
141
+
142
+ # all one mask
143
+ attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool() # [B, 1, Q, KV]
144
+ # attn_mask = create_block_mask(
145
+ # all_mask,
146
+ # B = batch,
147
+ # H = None,
148
+ # Q_LEN=noisy_lens.max(),
149
+ # KV_LEN=(text_lens + clean_lens + noisy_lens).max()
150
+ # )
151
+
152
+ # generate position id
153
+ position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
154
+ position_ids = position_ids[:, -noisy_lens.max():]
155
+
156
+ # core sample fn
157
+ def fn(t, x):
158
+ noisy_embed = self.transformer.latent_embed(x)
159
+
160
+ if t.ndim == 0:
161
+ t = t.repeat(batch)
162
+ time = t[:, None].repeat(1, noisy_lens.max())
163
+
164
+ pred, *_ = self.transformer(
165
+ x=noisy_embed,
166
+ time=time,
167
+ attn_mask=attn_mask,
168
+ position_ids=position_ids,
169
+ style_prompt=style_prompt,
170
+ use_cache=True,
171
+ past_key_value = kv_cache
172
+ )
173
+ if cfg_strength < 1e-5:
174
+ return pred
175
+
176
+ null_pred, *_ = self.transformer(
177
+ x=noisy_embed,
178
+ time=time,
179
+ attn_mask=attn_mask,
180
+ position_ids=position_ids,
181
+ style_prompt=torch.zeros_like(style_prompt),
182
+ use_cache=True,
183
+ past_key_value = cfg_kv_cache
184
+ )
185
+
186
+ return pred + (pred - null_pred) * cfg_strength
187
+
188
+ # generate time
189
+ noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
190
+ t_start = 0
191
+ t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
192
+
193
+ # sampling
194
+ outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
195
+ sampled = outputs[-1]
196
+
197
+ # generate next kv cache
198
+ cache_embed = self.transformer.latent_embed(sampled)
199
+ with kv_cache.cache_context():
200
+ _, _, kv_cache = self.transformer(
201
+ x = cache_embed,
202
+ time=cache_time,
203
+ attn_mask=attn_mask,
204
+ position_ids=position_ids,
205
+ style_prompt=style_prompt,
206
+ use_cache=True,
207
+ past_key_value = kv_cache
208
+ )
209
+ with cfg_kv_cache.cache_context():
210
+ _, _, cfg_kv_cache = self.transformer(
211
+ x = cache_embed,
212
+ time=cache_time,
213
+ attn_mask=attn_mask,
214
+ position_ids=position_ids,
215
+ style_prompt=torch.zeros_like(style_prompt),
216
+ use_cache=True,
217
+ past_key_value = cfg_kv_cache
218
+ )
219
+
220
+ # push new block
221
+ clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
222
+
223
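+ # EOS heuristic (the model is assumed to emit an approximately all-ones frame at
+ # the end of the song). Despite the `kl` name, `last_kl` below is an MSE distance
+ # to that all-ones frame; trailing frames within 0.05 of it are trimmed off by
+ # walking backwards from the end of the stream.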
+ pos = -1
224
+ curr_frame = clean_emb_stream[:, pos, :]
225
+ eos = torch.ones_like(curr_frame)
226
+ last_kl = torch.nn.functional.mse_loss(
227
+ curr_frame,
228
+ eos
229
+ )
230
+ if last_kl.abs() <= 0.05:
231
+ while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
232
+ pos -= 1
233
+ curr_frame = clean_emb_stream[:, pos, :]
234
+ last_kl = torch.nn.functional.mse_loss(
235
+ curr_frame,
236
+ eos
237
+ )
238
+ end_pos = clean_emb_stream.shape[1] + pos
239
+ break
240
+ else:
241
+ end_pos = clean_emb_stream.shape[1]
242
+
243
+ clean_emb_stream = clean_emb_stream[:, :end_pos, :]
244
+
245
+ return clean_emb_stream
246
+
247
+ def sample_cache_stream(
248
+ self,
249
+ decoder,
250
+ text,
251
+ duration, # noqa: F821
252
+ style_prompt,
253
+ steps=32,
254
+ cfg_strength=1.0,
255
+ seed: int | None = None,
256
+ chunk_size=10,
257
+ overlap=2,
258
+ odeint_method='euler'
259
+ ):
260
+ self.eval()
261
+
262
+ batch = text.shape[0]
263
+ device = self.device
264
+ num_blocks = duration // self.block_size + (duration % self.block_size > 0)
265
+
266
+ text_emb = self.transformer.text_embed(text)
267
+ cfg_text_emb = self.transformer.text_embed(torch.zeros_like(text))
268
+ text_lens = torch.LongTensor([text_emb.shape[1]]).to(device)
269
+ clean_emb_stream = torch.zeros(batch, 0, self.num_channels, device=device, dtype=text_emb.dtype)
270
+ noisy_lens = torch.LongTensor([self.block_size]).to(device)
271
+ block_iterator = range(num_blocks)
272
+ # create cache
273
+ kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
274
+ cfg_kv_cache = BlockFlowMatchingCache(text_lengths=text_lens, num_history_block=self.num_history_block)
275
+ cache_time = torch.tensor([1], device=device)[:, None].repeat(batch, self.block_size).to(style_prompt.dtype)
276
+
277
+ # generate text cache
278
+ text_time = torch.tensor([-1], device=device)[:, None].repeat(batch, text_emb.shape[1]).to(style_prompt.dtype)
279
+ text_position_ids = torch.arange(0, text_emb.shape[1], device=device)[None, :].repeat(batch, 1)
280
+ text_attn_mask = torch.ones(batch, 1, text_emb.shape[1], text_emb.shape[1], device=device).bool()
281
+
282
+ if text_emb.shape[1] != 0:
283
+ with kv_cache.cache_text():
284
+ _, _, kv_cache = self.transformer(
285
+ x = text_emb,
286
+ time=text_time,
287
+ attn_mask=text_attn_mask,
288
+ position_ids=text_position_ids,
289
+ style_prompt=style_prompt,
290
+ use_cache=True,
291
+ past_key_value = kv_cache
292
+ )
293
+ with cfg_kv_cache.cache_text():
294
+ _, _, cfg_kv_cache = self.transformer(
295
+ x = cfg_text_emb,
296
+ time=text_time,
297
+ attn_mask=text_attn_mask,
298
+ position_ids=text_position_ids,
299
+ style_prompt=torch.zeros_like(style_prompt),
300
+ use_cache=True,
301
+ past_key_value = cfg_kv_cache
302
+ )
303
+
304
+ end_pos = 0
305
+ last_decoder_pos = 0
306
+ decode_audio = []
307
+ for bid in block_iterator:
308
+ clean_lens = torch.LongTensor([clean_emb_stream.shape[1]]).to(device)
310
+
311
+ # all one mask
312
+ attn_mask = torch.ones(batch, 1, noisy_lens.max(), (text_lens + clean_lens + noisy_lens).max(), device=device).bool() # [B, 1, Q, KV]
313
+
314
+ # generate position id
315
+ position_ids = torch.arange(0, (clean_lens + noisy_lens).max(), device=device)[None, :].repeat(batch, 1)
316
+ position_ids = position_ids[:, -noisy_lens.max():]
317
+
318
+ # core sample fn
319
+ def fn(t, x):
320
+ noisy_embed = self.transformer.latent_embed(x)
321
+
322
+ if t.ndim == 0:
323
+ t = t.repeat(batch)
324
+ time = t[:, None].repeat(1, noisy_lens.max())
325
+
326
+ pred, *_ = self.transformer(
327
+ x=noisy_embed,
328
+ time=time,
329
+ attn_mask=attn_mask,
330
+ position_ids=position_ids,
331
+ style_prompt=style_prompt,
332
+ use_cache=True,
333
+ past_key_value = kv_cache
334
+ )
335
+ if cfg_strength < 1e-5:
336
+ return pred
337
+
338
+ null_pred, *_ = self.transformer(
339
+ x=noisy_embed,
340
+ time=time,
341
+ attn_mask=attn_mask,
342
+ position_ids=position_ids,
343
+ style_prompt=torch.zeros_like(style_prompt),
344
+ use_cache=True,
345
+ past_key_value = cfg_kv_cache
346
+ )
347
+
348
+ return pred + (pred - null_pred) * cfg_strength
349
+
350
+ # generate time
351
+ noisy_emb = torch.randn(batch, self.block_size, self.num_channels, device=device, dtype=style_prompt.dtype)
352
+ t_start = 0
353
+ t_set = torch.linspace(t_start, 1, steps, device=device, dtype=noisy_emb.dtype)
354
+
355
+ # sampling
356
+ outputs = odeint(fn, noisy_emb, t_set, method=odeint_method)
357
+ sampled = outputs[-1]
358
+
359
+ # generate next kv cache
360
+ cache_embed = self.transformer.latent_embed(sampled)
361
+ with kv_cache.cache_context():
362
+ _, _, kv_cache = self.transformer(
363
+ x = cache_embed,
364
+ time=cache_time,
365
+ attn_mask=attn_mask,
366
+ position_ids=position_ids,
367
+ style_prompt=style_prompt,
368
+ use_cache=True,
369
+ past_key_value = kv_cache
370
+ )
371
+ with cfg_kv_cache.cache_context():
372
+ _, _, cfg_kv_cache = self.transformer(
373
+ x = cache_embed,
374
+ time=cache_time,
375
+ attn_mask=attn_mask,
376
+ position_ids=position_ids,
377
+ style_prompt=torch.zeros_like(style_prompt),
378
+ use_cache=True,
379
+ past_key_value = cfg_kv_cache
380
+ )
381
+
382
+ # push new block
383
+ clean_emb_stream = torch.cat([clean_emb_stream, sampled], dim=1)
384
+
385
+ pos = -1
386
+ curr_frame = clean_emb_stream[:, pos, :]
387
+ eos = torch.ones_like(curr_frame)
388
+ last_kl = torch.nn.functional.mse_loss(
389
+ curr_frame,
390
+ eos
391
+ )
392
+ if last_kl.abs() <= 0.05:
393
+ while last_kl.abs() <= 0.05 and abs(pos) < clean_emb_stream.shape[1]:
394
+ pos -= 1
395
+ curr_frame = clean_emb_stream[:, pos, :]
396
+ last_kl = torch.nn.functional.mse_loss(
397
+ curr_frame,
398
+ eos
399
+ )
400
+ end_pos = clean_emb_stream.shape[1] + pos
401
+ break
402
+ else:
403
+ end_pos = clean_emb_stream.shape[1]
404
+ if end_pos - last_decoder_pos >= chunk_size:
405
+ start = max(0, last_decoder_pos - overlap)
406
+ overlap_frame = max(0, last_decoder_pos - start)
407
+ latent = clean_emb_stream[:, start:end_pos, :]
408
+ audio = decoder.decoder(latent.transpose(1, 2)) # [B, C, T]
410
+ audio = audio[:, :, overlap_frame * 9600:]
412
+ yield audio
413
+ last_decoder_pos = end_pos
414
+
415
+ clean_emb_stream = clean_emb_stream[:, :end_pos, :]
416
+ start = max(0, last_decoder_pos - overlap)
417
+ overlap_frame = max(0, last_decoder_pos - start)  # avoid clobbering the `overlap` parameter
418
+ latent = clean_emb_stream[:, start:end_pos, :]
419
+ audio = decoder.decoder(latent.transpose(1, 2)) # [B, C, T]
420
+ audio = audio[:, :, overlap_frame * 9600:]
422
+ audio = torch.cat([audio, torch.zeros(audio.shape[0], audio.shape[1], 5, device=audio.device, dtype=audio.dtype)], dim=-1)
424
+ yield audio
425
+
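Both samplers above reduce to integrating a classifier-free-guided velocity field from noise at t=0 to data at t=1. A self-contained toy sketch of that core step (the vector field is a stand-in for `fn(t, x)` above):

import torch
from torchdiffeq import odeint

def velocity(t, x):
    pred = torch.zeros_like(x)    # pretend conditional prediction
    null_pred = -x                # pretend unconditional prediction
    cfg_strength = 1.0
    return pred + (pred - null_pred) * cfg_strength

x0 = torch.randn(1, 4, 8)                    # a noise block
t_grid = torch.linspace(0.0, 1.0, 32)        # the `t_set` schedule used above
x1 = odeint(velocity, x0, t_grid, method="euler")[-1]   # sampled block at t=1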
g2p/__init__.py ADDED
File without changes
g2p/g2p/__init__.py ADDED
@@ -0,0 +1,87 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from g2p.g2p import cleaners
7
+ from tokenizers import Tokenizer
8
+ from g2p.g2p.text_tokenizers import TextTokenizer
9
+ from g2p.language_segmentation import LangSegment as LS
10
+ import json
11
+ import re
12
+
13
+ LangSegment = LS()
14
+
15
+ class PhonemeBpeTokenizer:
16
+ def __init__(self, vacab_path="./f5_tts/g2p/g2p/vocab.json"):
17
+ self.lang2backend = {
18
+ "zh": "cmn",
19
+ "ja": "ja",
20
+ "en": "en-us",
21
+ "fr": "fr-fr",
22
+ "ko": "ko",
23
+ "de": "de",
24
+ }
25
+ self.text_tokenizers = {}
26
+ self.init_text_tokenizers()
27
+
28
+ with open(vacab_path, "r") as f:
29
+ json_data = f.read()
30
+ data = json.loads(json_data)
31
+ self.vocab = data["vocab"]
32
+ LangSegment.setfilters(["en", "zh", "ja", "ko", "fr", "de"])
33
+
34
+ def init_text_tokenizers(self):
35
+ for key, value in self.lang2backend.items():
36
+ self.text_tokenizers[key] = TextTokenizer(language=value)
37
+
38
+ def tokenize(self, text, sentence, language):
39
+
40
+ # 1. convert text to phoneme
41
+ phonemes = []
42
+ if language == "auto":
43
+ seglist = LangSegment.getTexts(text)
44
+ tmp_ph = []
45
+ for seg in seglist:
46
+ tmp_ph.append(
47
+ self._clean_text(
48
+ seg["text"], sentence, seg["lang"], ["cjekfd_cleaners"]
49
+ )
50
+ )
51
+ phonemes = "|_|".join(tmp_ph)
52
+ else:
53
+ phonemes = self._clean_text(text, sentence, language, ["cjekfd_cleaners"])
54
+ # print('clean text: ', phonemes)
55
+
56
+ # 2. tokenize phonemes
57
+ phoneme_tokens = self.phoneme2token(phonemes)
58
+ # print('encode: ', phoneme_tokens)
59
+
60
+ # # 3. decode tokens [optional]
61
+ # decoded_text = self.tokenizer.decode(phoneme_tokens)
62
+ # print('decoded: ', decoded_text)
63
+
64
+ return phonemes, phoneme_tokens
65
+
66
+ def _clean_text(self, text, sentence, language, cleaner_names):
67
+ for name in cleaner_names:
68
+ cleaner = getattr(cleaners, name)
69
+ if not cleaner:
70
+ raise Exception("Unknown cleaner: %s" % name)
71
+ text = cleaner(text, sentence, language, self.text_tokenizers)
72
+ return text
73
+
74
+ def phoneme2token(self, phonemes):
75
+ tokens = []
76
+ if isinstance(phonemes, list):
77
+ for phone in phonemes:
78
+ phone = phone.split("\t")[0]
79
+ phonemes_split = phone.split("|")
80
+ tokens.append(
81
+ [self.vocab[p] for p in phonemes_split if p in self.vocab]
82
+ )
83
+ else:
84
+ phonemes = phonemes.split("\t")[0]
85
+ phonemes_split = phonemes.split("|")
86
+ tokens = [self.vocab[p] for p in phonemes_split if p in self.vocab]
87
+ return tokens
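A hypothetical call sequence for the tokenizer above. Note the default `vacab_path` points at an `f5_tts/` layout; in this repo the vocabulary lives at `g2p/g2p/vocab.json`, so the path below is an assumption:

tokenizer = PhonemeBpeTokenizer(vacab_path="g2p/g2p/vocab.json")   # path is an assumption
phonemes, tokens = tokenizer.tokenize(
    "hello world",   # raw text
    sentence=None,   # only consulted by the Mandarin cleaner
    language="en",   # or "auto" to language-segment the input first
)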
g2p/g2p/chinese_model_g2p.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ import json
11
+ from transformers import BertTokenizer
12
+ from torch.utils.data import Dataset
13
+ from transformers.models.bert.modeling_bert import *
15
+ import torch.nn.functional as F
16
+ from onnxruntime import InferenceSession, GraphOptimizationLevel, SessionOptions
17
+
18
+
19
+ class PolyDataset(Dataset):
20
+ def __init__(self, words, labels, word_pad_idx=0, label_pad_idx=-1):
21
+ self.dataset = self.preprocess(words, labels)
22
+ self.word_pad_idx = word_pad_idx
23
+ self.label_pad_idx = label_pad_idx
24
+
25
+ def preprocess(self, origin_sentences, origin_labels):
26
+ """
27
+ Maps tokens and tags to their indices and stores them in the dict data.
28
+ examples:
29
+ word:['[CLS]', 'ๆต™', 'ๅ•†', '้“ถ', '่กŒ', 'ไผ', 'ไธš', 'ไฟก', '่ดท', '้ƒจ']
30
+ sentence:([101, 3851, 1555, 7213, 6121, 821, 689, 928, 6587, 6956],
31
+ array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
32
+ label:[3, 13, 13, 13, 0, 0, 0, 0, 0]
33
+ """
34
+ data = []
35
+ labels = []
36
+ sentences = []
37
+ # tokenize
38
+ for line in origin_sentences:
39
+ # replace each token by its index
40
+ # we can not use encode_plus because our sentences are aligned to labels in list type
41
+ words = []
42
+ word_lens = []
43
+ for token in line:
44
+ words.append(token)
45
+ word_lens.append(1)
46
+ token_start_idxs = 1 + np.cumsum([0] + word_lens[:-1])
47
+ sentences.append(((words, token_start_idxs), 0))
48
+ ###
49
+ for tag in origin_labels:
50
+ labels.append(tag)
51
+
52
+ for sentence, label in zip(sentences, labels):
53
+ data.append((sentence, label))
54
+ return data
55
+
56
+ def __getitem__(self, idx):
57
+ """sample data to get batch"""
58
+ word = self.dataset[idx][0]
59
+ label = self.dataset[idx][1]
60
+ return [word, label]
61
+
62
+ def __len__(self):
63
+ """get dataset size"""
64
+ return len(self.dataset)
65
+
66
+ def collate_fn(self, batch):
67
+
68
+ sentences = [x[0][0] for x in batch]
69
+ ori_sents = [x[0][1] for x in batch]
70
+ labels = [x[1] for x in batch]
71
+ batch_len = len(sentences)
72
+
73
+ # compute length of longest sentence in batch
74
+ max_len = max([len(s[0]) for s in sentences])
75
+ max_label_len = 0
76
+ batch_data = np.ones((batch_len, max_len))
77
+ batch_label_starts = []
78
+
79
+ # padding and aligning
80
+ for j in range(batch_len):
81
+ cur_len = len(sentences[j][0])
82
+ batch_data[j][:cur_len] = sentences[j][0]
83
+ label_start_idx = sentences[j][-1]
84
+ label_starts = np.zeros(max_len)
85
+ label_starts[[idx for idx in label_start_idx if idx < max_len]] = 1
86
+ batch_label_starts.append(label_starts)
87
+ max_label_len = max(int(sum(label_starts)), max_label_len)
88
+
89
+ # padding label
90
+ batch_labels = self.label_pad_idx * np.ones((batch_len, max_label_len))
91
+ batch_pmasks = self.label_pad_idx * np.ones((batch_len, max_label_len))
92
+ for j in range(batch_len):
93
+ cur_tags_len = len(labels[j])
94
+ batch_labels[j][:cur_tags_len] = labels[j]
95
+ batch_pmasks[j][:cur_tags_len] = [
96
+ 1 if item > 0 else 0 for item in labels[j]
97
+ ]
98
+
99
+ # convert data to torch LongTensors
100
+ batch_data = torch.tensor(batch_data, dtype=torch.long)
101
+ batch_label_starts = torch.tensor(batch_label_starts, dtype=torch.long)
102
+ batch_labels = torch.tensor(batch_labels, dtype=torch.long)
103
+ batch_pmasks = torch.tensor(batch_pmasks, dtype=torch.long)
104
+ return [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
105
+
106
+
107
+ class BertPolyPredict:
108
+ def __init__(self, bert_model, jsonr_file, json_file):
109
+ self.tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=True)
110
+ with open(jsonr_file, "r", encoding="utf8") as fp:
111
+ self.pron_dict = json.load(fp)
112
+ with open(json_file, "r", encoding="utf8") as fp:
113
+ self.pron_dict_id_2_pinyin = json.load(fp)
114
+ self.num_polyphone = len(self.pron_dict)
115
+ self.device = "cpu"
116
+ self.polydataset = PolyDataset
117
+ options = SessionOptions() # initialize session options
118
+ options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
119
+ print(os.path.join(bert_model, "poly_bert_model.onnx"))
120
+ self.session = InferenceSession(
121
+ os.path.join(bert_model, "poly_bert_model.onnx"),
122
+ sess_options=options,
123
+ providers=[
124
+ "CUDAExecutionProvider",
125
+ "CPUExecutionProvider",
126
+ ], # CPUExecutionProvider #CUDAExecutionProvider
127
+ )
128
+ # self.session.set_providers(['CUDAExecutionProvider', "CPUExecutionProvider"], [ {'device_id': 0}])
129
+
130
+ # disable session.run() fallback mechanism, it prevents for a reset of the execution provider
131
+ self.session.disable_fallback()
132
+
133
+ def predict_process(self, txt_list):
134
+ word_test, label_test, texts_test = self.get_examples_po(txt_list)
135
+ data = self.polydataset(word_test, label_test)
136
+ predict_loader = DataLoader(
137
+ data, batch_size=1, shuffle=False, collate_fn=data.collate_fn
138
+ )
139
+ pred_tags = self.predict_onnx(predict_loader)
140
+ return pred_tags
141
+
142
+ def predict_onnx(self, dev_loader):
143
+ pred_tags = []
144
+ with torch.no_grad():
145
+ for idx, batch_samples in enumerate(dev_loader):
146
+ # [batch_data, batch_label_starts, batch_labels, batch_pmasks, ori_sents]
147
+ batch_data, batch_label_starts, batch_labels, batch_pmasks, _ = (
148
+ batch_samples
149
+ )
150
+ # shift tensors to GPU if available
151
+ batch_data = batch_data.to(self.device)
152
+ batch_label_starts = batch_label_starts.to(self.device)
153
+ batch_labels = batch_labels.to(self.device)
154
+ batch_pmasks = batch_pmasks.to(self.device)
155
+ batch_data = np.asarray(batch_data, dtype=np.int32)
156
+ batch_pmasks = np.asarray(batch_pmasks, dtype=np.int32)
157
+ # batch_output = self.session.run(output_names=['outputs'], input_feed={"input_ids":batch_data, "input_pmasks": batch_pmasks})[0][0]
158
+ batch_output = self.session.run(
159
+ output_names=["outputs"], input_feed={"input_ids": batch_data}
160
+ )[0]
161
+ label_masks = batch_pmasks == 1
162
+ batch_labels = batch_labels.to("cpu").numpy()
163
+ for i, indices in enumerate(np.argmax(batch_output, axis=2)):
164
+ for j, idx in enumerate(indices):
165
+ if label_masks[i][j]:
166
+ # pred_tag.append(idx)
167
+ pred_tags.append(self.pron_dict_id_2_pinyin[str(idx + 1)])
168
+ return pred_tags
169
+
170
+ def get_examples_po(self, text_list):
171
+
172
+ word_list = []
173
+ label_list = []
174
+ sentence_list = []
175
+ id = 0
176
+ for line in [text_list]:
177
+ sentence = line[0]
178
+ words = []
179
+ tokens = line[0]
180
+ index = line[-1]
181
+ front = index
182
+ back = len(tokens) - index - 1
183
+ labels = [0] * front + [1] + [0] * back
184
+ words = ["[CLS]"] + [item for item in sentence]
185
+ words = self.tokenizer.convert_tokens_to_ids(words)
186
+ word_list.append(words)
187
+ label_list.append(labels)
188
+ sentence_list.append(sentence)
189
+
190
+ id += 1
191
+ # mask_list.append(masks)
204
+ assert len(labels) + 1 == len(
205
+ words
206
+ ), "Number of labels does not match number of words"
207
+ assert len(labels) == len(
208
+ sentence
209
+ ), "Number of labels does not match number of sentences"
210
+ assert len(word_list) == len(
211
+ label_list
212
+ ), "Number of label sentences does not match number of word sentences"
213
+ return word_list, label_list, text_list
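A hypothetical invocation, inferred from `get_examples_po` (the input is `[sentence, polyphone_index]`; the paths are this repo's files, but the order of the two dictionary arguments is an assumption):

predictor = BertPolyPredict(
    "g2p/sources/g2p_chinese_model",                  # dir containing poly_bert_model.onnx
    "g2p/sources/g2p_chinese_model/polydict_r.json",  # assumed pinyin -> id mapping
    "g2p/sources/g2p_chinese_model/polydict.json",    # assumed id -> pinyin mapping
)
pinyins = predictor.predict_process(["和服", 0])      # polyphonic character at index 0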
g2p/g2p/cleaners.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from g2p.g2p.japanese import japanese_to_ipa
8
+ from g2p.g2p.mandarin import chinese_to_ipa
9
+ from g2p.g2p.english import english_to_ipa
10
+ from g2p.g2p.french import french_to_ipa
11
+ from g2p.g2p.korean import korean_to_ipa
12
+ from g2p.g2p.german import german_to_ipa
13
+
14
+
15
+ def cjekfd_cleaners(text, sentence, language, text_tokenizers):
16
+
17
+ if language == "zh":
18
+ return chinese_to_ipa(text, sentence, text_tokenizers["zh"])
19
+ elif language == "ja":
20
+ return japanese_to_ipa(text, text_tokenizers["ja"])
21
+ elif language == "en":
22
+ return english_to_ipa(text, text_tokenizers["en"])
23
+ elif language == "fr":
24
+ return french_to_ipa(text, text_tokenizers["fr"])
25
+ elif language == "ko":
26
+ return korean_to_ipa(text, text_tokenizers["ko"])
27
+ elif language == "de":
28
+ return german_to_ipa(text, text_tokenizers["de"])
29
+ else:
30
+ raise Exception("Unknown language: %s" % language)
g2p/g2p/english.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ from unidecode import unidecode
8
+ import inflect
9
+
10
+ """
11
+ Text cleaning utilities
12
+ """
13
+ _inflect = inflect.engine()
14
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
15
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
16
+ _percent_number_re = re.compile(r"([0-9\.\,]*[0-9]+%)")
17
+ _pounds_re = re.compile(r"ยฃ([0-9\,]*[0-9]+)")
18
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
19
+ _fraction_re = re.compile(r"([0-9]+)/([0-9]+)")
20
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
21
+ _number_re = re.compile(r"[0-9]+")
22
+
23
+ # List of (regular expression, replacement) pairs for abbreviations:
24
+ _abbreviations = [
25
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
26
+ for x in [
27
+ ("mrs", "misess"),
28
+ ("mr", "mister"),
29
+ ("dr", "doctor"),
30
+ ("st", "saint"),
31
+ ("co", "company"),
32
+ ("jr", "junior"),
33
+ ("maj", "major"),
34
+ ("gen", "general"),
35
+ ("drs", "doctors"),
36
+ ("rev", "reverend"),
37
+ ("lt", "lieutenant"),
38
+ ("hon", "honorable"),
39
+ ("sgt", "sergeant"),
40
+ ("capt", "captain"),
41
+ ("esq", "esquire"),
42
+ ("ltd", "limited"),
43
+ ("col", "colonel"),
44
+ ("ft", "fort"),
45
+ ("etc", "et cetera"),
46
+ ("btw", "by the way"),
47
+ ]
48
+ ]
49
+
50
+ _special_map = [
51
+ ("t|ษน", "tษน"),
52
+ ("d|ษน", "dษน"),
53
+ ("t|s", "ts"),
54
+ ("d|z", "dz"),
55
+ ("ษช|ษน", "ษชษน"),
56
+ ("ษ", "ษš"),
57
+ ("แตป", "ษช"),
58
+ ("ษ™l", "l"),
59
+ ("x", "k"),
60
+ ("ษฌ", "l"),
61
+ ("ส”", "t"),
62
+ ("nฬฉ", "n"),
63
+ ("oห|ษน", "oหษน"),
64
+ ]
65
+
66
+
67
+ def expand_abbreviations(text):
68
+ for regex, replacement in _abbreviations:
69
+ text = re.sub(regex, replacement, text)
70
+ return text
71
+
72
+
73
+ def _remove_commas(m):
74
+ return m.group(1).replace(",", "")
75
+
76
+
77
+ def _expand_decimal_point(m):
78
+ return m.group(1).replace(".", " point ")
79
+
80
+
81
+ def _expand_percent(m):
82
+ return m.group(1).replace("%", " percent ")
83
+
84
+
85
+ def _expand_dollars(m):
86
+ match = m.group(1)
87
+ parts = match.split(".")
88
+ if len(parts) > 2:
89
+ return " " + match + " dollars " # Unexpected format
90
+ dollars = int(parts[0]) if parts[0] else 0
91
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
92
+ if dollars and cents:
93
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
94
+ cent_unit = "cent" if cents == 1 else "cents"
95
+ return " %s %s, %s %s " % (dollars, dollar_unit, cents, cent_unit)
96
+ elif dollars:
97
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
98
+ return " %s %s " % (dollars, dollar_unit)
99
+ elif cents:
100
+ cent_unit = "cent" if cents == 1 else "cents"
101
+ return " %s %s " % (cents, cent_unit)
102
+ else:
103
+ return " zero dollars "
104
+
105
+
106
+ def fraction_to_words(numerator, denominator):
107
+ if numerator == 1 and denominator == 2:
108
+ return " one half "
109
+ if numerator == 1 and denominator == 4:
110
+ return " one quarter "
111
+ if denominator == 2:
112
+ return " " + _inflect.number_to_words(numerator) + " halves "
113
+ if denominator == 4:
114
+ return " " + _inflect.number_to_words(numerator) + " quarters "
115
+ return (
116
+ " "
117
+ + _inflect.number_to_words(numerator)
118
+ + " "
119
+ + _inflect.ordinal(_inflect.number_to_words(denominator))
120
+ + " "
121
+ )
122
+
123
+
124
+ def _expand_fraction(m):
125
+ numerator = int(m.group(1))
126
+ denominator = int(m.group(2))
127
+ return fraction_to_words(numerator, denominator)
128
+
129
+
130
+ def _expand_ordinal(m):
131
+ return " " + _inflect.number_to_words(m.group(0)) + " "
132
+
133
+
134
+ def _expand_number(m):
135
+ num = int(m.group(0))
136
+ if num > 1000 and num < 3000:
137
+ if num == 2000:
138
+ return " two thousand "
139
+ elif num > 2000 and num < 2010:
140
+ return " two thousand " + _inflect.number_to_words(num % 100) + " "
141
+ elif num % 100 == 0:
142
+ return " " + _inflect.number_to_words(num // 100) + " hundred "
143
+ else:
144
+ return (
145
+ " "
146
+ + _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(
147
+ ", ", " "
148
+ )
149
+ + " "
150
+ )
151
+ else:
152
+ return " " + _inflect.number_to_words(num, andword="") + " "
153
+
154
+
155
+ # Normalize numbers pronunciation
156
+ def normalize_numbers(text):
157
+ text = re.sub(_comma_number_re, _remove_commas, text)
158
+ text = re.sub(_pounds_re, r"\1 pounds", text)
159
+ text = re.sub(_dollars_re, _expand_dollars, text)
160
+ text = re.sub(_fraction_re, _expand_fraction, text)
161
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
162
+ text = re.sub(_percent_number_re, _expand_percent, text)
163
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
164
+ text = re.sub(_number_re, _expand_number, text)
165
+ return text
166
+
167
+
168
+ def _english_to_ipa(text):
169
+ # text = unidecode(text).lower()
170
+ text = expand_abbreviations(text)
171
+ text = normalize_numbers(text)
172
+ return text
173
+
174
+
175
+ # special map
176
+ def special_map(text):
177
+ for regex, replacement in _special_map:
178
+ regex = regex.replace("|", "\\|")  # escape the literal "|" separator for the regex
179
+ while re.search(r"(^|[_|]){}([_|]|$)".format(regex), text):
180
+ text = re.sub(
181
+ r"(^|[_|]){}([_|]|$)".format(regex), r"\1{}\2".format(replacement), text
182
+ )
183
+ # text = re.sub(r'([,.!?])', r'|\1', text)
184
+ return text
185
+
186
+
187
+ # Add some special operation
188
+ def english_to_ipa(text, text_tokenizer):
189
+ if type(text) == str:
190
+ text = _english_to_ipa(text)
191
+ else:
192
+ text = [_english_to_ipa(t) for t in text]
193
+ phonemes = text_tokenizer(text)
194
+ if phonemes[-1] in "pโผสฐmftnlkxสƒs`ษนaoษ™ษ›ษชeษ‘สŠล‹iuษฅwรฆjห":
195
+ phonemes += "|_"
196
+ if type(text) == str:
197
+ return special_map(phonemes)
198
+ else:
199
+ result_ph = []
200
+ for phone in phonemes:
201
+ result_ph.append(special_map(phone))
202
+ return result_ph
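Rough behaviour of the number normalization above (later regex passes expand the remaining digits, so the output is approximately):

print(normalize_numbers("He paid $5.50 for 3/4 of it on the 2nd"))
# ~ 'He paid five dollars, fifty cents for three quarters of it on the second'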
g2p/g2p/french.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text cleaning utilities
10
+ """
11
+ # List of (regular expression, replacement) pairs for abbreviations in french:
12
+ _abbreviations = [
13
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
14
+ for x in [
15
+ ("M", "monsieur"),
16
+ ("Mlle", "mademoiselle"),
17
+ ("Mlles", "mesdemoiselles"),
18
+ ("Mme", "Madame"),
19
+ ("Mmes", "Mesdames"),
20
+ ("N.B", "nota bene"),
22
+ ("p.c.q", "parce que"),
23
+ ("Pr", "professeur"),
24
+ ("qqch", "quelque chose"),
25
+ ("rdv", "rendez-vous"),
26
+ ("max", "maximum"),
27
+ ("min", "minimum"),
28
+ ("no", "numรฉro"),
29
+ ("adr", "adresse"),
30
+ ("dr", "docteur"),
31
+ ("st", "saint"),
32
+ ("co", "compagnie"),
33
+ ("jr", "junior"),
34
+ ("sgt", "sergent"),
35
+ ("capt", "capitaine"),
36
+ ("col", "colonel"),
37
+ ("av", "avenue"),
38
+ ("av. J.-C", "avant Jรฉsus-Christ"),
39
+ ("apr. J.-C", "aprรจs Jรฉsus-Christ"),
40
+ ("art", "article"),
41
+ ("boul", "boulevard"),
42
+ ("c.-ร -d", "cโ€™est-ร -dire"),
43
+ ("etc", "et cetera"),
44
+ ("ex", "exemple"),
45
+ ("excl", "exclusivement"),
47
+ ]
48
+ ] + [
49
+ (re.compile("\\b%s" % x[0]), x[1])
50
+ for x in [
51
+ ("Mlle", "mademoiselle"),
52
+ ("Mlles", "mesdemoiselles"),
53
+ ("Mme", "Madame"),
54
+ ("Mmes", "Mesdames"),
55
+ ]
56
+ ]
57
+
58
+ rep_map = {
59
+ "๏ผš": ",",
60
+ "๏ผ›": ",",
61
+ "๏ผŒ": ",",
62
+ "ใ€‚": ".",
63
+ "๏ผ": "!",
64
+ "๏ผŸ": "?",
65
+ "\n": ".",
66
+ "ยท": ",",
67
+ "ใ€": ",",
68
+ "...": ".",
69
+ "โ€ฆ": ".",
70
+ "$": ".",
71
+ "โ€œ": "",
72
+ "โ€": "",
73
+ "โ€˜": "",
74
+ "โ€™": "",
75
+ "๏ผˆ": "",
76
+ "๏ผ‰": "",
77
+ "(": "",
78
+ ")": "",
79
+ "ใ€Š": "",
80
+ "ใ€‹": "",
81
+ "ใ€": "",
82
+ "ใ€‘": "",
83
+ "[": "",
84
+ "]": "",
85
+ "โ€”": "",
86
+ "๏ฝž": "-",
87
+ "~": "-",
88
+ "ใ€Œ": "",
89
+ "ใ€": "",
90
+ "ยฟ": "",
91
+ "ยก": "",
92
+ }
93
+
94
+
95
+ def collapse_whitespace(text):
96
+ # Regular expression matching whitespace:
97
+ _whitespace_re = re.compile(r"\s+")
98
+ return re.sub(_whitespace_re, " ", text).strip()
99
+
100
+
101
+ def remove_punctuation_at_begin(text):
102
+ return re.sub(r"^[,.!?]+", "", text)
103
+
104
+
105
+ def remove_aux_symbols(text):
106
+ text = re.sub(r"[\<\>\(\)\[\]\"\ยซ\ยป]+", "", text)
107
+ return text
108
+
109
+
110
+ def replace_symbols(text):
111
+ text = text.replace(";", ",")
112
+ text = text.replace("-", " ")
113
+ text = text.replace(":", ",")
114
+ text = text.replace("&", " et ")
115
+ return text
116
+
117
+
118
+ def expand_abbreviations(text):
119
+ for regex, replacement in _abbreviations:
120
+ text = re.sub(regex, replacement, text)
121
+ return text
122
+
123
+
124
+ def replace_punctuation(text):
125
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
126
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
127
+ return replaced_text
128
+
129
+
130
+ def text_normalize(text):
131
+ text = expand_abbreviations(text)
132
+ text = replace_punctuation(text)
133
+ text = replace_symbols(text)
134
+ text = remove_aux_symbols(text)
135
+ text = remove_punctuation_at_begin(text)
136
+ text = collapse_whitespace(text)
137
+ text = re.sub(r"([^\.,!\?\-โ€ฆ])$", r"\1", text)
138
+ return text
139
+
140
+
141
+ def french_to_ipa(text, text_tokenizer):
142
+ if type(text) == str:
143
+ text = text_normalize(text)
144
+ phonemes = text_tokenizer(text)
145
+ return phonemes
146
+ else:
147
+ for i, t in enumerate(text):
148
+ text[i] = text_normalize(t)
149
+ return text_tokenizer(text)
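Rough behaviour of `text_normalize` above, combining abbreviation expansion, punctuation mapping, and symbol replacement (output approximate):

print(text_normalize("Bonjour... M. Dupont & Mme. Martin"))
# ~ 'Bonjour. monsieur Dupont et Madame Martin'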
g2p/g2p/german.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text cleaning utilities
10
+ """
11
+ rep_map = {
12
+ "๏ผš": ",",
13
+ "๏ผ›": ",",
14
+ "๏ผŒ": ",",
15
+ "ใ€‚": ".",
16
+ "๏ผ": "!",
17
+ "๏ผŸ": "?",
18
+ "\n": ".",
19
+ "ยท": ",",
20
+ "ใ€": ",",
21
+ "...": ".",
22
+ "โ€ฆ": ".",
23
+ "$": ".",
24
+ "โ€œ": "",
25
+ "โ€": "",
26
+ "โ€˜": "",
27
+ "โ€™": "",
28
+ "๏ผˆ": "",
29
+ "๏ผ‰": "",
30
+ "(": "",
31
+ ")": "",
32
+ "ใ€Š": "",
33
+ "ใ€‹": "",
34
+ "ใ€": "",
35
+ "ใ€‘": "",
36
+ "[": "",
37
+ "]": "",
38
+ "โ€”": "",
39
+ "๏ฝž": "-",
40
+ "~": "-",
41
+ "ใ€Œ": "",
42
+ "ใ€": "",
43
+ "ยฟ": "",
44
+ "ยก": "",
45
+ }
46
+
47
+
48
+ def collapse_whitespace(text):
49
+ # Regular expression matching whitespace:
50
+ _whitespace_re = re.compile(r"\s+")
51
+ return re.sub(_whitespace_re, " ", text).strip()
52
+
53
+
54
+ def remove_punctuation_at_begin(text):
55
+ return re.sub(r"^[,.!?]+", "", text)
56
+
57
+
58
+ def remove_aux_symbols(text):
59
+ text = re.sub(r"[\<\>\(\)\[\]\"\ยซ\ยป]+", "", text)
60
+ return text
61
+
62
+
63
+ def replace_symbols(text):
64
+ text = text.replace(";", ",")
65
+ text = text.replace("-", " ")
66
+ text = text.replace(":", ",")
67
+ return text
68
+
69
+
70
+ def replace_punctuation(text):
71
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
72
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
73
+ return replaced_text
74
+
75
+
76
+ def text_normalize(text):
77
+ text = replace_punctuation(text)
78
+ text = replace_symbols(text)
79
+ text = remove_aux_symbols(text)
80
+ text = remove_punctuation_at_begin(text)
81
+ text = collapse_whitespace(text)
82
+ text = re.sub(r"([^\.,!\?\-โ€ฆ])$", r"\1", text)
83
+ return text
84
+
85
+
86
+ def german_to_ipa(text, text_tokenizer):
87
+ if type(text) == str:
88
+ text = text_normalize(text)
89
+ phonemes = text_tokenizer(text)
90
+ return phonemes
91
+ else:
92
+ for i, t in enumerate(text):
93
+ text[i] = text_normalize(t)
94
+ return text_tokenizer(text)
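The German cleaner has no abbreviation table; `text_normalize` above mainly maps full-width/CJK punctuation and collapses whitespace:

print(text_normalize("Hallo！ Wie geht's？"))
# -> "Hallo! Wie geht's?"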
g2p/g2p/japanese.py ADDED
@@ -0,0 +1,816 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io, re, os, sys, time, argparse, pdb, json
7
+ from io import StringIO
8
+ from typing import Optional
9
+ import numpy as np
10
+ import traceback
11
+ import pyopenjtalk
12
+ from pykakasi import kakasi
13
+
14
+ punctuation = [",", ".", "!", "?", ":", ";", "'", "โ€ฆ"]
15
+
16
+ jp_xphone2ipa = [
17
+ " a a",
18
+ " i i",
19
+ " u ษฏ",
20
+ " e e",
21
+ " o o",
22
+ " a: aห",
23
+ " i: iห",
24
+ " u: ษฏห",
25
+ " e: eห",
26
+ " o: oห",
27
+ " k k",
28
+ " s s",
29
+ " t t",
30
+ " n n",
31
+ " h cฬง",
32
+ " f ษธ",
33
+ " m m",
34
+ " y j",
35
+ " r ษพ",
36
+ " w ษฐแต",
37
+ " N ษด",
38
+ " g g",
39
+ " j d ส‘",
40
+ " z z",
41
+ " d d",
42
+ " b b",
43
+ " p p",
44
+ " q q",
45
+ " v v",
46
+ " : :",
47
+ " by b j",
48
+ " ch t ษ•",
49
+ " dy d e j",
50
+ " ty t e j",
51
+ " gy g j",
52
+ " gw g ษฏ",
53
+ " hy cฬง j",
54
+ " ky k j",
55
+ " kw k ษฏ",
56
+ " my m j",
57
+ " ny n j",
58
+ " py p j",
59
+ " ry ษพ j",
60
+ " sh ษ•",
61
+ " ts t s ษฏ",
62
+ ]
63
+
64
+ _mora_list_minimum: list[tuple[str, Optional[str], str]] = [
65
+ ("ใƒดใ‚ฉ", "v", "o"),
66
+ ("ใƒดใ‚ง", "v", "e"),
67
+ ("ใƒดใ‚ฃ", "v", "i"),
68
+ ("ใƒดใ‚ก", "v", "a"),
69
+ ("ใƒด", "v", "u"),
70
+ ("ใƒณ", None, "N"),
71
+ ("ใƒฏ", "w", "a"),
72
+ ("ใƒญ", "r", "o"),
73
+ ("ใƒฌ", "r", "e"),
74
+ ("ใƒซ", "r", "u"),
75
+ ("ใƒชใƒง", "ry", "o"),
76
+ ("ใƒชใƒฅ", "ry", "u"),
77
+ ("ใƒชใƒฃ", "ry", "a"),
78
+ ("ใƒชใ‚ง", "ry", "e"),
79
+ ("ใƒช", "r", "i"),
80
+ ("ใƒฉ", "r", "a"),
81
+ ("ใƒจ", "y", "o"),
82
+ ("ใƒฆ", "y", "u"),
83
+ ("ใƒค", "y", "a"),
84
+ ("ใƒข", "m", "o"),
85
+ ("ใƒก", "m", "e"),
86
+ ("ใƒ ", "m", "u"),
87
+ ("ใƒŸใƒง", "my", "o"),
88
+ ("ใƒŸใƒฅ", "my", "u"),
89
+ ("ใƒŸใƒฃ", "my", "a"),
90
+ ("ใƒŸใ‚ง", "my", "e"),
91
+ ("ใƒŸ", "m", "i"),
92
+ ("ใƒž", "m", "a"),
93
+ ("ใƒ", "p", "o"),
94
+ ("ใƒœ", "b", "o"),
95
+ ("ใƒ›", "h", "o"),
96
+ ("ใƒš", "p", "e"),
97
+ ("ใƒ™", "b", "e"),
98
+ ("ใƒ˜", "h", "e"),
99
+ ("ใƒ—", "p", "u"),
100
+ ("ใƒ–", "b", "u"),
101
+ ("ใƒ•ใ‚ฉ", "f", "o"),
102
+ ("ใƒ•ใ‚ง", "f", "e"),
103
+ ("ใƒ•ใ‚ฃ", "f", "i"),
104
+ ("ใƒ•ใ‚ก", "f", "a"),
105
+ ("ใƒ•", "f", "u"),
106
+ ("ใƒ”ใƒง", "py", "o"),
107
+ ("ใƒ”ใƒฅ", "py", "u"),
108
+ ("ใƒ”ใƒฃ", "py", "a"),
109
+ ("ใƒ”ใ‚ง", "py", "e"),
110
+ ("ใƒ”", "p", "i"),
111
+ ("ใƒ“ใƒง", "by", "o"),
112
+ ("ใƒ“ใƒฅ", "by", "u"),
113
+ ("ใƒ“ใƒฃ", "by", "a"),
114
+ ("ใƒ“ใ‚ง", "by", "e"),
115
+ ("ใƒ“", "b", "i"),
116
+ ("ใƒ’ใƒง", "hy", "o"),
117
+ ("ใƒ’ใƒฅ", "hy", "u"),
118
+ ("ใƒ’ใƒฃ", "hy", "a"),
119
+ ("ใƒ’ใ‚ง", "hy", "e"),
120
+ ("ใƒ’", "h", "i"),
121
+ ("ใƒ‘", "p", "a"),
122
+ ("ใƒ", "b", "a"),
123
+ ("ใƒ", "h", "a"),
124
+ ("ใƒŽ", "n", "o"),
125
+ ("ใƒ", "n", "e"),
126
+ ("ใƒŒ", "n", "u"),
127
+ ("ใƒ‹ใƒง", "ny", "o"),
128
+ ("ใƒ‹ใƒฅ", "ny", "u"),
129
+ ("ใƒ‹ใƒฃ", "ny", "a"),
130
+ ("ใƒ‹ใ‚ง", "ny", "e"),
131
+ ("ใƒ‹", "n", "i"),
132
+ ("ใƒŠ", "n", "a"),
133
+ ("ใƒ‰ใ‚ฅ", "d", "u"),
134
+ ("ใƒ‰", "d", "o"),
135
+ ("ใƒˆใ‚ฅ", "t", "u"),
136
+ ("ใƒˆ", "t", "o"),
137
+ ("ใƒ‡ใƒง", "dy", "o"),
138
+ ("ใƒ‡ใƒฅ", "dy", "u"),
139
+ ("ใƒ‡ใƒฃ", "dy", "a"),
140
+ # ("ใƒ‡ใ‚ง", "dy", "e"),
141
+ ("ใƒ‡ใ‚ฃ", "d", "i"),
142
+ ("ใƒ‡", "d", "e"),
143
+ ("ใƒ†ใƒง", "ty", "o"),
144
+ ("ใƒ†ใƒฅ", "ty", "u"),
145
+ ("ใƒ†ใƒฃ", "ty", "a"),
146
+ ("ใƒ†ใ‚ฃ", "t", "i"),
147
+ ("ใƒ†", "t", "e"),
148
+ ("ใƒ„ใ‚ฉ", "ts", "o"),
149
+ ("ใƒ„ใ‚ง", "ts", "e"),
150
+ ("ใƒ„ใ‚ฃ", "ts", "i"),
151
+ ("ใƒ„ใ‚ก", "ts", "a"),
152
+ ("ใƒ„", "ts", "u"),
153
+ ("ใƒƒ", None, "q"), # ใ€Œclใ€ใ‹ใ‚‰ใ€Œqใ€ใซๅค‰ๆ›ด
154
+ ("ใƒใƒง", "ch", "o"),
155
+ ("ใƒใƒฅ", "ch", "u"),
156
+ ("ใƒใƒฃ", "ch", "a"),
157
+ ("ใƒใ‚ง", "ch", "e"),
158
+ ("ใƒ", "ch", "i"),
159
+ ("ใƒ€", "d", "a"),
160
+ ("ใ‚ฟ", "t", "a"),
161
+ ("ใ‚พ", "z", "o"),
162
+ ("ใ‚ฝ", "s", "o"),
163
+ ("ใ‚ผ", "z", "e"),
164
+ ("ใ‚ป", "s", "e"),
165
+ ("ใ‚บใ‚ฃ", "z", "i"),
166
+ ("ใ‚บ", "z", "u"),
167
+ ("ใ‚นใ‚ฃ", "s", "i"),
168
+ ("ใ‚น", "s", "u"),
169
+ ("ใ‚ธใƒง", "j", "o"),
170
+ ("ใ‚ธใƒฅ", "j", "u"),
171
+ ("ใ‚ธใƒฃ", "j", "a"),
172
+ ("ใ‚ธใ‚ง", "j", "e"),
173
+ ("ใ‚ธ", "j", "i"),
174
+ ("ใ‚ทใƒง", "sh", "o"),
175
+ ("ใ‚ทใƒฅ", "sh", "u"),
176
+ ("ใ‚ทใƒฃ", "sh", "a"),
177
+ ("ใ‚ทใ‚ง", "sh", "e"),
178
+ ("ใ‚ท", "sh", "i"),
179
+ ("ใ‚ถ", "z", "a"),
180
+ ("ใ‚ต", "s", "a"),
181
+ ("ใ‚ด", "g", "o"),
182
+ ("ใ‚ณ", "k", "o"),
183
+ ("ใ‚ฒ", "g", "e"),
184
+ ("ใ‚ฑ", "k", "e"),
185
+ ("ใ‚ฐใƒฎ", "gw", "a"),
186
+ ("ใ‚ฐ", "g", "u"),
187
+ ("ใ‚ฏใƒฎ", "kw", "a"),
188
+ ("ใ‚ฏ", "k", "u"),
189
+ ("ใ‚ฎใƒง", "gy", "o"),
190
+ ("ใ‚ฎใƒฅ", "gy", "u"),
191
+ ("ใ‚ฎใƒฃ", "gy", "a"),
192
+ ("ใ‚ฎใ‚ง", "gy", "e"),
193
+ ("ใ‚ฎ", "g", "i"),
194
+ ("ใ‚ญใƒง", "ky", "o"),
195
+ ("ใ‚ญใƒฅ", "ky", "u"),
196
+ ("ใ‚ญใƒฃ", "ky", "a"),
197
+ ("ใ‚ญใ‚ง", "ky", "e"),
198
+ ("ใ‚ญ", "k", "i"),
199
+ ("ใ‚ฌ", "g", "a"),
200
+ ("ใ‚ซ", "k", "a"),
201
+ ("ใ‚ช", None, "o"),
202
+ ("ใ‚จ", None, "e"),
203
+ ("ใ‚ฆใ‚ฉ", "w", "o"),
204
+ ("ใ‚ฆใ‚ง", "w", "e"),
205
+ ("ใ‚ฆใ‚ฃ", "w", "i"),
206
+ ("ใ‚ฆ", None, "u"),
207
+ ("ใ‚คใ‚ง", "y", "e"),
208
+ ("ใ‚ค", None, "i"),
209
+ ("ใ‚ข", None, "a"),
210
+ ]
211
+
212
+ _mora_list_additional: list[tuple[str, Optional[str], str]] = [
213
+ ("ใƒดใƒง", "by", "o"),
214
+ ("ใƒดใƒฅ", "by", "u"),
215
+ ("ใƒดใƒฃ", "by", "a"),
216
+ ("ใƒฒ", None, "o"),
217
+ ("ใƒฑ", None, "e"),
218
+ ("ใƒฐ", None, "i"),
219
+ ("ใƒฎ", "w", "a"),
220
+ ("ใƒง", "y", "o"),
221
+ ("ใƒฅ", "y", "u"),
222
+ ("ใƒ…", "z", "u"),
223
+ ("ใƒ‚", "j", "i"),
224
+ ("ใƒถ", "k", "e"),
225
+ ("ใƒฃ", "y", "a"),
226
+ ("ใ‚ฉ", None, "o"),
227
+ ("ใ‚ง", None, "e"),
228
+ ("ใ‚ฅ", None, "u"),
229
+ ("ใ‚ฃ", None, "i"),
230
+ ("ใ‚ก", None, "a"),
231
+ ]
232
+
233
+ # ไพ‹: "vo" -> "ใƒดใ‚ฉ", "a" -> "ใ‚ข"
234
+ mora_phonemes_to_mora_kata: dict[str, str] = {
235
+ (consonant or "") + vowel: kana for [kana, consonant, vowel] in _mora_list_minimum
236
+ }
237
+
238
+ # ไพ‹: "ใƒดใ‚ฉ" -> ("v", "o"), "ใ‚ข" -> (None, "a")
239
+ mora_kata_to_mora_phonemes: dict[str, tuple[Optional[str], str]] = {
240
+ kana: (consonant, vowel)
241
+ for [kana, consonant, vowel] in _mora_list_minimum + _mora_list_additional
242
+ }
243
+
244
+
245
+ # ๆญฃ่ฆๅŒ–ใง่จ˜ๅทใ‚’ๅค‰ๆ›ใ™ใ‚‹ใŸใ‚ใฎ่พžๆ›ธ
246
+ rep_map = {
247
+ "๏ผš": ":",
248
+ "๏ผ›": ";",
249
+ "๏ผŒ": ",",
250
+ "ใ€‚": ".",
251
+ "๏ผ": "!",
252
+ "๏ผŸ": "?",
253
+ "\n": ".",
254
+ "๏ผŽ": ".",
255
+ "โ‹ฏ": "โ€ฆ",
256
+ "ยทยทยท": "โ€ฆ",
257
+ "ใƒปใƒปใƒป": "โ€ฆ",
258
+ "ยท": ",",
259
+ "ใƒป": ",",
260
+ "โ€ข": ",",
261
+ "ใ€": ",",
262
+ "$": ".",
263
+ # "โ€œ": "'",
264
+ # "โ€": "'",
265
+ # '"': "'",
266
+ "โ€˜": "'",
267
+ "โ€™": "'",
268
+ # "๏ผˆ": "'",
269
+ # "๏ผ‰": "'",
270
+ # "(": "'",
271
+ # ")": "'",
272
+ # "ใ€Š": "'",
273
+ # "ใ€‹": "'",
274
+ # "ใ€": "'",
275
+ # "ใ€‘": "'",
276
+ # "[": "'",
277
+ # "]": "'",
278
+ # "โ€”โ€”": "-",
279
+ # "โˆ’": "-",
280
+ # "-": "-",
281
+ # "ใ€Ž": "'",
282
+ # "ใ€": "'",
283
+ # "ใ€ˆ": "'",
284
+ # "ใ€‰": "'",
285
+ # "ยซ": "'",
286
+ # "ยป": "'",
287
+ # # "๏ฝž": "-", # ใ“ใ‚Œใฏ้•ท้Ÿณ่จ˜ๅทใ€Œใƒผใ€ใจใ—ใฆๆ‰ฑใ†ใ‚ˆใ†ๅค‰ๆ›ด
288
+ # # "~": "-", # ใ“ใ‚Œใฏ้•ท้Ÿณ่จ˜ๅทใ€Œใƒผใ€ใจใ—ใฆๆ‰ฑใ†ใ‚ˆใ†ๅค‰ๆ›ด
289
+ # "ใ€Œ": "'",
290
+ # "ใ€": "'",
291
+ }
292
+
293
+
294
+ def _numeric_feature_by_regex(regex, s):
295
+ match = re.search(regex, s)
296
+ if match is None:
297
+ return -50
298
+ return int(match.group(1))
299
+
300
+
301
+ def replace_punctuation(text: str) -> str:
302
+ """ๅฅ่ชญ็‚น็ญ‰ใ‚’ใ€Œ.ใ€ใ€Œ,ใ€ใ€Œ!ใ€ใ€Œ?ใ€ใ€Œ'ใ€ใ€Œ-ใ€ใซๆญฃ่ฆๅŒ–ใ—ใ€OpenJTalkใง่ชญใฟใŒๅ–ๅพ—ใงใใ‚‹ใ‚‚ใฎใฎใฟๆฎ‹ใ™๏ผš
303
+ ๆผขๅญ—ใƒปๅนณไปฎๅใƒปใ‚ซใ‚ฟใ‚ซใƒŠใ€ใ‚ขใƒซใƒ•ใ‚กใƒ™ใƒƒใƒˆใ€ใ‚ฎใƒชใ‚ทใƒฃๆ–‡ๅญ—
304
+ """
305
+ pattern = re.compile("|".join(re.escape(p) for p in rep_map.keys()))
306
+ # print("before: ", text)
307
+ # ๅฅ่ชญ็‚นใ‚’่พžๆ›ธใง็ฝฎๆ›
308
+ replaced_text = pattern.sub(lambda x: rep_map[x.group()], text)
309
+
310
+     replaced_text = re.sub(
+         # ↓ hiragana, katakana, kanji
+         r"[^\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FFF\u3400-\u4DBF\u3005"
+         # ↓ half-width alphabet (upper and lower case)
+         + r"\u0041-\u005A\u0061-\u007A"
+         # ↓ full-width alphabet (upper and lower case)
+         + r"\uFF21-\uFF3A\uFF41-\uFF5A"
+         # ↓ Greek letters
+         + r"\u0370-\u03FF\u1F00-\u1FFF"
+         # ↓ "!", "?", "…", ",", ".", "'", "-"; note "…" has already been converted to "..."
+         + "".join(punctuation) + r"]+",
+         # delete every character not listed above
322
+ "",
323
+ replaced_text,
324
+ )
325
+ # print("after: ", replaced_text)
326
+ return replaced_text
327
+
328
+
329
+ def fix_phone_tone(phone_tone_list: list[tuple[str, int]]) -> list[tuple[str, int]]:
330
+ """
331
+ `phone_tone_list`ใฎtone๏ผˆใ‚ขใ‚ฏใ‚ปใƒณใƒˆใฎๅ€ค๏ผ‰ใ‚’0ใ‹1ใฎ็ฏ„ๅ›ฒใซไฟฎๆญฃใ™ใ‚‹ใ€‚
332
+ ไพ‹: [(a, 0), (i, -1), (u, -1)] โ†’ [(a, 1), (i, 0), (u, 0)]
333
+ """
334
+ tone_values = set(tone for _, tone in phone_tone_list)
335
+ if len(tone_values) == 1:
336
+ assert tone_values == {0}, tone_values
337
+ return phone_tone_list
338
+ elif len(tone_values) == 2:
339
+ if tone_values == {0, 1}:
340
+ return phone_tone_list
341
+ elif tone_values == {-1, 0}:
342
+ return [
343
+ (letter, 0 if tone == -1 else 1) for letter, tone in phone_tone_list
344
+ ]
345
+ else:
346
+ raise ValueError(f"Unexpected tone values: {tone_values}")
347
+ else:
348
+ raise ValueError(f"Unexpected tone values: {tone_values}")
349
+
350
+
351
+ def fix_phone_tone_wplen(phone_tone_list, word_phone_length_list):
352
+ phones = []
353
+ tones = []
354
+ w_p_len = []
355
+ p_len = len(phone_tone_list)
356
+ idx = 0
357
+ w_idx = 0
358
+ while idx < p_len:
359
+ offset = 0
360
+ if phone_tone_list[idx] == "โ–":
361
+ w_p_len.append(w_idx + 1)
362
+
363
+ curr_w_p_len = word_phone_length_list[w_idx]
364
+ for i in range(curr_w_p_len):
365
+ p, t = phone_tone_list[idx]
366
+ if p == ":" and len(phones) > 0:
367
+ if phones[-1][-1] != ":":
368
+ phones[-1] += ":"
369
+ offset -= 1
370
+ else:
371
+ phones.append(p)
372
+ tones.append(str(t))
373
+ idx += 1
374
+ if idx >= p_len:
375
+ break
376
+ w_p_len.append(curr_w_p_len + offset)
377
+ w_idx += 1
378
+ # print(w_p_len)
379
+ return phones, tones, w_p_len
380
+
381
+
382
+ def g2phone_tone_wo_punct(prosodies) -> list[tuple[str, int]]:
383
+ """
384
+ ใƒ†ใ‚ญใ‚นใƒˆใซๅฏพใ—ใฆใ€้Ÿณ็ด ใจใ‚ขใ‚ฏใ‚ปใƒณใƒˆ๏ผˆ0ใ‹1๏ผ‰ใฎใƒšใ‚ขใฎใƒชใ‚นใƒˆใ‚’่ฟ”ใ™ใ€‚
385
+ ใŸใ ใ—ใ€Œ!ใ€ใ€Œ.ใ€ใ€Œ?ใ€็ญ‰ใฎ้ž้Ÿณ็ด ่จ˜ๅท(punctuation)ใฏๅ…จใฆๆถˆใˆใ‚‹๏ผˆใƒใƒผใ‚บ่จ˜ๅทใ‚‚ๆฎ‹ใ•ใชใ„๏ผ‰ใ€‚
386
+ ้ž้Ÿณ็ด ่จ˜ๅทใ‚’ๅซใ‚ใ‚‹ๅ‡ฆ็†ใฏ`align_tones()`ใง่กŒใ‚ใ‚Œใ‚‹ใ€‚
387
+ ใพใŸใ€Œใฃใ€ใฏใ€Œclใ€ใงใชใใ€Œqใ€ใซๅค‰ๆ›ใ•ใ‚Œใ‚‹๏ผˆใ€Œใ‚“ใ€ใฏใ€ŒNใ€ใฎใพใพ๏ผ‰ใ€‚
388
+ ไพ‹: "ใ“ใ‚“ใซใกใฏใ€ไธ–็•Œใƒผใ€‚ใ€‚ๅ…ƒๆฐ—๏ผŸ๏ผ" โ†’
389
+ [('k', 0), ('o', 0), ('N', 1), ('n', 1), ('i', 1), ('ch', 1), ('i', 1), ('w', 1), ('a', 1), ('s', 1), ('e', 1), ('k', 0), ('a', 0), ('i', 0), ('i', 0), ('g', 1), ('e', 1), ('N', 0), ('k', 0), ('i', 0)]
390
+ """
391
+ result: list[tuple[str, int]] = []
392
+ current_phrase: list[tuple[str, int]] = []
393
+ current_tone = 0
394
+ last_accent = ""
395
+ for i, letter in enumerate(prosodies):
396
+         # handle special symbols
+
+         # sentence-initial marker; ignored
399
+ if letter == "^":
400
+ assert i == 0, "Unexpected ^"
401
+ # ใ‚ขใ‚ฏใ‚ปใƒณใƒˆๅฅใฎ็ต‚ใ‚ใ‚Šใซๆฅใ‚‹่จ˜ๅท
402
+ elif letter in ("$", "?", "_", "#"):
403
+ # ไฟๆŒใ—ใฆใ„ใ‚‹ใƒ•ใƒฌใƒผใ‚บใ‚’ใ€ใ‚ขใ‚ฏใ‚ปใƒณใƒˆๆ•ฐๅ€คใ‚’0-1ใซไฟฎๆญฃใ—็ตๆžœใซ่ฟฝๅŠ 
404
+ result.extend(fix_phone_tone(current_phrase))
405
+             # ignore end-of-text markers (a mid-sentence question becomes "_")
406
+ if letter in ("$", "?"):
407
+ assert i == len(prosodies) - 1, f"Unexpected {letter}"
408
+ # ใ‚ใจใฏ"_"๏ผˆใƒใƒผใ‚บ๏ผ‰ใจ"#"๏ผˆใ‚ขใ‚ฏใ‚ปใƒณใƒˆๅฅใฎๅขƒ็•Œ๏ผ‰ใฎใฟ
409
+ # ใ“ใ‚Œใ‚‰ใฏๆฎ‹ใ•ใšใ€ๆฌกใฎใ‚ขใ‚ฏใ‚ปใƒณใƒˆๅฅใซๅ‚™ใˆใ‚‹ใ€‚
410
+
411
+ current_phrase = []
412
+             # 0 is the baseline; tones rise and fall from it (negatives are fixed by fix_phone_tone above)
413
+ current_tone = 0
414
+ last_accent = ""
415
+         # accent-rise marker
416
+ elif letter == "[":
417
+ if last_accent != letter:
418
+ current_tone = current_tone + 1
419
+ last_accent = letter
420
+         # accent-fall marker
421
+ elif letter == "]":
422
+ if last_accent != letter:
423
+ current_tone = current_tone - 1
424
+ last_accent = letter
425
+ # ใใ‚Œไปฅๅค–ใฏ้€šๅธธใฎ้Ÿณ็ด 
426
+ else:
427
+ if letter == "cl": # ใ€Œใฃใ€ใฎๅ‡ฆ็†
428
+ letter = "q"
429
+ current_phrase.append((letter, current_tone))
430
+ return result
431
+
432
+
433
+ def handle_long(sep_phonemes: list[list[str]]) -> list[list[str]]:
434
+     for i in range(len(sep_phonemes)):
+         if sep_phonemes[i][0] == "ー":
+             # sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+             sep_phonemes[i][0] = ":"
+         if "ー" in sep_phonemes[i]:
+             for j in range(len(sep_phonemes[i])):
+                 if sep_phonemes[i][j] == "ー":
+                     # sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+                     sep_phonemes[i][j] = ":"
+     return sep_phonemes
444
+
445
+
446
+ def handle_long_word(sep_phonemes: list[list[str]]) -> list[list[str]]:
447
+ res = []
448
+     for i in range(len(sep_phonemes)):
+         if sep_phonemes[i][0] == "ー":
+             sep_phonemes[i][0] = sep_phonemes[i - 1][-1]
+             # sep_phonemes[i][0] = ':'
+         if "ー" in sep_phonemes[i]:
+             for j in range(len(sep_phonemes[i])):
+                 if sep_phonemes[i][j] == "ー":
+                     sep_phonemes[i][j] = sep_phonemes[i][j - 1][-1]
+                     # sep_phonemes[i][j] = ':'
+         res.append(sep_phonemes[i])
+         res.append("▁")
459
+ return res
460
+
461
+
462
+ def align_tones(
463
+ phones_with_punct: list[str], phone_tone_list: list[tuple[str, int]]
464
+ ) -> list[tuple[str, int]]:
465
+ """
466
+ ไพ‹:
467
+ โ€ฆ็งใฏใ€ใ€ใใ†ๆ€ใ†ใ€‚
468
+ phones_with_punct:
469
+ [".", ".", ".", "w", "a", "t", "a", "sh", "i", "w", "a", ",", ",", "s", "o", "o", "o", "m", "o", "u", "."]
470
+ phone_tone_list:
471
+ [("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0))]
472
+ Return:
473
+ [(".", 0), (".", 0), (".", 0), ("w", 0), ("a", 0), ("t", 1), ("a", 1), ("sh", 1), ("i", 1), ("w", 1), ("a", 1), (",", 0), (",", 0), ("s", 0), ("o", 0), ("o", 1), ("o", 1), ("m", 1), ("o", 1), ("u", 0), (".", 0)]
474
+ """
475
+ result: list[tuple[str, int]] = []
476
+ tone_index = 0
477
+ for phone in phones_with_punct:
478
+ if tone_index >= len(phone_tone_list):
479
+ # ไฝ™ใฃใŸpunctuationใŒใ‚ใ‚‹ๅ ดๅˆ โ†’ (punctuation, 0)ใ‚’่ฟฝๅŠ 
480
+ result.append((phone, 0))
481
+ elif phone == phone_tone_list[tone_index][0]:
482
+             # matches the current phoneme in phone_tone_list -> take its tone, append (phone, tone)
+             result.append((phone, phone_tone_list[tone_index][1]))
+             # advance the search index by one
485
+ tone_index += 1
486
+         elif phone in punctuation or phone == "▁":
+             # phone is punctuation -> append (phone, 0)
488
+ result.append((phone, 0))
489
+ else:
490
+ print(f"phones: {phones_with_punct}")
491
+ print(f"phone_tone_list: {phone_tone_list}")
492
+ print(f"result: {result}")
493
+ print(f"tone_index: {tone_index}")
494
+ print(f"phone: {phone}")
495
+ raise ValueError(f"Unexpected phone: {phone}")
496
+ return result
497
+
498
+
499
+ def kata2phoneme_list(text: str) -> list[str]:
500
+ """
501
+ ๅŽŸๅ‰‡ใ‚ซใ‚ฟใ‚ซใƒŠใฎ`text`ใ‚’ๅ—ใ‘ๅ–ใ‚Šใ€ใใ‚Œใ‚’ใใฎใพใพใ„ใ˜ใ‚‰ใšใซ้Ÿณ็ด ่จ˜ๅทใฎใƒชใ‚นใƒˆใซๅค‰ๆ›ใ€‚
502
+ ๆณจๆ„็‚น๏ผš
503
+ - punctuationใŒๆฅใŸๅ ดๅˆ๏ผˆpunctuationใŒ1ๆ–‡ๅญ—ใฎๅ ดๅˆใŒใ‚ใ‚Šใ†ใ‚‹๏ผ‰ใ€ๅ‡ฆ็†ใ›ใš1ๆ–‡ๅญ—ใฎใƒชใ‚นใƒˆใ‚’่ฟ”ใ™
504
+ - ๅ†’้ ญใซ็ถšใใ€Œใƒผใ€ใฏใใฎใพใพใ€Œใƒผใ€ใฎใพใพใซใ™ใ‚‹๏ผˆ`handle_long()`ใงๅ‡ฆ็†ใ•ใ‚Œใ‚‹๏ผ‰
505
+ - ๆ–‡ไธญใฎใ€Œใƒผใ€ใฏๅ‰ใฎ้Ÿณ็ด ่จ˜ๅทใฎๆœ€ๅพŒใฎ้Ÿณ็ด ่จ˜ๅทใซๅค‰ๆ›ใ•ใ‚Œใ‚‹ใ€‚
506
+ ไพ‹๏ผš
507
+ `ใƒผใƒผใ‚ฝใƒผใƒŠใƒŽใ‚ซใƒผใƒผ` โ†’ ["ใƒผ", "ใƒผ", "s", "o", "o", "n", "a", "n", "o", "k", "a", "a", "a"]
508
+ `?` โ†’ ["?"]
509
+ """
510
+ if text in punctuation:
511
+ return [text]
512
+     # check that `text` consists only of katakana (including "ー")
513
+ if re.fullmatch(r"[\u30A0-\u30FF]+", text) is None:
514
+ raise ValueError(f"Input must be katakana only: {text}")
515
+ sorted_keys = sorted(mora_kata_to_mora_phonemes.keys(), key=len, reverse=True)
516
+ pattern = "|".join(map(re.escape, sorted_keys))
517
+
518
+ def mora2phonemes(mora: str) -> str:
519
+         consonant, vowel = mora_kata_to_mora_phonemes[mora]
+         if consonant is None:
+             return f" {vowel}"
+         return f" {consonant} {vowel}"
523
+
524
+ spaced_phonemes = re.sub(pattern, lambda m: mora2phonemes(m.group()), text)
525
+
526
+     # handle the long-vowel mark "ー"
+     long_pattern = r"(\w)(ー*)"
528
+ long_replacement = lambda m: m.group(1) + (" " + m.group(1)) * len(m.group(2))
529
+ spaced_phonemes = re.sub(long_pattern, long_replacement, spaced_phonemes)
530
+     # spaced_phonemes += ' ▁'
531
+ return spaced_phonemes.strip().split(" ")
532
+
533
+
534
+ def frontend2phoneme(labels, drop_unvoiced_vowels=False):
535
+ N = len(labels)
536
+
537
+ phones = []
538
+ for n in range(N):
539
+ lab_curr = labels[n]
540
+ # print(lab_curr)
541
+ # current phoneme
542
+ p3 = re.search(r"\-(.*?)\+", lab_curr).group(1)
543
+
544
+         # treat unvoiced vowels as normal vowels
545
+ if drop_unvoiced_vowels and p3 in "AEIOU":
546
+ p3 = p3.lower()
547
+
548
+ # deal with sil at the beginning and the end of text
549
+ if p3 == "sil":
550
+ # assert n == 0 or n == N - 1
551
+ # if n == 0:
552
+ # phones.append("^")
553
+ # elif n == N - 1:
554
+ # # check question form or not
555
+ # e3 = _numeric_feature_by_regex(r"!(\d+)_", lab_curr)
556
+ # if e3 == 0:
557
+ # phones.append("$")
558
+ # elif e3 == 1:
559
+ # phones.append("?")
560
+ continue
561
+ elif p3 == "pau":
562
+ phones.append("_")
563
+ continue
564
+ else:
565
+ phones.append(p3)
566
+
567
+ # accent type and position info (forward or backward)
568
+ a1 = _numeric_feature_by_regex(r"/A:([0-9\-]+)\+", lab_curr)
569
+ a2 = _numeric_feature_by_regex(r"\+(\d+)\+", lab_curr)
570
+ a3 = _numeric_feature_by_regex(r"\+(\d+)/", lab_curr)
571
+
572
+ # number of mora in accent phrase
573
+ f1 = _numeric_feature_by_regex(r"/F:(\d+)_", lab_curr)
574
+
575
+ a2_next = _numeric_feature_by_regex(r"\+(\d+)\+", labels[n + 1])
576
+ # accent phrase border
577
+ # print(p3, a1, a2, a3, f1, a2_next, lab_curr)
578
+ if a3 == 1 and a2_next == 1 and p3 in "aeiouAEIOUNcl":
579
+ phones.append("#")
580
+ # pitch falling
581
+ elif a1 == 0 and a2_next == a2 + 1 and a2 != f1:
582
+ phones.append("]")
583
+ # pitch rising
584
+ elif a2 == 1 and a2_next == 2:
585
+ phones.append("[")
586
+
587
+ # phones = ' '.join(phones)
588
+ return phones
589
+
590
+
591
+ class JapanesePhoneConverter(object):
592
+ def __init__(self, lexicon_path=None, ipa_dict_path=None):
593
+ # lexicon_lines = open(lexicon_path, 'r', encoding='utf-8').readlines()
594
+ # self.lexicon = {}
595
+ # self.single_dict = {}
596
+ # self.double_dict = {}
597
+ # for curr_line in lexicon_lines:
598
+ # k,v = curr_line.strip().split('+',1)
599
+ # self.lexicon[k] = v
600
+ # if len(k) == 2:
601
+ # self.double_dict[k] = v
602
+ # elif len(k) == 1:
603
+ # self.single_dict[k] = v
604
+ self.ipa_dict = {}
605
+ for curr_line in jp_xphone2ipa:
606
+ k, v = curr_line.strip().split(" ", 1)
607
+             self.ipa_dict[k] = re.sub(r"\s", "", v)
608
+ # kakasi1 = kakasi()
609
+ # kakasi1.setMode("H","K")
610
+ # kakasi1.setMode("J","K")
611
+ # kakasi1.setMode("r","Hepburn")
612
+ self.japan_JH2K = kakasi()
613
+         self.table = {ord(f): ord(t) for f, t in zip("67", "_¯")}
614
+
615
+     def text2sep_kata(self, parsed) -> tuple[list[str], list[str], list]:
+         """
+         Take the OpenJTalk parse of text already normalized by `text_normalize`,
+         split it into words, and return the word list, the list of their readings
+         (katakana, or a single punctuation symbol), and the fixed-up parse.
+         The word split is used by `g2p()`'s `word2ph` to decide how many phoneme
+         symbols to assign to each character.
+         Example:
+         `私はそう思う!って感じ?` ->
+         ["私", "は", "そう", "思う", "!", "って", "感じ", "?"], ["ワタシ", "ワ", "ソー", "オモウ", "!", "ッテ", "カンジ", "?"]
+         """
624
+         # parsed: the OpenJTalk parse result
625
+ sep_text: list[str] = []
626
+ sep_kata: list[str] = []
627
+ fix_parsed = []
628
+ i = 0
629
+ while i <= len(parsed) - 1:
630
+             # word: the actual word string
+             # yomi: its reading, with the devoicing mark "’" removed
632
+ # print(parsed)
633
+ yomi = parsed[i]["pron"]
634
+ tmp_parsed = parsed[i]
635
+             if i != len(parsed) - 1 and parsed[i + 1]["string"] in [
+                 "々",
+                 "ゝ",
+                 "ヽ",
+                 "ゞ",
+                 "ヾ",
+                 "゛",
+             ]:
643
+ word = parsed[i]["string"] + parsed[i + 1]["string"]
644
+ i += 1
645
+ else:
646
+ word = parsed[i]["string"]
647
+             word, yomi = replace_punctuation(word), yomi.replace("’", "")
648
+ """
649
+ ใ“ใ“ใง`yomi`ใฎๅ–ใ‚Šใ†ใ‚‹ๅ€คใฏไปฅไธ‹ใฎ้€šใ‚Šใฎใฏใšใ€‚
650
+ - `word`ใŒ้€šๅธธๅ˜่ชž โ†’ ้€šๅธธใฎ่ชญใฟ๏ผˆใ‚ซใ‚ฟใ‚ซใƒŠ๏ผ‰
651
+ ๏ผˆใ‚ซใ‚ฟใ‚ซใƒŠใ‹ใ‚‰ใชใ‚Šใ€้•ท้Ÿณ่จ˜ๅทใ‚‚ๅซใฟใ†ใ‚‹ใ€`ใ‚ขใƒผ` ็ญ‰๏ผ‰
652
+ - `word`ใŒ`ใƒผ` ใ‹ใ‚‰ๅง‹ใพใ‚‹ โ†’ `ใƒผใƒฉใƒผ` ใ‚„ `ใƒผใƒผใƒผ` ใชใฉ
653
+ - `word`ใŒๅฅ่ชญ็‚นใ‚„็ฉบ็™ฝ็ญ‰ โ†’ `ใ€`
654
+ - `word`ใŒ`?` โ†’ `๏ผŸ`๏ผˆๅ…จ่ง’ใซใชใ‚‹๏ผ‰
655
+ ไป–ใซใ‚‚`word`ใŒ่ชญใ‚ใชใ„ใ‚ญใƒชใƒซๆ–‡ๅญ—ใ‚ขใƒฉใƒ“ใ‚ขๆ–‡ๅญ—็ญ‰ใŒๆฅใ‚‹ใจ`ใ€`ใซใชใ‚‹ใŒใ€ๆญฃ่ฆๅŒ–ใงใ“ใฎๅ ดๅˆใฏ่ตทใใชใ„ใฏใšใ€‚
656
+ ใพใŸๅ…ƒใฎใ‚ณใƒผใƒ‰ใงใฏ`yomi`ใŒ็ฉบ็™ฝใฎๅ ดๅˆใฎๅ‡ฆ็†ใŒใ‚ใฃใŸใŒใ€ใ“ใ‚Œใฏ่ตทใใชใ„ใฏใšใ€‚
657
+ ๅ‡ฆ็†ใ™ในใใฏ`yomi`ใŒ`ใ€`ใฎๅ ดๅˆใฎใฟใฎใฏใšใ€‚
658
+ """
659
+ assert yomi != "", f"Empty yomi: {word}"
660
+ if yomi == "ใ€":
661
+ # wordใฏๆญฃ่ฆๅŒ–ใ•ใ‚Œใฆใ„ใ‚‹ใฎใงใ€`.`, `,`, `!`, `'`, `-`ใฎใ„ใšใ‚Œใ‹
662
+ if word not in (
663
+ ".",
664
+ ",",
665
+ "!",
666
+ "'",
667
+ "-",
668
+ "?",
669
+ ":",
670
+ ";",
671
+ "โ€ฆ",
672
+ "",
673
+ ):
674
+ # ใ“ใ“ใฏpyopenjtalkใŒ่ชญใ‚ใชใ„ๆ–‡ๅญ—็ญ‰ใฎใจใใซ่ตทใ“ใ‚‹
675
+ #print(
676
+ # "{}Cannot read:{}, yomi:{}, new_word:{};".format(
677
+ # parsed, word, yomi, self.japan_JH2K.convert(word)[0]["kana"]
678
+ # )
679
+ #)
680
+ # raise ValueError(word)
681
+ word = self.japan_JH2K.convert(word)[0]["kana"]
682
+ # print(word, self.japan_JH2K.convert(word)[0]['kana'], kata2phoneme_list(self.japan_JH2K.convert(word)[0]['kana']))
683
+ tmp_parsed["pron"] = word
684
+ # yomi = "-"
685
+ # word = ','
686
+ # yomiใฏๅ…ƒใฎ่จ˜ๅทใฎใพใพใซๅค‰ๆ›ด
687
+ # else:
688
+ # parsed[i]['pron'] = parsed[i]["string"]
689
+ yomi = word
690
+ elif yomi == "๏ผŸ":
691
+ assert word == "?", f"yomi `๏ผŸ` comes from: {word}"
692
+ yomi = "?"
693
+ if word == "":
694
+ i += 1
695
+ continue
696
+ sep_text.append(word)
697
+ sep_kata.append(yomi)
698
+ # print(word, yomi, parts)
699
+ fix_parsed.append(tmp_parsed)
700
+ i += 1
701
+ # print(sep_text, sep_kata)
702
+ return sep_text, sep_kata, fix_parsed
703
+
704
+ def getSentencePhone(self, sentence, blank_mode=True, phoneme_mode=False):
705
+ # print("origin:", sentence)
706
+ words = []
707
+ words_phone_len = []
708
+ short_char_flag = False
709
+ output_duration_flag = []
710
+ output_before_sil_flag = []
711
+ normed_text = []
712
+ sentence = sentence.strip().strip("'")
713
+ sentence = re.sub(r"\s+", "", sentence)
714
+ output_res = []
715
+ failed_words = []
716
+ last_long_pause = 4
717
+ last_word = None
718
+ frontend_text = pyopenjtalk.run_frontend(sentence)
719
+ # print("frontend_text: ", frontend_text)
720
+ try:
721
+ frontend_text = pyopenjtalk.estimate_accent(frontend_text)
722
+         except Exception:  # estimate_accent can fail on rare inputs; keep the unaccented parse
723
+ pass
724
+ # print("estimate_accent: ", frontend_text)
725
+         # sep_text: list of words
+         # sep_kata: list of katakana readings, one per word
727
+ sep_text, sep_kata, frontend_text = self.text2sep_kata(frontend_text)
728
+ # print("sep_text: ", sep_text)
729
+ # print("sep_kata: ", sep_kata)
730
+ # print("frontend_text: ", frontend_text)
731
+         # sep_phonemes: list of per-word phoneme lists
732
+ sep_phonemes = handle_long_word([kata2phoneme_list(i) for i in sep_kata])
733
+ # print("sep_phonemes: ", sep_phonemes)
734
+
735
+ pron_text = [x["pron"].strip().replace("โ€™", "") for x in frontend_text]
736
+ # pdb.set_trace()
737
+ prosodys = pyopenjtalk.make_label(frontend_text)
738
+ prosodys = frontend2phoneme(prosodys, drop_unvoiced_vowels=True)
739
+ # print("prosodys: ", ' '.join(prosodys))
740
+ # print("pron_text: ", pron_text)
741
+ normed_text = [x["string"].strip() for x in frontend_text]
742
+ # punctuationใŒใ™ในใฆๆถˆใˆใŸใ€้Ÿณ็ด ใจใ‚ขใ‚ฏใ‚ปใƒณใƒˆใฎใ‚ฟใƒ—ใƒซใฎใƒชใ‚นใƒˆ
743
+ phone_tone_list_wo_punct = g2phone_tone_wo_punct(prosodys)
744
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
745
+
746
+         # phone_w_punct: sep_phonemes joined into one phoneme sequence, punctuation kept as-is
747
+ phone_w_punct: list[str] = []
748
+ w_p_len = []
749
+ for i in sep_phonemes:
750
+ phone_w_punct += i
751
+ w_p_len.append(len(i))
752
+ phone_w_punct = phone_w_punct[:-1]
753
+         # use the punctuation-free accent info to build accent info that includes punctuation
754
+ # print("phone_w_punct: ", phone_w_punct)
755
+ # print("phone_tone_list_wo_punct: ", phone_tone_list_wo_punct)
756
+ phone_tone_list = align_tones(phone_w_punct, phone_tone_list_wo_punct)
757
+
758
+ jp_item = {}
759
+ jp_p = ""
760
+ jp_t = ""
761
+ # mye rye pye bye nye
762
+ # je she
763
+ # print(phone_tone_list)
764
+ for p, t in phone_tone_list:
765
+ if p in self.ipa_dict:
766
+ curr_p = self.ipa_dict[p]
767
+ jp_p += curr_p
768
+ jp_t += str(t + 6) * len(curr_p)
769
+ elif p in punctuation:
770
+ jp_p += p
771
+ jp_t += "0"
772
+ elif p == "โ–":
773
+ jp_p += p
774
+ jp_t += " "
775
+ else:
776
+ print(p, t)
777
+ jp_p += "|"
778
+ jp_t += "0"
779
+ # return phones, tones, w_p_len
780
+         jp_p = jp_p.replace("▁", " ")
781
+ jp_t = jp_t.translate(self.table)
782
+ jp_l = ""
783
+ for t in jp_t:
784
+ if t == " ":
785
+ jp_l += " "
786
+ else:
787
+ jp_l += "2"
788
+ # print(jp_p)
789
+ # print(jp_t)
790
+ # print(jp_l)
791
+ # print(len(jp_p_len), sum(w_p_len), len(jp_p), sum(jp_p_len))
792
+ assert len(jp_p) == len(jp_t) and len(jp_p) == len(jp_l)
793
+
794
+ jp_item["jp_p"] = jp_p.replace("| |", "|").rstrip("|")
795
+ jp_item["jp_t"] = jp_t
796
+ jp_item["jp_l"] = jp_l
797
+ jp_item["jp_normed_text"] = " ".join(normed_text)
798
+ jp_item["jp_pron_text"] = " ".join(pron_text)
799
+ # jp_item['jp_ruoma'] = sep_phonemes
800
+ # print(len(normed_text), len(sep_phonemes))
801
+ # print(normed_text)
802
+ return jp_item
803
+
804
+
805
+ jpc = JapanesePhoneConverter()
806
+
807
+
808
+ def japanese_to_ipa(text, text_tokenizer):
809
+ # phonemes = text_tokenizer(text)
810
+ if type(text) == str:
811
+ return jpc.getSentencePhone(text)["jp_p"]
812
+ else:
813
+ result_ph = []
814
+ for t in text:
815
+ result_ph.append(jpc.getSentencePhone(t)["jp_p"])
816
+ return result_ph
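A usage sketch for the Japanese converter above (mine, not from the commit; it assumes pyopenjtalk and pykakasi are installed, and note that the module-level `jpc` instance is built at import time):

    # Sketch: convert one Japanese sentence to the "jp_p" IPA string assembled
    # from the jp_xphone2ipa table; the text_tokenizer argument is unused here.
    from g2p.g2p.japanese import japanese_to_ipa

    print(japanese_to_ipa("こんにちは、世界。", text_tokenizer=None))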
g2p/g2p/korean.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+
8
+ """
9
+ Text clean time
10
+ """
11
+ english_dictionary = {
+     "KOREA": "코리아",
+     "IDOL": "아이돌",
+     "IT": "아이티",
+     "IQ": "아이큐",
+     "UP": "업",
+     "DOWN": "다운",
+     "PC": "피씨",
+     "CCTV": "씨씨티비",
+     "SNS": "에스엔에스",
+     "AI": "에이아이",
+     "CEO": "씨이오",
+     "A": "에이",
+     "B": "비",
+     "C": "씨",
+     "D": "디",
+     "E": "이",
+     "F": "에프",
+     "G": "지",
+     "H": "에이치",
+     "I": "아이",
+     "J": "제이",
+     "K": "케이",
+     "L": "엘",
+     "M": "엠",
+     "N": "엔",
+     "O": "오",
+     "P": "피",
+     "Q": "큐",
+     "R": "알",
+     "S": "에스",
+     "T": "티",
+     "U": "유",
+     "V": "브이",
+     "W": "더블유",
+     "X": "엑스",
+     "Y": "와이",
+     "Z": "제트",
+ }
50
+
51
+
52
+ def normalize(text):
53
+ text = text.strip()
54
+ text = re.sub(
55
+ "[โบ€-โบ™โบ›-โปณโผ€-โฟ•ใ€…ใ€‡ใ€ก-ใ€ฉใ€ธ-ใ€บใ€ปใ€-ไถตไธ€-้ฟƒ่ฑˆ-้ถดไพฎ-้ ปไธฆ-้พŽ]", "", text
56
+ )
57
+ text = normalize_english(text)
58
+ text = text.lower()
59
+ return text
60
+
61
+
62
+ def normalize_english(text):
63
+ def fn(m):
64
+ word = m.group()
65
+ if word in english_dictionary:
66
+ return english_dictionary.get(word)
67
+ return word
68
+
69
+ text = re.sub("([A-Za-z]+)", fn, text)
70
+ return text
71
+
72
+
73
+ def korean_to_ipa(text, text_tokenizer):
74
+ if type(text) == str:
75
+ text = normalize(text)
76
+ phonemes = text_tokenizer(text)
77
+ return phonemes
78
+ else:
79
+ for i, t in enumerate(text):
80
+ text[i] = normalize(t)
81
+ return text_tokenizer(text)
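A usage sketch for the Korean normalizer above (mine; the expected output follows directly from english_dictionary):

    # Sketch: Latin acronyms are spelled out in Hangul before phonemization.
    from g2p.g2p.korean import normalize

    print(normalize("AI와 PC"))  # -> "에이아이와 피씨"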
g2p/g2p/mandarin.py ADDED
@@ -0,0 +1,597 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import jieba
8
+ import cn2an
9
+ from pypinyin import lazy_pinyin, BOPOMOFO
10
+ from typing import List
11
+ from g2p.g2p.chinese_model_g2p import BertPolyPredict
12
+ from g2p.utils.front_utils import *
13
+ import os
14
+
15
+ # from g2pw import G2PWConverter
16
+
17
+
18
+ # set blank level: {0: "none", 1: "char", 2: "word"}
19
+ BLANK_LEVEL = 0
20
+
21
+ # conv = G2PWConverter(style='pinyin', enable_non_tradional_chinese=True)
22
+ resource_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
23
+ poly_all_class_path = os.path.join(
24
+ resource_path, "sources", "g2p_chinese_model", "polychar.txt"
25
+ )
26
+ if not os.path.exists(poly_all_class_path):
27
+ print(
28
+ "Incorrect path for polyphonic character class dictionary: {}, please check...".format(
29
+ poly_all_class_path
30
+ )
31
+ )
32
+ exit()
33
+ poly_dict = generate_poly_lexicon(poly_all_class_path)
34
+
35
+ # Set up G2PW model parameters
36
+ g2pw_poly_model_path = os.path.join(resource_path, "sources", "g2p_chinese_model")
37
+ if not os.path.exists(g2pw_poly_model_path):
38
+ print(
39
+ "Incorrect path for g2pw polyphonic character model: {}, please check...".format(
40
+ g2pw_poly_model_path
41
+ )
42
+ )
43
+ exit()
44
+
45
+ json_file_path = os.path.join(
46
+ resource_path, "sources", "g2p_chinese_model", "polydict.json"
47
+ )
48
+ if not os.path.exists(json_file_path):
49
+ print(
50
+ "Incorrect path for g2pw id to pinyin dictionary: {}, please check...".format(
51
+ json_file_path
52
+ )
53
+ )
54
+ exit()
55
+
56
+ jsonr_file_path = os.path.join(
57
+ resource_path, "sources", "g2p_chinese_model", "polydict_r.json"
58
+ )
59
+ if not os.path.exists(jsonr_file_path):
60
+ print(
61
+ "Incorrect path for g2pw pinyin to id dictionary: {}, please check...".format(
62
+ jsonr_file_path
63
+ )
64
+ )
65
+ exit()
66
+
67
+ g2pw_poly_predict = BertPolyPredict(
68
+ g2pw_poly_model_path, jsonr_file_path, json_file_path
69
+ )
70
+
71
+
72
+ """
73
+ Text cleaning utilities.
74
+ """
75
+ # List of (Latin alphabet, bopomofo) pairs:
76
+ _latin_to_bopomofo = [
+     (re.compile("%s" % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("a", "ㄟˉ"),
+         ("b", "ㄅㄧˋ"),
+         ("c", "ㄙㄧˉ"),
+         ("d", "ㄉㄧˋ"),
+         ("e", "ㄧˋ"),
+         ("f", "ㄝˊㄈㄨˋ"),
+         ("g", "ㄐㄧˋ"),
+         ("h", "ㄝˇㄑㄩˋ"),
+         ("i", "ㄞˋ"),
+         ("j", "ㄐㄟˋ"),
+         ("k", "ㄎㄟˋ"),
+         ("l", "ㄝˊㄛˋ"),
+         ("m", "ㄝˊㄇㄨˋ"),
+         ("n", "ㄣˉ"),
+         ("o", "ㄡˉ"),
+         ("p", "ㄆㄧˉ"),
+         ("q", "ㄎㄧㄡˉ"),
+         ("r", "ㄚˋ"),
+         ("s", "ㄝˊㄙˋ"),
+         ("t", "ㄊㄧˋ"),
+         ("u", "ㄧㄡˉ"),
+         ("v", "ㄨㄧˉ"),
+         ("w", "ㄉㄚˋㄅㄨˋㄌㄧㄡˋ"),
+         ("x", "ㄝˉㄎㄨˋㄙˋ"),
+         ("y", "ㄨㄞˋ"),
+         ("z", "ㄗㄟˋ"),
+     ]
+ ]
107
+
108
+ # List of (bopomofo, ipa) pairs:
109
+ _bopomofo_to_ipa = [
+     (re.compile("%s" % x[0]), x[1])
+     for x in [
+         ("ㄅㄛ", "p⁼wo"),
+         ("ㄆㄛ", "pʰwo"),
+         ("ㄇㄛ", "mwo"),
+         ("ㄈㄛ", "fwo"),
+         ("ㄧㄢ", "|jɛn"),
+         ("ㄩㄢ", "|ɥæn"),
+         ("ㄧㄣ", "|in"),
+         ("ㄩㄣ", "|ɥn"),
+         ("ㄧㄥ", "|iŋ"),
+         ("ㄨㄥ", "|ʊŋ"),
+         ("ㄩㄥ", "|jʊŋ"),
+         # Add
+         ("ㄧㄚ", "|ia"),
+         ("ㄧㄝ", "|iɛ"),
+         ("ㄧㄠ", "|iɑʊ"),
+         ("ㄧㄡ", "|ioʊ"),
+         ("ㄧㄤ", "|iɑŋ"),
+         ("ㄨㄚ", "|ua"),
+         ("ㄨㄛ", "|uo"),
+         ("ㄨㄞ", "|uaɪ"),
+         ("ㄨㄟ", "|ueɪ"),
+         ("ㄨㄢ", "|uan"),
+         ("ㄨㄣ", "|uən"),
+         ("ㄨㄤ", "|uɑŋ"),
+         ("ㄩㄝ", "|ɥɛ"),
+         # End
+         ("ㄅ", "p⁼"),
+         ("ㄆ", "pʰ"),
+         ("ㄇ", "m"),
+         ("ㄈ", "f"),
+         ("ㄉ", "t⁼"),
+         ("ㄊ", "tʰ"),
+         ("ㄋ", "n"),
+         ("ㄌ", "l"),
+         ("ㄍ", "k⁼"),
+         ("ㄎ", "kʰ"),
+         ("ㄏ", "x"),
+         ("ㄐ", "tʃ⁼"),
+         ("ㄑ", "tʃʰ"),
+         ("ㄒ", "ʃ"),
+         ("ㄓ", "ts`⁼"),
+         ("ㄔ", "ts`ʰ"),
+         ("ㄕ", "s`"),
+         ("ㄖ", "ɹ`"),
+         ("ㄗ", "ts⁼"),
+         ("ㄘ", "tsʰ"),
+         ("ㄙ", "|s"),
+         ("ㄚ", "|a"),
+         ("ㄛ", "|o"),
+         ("ㄜ", "|ə"),
+         ("ㄝ", "|ɛ"),
+         ("ㄞ", "|aɪ"),
+         ("ㄟ", "|eɪ"),
+         ("ㄠ", "|ɑʊ"),
+         ("ㄡ", "|oʊ"),
+         ("ㄢ", "|an"),
+         ("ㄣ", "|ən"),
+         ("ㄤ", "|ɑŋ"),
+         ("ㄥ", "|əŋ"),
+         ("ㄦ", "əɹ"),
+         ("ㄧ", "|i"),
+         ("ㄨ", "|u"),
+         ("ㄩ", "|ɥ"),
+         ("ˉ", "→|"),
+         ("ˊ", "↑|"),
+         ("ˇ", "↓↑|"),
+         ("ˋ", "↓|"),
+         ("˙", "|"),
+     ]
+ ]
182
+ must_not_er_words = {"女儿", "老儿", "男儿", "少儿", "小儿"}
183
+
184
+ word_pinyin_dict = {}
185
+ with open(
186
+ os.path.join(resource_path, "sources", "chinese_lexicon.txt"), "r", encoding="utf-8"
187
+ ) as fread:
188
+ txt_list = fread.readlines()
189
+ for txt in txt_list:
190
+ word, pinyin = txt.strip().split("\t")
191
+ word_pinyin_dict[word] = pinyin
192
+ fread.close()
193
+
194
+ pinyin_2_bopomofo_dict = {}
195
+ with open(
196
+ os.path.join(resource_path, "sources", "pinyin_2_bpmf.txt"), "r", encoding="utf-8"
197
+ ) as fread:
198
+ txt_list = fread.readlines()
199
+ for txt in txt_list:
200
+ pinyin, bopomofo = txt.strip().split("\t")
201
+ pinyin_2_bopomofo_dict[pinyin] = bopomofo
202
+ fread.close()
203
+
204
+ tone_dict = {
+     "0": "˙",
+     "5": "˙",
+     "1": "",
+     "2": "ˊ",
+     "3": "ˇ",
+     "4": "ˋ",
+ }
212
+
213
+ bopomofos2pinyin_dict = {}
214
+ with open(
215
+ os.path.join(resource_path, "sources", "bpmf_2_pinyin.txt"), "r", encoding="utf-8"
216
+ ) as fread:
217
+ txt_list = fread.readlines()
218
+ for txt in txt_list:
219
+ v, k = txt.strip().split("\t")
220
+ bopomofos2pinyin_dict[k] = v
221
+ fread.close()
222
+
223
+
224
+ def bpmf_to_pinyin(text):
225
+ bopomofo_list = text.split("|")
226
+ pinyin_list = []
227
+ for info in bopomofo_list:
228
+ pinyin = ""
229
+ for c in info:
230
+ if c in bopomofos2pinyin_dict:
231
+ pinyin += bopomofos2pinyin_dict[c]
232
+ if len(pinyin) == 0:
233
+ continue
234
+ if pinyin[-1] not in "01234":
235
+ pinyin += "1"
236
+ if pinyin[:-1] == "ve":
237
+ pinyin = "y" + pinyin
238
+ if pinyin[:-1] == "sh":
239
+ pinyin = pinyin[:-1] + "i" + pinyin[-1]
240
+ if pinyin == "sh":
241
+ pinyin = pinyin[:-1] + "i"
242
+ if pinyin[:-1] == "s":
243
+ pinyin = "si" + pinyin[-1]
244
+ if pinyin[:-1] == "c":
245
+ pinyin = "ci" + pinyin[-1]
246
+ if pinyin[:-1] == "i":
247
+ pinyin = "yi" + pinyin[-1]
248
+ if pinyin[:-1] == "iou":
249
+ pinyin = "you" + pinyin[-1]
250
+ if pinyin[:-1] == "ien":
251
+ pinyin = "yin" + pinyin[-1]
252
+ if "iou" in pinyin and pinyin[-4:-1] == "iou":
253
+ pinyin = pinyin[:-4] + "iu" + pinyin[-1]
254
+ if "uei" in pinyin:
255
+ if pinyin[:-1] == "uei":
256
+ pinyin = "wei" + pinyin[-1]
257
+ elif pinyin[-4:-1] == "uei":
258
+ pinyin = pinyin[:-4] + "ui" + pinyin[-1]
259
+ if "uen" in pinyin and pinyin[-4:-1] == "uen":
260
+ if pinyin[:-1] == "uen":
261
+ pinyin = "wen" + pinyin[-1]
262
+ elif pinyin[-4:-1] == "uei":
263
+ pinyin = pinyin[:-4] + "un" + pinyin[-1]
264
+ if "van" in pinyin and pinyin[-4:-1] == "van":
265
+ if pinyin[:-1] == "van":
266
+ pinyin = "yuan" + pinyin[-1]
267
+ elif pinyin[-4:-1] == "van":
268
+ pinyin = pinyin[:-4] + "uan" + pinyin[-1]
269
+ if "ueng" in pinyin and pinyin[-5:-1] == "ueng":
270
+ pinyin = pinyin[:-5] + "ong" + pinyin[-1]
271
+ if pinyin[:-1] == "veng":
272
+ pinyin = "yong" + pinyin[-1]
273
+ if "veng" in pinyin and pinyin[-5:-1] == "veng":
274
+ pinyin = pinyin[:-5] + "iong" + pinyin[-1]
275
+ if pinyin[:-1] == "ieng":
276
+ pinyin = "ying" + pinyin[-1]
277
+ if pinyin[:-1] == "u":
278
+ pinyin = "wu" + pinyin[-1]
279
+ if pinyin[:-1] == "v":
280
+ pinyin = "yv" + pinyin[-1]
281
+ if pinyin[:-1] == "ing":
282
+ pinyin = "ying" + pinyin[-1]
283
+ if pinyin[:-1] == "z":
284
+ pinyin = "zi" + pinyin[-1]
285
+ if pinyin[:-1] == "zh":
286
+ pinyin = "zhi" + pinyin[-1]
287
+ if pinyin[0] == "u":
288
+ pinyin = "w" + pinyin[1:]
289
+ if pinyin[0] == "i":
290
+ pinyin = "y" + pinyin[1:]
291
+ pinyin = pinyin.replace("ien", "in")
292
+
293
+ pinyin_list.append(pinyin)
294
+ return " ".join(pinyin_list)
295
+
296
+
297
+ # Convert numbers to Chinese pronunciation
298
+ def number_to_chinese(text):
299
+ # numbers = re.findall(r'\d+(?:\.?\d+)?', text)
300
+ # for number in numbers:
301
+ # text = text.replace(number, cn2an.an2cn(number), 1)
302
+ text = cn2an.transform(text, "an2cn")
303
+ return text
304
+
305
+
306
+ def normalization(text):
307
+ text = text.replace("๏ผŒ", ",")
308
+ text = text.replace("ใ€‚", ".")
309
+ text = text.replace("๏ผ", "!")
310
+ text = text.replace("๏ผŸ", "?")
311
+ text = text.replace("๏ผ›", ";")
312
+ text = text.replace("๏ผš", ":")
313
+ text = text.replace("ใ€", ",")
314
+ text = text.replace("โ€˜", "'")
315
+ text = text.replace("โ€™", "'")
316
+ text = text.replace("โ‹ฏ", "โ€ฆ")
317
+ text = text.replace("ยทยทยท", "โ€ฆ")
318
+ text = text.replace("ใƒปใƒปใƒป", "โ€ฆ")
319
+ text = text.replace("...", "โ€ฆ")
320
+ text = re.sub(r"\s+", "", text)
321
+ text = re.sub(r"[^\u4e00-\u9fff\s_,\.\?!;:\'โ€ฆ]", "", text)
322
+ text = re.sub(r"\s*([,\.\?!;:\'โ€ฆ])\s*", r"\1", text)
323
+ return text
324
+
325
+
326
+ def change_tone(bopomofo: str, tone: str) -> str:
327
+     if bopomofo[-1] not in "˙ˊˇˋ":
328
+ bopomofo = bopomofo + tone
329
+ else:
330
+ bopomofo = bopomofo[:-1] + tone
331
+ return bopomofo
332
+
333
+
334
+ def er_sandhi(word: str, bopomofos: List[str]) -> List[str]:
335
+     if len(word) > 1 and word[-1] == "儿" and word not in must_not_er_words:
+         bopomofos[-1] = change_tone(bopomofos[-1], "˙")
337
+ return bopomofos
338
+
339
+
340
+ def bu_sandhi(word: str, bopomofos: List[str]) -> List[str]:
341
+ valid_char = set(word)
342
+     if len(valid_char) == 1 and "不" in valid_char:
343
+ pass
344
+ elif word in ["ไธๅญ—"]:
345
+ pass
346
+ elif len(word) == 3 and word[1] == "ไธ" and bopomofos[1][:-1] == "ใ„…ใ„จ":
347
+ bopomofos[1] = bopomofos[1][:-1] + "ห™"
348
+ else:
349
+ for i, char in enumerate(word):
350
+ if (
351
+ i + 1 < len(bopomofos)
352
+ and char == "ไธ"
353
+ and i + 1 < len(word)
354
+ and 0 < len(bopomofos[i + 1])
355
+ and bopomofos[i + 1][-1] == "ห‹"
356
+ ):
357
+                 bopomofos[i] = bopomofos[i][:-1] + "ˊ"
358
+ return bopomofos
359
+
360
+
361
+ def yi_sandhi(word: str, bopomofos: List[str]) -> List[str]:
+     punc = "：，；。？！“”‘’':,;.?!()（）{}【】[]-~`、 "
+     if word.find("一") != -1 and any(
+         [item.isnumeric() for item in word if item != "一"]
+     ):
+         for i in range(len(word)):
+             if (
+                 i == 0
+                 and word[0] == "一"
+                 and len(word) > 1
+                 and word[1]
+                 not in [
+                     "零",
+                     "一",
+                     "二",
+                     "三",
+                     "四",
+                     "五",
+                     "六",
+                     "七",
+                     "八",
+                     "九",
+                     "十",
+                 ]
+             ):
+                 if len(bopomofos[0]) > 0 and bopomofos[1][-1] in ["ˋ", "˙"]:
+                     bopomofos[0] = change_tone(bopomofos[0], "ˊ")
+                 else:
+                     bopomofos[0] = change_tone(bopomofos[0], "ˋ")
+             elif word[i] == "一":
+                 bopomofos[i] = change_tone(bopomofos[i], "")
+         return bopomofos
+     elif len(word) == 3 and word[1] == "一" and word[0] == word[-1]:
+         bopomofos[1] = change_tone(bopomofos[1], "˙")
+     elif word.startswith("第一"):
+         bopomofos[1] = change_tone(bopomofos[1], "")
+     elif word.startswith("一月") or word.startswith("一日") or word.startswith("一号"):
+         bopomofos[0] = change_tone(bopomofos[0], "")
+     else:
+         for i, char in enumerate(word):
+             if char == "一" and i + 1 < len(word):
+                 if (
+                     len(bopomofos) > i + 1
+                     and len(bopomofos[i + 1]) > 0
+                     and bopomofos[i + 1][-1] in {"ˋ"}
+                 ):
+                     bopomofos[i] = change_tone(bopomofos[i], "ˊ")
+                 else:
+                     if word[i + 1] not in punc:
+                         bopomofos[i] = change_tone(bopomofos[i], "ˋ")
+                     else:
+                         pass
+     return bopomofos
414
+
415
+
416
+ def merge_bu(seg: List) -> List:
417
+ new_seg = []
418
+ last_word = ""
419
+ for word in seg:
420
+ if word != "ไธ":
421
+ if last_word == "ไธ":
422
+ word = last_word + word
423
+ new_seg.append(word)
424
+ last_word = word
425
+ return new_seg
426
+
427
+
428
+ def merge_er(seg: List) -> List:
429
+ new_seg = []
430
+ for i, word in enumerate(seg):
431
+ if i - 1 >= 0 and word == "ๅ„ฟ":
432
+ new_seg[-1] = new_seg[-1] + seg[i]
433
+ else:
434
+ new_seg.append(word)
435
+ return new_seg
436
+
437
+
438
+ def merge_yi(seg: List) -> List:
439
+ new_seg = []
440
+ # function 1
441
+ for i, word in enumerate(seg):
442
+ if (
443
+ i - 1 >= 0
444
+ and word == "ไธ€"
445
+ and i + 1 < len(seg)
446
+ and seg[i - 1] == seg[i + 1]
447
+ ):
448
+ if i - 1 < len(new_seg):
449
+                 new_seg[i - 1] = new_seg[i - 1] + "一" + new_seg[i - 1]
450
+ else:
451
+ new_seg.append(word)
452
+ new_seg.append(seg[i + 1])
453
+ else:
454
+             if i - 2 >= 0 and seg[i - 1] == "一" and seg[i - 2] == word:
455
+ continue
456
+ else:
457
+ new_seg.append(word)
458
+ seg = new_seg
459
+ new_seg = []
460
+ isnumeric_flag = False
461
+ for i, word in enumerate(seg):
462
+ if all([item.isnumeric() for item in word]) and not isnumeric_flag:
463
+ isnumeric_flag = True
464
+ new_seg.append(word)
465
+ else:
466
+ new_seg.append(word)
467
+ seg = new_seg
468
+ new_seg = []
469
+ # function 2
470
+ for i, word in enumerate(seg):
471
+         if new_seg and new_seg[-1] == "一":
472
+ new_seg[-1] = new_seg[-1] + word
473
+ else:
474
+ new_seg.append(word)
475
+ return new_seg
476
+
477
+
478
+ # Word Segmentation, and convert Chinese pronunciation to pinyin (bopomofo)
479
+ def chinese_to_bopomofo(text_short, sentence):
480
+ # bopomofos = conv(text_short)
481
+ words = jieba.lcut(text_short, cut_all=False)
482
+ words = merge_yi(words)
483
+ words = merge_bu(words)
484
+ words = merge_er(words)
485
+ text = ""
486
+
487
+ char_index = 0
488
+ for word in words:
489
+ bopomofos = []
490
+ if word in word_pinyin_dict and word not in poly_dict:
491
+ pinyin = word_pinyin_dict[word]
492
+ for py in pinyin.split(" "):
493
+ if py[:-1] in pinyin_2_bopomofo_dict and py[-1] in tone_dict:
494
+ bopomofos.append(
495
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
496
+ )
497
+ if BLANK_LEVEL == 1:
498
+ bopomofos.append("_")
499
+ else:
500
+ bopomofos_lazy = lazy_pinyin(word, BOPOMOFO)
501
+ bopomofos += bopomofos_lazy
502
+ if BLANK_LEVEL == 1:
503
+ bopomofos.append("_")
504
+ else:
505
+ for i in range(len(word)):
506
+ c = word[i]
507
+ if c in poly_dict:
508
+ poly_pinyin = g2pw_poly_predict.predict_process(
509
+ [text_short, char_index + i]
510
+ )[0]
511
+ py = poly_pinyin[2:-1]
512
+ bopomofos.append(
513
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
514
+ )
515
+ if BLANK_LEVEL == 1:
516
+ bopomofos.append("_")
517
+ elif c in word_pinyin_dict:
518
+ py = word_pinyin_dict[c]
519
+ bopomofos.append(
520
+ pinyin_2_bopomofo_dict[py[:-1]] + tone_dict[py[-1]]
521
+ )
522
+ if BLANK_LEVEL == 1:
523
+ bopomofos.append("_")
524
+ else:
525
+ bopomofos.append(c)
526
+ if BLANK_LEVEL == 1:
527
+ bopomofos.append("_")
528
+ if BLANK_LEVEL == 2:
529
+ bopomofos.append("_")
530
+ char_index += len(word)
531
+
532
+         if (
+             len(word) == 3
+             and bopomofos[0][-1] == "ˇ"
+             and bopomofos[1][-1] == "ˇ"
+             and bopomofos[-1][-1] == "ˇ"
+         ):
+             bopomofos[0] = bopomofos[0] + "ˊ"
+             bopomofos[1] = bopomofos[1] + "ˊ"
+         if len(word) == 2 and bopomofos[0][-1] == "ˇ" and bopomofos[-1][-1] == "ˇ":
+             bopomofos[0] = bopomofos[0][:-1] + "ˊ"
542
+ bopomofos = bu_sandhi(word, bopomofos)
543
+ bopomofos = yi_sandhi(word, bopomofos)
544
+ bopomofos = er_sandhi(word, bopomofos)
545
+ if not re.search("[\u4e00-\u9fff]", word):
546
+ text += "|" + word
547
+ continue
548
+ for i in range(len(bopomofos)):
549
+ bopomofos[i] = re.sub(r"([\u3105-\u3129])$", r"\1ห‰", bopomofos[i])
550
+ if text != "":
551
+ text += "|"
552
+ text += "|".join(bopomofos)
553
+ return text
554
+
555
+
556
+ # Convert latin pronunciation to pinyin (bopomofo)
557
+ def latin_to_bopomofo(text):
558
+ for regex, replacement in _latin_to_bopomofo:
559
+ text = re.sub(regex, replacement, text)
560
+ return text
561
+
562
+
563
+ # Convert pinyin (bopomofo) to IPA
564
+ def bopomofo_to_ipa(text):
565
+ for regex, replacement in _bopomofo_to_ipa:
566
+ text = re.sub(regex, replacement, text)
567
+ return text
568
+
569
+
570
+ def _chinese_to_ipa(text, sentence):
571
+ text = re.sub(r"\s", "_", text)
572
+
573
+ text = number_to_chinese(text.strip())
574
+ text = normalization(text)
575
+ text = chinese_to_bopomofo(text, sentence)
576
+ # pinyin = bpmf_to_pinyin(text)
577
+ text = latin_to_bopomofo(text)
578
+ text = bopomofo_to_ipa(text)
579
+ text = re.sub("([sษน]`[โผสฐ]?)([โ†’โ†“โ†‘ ]+|$)", r"\1ษน\2", text)
580
+ text = re.sub("([s][โผสฐ]?)([โ†’โ†“โ†‘ ]+|$)", r"\1ษน\2", text)
581
+ text = re.sub(r"^\||[^\w\s_,\.\?!;:\'โ€ฆ\|โ†’โ†“โ†‘โผสฐ`]", "", text)
582
+ text = re.sub(r"([,\.\?!;:\'โ€ฆ])", r"|\1|", text)
583
+ text = re.sub(r"\|+", "|", text)
584
+ text = text.rstrip("|")
585
+ return text
586
+
587
+
588
+ # Convert Chinese to IPA
589
+ def chinese_to_ipa(text, sentence, text_tokenizer):
590
+ # phonemes = text_tokenizer(text.strip())
591
+ if type(text) == str:
592
+ return _chinese_to_ipa(text, sentence)
593
+ else:
594
+ result_ph = []
595
+ for t in text:
596
+ result_ph.append(_chinese_to_ipa(t, sentence))
597
+ return result_ph
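A usage sketch for the Mandarin pipeline above (mine; it needs jieba, cn2an, pypinyin and the lexicon/G2PW model files under g2p/sources/, all loaded at import time):

    # Sketch: numbers -> Chinese, normalization, jieba segmentation with sandhi,
    # bopomofo, then IPA with tone arrows (→ ↑ ↓↑ ↓) appended per syllable.
    from g2p.g2p.mandarin import chinese_to_ipa

    print(chinese_to_ipa("你好，世界。", sentence=None, text_tokenizer=None))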
g2p/g2p/text_tokenizers.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import re
7
+ import os
8
+ from typing import List, Pattern, Union
9
+ from phonemizer.utils import list2str, str2list
10
+ from phonemizer.backend import EspeakBackend
11
+ from phonemizer.backend.espeak.language_switch import LanguageSwitch
12
+ from phonemizer.backend.espeak.words_mismatch import WordMismatch
13
+ from phonemizer.punctuation import Punctuation
14
+ from phonemizer.separator import Separator
15
+
16
+
17
+ class TextTokenizer:
18
+ """Phonemize Text."""
19
+
20
+ def __init__(
21
+ self,
22
+ language="en-us",
23
+ backend="espeak",
24
+ separator=Separator(word="|_|", syllable="-", phone="|"),
25
+ preserve_punctuation=True,
26
+ with_stress: bool = False,
27
+ tie: Union[bool, str] = False,
28
+ language_switch: LanguageSwitch = "remove-flags",
29
+ words_mismatch: WordMismatch = "ignore",
30
+ ) -> None:
31
+ self.preserve_punctuation_marks = ",.?!;:'โ€ฆ"
32
+ self.backend = EspeakBackend(
33
+ language,
34
+ punctuation_marks=self.preserve_punctuation_marks,
35
+ preserve_punctuation=preserve_punctuation,
36
+ with_stress=with_stress,
37
+ tie=tie,
38
+ language_switch=language_switch,
39
+ words_mismatch=words_mismatch,
40
+ )
41
+
42
+ self.separator = separator
43
+
44
+ # convert chinese punctuation to english punctuation
45
+ def convert_chinese_punctuation(self, text: str) -> str:
46
+ text = text.replace("๏ผŒ", ",")
47
+ text = text.replace("ใ€‚", ".")
48
+ text = text.replace("๏ผ", "!")
49
+ text = text.replace("๏ผŸ", "?")
50
+ text = text.replace("๏ผ›", ";")
51
+ text = text.replace("๏ผš", ":")
52
+ text = text.replace("ใ€", ",")
53
+ text = text.replace("โ€˜", "'")
54
+ text = text.replace("โ€™", "'")
55
+ text = text.replace("โ‹ฏ", "โ€ฆ")
56
+ text = text.replace("ยทยทยท", "โ€ฆ")
57
+ text = text.replace("ใƒปใƒปใƒป", "โ€ฆ")
58
+ text = text.replace("...", "โ€ฆ")
59
+ return text
60
+
61
+ def __call__(self, text, strip=True) -> List[str]:
62
+
63
+ text_type = type(text)
64
+ normalized_text = []
65
+ for line in str2list(text):
66
+ line = self.convert_chinese_punctuation(line.strip())
67
+ line = re.sub(r"[^\w\s_,\.\?!;:\'โ€ฆ]", "", line)
68
+ line = re.sub(r"\s*([,\.\?!;:\'โ€ฆ])\s*", r"\1", line)
69
+ line = re.sub(r"\s+", " ", line)
70
+ normalized_text.append(line)
71
+ # print("Normalized test: ", normalized_text[0])
72
+ phonemized = self.backend.phonemize(
73
+ normalized_text, separator=self.separator, strip=strip, njobs=1
74
+ )
75
+ if text_type == str:
76
+ phonemized = re.sub(r"([,\.\?!;:\'โ€ฆ])", r"|\1|", list2str(phonemized))
77
+ phonemized = re.sub(r"\|+", "|", phonemized)
78
+ phonemized = phonemized.rstrip("|")
79
+ else:
80
+ for i in range(len(phonemized)):
81
+ phonemized[i] = re.sub(r"([,\.\?!;:\'โ€ฆ])", r"|\1|", phonemized[i])
82
+ phonemized[i] = re.sub(r"\|+", "|", phonemized[i])
83
+ phonemized[i] = phonemized[i].rstrip("|")
84
+ return phonemized
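A usage sketch for the tokenizer above (mine; it requires phonemizer with a working espeak backend):

    # Sketch: phones come back "|"-separated, words separated by "|_|",
    # with preserved punctuation wrapped in "|" by the post-processing.
    from g2p.g2p.text_tokenizers import TextTokenizer

    tok = TextTokenizer(language="en-us")
    print(tok("Hello, world!"))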
g2p/g2p/vocab.json ADDED
@@ -0,0 +1,372 @@
1
+ {
2
+ "vocab": {
3
+ ",": 0,
4
+ ".": 1,
5
+ "?": 2,
6
+ "!": 3,
7
+ "_": 4,
8
+ "iห": 5,
9
+ "ษช": 6,
10
+ "ษœห": 7,
11
+ "ษš": 8,
12
+ "oหษน": 9,
13
+ "ษ”ห": 10,
14
+ "ษ”หษน": 11,
15
+ "ษ‘ห": 12,
16
+ "uห": 13,
17
+ "สŠ": 14,
18
+ "ษ‘หษน": 15,
19
+ "สŒ": 16,
20
+ "ษ›": 17,
21
+ "รฆ": 18,
22
+ "eษช": 19,
23
+ "aษช": 20,
24
+ "ษ”ษช": 21,
25
+ "aสŠ": 22,
26
+ "oสŠ": 23,
27
+ "ษชษน": 24,
28
+ "ษ›ษน": 25,
29
+ "สŠษน": 26,
30
+ "p": 27,
31
+ "b": 28,
32
+ "t": 29,
33
+ "d": 30,
34
+ "k": 31,
35
+ "ษก": 32,
36
+ "f": 33,
37
+ "v": 34,
38
+ "ฮธ": 35,
39
+ "รฐ": 36,
40
+ "s": 37,
41
+ "z": 38,
42
+ "สƒ": 39,
43
+ "ส’": 40,
44
+ "h": 41,
45
+ "tสƒ": 42,
46
+ "dส’": 43,
47
+ "m": 44,
48
+ "n": 45,
49
+ "ล‹": 46,
50
+ "j": 47,
51
+ "w": 48,
52
+ "ษน": 49,
53
+ "l": 50,
54
+ "tษน": 51,
55
+ "dษน": 52,
56
+ "ts": 53,
57
+ "dz": 54,
58
+ "i": 55,
59
+ "ษ”": 56,
60
+ "ษ™": 57,
61
+ "ษพ": 58,
62
+ "iษ™": 59,
63
+ "r": 60,
64
+ "u": 61,
65
+ "oห": 62,
66
+ "ษ›ห": 63,
67
+ "ษชห": 64,
68
+ "aษชษ™": 65,
69
+ "aษชษš": 66,
70
+ "ษ‘ฬƒ": 67,
71
+ "รง": 68,
72
+ "ษ”ฬƒ": 69,
73
+ "รฆรฆ": 70,
74
+ "ษษ": 71,
75
+ "ษกสฒ": 72,
76
+ "nสฒ": 73,
77
+ "iหห": 74,
78
+
79
+ "pโผ": 75,
80
+ "pสฐ": 76,
81
+ "tโผ": 77,
82
+ "tสฐ": 78,
83
+ "kโผ": 79,
84
+ "kสฐ": 80,
85
+ "x": 81,
86
+ "tสƒโผ": 82,
87
+ "tสƒสฐ": 83,
88
+ "ts`โผ": 84,
89
+ "ts`สฐ": 85,
90
+ "s`": 86,
91
+ "ษน`": 87,
92
+ "tsโผ": 88,
93
+ "tsสฐ": 89,
94
+ "pโผwo": 90,
95
+ "pโผwoโ†’": 91,
96
+ "pโผwoโ†‘": 92,
97
+ "pโผwoโ†“โ†‘": 93,
98
+ "pโผwoโ†“": 94,
99
+ "pสฐwo": 95,
100
+ "pสฐwoโ†’": 96,
101
+ "pสฐwoโ†‘": 97,
102
+ "pสฐwoโ†“โ†‘": 98,
103
+ "pสฐwoโ†“": 99,
104
+ "mwo": 100,
105
+ "mwoโ†’": 101,
106
+ "mwoโ†‘": 102,
107
+ "mwoโ†“โ†‘": 103,
108
+ "mwoโ†“": 104,
109
+ "fwo": 105,
110
+ "fwoโ†’": 106,
111
+ "fwoโ†‘": 107,
112
+ "fwoโ†“โ†‘": 108,
113
+ "fwoโ†“": 109,
114
+ "jษ›n": 110,
115
+ "jษ›nโ†’": 111,
116
+ "jษ›nโ†‘": 112,
117
+ "jษ›nโ†“โ†‘": 113,
118
+ "jษ›nโ†“": 114,
119
+ "ษฅรฆn": 115,
120
+ "ษฅรฆnโ†’": 116,
121
+ "ษฅรฆnโ†‘": 117,
122
+ "ษฅรฆnโ†“โ†‘": 118,
123
+ "ษฅรฆnโ†“": 119,
124
+ "in": 120,
125
+ "inโ†’": 121,
126
+ "inโ†‘": 122,
127
+ "inโ†“โ†‘": 123,
128
+ "inโ†“": 124,
129
+ "ษฅn": 125,
130
+ "ษฅnโ†’": 126,
131
+ "ษฅnโ†‘": 127,
132
+ "ษฅnโ†“โ†‘": 128,
133
+ "ษฅnโ†“": 129,
134
+ "iล‹": 130,
135
+ "iล‹โ†’": 131,
136
+ "iล‹โ†‘": 132,
137
+ "iล‹โ†“โ†‘": 133,
138
+ "iล‹โ†“": 134,
139
+ "สŠล‹": 135,
140
+ "สŠล‹โ†’": 136,
141
+ "สŠล‹โ†‘": 137,
142
+ "สŠล‹โ†“โ†‘": 138,
143
+ "สŠล‹โ†“": 139,
144
+ "jสŠล‹": 140,
145
+ "jสŠล‹โ†’": 141,
146
+ "jสŠล‹โ†‘": 142,
147
+ "jสŠล‹โ†“โ†‘": 143,
148
+ "jสŠล‹โ†“": 144,
149
+ "ia": 145,
150
+ "iaโ†’": 146,
151
+ "iaโ†‘": 147,
152
+ "iaโ†“โ†‘": 148,
153
+ "iaโ†“": 149,
154
+ "iษ›": 150,
155
+ "iษ›โ†’": 151,
156
+ "iษ›โ†‘": 152,
157
+ "iษ›โ†“โ†‘": 153,
158
+ "iษ›โ†“": 154,
159
+ "iษ‘สŠ": 155,
160
+ "iษ‘สŠโ†’": 156,
161
+ "iษ‘สŠโ†‘": 157,
162
+ "iษ‘สŠโ†“โ†‘": 158,
163
+ "iษ‘สŠโ†“": 159,
164
+ "ioสŠ": 160,
165
+ "ioสŠโ†’": 161,
166
+ "ioสŠโ†‘": 162,
167
+ "ioสŠโ†“โ†‘": 163,
168
+ "ioสŠโ†“": 164,
169
+ "iษ‘ล‹": 165,
170
+ "iษ‘ล‹โ†’": 166,
171
+ "iษ‘ล‹โ†‘": 167,
172
+ "iษ‘ล‹โ†“โ†‘": 168,
173
+ "iษ‘ล‹โ†“": 169,
174
+ "ua": 170,
175
+ "uaโ†’": 171,
176
+ "uaโ†‘": 172,
177
+ "uaโ†“โ†‘": 173,
178
+ "uaโ†“": 174,
179
+ "uo": 175,
180
+ "uoโ†’": 176,
181
+ "uoโ†‘": 177,
182
+ "uoโ†“โ†‘": 178,
183
+ "uoโ†“": 179,
184
+ "uaษช": 180,
185
+ "uaษชโ†’": 181,
186
+ "uaษชโ†‘": 182,
187
+ "uaษชโ†“โ†‘": 183,
188
+ "uaษชโ†“": 184,
189
+ "ueษช": 185,
190
+ "ueษชโ†’": 186,
191
+ "ueษชโ†‘": 187,
192
+ "ueษชโ†“โ†‘": 188,
193
+ "ueษชโ†“": 189,
194
+ "uan": 190,
195
+ "uanโ†’": 191,
196
+ "uanโ†‘": 192,
197
+ "uanโ†“โ†‘": 193,
198
+ "uanโ†“": 194,
199
+ "uษ™n": 195,
200
+ "uษ™nโ†’": 196,
201
+ "uษ™nโ†‘": 197,
202
+ "uษ™nโ†“โ†‘": 198,
203
+ "uษ™nโ†“": 199,
204
+ "uษ‘ล‹": 200,
205
+ "uษ‘ล‹โ†’": 201,
206
+ "uษ‘ล‹โ†‘": 202,
207
+ "uษ‘ล‹โ†“โ†‘": 203,
208
+ "uษ‘ล‹โ†“": 204,
209
+ "ษฅษ›": 205,
210
+ "ษฅษ›โ†’": 206,
211
+ "ษฅษ›โ†‘": 207,
212
+ "ษฅษ›โ†“โ†‘": 208,
213
+ "ษฅษ›โ†“": 209,
214
+ "a": 210,
215
+ "aโ†’": 211,
216
+ "aโ†‘": 212,
217
+ "aโ†“โ†‘": 213,
218
+ "aโ†“": 214,
219
+ "o": 215,
220
+ "oโ†’": 216,
221
+ "oโ†‘": 217,
222
+ "oโ†“โ†‘": 218,
223
+ "oโ†“": 219,
224
+ "ษ™โ†’": 220,
225
+ "ษ™โ†‘": 221,
226
+ "ษ™โ†“โ†‘": 222,
227
+ "ษ™โ†“": 223,
228
+ "ษ›โ†’": 224,
229
+ "ษ›โ†‘": 225,
230
+ "ษ›โ†“โ†‘": 226,
231
+ "ษ›โ†“": 227,
232
+ "aษชโ†’": 228,
233
+ "aษชโ†‘": 229,
234
+ "aษชโ†“โ†‘": 230,
235
+ "aษชโ†“": 231,
236
+ "eษชโ†’": 232,
237
+ "eษชโ†‘": 233,
238
+ "eษชโ†“โ†‘": 234,
239
+ "eษชโ†“": 235,
240
+ "ษ‘สŠ": 236,
241
+ "ษ‘สŠโ†’": 237,
242
+ "ษ‘สŠโ†‘": 238,
243
+ "ษ‘สŠโ†“โ†‘": 239,
244
+ "ษ‘สŠโ†“": 240,
245
+ "oสŠโ†’": 241,
246
+ "oสŠโ†‘": 242,
247
+ "oสŠโ†“โ†‘": 243,
248
+ "oสŠโ†“": 244,
249
+ "an": 245,
250
+ "anโ†’": 246,
251
+ "anโ†‘": 247,
252
+ "anโ†“โ†‘": 248,
253
+ "anโ†“": 249,
254
+ "ษ™n": 250,
255
+ "ษ™nโ†’": 251,
256
+ "ษ™nโ†‘": 252,
257
+ "ษ™nโ†“โ†‘": 253,
258
+ "ษ™nโ†“": 254,
259
+ "ษ‘ล‹": 255,
260
+ "ษ‘ล‹โ†’": 256,
261
+ "ษ‘ล‹โ†‘": 257,
262
+ "ษ‘ล‹โ†“โ†‘": 258,
263
+ "ษ‘ล‹โ†“": 259,
264
+ "ษ™ล‹": 260,
265
+ "ษ™ล‹โ†’": 261,
266
+ "ษ™ล‹โ†‘": 262,
267
+ "ษ™ล‹โ†“โ†‘": 263,
268
+ "ษ™ล‹โ†“": 264,
269
+ "ษ™ษน": 265,
270
+ "ษ™ษนโ†’": 266,
271
+ "ษ™ษนโ†‘": 267,
272
+ "ษ™ษนโ†“โ†‘": 268,
273
+ "ษ™ษนโ†“": 269,
274
+ "iโ†’": 270,
275
+ "iโ†‘": 271,
276
+ "iโ†“โ†‘": 272,
277
+ "iโ†“": 273,
278
+ "uโ†’": 274,
279
+ "uโ†‘": 275,
280
+ "uโ†“โ†‘": 276,
281
+ "uโ†“": 277,
282
+ "ษฅ": 278,
283
+ "ษฅโ†’": 279,
284
+ "ษฅโ†‘": 280,
285
+ "ษฅโ†“โ†‘": 281,
286
+ "ษฅโ†“": 282,
287
+ "ts`โผษน": 283,
288
+ "ts`โผษนโ†’": 284,
289
+ "ts`โผษนโ†‘": 285,
290
+ "ts`โผษนโ†“โ†‘": 286,
291
+ "ts`โผษนโ†“": 287,
292
+ "ts`สฐษน": 288,
293
+ "ts`สฐษนโ†’": 289,
294
+ "ts`สฐษนโ†‘": 290,
295
+ "ts`สฐษนโ†“โ†‘": 291,
296
+ "ts`สฐษนโ†“": 292,
297
+ "s`ษน": 293,
298
+ "s`ษนโ†’": 294,
299
+ "s`ษนโ†‘": 295,
300
+ "s`ษนโ†“โ†‘": 296,
301
+ "s`ษน๏ฟฝ๏ฟฝ๏ฟฝ": 297,
302
+ "ษน`ษน": 298,
303
+ "ษน`ษนโ†’": 299,
304
+ "ษน`ษนโ†‘": 300,
305
+ "ษน`ษนโ†“โ†‘": 301,
306
+ "ษน`ษนโ†“": 302,
307
+ "tsโผษน": 303,
308
+ "tsโผษนโ†’": 304,
309
+ "tsโผษนโ†‘": 305,
310
+ "tsโผษนโ†“โ†‘": 306,
311
+ "tsโผษนโ†“": 307,
312
+ "tsสฐษน": 308,
313
+ "tsสฐษนโ†’": 309,
314
+ "tsสฐษนโ†‘": 310,
315
+ "tsสฐษนโ†“โ†‘": 311,
316
+ "tsสฐษนโ†“": 312,
317
+ "sษน": 313,
318
+ "sษนโ†’": 314,
319
+ "sษนโ†‘": 315,
320
+ "sษนโ†“โ†‘": 316,
321
+ "sษนโ†“": 317,
322
+
323
+ "ษฏ": 318,
324
+ "e": 319,
325
+ "aห": 320,
326
+ "ษฏห": 321,
327
+ "eห": 322,
328
+ "cฬง": 323,
329
+ "ษธ": 324,
330
+ "ษฐแต": 325,
331
+ "ษด": 326,
332
+ "g": 327,
333
+ "dส‘": 328,
334
+ "q": 329,
335
+ "ห": 330,
336
+ "bj": 331,
337
+ "tษ•": 332,
338
+ "dej": 333,
339
+ "tej": 334,
340
+ "gj": 335,
341
+ "gษฏ": 336,
342
+ "cฬงj": 337,
343
+ "kj": 338,
344
+ "kษฏ": 339,
345
+ "mj": 340,
346
+ "nj": 341,
347
+ "pj": 342,
348
+ "ษพj": 343,
349
+ "ษ•": 344,
350
+ "tsษฏ": 345,
351
+
352
+ "ษ": 346,
353
+ "ษ‘": 347,
354
+ "ษ’": 348,
355
+ "ษœ": 349,
356
+ "ษซ": 350,
357
+ "ส‘": 351,
358
+ "สฒ": 352,
359
+
360
+ "y": 353,
361
+ "รธ": 354,
362
+ "ล“": 355,
363
+ "ส": 356,
364
+ "ฬƒ": 357,
365
+ "ษฒ": 358,
366
+
367
+ ":": 359,
368
+ ";": 360,
369
+ "'": 361,
370
+ "โ€ฆ": 362
371
+ }
372
+ }
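For orientation, a minimal sketch of how this vocabulary file is consumed (it mirrors the loading code in g2p/g2p_generation.py below; the sample phoneme list is hypothetical):

    import json
    import os

    # The JSON root holds a "vocab" object mapping each phoneme string to an integer id.
    vocab_path = os.path.join("g2p", "g2p", "vocab.json")
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)["vocab"]

    # Map a phoneme sequence to token ids, skipping symbols absent from the table.
    phonemes = ["a", "o", "i"]
    token_ids = [vocab[p] for p in phonemes if p in vocab]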
g2p/g2p_generation.py ADDED
@@ -0,0 +1,134 @@
1
+ # Copyright (c) 2024 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import sys
8
+
9
+ from g2p.g2p import PhonemeBpeTokenizer
10
+ from g2p.utils.g2p import phonemizer_g2p
11
+ import tqdm
12
+ from typing import List
13
+ import json
15
+ import re
16
+
17
+
18
+ def ph_g2p(text, language):
19
+
20
+ return phonemizer_g2p(text=text, language=language)
21
+
22
+
23
+ def g2p(text, sentence, language):
24
+
25
+ return text_tokenizer.tokenize(text=text, sentence=sentence, language=language)
26
+
27
+
28
+ def is_chinese(char):
29
+ if char >= "\u4e00" and char <= "\u9fa5":
30
+ return True
31
+ else:
32
+ return False
33
+
34
+
35
+ def is_alphabet(char):
36
+ if (char >= "\u0041" and char <= "\u005a") or (
37
+ char >= "\u0061" and char <= "\u007a"
38
+ ):
39
+ return True
40
+ else:
41
+ return False
42
+
43
+
44
+ def is_other(char):
45
+ if not (is_chinese(char) or is_alphabet(char)):
46
+ return True
47
+ else:
48
+ return False
49
+
50
+
51
+ def get_segment(text: str) -> List[str]:
52
+ # sentence --> [ch_part, en_part, ch_part, ...]
53
+ segments = []
54
+ types = []
55
+ flag = 0
56
+ temp_seg = ""
57
+ temp_lang = ""
58
+
59
+ # Determine the type of each character: "zh" (Chinese), "en" (Latin letters), or "other" (everything else).
60
+ for i, ch in enumerate(text):
61
+ if is_chinese(ch):
62
+ types.append("zh")
63
+ elif is_alphabet(ch):
64
+ types.append("en")
65
+ else:
66
+ types.append("other")
67
+
68
+ assert len(types) == len(text)
69
+
70
+ for i in range(len(types)):
71
+ # find the first char of the seg
72
+ if flag == 0:
73
+ temp_seg += text[i]
74
+ temp_lang = types[i]
75
+ flag = 1
76
+ else:
77
+ if temp_lang == "other":
78
+ if types[i] == temp_lang:
79
+ temp_seg += text[i]
80
+ else:
81
+ temp_seg += text[i]
82
+ temp_lang = types[i]
83
+ else:
84
+ if types[i] == temp_lang:
85
+ temp_seg += text[i]
86
+ elif types[i] == "other":
87
+ temp_seg += text[i]
88
+ else:
89
+ segments.append((temp_seg, temp_lang))
90
+ temp_seg = text[i]
91
+ temp_lang = types[i]
92
+ flag = 1
93
+
94
+ segments.append((temp_seg, temp_lang))
95
+ return segments
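For illustration, a sketch of get_segment's grouping behavior, inferred from the loop above (the sample strings are hypothetical):

    # "other" characters (digits, spaces, punctuation) attach to the current run,
    # so embedded numbers do not split a segment:
    # get_segment("hello 123 world") -> [("hello 123 world", "en")]
    # A genuine language switch starts a new segment, so mixed Chinese/English
    # input yields alternating ("...", "zh") and ("...", "en") pairs.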
96
+
97
+
98
+ def chn_eng_g2p(text: str):
99
+ # now only en and ch
100
+ segments = get_segment(text)
101
+ all_phoneme = ""
102
+ all_tokens = []
103
+
104
+ for index in range(len(segments)):
105
+ seg = segments[index]
106
+ phoneme, token = g2p(seg[0], text, seg[1])
107
+ all_phoneme += phoneme + "|"
108
+ all_tokens += token
109
+
110
+ if seg[1] == "en" and index == len(segments) - 1 and all_phoneme[-2] == "_":
111
+ all_phoneme = all_phoneme[:-2]
112
+ all_tokens = all_tokens[:-1]
113
+ return all_phoneme, all_tokens
114
+
115
+
116
+ vocab_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "g2p/vocab.json")
117
+ text_tokenizer = PhonemeBpeTokenizer(vacab_path=vocab_path)
118
+ with open(vocab_path, "r") as f:
119
+ json_data = f.read()
120
+ data = json.loads(json_data)
121
+ vocab = data["vocab"]
122
+
123
+ if __name__ == '__main__':
124
+ phone, token = chn_eng_g2p("ไฝ ๅฅฝ๏ผŒhello world")
125
+ phone, token = chn_eng_g2p("ไฝ ๅฅฝ๏ผŒhello world, Bonjour, ํ…Œ์ŠคํŠธ ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค, ไบ”ๆœˆ้›จ็ท‘")
126
+ print(phone)
127
+ print(token)
128
+
129
+ #phone, token = text_tokenizer.tokenize("ไฝ ๅฅฝ๏ผŒhello world, Bonjour, ํ…Œ์ŠคํŠธ ํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค, ไบ”ๆœˆ้›จ็ท‘", "", "auto")
130
+ phone, token = text_tokenizer.tokenize("็ท‘", "", "auto")
131
+ #phone, token = text_tokenizer.tokenize("เค†เค‡เค เค‡เคธเค•เคพ เคชเคฐเฅ€เค•เฅเคทเคฃ เค•เคฐเฅ‡เค‚", "", "auto")
132
+ #phone, token = text_tokenizer.tokenize("เค†เค‡เค เค‡เคธเค•เคพ เคชเคฐเฅ€เค•เฅเคทเคฃ เค•เคฐเฅ‡เค‚", "", "other")
133
+ print(phone)
134
+ print(token)
g2p/language_segmentation/LangSegment.py ADDED
@@ -0,0 +1,865 @@
1
+ """
2
+ This file bundles language identification functions.
3
+
4
+ Modifications (fork): Copyright (c) 2021, Adrien Barbaresi.
5
+
6
+ Original code: Copyright (c) 2011 Marco Lui <saffsd@gmail.com>.
7
+ Based on research by Marco Lui and Tim Baldwin.
8
+
9
+ See LICENSE file for more info.
10
+ https://github.com/adbar/py3langid
11
+
12
+ Projects:
13
+ https://github.com/juntaosun/LangSegment
14
+ """
15
+
16
+ import os
17
+ import re
18
+ import sys
19
+ import numpy as np
20
+ from collections import Counter
21
+ from collections import defaultdict
22
+
23
+ # import langid
24
+ # import py3langid as langid
25
+ # pip install py3langid==0.2.2
26
+
27
+ # ๅฏ็”จ่ฏญ่จ€้ข„ๆต‹ๆฆ‚็އๅฝ’ไธ€ๅŒ–๏ผŒๆฆ‚็އ้ข„ๆต‹็š„ๅˆ†ๆ•ฐใ€‚ๅ› ๆญค๏ผŒๅฎž็Žฐ้‡ๆ–ฐ่ง„่ŒƒๅŒ– ไบง็”Ÿ 0-1 ่Œƒๅ›ดๅ†…็š„่พ“ๅ‡บใ€‚
28
+ # langid disables probability normalization by default. For command-line usages of , it can be enabled by passing the flag.
29
+ # For probability normalization in library use, the user must instantiate their own . An example of such usage is as follows:
30
+ from py3langid.langid import LanguageIdentifier, MODEL_FILE
31
+
32
+ # Digital processing
33
+ try:from .utils.num import num2str
34
+ except ImportError:
35
+ try:from utils.num import num2str
36
+ except ImportError as e:
37
+ raise e
38
+
39
+ # -----------------------------------
40
+ # ๆ›ดๆ–ฐๆ—ฅๅฟ—๏ผšๆ–ฐ็‰ˆๆœฌๅˆ†่ฏๆ›ดๅŠ ็ฒพๅ‡†ใ€‚
41
+ # Changelog: The new version of the word segmentation is more accurate.
42
+ # ใƒใ‚งใƒณใ‚ธใƒญใ‚ฐ:ๆ–ฐใ—ใ„ใƒใƒผใ‚ธใƒงใƒณใฎๅ˜่ชžใ‚ปใ‚ฐใƒกใƒณใƒ†ใƒผใ‚ทใƒงใƒณใฏใ‚ˆใ‚Šๆญฃ็ขบใงใ™ใ€‚
43
+ # Changelog: ๋ถ„ํ• ์ด๋ผ๋Š” ๋‹จ์–ด์˜ ์ƒˆ๋กœ์šด ๋ฒ„์ „์ด ๋” ์ •ํ™•ํ•ฉ๋‹ˆ๋‹ค.
44
+ # -----------------------------------
45
+
46
+
47
+ # Word segmentation function:
48
+ # automatically identify and split the words (Chinese/English/Japanese/Korean) in the article or sentence according to different languages,
49
+ # making it more suitable for TTS processing.
50
+ # This code is designed for front-end text multi-lingual mixed annotation distinction, multi-language mixed training and inference of various TTS projects.
51
+ # This processing result is mainly for (Chinese = zh, Japanese = ja, English = en, Korean = ko), and can actually support up to 97 different language mixing processing.
52
+
53
+ #===========================================================================================================
54
+ #ๅˆ†ใ‹ใกๆ›ธใๆฉŸ่ƒฝ:ๆ–‡็ซ ใ‚„ๆ–‡็ซ ใฎไธญใฎไพ‹ใˆใฐ๏ผˆไธญๅ›ฝ่ชž/่‹ฑ่ชž/ๆ—ฅๆœฌ่ชž/้Ÿ“ๅ›ฝ่ชž๏ผ‰ใ‚’ใ€็•ฐใชใ‚‹่จ€่ชžใง่‡ชๅ‹•็š„ใซ่ช่ญ˜ใ—ใฆๅˆ†ๅ‰ฒใ—ใ€TTSๅ‡ฆ็†ใซใ‚ˆใ‚Š้ฉใ—ใŸใ‚‚ใฎใซใ—ใพใ™ใ€‚
55
+ #ใ“ใฎใ‚ณใƒผใƒ‰ใฏใ€ใ•ใพใ–ใพใชTTSใƒ—ใƒญใ‚ธใ‚งใ‚ฏใƒˆใฎใƒ•ใƒญใƒณใƒˆใ‚จใƒณใƒ‰ใƒ†ใ‚ญใ‚นใƒˆใฎๅคš่จ€่ชžๆททๅˆๆณจ้‡ˆๅŒบๅˆฅใ€ๅคš่จ€่ชžๆททๅˆใƒˆใƒฌใƒผใƒ‹ใƒณใ‚ฐใ€ใŠใ‚ˆใณๆŽจ่ซ–ใฎใŸใ‚ใซ็‰นๅˆฅใซไฝœๆˆใ•ใ‚Œใฆใ„ใพใ™ใ€‚
56
+ #===========================================================================================================
57
+ #(1)่‡ชๅ‹•ๅˆ†่ฉž:ใ€Œ้Ÿ“ๅ›ฝ่ชžใงใฏไฝ•ใ‚’่ชญใ‚€ใฎใงใ™ใ‹ใ‚ใชใŸใฎไฝ“่‚ฒใฎๅ…ˆ็”Ÿใฏ่ชฐใงใ™ใ‹?ไปŠๅ›žใฎ็™บ่กจไผšใงใฏใ€iPhone 15ใ‚ทใƒชใƒผใ‚บใฎ4ๆฉŸ็จฎใŒ็™ปๅ ดใ—ใพใ—ใŸใ€
58
+ #๏ผˆ2๏ผ‰ๆ‰‹ๅŠจๅˆ†่ฏ:โ€œใ‚ใชใŸใฎๅๅ‰ใฏ<ja>ไฝใ€…ๆœจใงใ™ใ‹?<ja>ใงใ™ใ‹?โ€
59
+ #ใ“ใฎๅ‡ฆ็†็ตๆžœใฏไธปใซ๏ผˆไธญๅ›ฝ่ชž=jaใ€ๆ—ฅๆœฌ่ชž=jaใ€่‹ฑ่ชž=enใ€้Ÿ“ๅ›ฝ่ชž=ko๏ผ‰ใ‚’ๅฏพ่ฑกใจใ—ใฆใŠใ‚Šใ€ๅฎŸ้š›ใซใฏๆœ€ๅคง97ใฎ็•ฐใชใ‚‹่จ€่ชžใฎๆททๅˆๅ‡ฆ็†ใ‚’ใ‚ตใƒใƒผใƒˆใงใใพใ™ใ€‚
60
+ #===========================================================================================================
61
+
62
+ #===========================================================================================================
63
+ # ๋‹จ์–ด ๋ถ„ํ•  ๊ธฐ๋Šฅ: ๊ธฐ์‚ฌ ๋˜๋Š” ๋ฌธ์žฅ์—์„œ ๋‹จ์–ด(์ค‘๊ตญ์–ด/์˜์–ด/์ผ๋ณธ์–ด/ํ•œ๊ตญ์–ด)๋ฅผ ๋‹ค๋ฅธ ์–ธ์–ด์— ๋”ฐ๋ผ ์ž๋™์œผ๋กœ ์‹๋ณ„ํ•˜๊ณ  ๋ถ„ํ• ํ•˜์—ฌ TTS ์ฒ˜๋ฆฌ์— ๋” ์ ํ•ฉํ•ฉ๋‹ˆ๋‹ค.
64
+ # ์ด ์ฝ”๋“œ๋Š” ํ”„๋ŸฐํŠธ ์—”๋“œ ํ…์ŠคํŠธ ๋‹ค๊ตญ์–ด ํ˜ผํ•ฉ ์ฃผ์„ ๋ถ„ํ™”, ๋‹ค๊ตญ์–ด ํ˜ผํ•ฉ ๊ต์œก ๋ฐ ๋‹ค์–‘ํ•œ TTS ํ”„๋กœ์ ํŠธ์˜ ์ถ”๋ก ์„ ์œ„ํ•ด ์„ค๊ณ„๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
65
+ #===========================================================================================================
66
+ # (1) ์ž๋™ ๋‹จ์–ด ๋ถ„ํ• : "ํ•œ๊ตญ์–ด๋กœ ๋ฌด์—‡์„ ์ฝ์Šต๋‹ˆ๊นŒ? ์Šคํฌ์ธ  ์”จ? ์ด ์ปจํผ๋Ÿฐ์Šค๋Š” 4๊ฐœ์˜ iPhone 15 ์‹œ๋ฆฌ์ฆˆ ๋ชจ๋ธ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค."
67
+ # (2) ์ˆ˜๋™ ์ฐธ์—ฌ: "์ด๋ฆ„์ด <ja>Saki์ž…๋‹ˆ๊นŒ? <ja>?"
68
+ # ์ด ์ฒ˜๋ฆฌ ๊ฒฐ๊ณผ๋Š” ์ฃผ๋กœ (์ค‘๊ตญ์–ด = zh, ์ผ๋ณธ์–ด = ja, ์˜์–ด = en, ํ•œ๊ตญ์–ด = ko)๋ฅผ ์œ„ํ•œ ๊ฒƒ์ด๋ฉฐ ์‹ค์ œ๋กœ ํ˜ผํ•ฉ ์ฒ˜๋ฆฌ๋ฅผ ์œ„ํ•ด ์ตœ๋Œ€ 97๊ฐœ์˜ ์–ธ์–ด๋ฅผ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.
69
+ #===========================================================================================================
70
+
71
+ # ===========================================================================================================
72
+ # ๅˆ†่ฏๅŠŸ่ƒฝ๏ผšๅฐ†ๆ–‡็ซ ๆˆ–ๅฅๅญ้‡Œ็š„ไพ‹ๅฆ‚๏ผˆไธญ/่‹ฑ/ๆ—ฅ/้Ÿฉ๏ผ‰๏ผŒๆŒ‰ไธๅŒ่ฏญ่จ€่‡ชๅŠจ่ฏ†ๅˆซๅนถๆ‹†ๅˆ†๏ผŒ่ฎฉๅฎƒๆ›ด้€‚ๅˆTTSๅค„็†ใ€‚
73
+ # ๆœฌไปฃ็ ไธ“ไธบๅ„็ง TTS ้กน็›ฎ็š„ๅ‰็ซฏๆ–‡ๆœฌๅคš่ฏญ็งๆททๅˆๆ ‡ๆณจๅŒบๅˆ†๏ผŒๅคš่ฏญ่จ€ๆททๅˆ่ฎญ็ปƒๅ’ŒๆŽจ็†่€Œ็ผ–ๅ†™ใ€‚
74
+ # ===========================================================================================================
75
+ # ๏ผˆ1๏ผ‰่‡ชๅŠจๅˆ†่ฏ๏ผšโ€œ้Ÿฉ่ฏญไธญ็š„์˜ค๋น ่ฏปไป€ไนˆๅ‘ข๏ผŸใ‚ใชใŸใฎไฝ“่‚ฒใฎๅ…ˆ็”Ÿใฏ่ชฐใงใ™ใ‹? ๆญคๆฌกๅ‘ๅธƒไผšๅธฆๆฅไบ†ๅ››ๆฌพiPhone 15็ณปๅˆ—ๆœบๅž‹โ€
76
+ # ๏ผˆ2๏ผ‰ๆ‰‹ๅŠจๅˆ†่ฏ๏ผšโ€œไฝ ็š„ๅๅญ—ๅซ<ja>ไฝใ€…ๆœจ๏ผŸ<ja>ๅ—๏ผŸโ€
77
+ # ๆœฌๅค„็†็ป“ๆžœไธป่ฆ้’ˆๅฏน๏ผˆไธญๆ–‡=zh , ๆ—ฅๆ–‡=ja , ่‹ฑๆ–‡=en , ้Ÿฉ่ฏญ=ko๏ผ‰, ๅฎž้™…ไธŠๅฏๆ”ฏๆŒๅคš่พพ 97 ็งไธๅŒ็š„่ฏญ่จ€ๆททๅˆๅค„็†ใ€‚
78
+ # ===========================================================================================================
79
+
80
+
81
+ # ๆ‰‹ๅŠจๅˆ†่ฏๆ ‡็ญพ่ง„่Œƒ๏ผš<่ฏญ่จ€ๆ ‡็ญพ>ๆ–‡ๆœฌๅ†…ๅฎน</่ฏญ่จ€ๆ ‡็ญพ>
82
+ # ์ˆ˜๋™ ๋‹จ์–ด ๋ถ„ํ•  ํƒœ๊ทธ ์‚ฌ์–‘: <์–ธ์–ด ํƒœ๊ทธ> ํ…์ŠคํŠธ ๋‚ด์šฉ</์–ธ์–ด ํƒœ๊ทธ>
83
+ # Manual word segmentation tag specification: <language tags> text content </language tags>
84
+ # ๆ‰‹ๅ‹•ๅˆ†่ฉžใ‚ฟใ‚ฐไป•ๆง˜:<่จ€่ชžใ‚ฟใ‚ฐ>ใƒ†ใ‚ญใ‚นใƒˆๅ†…ๅฎน</่จ€่ชžใ‚ฟใ‚ฐ>
85
+ # ===========================================================================================================
86
+ # For manual word segmentation, labels need to appear in pairs, such as:
87
+ # ๅฆ‚้œ€ๆ‰‹ๅŠจๅˆ†่ฏ๏ผŒๆ ‡็ญพ้œ€่ฆๆˆๅฏนๅ‡บ็Žฐ๏ผŒไพ‹ๅฆ‚๏ผšโ€œ<ja>ไฝใ€…ๆœจ<ja>โ€ ๆˆ–่€… โ€œ<ja>ไฝใ€…ๆœจ</ja>โ€
88
+ # ้”™่ฏฏ็คบ่Œƒ๏ผšโ€œไฝ ็š„ๅๅญ—ๅซ<ja>ไฝใ€…ๆœจใ€‚โ€ ๆญคๅฅๅญไธญๅ‡บ็Žฐ็š„ๅ•ไธช<ja>ๆ ‡็ญพๅฐ†่ขซๅฟฝ็•ฅ๏ผŒไธไผšๅค„็†ใ€‚
89
+ # Error demonstration: "Your name is <ja>ไฝใ€…ๆœจใ€‚" Single <ja> tags that appear in this sentence will be ignored and will not be processed.
90
+ # ===========================================================================================================
91
+
92
+
93
+ # ===========================================================================================================
94
+ # ่ฏญ้Ÿณๅˆๆˆๆ ‡่ฎฐ่ฏญ่จ€ SSML , ่ฟ™้‡Œๅชๆ”ฏๆŒๅฎƒ็š„ๆ ‡็ญพ๏ผˆ้ž XML๏ผ‰Speech Synthesis Markup Language SSML, only its tags are supported here (not XML)
95
+ # ๆƒณๆ”ฏๆŒๆ›ดๅคš็š„ SSML ๆ ‡็ญพ๏ผŸๆฌข่ฟŽ PR๏ผ Want to support more SSML tags? PRs are welcome!
96
+ # ่ฏดๆ˜Ž๏ผš้™คไบ†ไธญๆ–‡ไปฅๅค–๏ผŒๅฎƒไนŸๅฏๆ”น้€ ๆˆๆ”ฏๆŒๅคš่ฏญ็ง SSML ๏ผŒไธไป…ไป…ๆ˜ฏไธญๆ–‡ใ€‚
97
+ # Note: In addition to Chinese, it can also be modified to support multi-language SSML, not just Chinese.
98
+ # ===========================================================================================================
99
+ # ไธญๆ–‡ๅฎž็Žฐ๏ผšChinese implementation:
100
+ # ใ€SSMLใ€‘<number>=ไธญๆ–‡ๅคงๅ†™ๆ•ฐๅญ—่ฏปๆณ•๏ผˆๅ•ๅญ—๏ผ‰
101
+ # ใ€SSMLใ€‘<telephone>=ๆ•ฐๅญ—่ฝฌๆˆไธญๆ–‡็”ต่ฏๅท็ ๅคงๅ†™ๆฑ‰ๅญ—๏ผˆๅ•ๅญ—๏ผ‰
102
+ # ใ€SSMLใ€‘<currency>=ๆŒ‰้‡‘้ขๅ‘้Ÿณใ€‚
103
+ # ใ€SSMLใ€‘<date>=ๆŒ‰ๆ—ฅๆœŸๅ‘้Ÿณใ€‚ๆ”ฏๆŒ 2024ๅนด08ๆœˆ24, 2024/8/24, 2024-08, 08-24, 24 ็ญ‰่พ“ๅ…ฅใ€‚
104
+ # ===========================================================================================================
105
+ class LangSSML:
106
+
107
+ def __init__(self):
108
+ # Plain digits mapped to Chinese numerals
109
+ self._zh_numerals_number = {
110
+ '0': '้›ถ',
111
+ '1': 'ไธ€',
112
+ '2': 'ไบŒ',
113
+ '3': 'ไธ‰',
114
+ '4': 'ๅ››',
115
+ '5': 'ไบ”',
116
+ '6': 'ๅ…ญ',
117
+ '7': 'ไธƒ',
118
+ '8': 'ๅ…ซ',
119
+ '9': 'ไน'
120
+ }
121
+
122
+ # ๅฐ†2024/8/24, 2024-08, 08-24, 24 ๆ ‡ๅ‡†ๅŒ–โ€œๅนดๆœˆๆ—ฅโ€
123
+ # Standardize 2024/8/24, 2024-08, 08-24, 24 to "year-month-day"
124
+ def _format_chinese_data(self, date_str:str):
125
+ # ๅค„็†ๆ—ฅๆœŸๆ ผๅผ
126
+ input_date = date_str
127
+ if date_str is None or date_str.strip() == "":return ""
128
+ date_str = re.sub(r"[\/\._|ๅนด|ๆœˆ]","-",date_str)
129
+ date_str = re.sub(r"ๆ—ฅ",r"",date_str)
130
+ date_arrs = date_str.split(' ')
131
+ if len(date_arrs) == 1 and ":" in date_arrs[0]:
132
+ time_str = date_arrs[0]
133
+ date_arrs = []
134
+ else:
135
+ time_str = date_arrs[1] if len(date_arrs) >=2 else ""
136
+ def nonZero(num,cn,func=None):
137
+ if func is not None:num=func(num)
138
+ return f"{num}{cn}" if num is not None and num != "" and num != "0" else ""
139
+ f_number = self.to_chinese_number
140
+ f_currency = self.to_chinese_currency
141
+ # year, month, day
142
+ year_month_day = ""
143
+ if len(date_arrs) > 0:
144
+ year, month, day = "","",""
145
+ parts = date_arrs[0].split('-')
146
+ if len(parts) == 3: # format: YYYY-MM-DD
147
+ year, month, day = parts
148
+ elif len(parts) == 2: # format: MM-DD or YYYY-MM
149
+ if len(parts[0]) == 4: # year-month
150
+ year, month = parts
151
+ else:month, day = parts # month-day
152
+ elif len(parts[0]) > 0: # only a single field: year or day
153
+ if len(parts[0]) == 4:
154
+ year = parts[0]
155
+ else:day = parts[0]
156
+ year,month,day = nonZero(year,"ๅนด",f_number),nonZero(month,"ๆœˆ",f_currency),nonZero(day,"ๆ—ฅ",f_currency)
157
+ year_month_day = re.sub(r"([ๅนด|ๆœˆ|ๆ—ฅ])+",r"\1",f"{year}{month}{day}")
158
+ # hours, minutes, seconds
159
+ time_str = re.sub(r"[\/\.\-๏ผš_]",":",time_str)
160
+ time_arrs = time_str.split(":")
161
+ hours, minutes, seconds = "","",""
162
+ if len(time_arrs) == 3: # H/M/S
163
+ hours, minutes, seconds = time_arrs
164
+ elif len(time_arrs) == 2:# H/M
165
+ hours, minutes = time_arrs
166
+ elif len(time_arrs[0]) > 0:hours = f'{time_arrs[0]}็‚น' # H
167
+ if len(time_arrs) > 1:
168
+ hours, minutes, seconds = nonZero(hours,"็‚น",f_currency),nonZero(minutes,"ๅˆ†",f_currency),nonZero(seconds,"็ง’",f_currency)
169
+ hours_minutes_seconds = re.sub(r"([็‚น|ๅˆ†|็ง’])+",r"\1",f"{hours}{minutes}{seconds}")
170
+ output_date = f"{year_month_day}{hours_minutes_seconds}"
171
+ return output_date
172
+
173
+ # ใ€SSMLใ€‘number=ไธญๆ–‡ๅคงๅ†™ๆ•ฐๅญ—่ฏปๆณ•๏ผˆๅ•ๅญ—๏ผ‰
174
+ # Chinese Numbers(single word)
175
+ def to_chinese_number(self, num:str):
176
+ pattern = r'(\d+)'
177
+ zh_numerals = self._zh_numerals_number
178
+ arrs = re.split(pattern, num)
179
+ output = ""
180
+ for item in arrs:
181
+ if re.match(pattern,item):
182
+ output += ''.join(zh_numerals[digit] if digit in zh_numerals else "" for digit in str(item))
183
+ else:output += item
184
+ output = output.replace(".","็‚น")
185
+ return output
186
+
187
+ # ใ€SSMLใ€‘telephone=ๆ•ฐๅญ—่ฝฌๆˆไธญๆ–‡็”ต่ฏๅท็ ๅคงๅ†™ๆฑ‰ๅญ—๏ผˆๅ•ๅญ—๏ผ‰
188
+ # Convert numbers to Chinese phone numbers in uppercase Chinese characters(single word)
189
+ def to_chinese_telephone(self, num:str):
190
+ output = self.to_chinese_number(num.replace("+86","")) # zh +86
191
+ output = output.replace("ไธ€","ๅนบ")
192
+ return output
193
+
194
+ # ใ€SSMLใ€‘currency=ๆŒ‰้‡‘้ขๅ‘้Ÿณใ€‚
195
+ # Digital processing from GPT_SoVITS num.py ๏ผˆthanks๏ผ‰
196
+ def to_chinese_currency(self, num:str):
197
+ pattern = r'(\d+)'
198
+ arrs = re.split(pattern, num)
199
+ output = ""
200
+ for item in arrs:
201
+ if re.match(pattern,item):
202
+ output += num2str(item)
203
+ else:output += item
204
+ output = output.replace(".","็‚น")
205
+ return output
206
+
207
+ # ใ€SSMLใ€‘date=ๆŒ‰ๆ—ฅๆœŸๅ‘้Ÿณใ€‚ๆ”ฏๆŒ 2024ๅนด08ๆœˆ24, 2024/8/24, 2024-08, 08-24, 24 ็ญ‰่พ“ๅ…ฅใ€‚
208
+ def to_chinese_date(self, num:str):
209
+ chinese_date = self._format_chinese_data(num)
210
+ return chinese_date
211
+
212
+
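A short usage sketch of the LangSSML helpers above; the outputs are described rather than spelled out, since the exact strings come from the tables in this file:

    # ssml = LangSSML()
    # ssml.to_chinese_number("2024")      # digit-by-digit reading: 2, 0, 2, 4
    # ssml.to_chinese_telephone("110")    # same, except "1" is read "yao" instead of "yi"
    # ssml.to_chinese_date("2024-08-24")  # "year month day" reading via _format_chinese_data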
213
+ class LangSegment:
214
+
215
+ def __init__(self):
216
+
217
+ self.langid = LanguageIdentifier.from_pickled_model(MODEL_FILE, norm_probs=True)
218
+
219
+ self._text_cache = None
220
+ self._text_lasts = None
221
+ self._text_langs = None
222
+ self._lang_count = None
223
+ self._lang_eos = None
224
+
225
+ # ๅฏ่‡ชๅฎšไน‰่ฏญ่จ€ๅŒน้…ๆ ‡็ญพ๏ผšใ‚ซใ‚นใ‚ฟใƒžใ‚คใ‚บๅฏ่ƒฝใช่จ€่ชžๅฏพๅฟœใ‚ฟใ‚ฐ:์‚ฌ์šฉ์ž ์ง€์ • ๊ฐ€๋Šฅํ•œ ์–ธ์–ด ์ผ์น˜ ํƒœ๊ทธ:
226
+ # Customizable language matching tags: These are supported๏ผŒ์ด ํ‘œํ˜„๋“ค์€ ๋ชจ๋‘ ์ง€์ง€ํ•ฉ๋‹ˆ๋‹ค
227
+ # <zh>ไฝ ๅฅฝ<zh> , <ja>ไฝใ€…ๆœจ</ja> , <en>OK<en> , <ko>์˜ค๋น </ko> ่ฟ™ไบ›ๅ†™ๆณ•ๅ‡ๆ”ฏๆŒ
228
+ self.SYMBOLS_PATTERN = r'(<([a-zA-Z|-]*)>(.*?)<\/*[a-zA-Z|-]*>)'
229
+
230
+ # ่ฏญ่จ€่ฟ‡ๆปค็ป„ๅŠŸ่ƒฝ, ๅฏไปฅๆŒ‡ๅฎšไฟ็•™่ฏญ่จ€ใ€‚ไธๅœจ่ฟ‡ๆปค็ป„ไธญ็š„่ฏญ่จ€ๅฐ†่ขซๆธ…้™คใ€‚ๆ‚จๅฏ้šๅฟƒๆญ้…TTS่ฏญ้Ÿณๅˆๆˆๆ‰€ๆ”ฏๆŒ็š„่ฏญ่จ€ใ€‚
231
+ # ์–ธ์–ด ํ•„ํ„ฐ ๊ทธ๋ฃน ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•˜๋ฉด ์˜ˆ์•ฝ๋œ ์–ธ์–ด๋ฅผ ์ง€์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•„ํ„ฐ ๊ทธ๋ฃน์— ์—†๋Š” ์–ธ์–ด๋Š” ์ง€์›Œ์ง‘๋‹ˆ๋‹ค. TTS ํ…์ŠคํŠธ์—์„œ ์ง€์›ํ•˜๋Š” ์–ธ์–ด๋ฅผ ์›ํ•˜๋Š” ๋Œ€๋กœ ์ผ์น˜์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
232
+ # ่จ€่ชžใƒ•ใ‚ฃใƒซใ‚ฟใƒผใ‚ฐใƒซใƒผใƒ—ๆฉŸ่ƒฝใงใฏใ€ไบˆ็ด„่จ€่ชžใ‚’ๆŒ‡ๅฎšใงใใพใ™ใ€‚ใƒ•ใ‚ฃใƒซใ‚ฟใƒผใ‚ฐใƒซใƒผใƒ—ใซๅซใพใ‚Œใฆใ„ใชใ„่จ€่ชžใฏใ‚ฏใƒชใ‚ขใ•ใ‚Œใพใ™ใ€‚TTS้ŸณๅฃฐๅˆๆˆใŒใ‚ตใƒใƒผใƒˆใ™ใ‚‹่จ€่ชžใ‚’่‡ช็”ฑใซ็ต„ใฟๅˆใ‚ใ›ใ‚‹ใ“ใจใŒใงใใพใ™ใ€‚
233
+ # The language filter group function allows you to specify reserved languages.
234
+ # Languages not in the filter group will be cleared. You can match the languages supported by TTS Text To Speech as you like.
235
+ # ๆŽ’ๅ่ถŠๅ‰๏ผŒไผ˜ๅ…ˆ็บง่ถŠ้ซ˜๏ผŒThe higher the ranking, the higher the priority๏ผŒใƒฉใƒณใ‚ญใƒณใ‚ฐใŒไธŠไฝใซใชใ‚‹ใปใฉใ€ๅ„ชๅ…ˆๅบฆใŒ้ซ˜ใใชใ‚Šใพใ™ใ€‚
236
+
237
+ # ็ณป็ปŸ้ป˜่ฎค่ฟ‡ๆปคๅ™จใ€‚System default filterใ€‚(ISO 639-1 codes given)
238
+ # ----------------------------------------------------------------------------------------------------------------------------------
239
+ # "zh"ไธญๆ–‡=Chinese ,"en"่‹ฑ่ฏญ=English ,"ja"ๆ—ฅ่ฏญ=Japanese ,"ko"้Ÿฉ่ฏญ=Korean ,"fr"ๆณ•่ฏญ=French ,"vi"่ถŠๅ—่ฏญ=Vietnamese , "ru"ไฟ„่ฏญ=Russian
240
+ # "th"ๆณฐ่ฏญ=Thai
241
+ # ----------------------------------------------------------------------------------------------------------------------------------
242
+ self.DEFAULT_FILTERS = ["zh", "ja", "ko", "en"]
243
+
244
+ # ็”จๆˆทๅฏ่‡ชๅฎšไน‰่ฟ‡ๆปคๅ™จใ€‚User-defined filters
245
+ self.Langfilters = self.DEFAULT_FILTERS[:] # make a copy
246
+
247
+ # ๅˆๅนถๆ–‡ๆœฌ
248
+ self.isLangMerge = True
249
+
250
+ # ่ฏ•้ชŒๆ€งๆ”ฏๆŒ๏ผšๆ‚จๅฏ่‡ชๅฎšไน‰ๆทปๅŠ ๏ผš"fr"ๆณ•่ฏญ , "vi"่ถŠๅ—่ฏญใ€‚Experimental: You can customize to add: "fr" French, "vi" Vietnamese.
251
+ # ่ฏทไฝฟ็”จAPIๅฏ็”จ๏ผšself.setfilters(["zh", "en", "ja", "ko", "fr", "vi" , "ru" , "th"]) # ๆ‚จๅฏ่‡ชๅฎšไน‰ๆทปๅŠ ๏ผŒๅฆ‚๏ผš"fr"ๆณ•่ฏญ , "vi"่ถŠๅ—่ฏญใ€‚
252
+
253
+ # ้ข„่งˆ็‰ˆๅŠŸ่ƒฝ๏ผŒ่‡ชๅŠจๅฏ็”จๆˆ–็ฆ็”จ๏ผŒๆ— ้œ€่ฎพ็ฝฎ
254
+ # Preview feature, automatically enabled or disabled, no settings required
255
+ self.EnablePreview = False
256
+
257
+ # ้™คๆญคไปฅๅค–๏ผŒๅฎƒๆ”ฏๆŒ็ฎ€ๅ†™่ฟ‡ๆปคๅ™จ๏ผŒๅช้œ€ๆŒ‰ไธๅŒ่ฏญ็งไปปๆ„็ป„ๅˆๅณๅฏใ€‚
258
+ # In addition to that, it supports abbreviation filters, allowing for any combination of different languages.
259
+ # ็คบไพ‹๏ผšๆ‚จๅฏไปฅไปปๆ„ๆŒ‡ๅฎšๅคš็ง็ป„ๅˆ๏ผŒ่ฟ›่กŒ่ฟ‡ๆปค
260
+ # Example: You can specify any combination to filter
261
+
262
+ # ไธญ/ๆ—ฅ่ฏญ่จ€ไผ˜ๅ…ˆ็บง้˜€ๅ€ผ๏ผˆ่ฏ„ๅˆ†่Œƒๅ›ดไธบ 0 ~ 1๏ผ‰:่ฏ„ๅˆ†ไฝŽไบŽ่ฎพๅฎš้˜€ๅ€ผ <0.89 ๆ—ถ๏ผŒๅฏ็”จ filters ไธญ็š„ไผ˜ๅ…ˆ็บงใ€‚\n
263
+ # ์ค‘/์ผ๋ณธ์–ด ์šฐ์„  ์ˆœ์œ„ ์ž„๊ณ„๊ฐ’(์ ์ˆ˜ ๋ฒ”์œ„ 0-1): ์ ์ˆ˜๊ฐ€ ์„ค์ •๋œ ์ž„๊ณ„๊ฐ’ <0.89๋ณด๋‹ค ๋‚ฎ์„ ๋•Œ ํ•„ํ„ฐ์—์„œ ์šฐ์„  ์ˆœ์œ„๋ฅผ ํ™œ์„ฑํ™”ํ•ฉ๋‹ˆ๋‹ค.
264
+ # ไธญๅ›ฝ่ชž/ๆ—ฅๆœฌ่ชžใฎๅ„ชๅ…ˆๅบฆใ—ใใ„ๅ€ค๏ผˆใ‚นใ‚ณใ‚ข็ฏ„ๅ›ฒ0ใ€œ1๏ผ‰:ใ‚นใ‚ณใ‚ขใŒ่จญๅฎšใ•ใ‚ŒใŸใ—ใใ„ๅ€ค<0.89ๆœชๆบ€ใฎๅ ดๅˆใ€ใƒ•ใ‚ฃใƒซใ‚ฟใƒผใฎๅ„ชๅ…ˆๅบฆใŒๆœ‰ๅŠนใซใชใ‚Šใพใ™ใ€‚\n
265
+ # Chinese and Japanese language priority threshold (score range is 0 ~ 1): The default threshold is 0.89. \n
266
+ # Only the common characters between Chinese and Japanese are processed with confidence and priority. \n
267
+ self.LangPriorityThreshold = 0.89
268
+
269
+ # Langfilters = ["zh"] # ๆŒ‰ไธญๆ–‡่ฏ†ๅˆซ
270
+ # Langfilters = ["en"] # ๆŒ‰่‹ฑๆ–‡่ฏ†ๅˆซ
271
+ # Langfilters = ["ja"] # ๆŒ‰ๆ—ฅๆ–‡่ฏ†ๅˆซ
272
+ # Langfilters = ["ko"] # ๆŒ‰้Ÿฉๆ–‡่ฏ†ๅˆซ
273
+ # Langfilters = ["zh_ja"] # ไธญๆ—ฅๆททๅˆ่ฏ†ๅˆซ
274
+ # Langfilters = ["zh_en"] # ไธญ่‹ฑๆททๅˆ่ฏ†ๅˆซ
275
+ # Langfilters = ["ja_en"] # ๆ—ฅ่‹ฑๆททๅˆ่ฏ†ๅˆซ
276
+ # Langfilters = ["zh_ko"] # ไธญ้Ÿฉๆททๅˆ่ฏ†ๅˆซ
277
+ # Langfilters = ["ja_ko"] # ๆ—ฅ้Ÿฉๆททๅˆ่ฏ†ๅˆซ
278
+ # Langfilters = ["en_ko"] # ่‹ฑ้Ÿฉๆททๅˆ่ฏ†ๅˆซ
279
+ # Langfilters = ["zh_ja_en"] # ไธญๆ—ฅ่‹ฑๆททๅˆ่ฏ†ๅˆซ
280
+ # Langfilters = ["zh_ja_en_ko"] # ไธญๆ—ฅ่‹ฑ้Ÿฉๆททๅˆ่ฏ†ๅˆซ
281
+
282
+ # ๆ›ดๅคš่ฟ‡ๆปค็ป„ๅˆ๏ผŒ่ฏทๆ‚จ้šๆ„ใ€‚ใ€‚ใ€‚For more filter combinations, please feel free to......
283
+ # ใ‚ˆใ‚Šๅคšใใฎใƒ•ใ‚ฃใƒซใ‚ฟใƒผใฎ็ต„ใฟๅˆใ‚ใ›ใ€ใŠๆฐ—่ปฝใซใ€‚ใ€‚ใ€‚๋” ๋งŽ์€ ํ•„ํ„ฐ ์กฐํ•ฉ์„ ์›ํ•˜์‹œ๋ฉด ์ž์œ ๋กญ๊ฒŒ ํ•ด์ฃผ์„ธ์š”. .....
284
+
285
+ # ๅฏ้€‰ไฟ็•™๏ผšๆ”ฏๆŒไธญๆ–‡ๆ•ฐๅญ—ๆ‹ผ้Ÿณๆ ผๅผ๏ผŒๆ›ดๆ–นไพฟๅ‰็ซฏๅฎž็Žฐๆ‹ผ้Ÿณ้Ÿณ็ด ไฟฎๆ”นๅ’ŒๆŽจ็†๏ผŒ้ป˜่ฎคๅ…ณ้—ญ False ใ€‚
286
+ # ๅผ€ๅฏๅŽ True ๏ผŒๆ‹ฌๅทๅ†…็š„ๆ•ฐๅญ—ๆ‹ผ้Ÿณๆ ผๅผๅ‡ไฟ็•™๏ผŒๅนถ่ฏ†ๅˆซ่พ“ๅ‡บไธบ๏ผš"zh"ไธญๆ–‡ใ€‚
287
+ self.keepPinyin = False
288
+
289
+ # DEFINITION
290
+ self.PARSE_TAG = re.compile(r'(โ‘ฅ\$*\d+[\d]{6,}โ‘ฅ)')
291
+
292
+ self.LangSSML = LangSSML()
293
+
294
+ def _clears(self):
295
+ self._text_cache = None
296
+ self._text_lasts = None
297
+ self._text_langs = None
298
+ self._text_waits = None
299
+ self._lang_count = None
300
+ self._lang_eos = None
301
+
302
+ def _is_english_word(self, word):
303
+ return bool(re.match(r'^[a-zA-Z]+$', word))
304
+
305
+ def _is_chinese(self, word):
306
+ for char in word:
307
+ if '\u4e00' <= char <= '\u9fff':
308
+ return True
309
+ return False
310
+
311
+ def _is_japanese_kana(self, word):
312
+ pattern = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]+')
313
+ matches = pattern.findall(word)
314
+ return len(matches) > 0
315
+
316
+ def _insert_english_uppercase(self, word):
317
+ modified_text = re.sub(r'(?<!\b)([A-Z])', r' \1', word)
318
+ modified_text = modified_text.strip('-')
319
+ return modified_text + " "
320
+
321
+ def _split_camel_case(self, word):
322
+ return re.sub(r'(?<!^)(?=[A-Z])', ' ', word)
323
+
324
+ def _statistics(self, language, text):
325
+ # Language word statistics:
326
+ # Chinese characters usually occupy double bytes
327
+ if self._lang_count is None or not isinstance(self._lang_count, defaultdict):
328
+ self._lang_count = defaultdict(int)
329
+ lang_count = self._lang_count
330
+ if not "|" in language:
331
+ lang_count[language] += int(len(text)*2) if language == "zh" else len(text)
332
+ self._lang_count = lang_count
333
+
334
+ def _clear_text_number(self, text):
335
+ if text == "\n":return text,False # Keep Line Breaks
336
+ clear_text = re.sub(r'([^\w\s]+)','',re.sub(r'\n+','',text)).strip()
337
+ is_number = len(re.sub(re.compile(r'(\d+)'),'',clear_text)) == 0
338
+ return clear_text,is_number
339
+
340
+ def _saveData(self, words,language:str,text:str,score:float,symbol=None):
341
+ # Pre-detection
342
+ clear_text , is_number = self._clear_text_number(text)
343
+ # Merge the same language and save the results
344
+ preData = words[-1] if len(words) > 0 else None
345
+ if symbol is not None:pass
346
+ elif preData is not None and preData["symbol"] is None:
347
+ if len(clear_text) == 0:language = preData["lang"]
348
+ elif is_number == True:language = preData["lang"]
349
+ _ , pre_is_number = self._clear_text_number(preData["text"])
350
+ if (preData["lang"] == language):
351
+ self._statistics(preData["lang"],text)
352
+ text = preData["text"] + text
353
+ preData["text"] = text
354
+ return preData
355
+ elif pre_is_number == True:
356
+ text = f'{preData["text"]}{text}'
357
+ words.pop()
358
+ elif is_number == True:
359
+ priority_language = self._get_filters_string()[:2]
360
+ if priority_language in "ja-zh-en-ko-fr-vi":language = priority_language
361
+ data = {"lang":language,"text": text,"score":score,"symbol":symbol}
362
+ filters = self.Langfilters
363
+ if filters is None or len(filters) == 0 or "?" in language or \
364
+ language in filters or language in filters[0] or \
365
+ filters[0] == "*" or filters[0] in "alls-mixs-autos":
366
+ words.append(data)
367
+ self._statistics(data["lang"],data["text"])
368
+ return data
369
+
370
+ def _addwords(self, words,language,text,score,symbol=None):
371
+ if text == "\n":pass # Keep Line Breaks
372
+ elif text is None or len(text.strip()) == 0:return True
373
+ if language is None:language = ""
374
+ language = language.lower()
375
+ if language == 'en':text = self._insert_english_uppercase(text)
376
+ # text = re.sub(r'[(๏ผˆ๏ผ‰)]', ',' , text) # Keep it.
377
+ text_waits = self._text_waits
378
+ ispre_waits = len(text_waits)>0
379
+ preResult = text_waits.pop() if ispre_waits else None
380
+ if preResult is None:preResult = words[-1] if len(words) > 0 else None
381
+ if preResult and ("|" in preResult["lang"]):
382
+ pre_lang = preResult["lang"]
383
+ if language in pre_lang:preResult["lang"] = language = language.split("|")[0]
384
+ else:preResult["lang"]=pre_lang.split("|")[0]
385
+ if ispre_waits:preResult = self._saveData(words,preResult["lang"],preResult["text"],preResult["score"],preResult["symbol"])
386
+ pre_lang = preResult["lang"] if preResult else None
387
+ if ("|" in language) and (pre_lang and not pre_lang in language and not "โ€ฆ" in language):language = language.split("|")[0]
388
+ if "|" in language:self._text_waits.append({"lang":language,"text": text,"score":score,"symbol":symbol})
389
+ else:self._saveData(words,language,text,score,symbol)
390
+ return False
391
+
392
+ def _get_prev_data(self, words):
393
+ data = words[-1] if words and len(words) > 0 else None
394
+ if data:return (data["lang"] , data["text"])
395
+ return (None,"")
396
+
397
+ def _match_ending(self, input , index):
398
+ if input is None or len(input) == 0:return False,None
399
+ input = re.sub(r'\s+', '', input)
400
+ if len(input) == 0 or abs(index) > len(input):return False,None
401
+ ending_pattern = re.compile(r'([ใ€Œใ€โ€œโ€โ€˜โ€™"\':๏ผšใ€‚.๏ผ!?๏ผŽ๏ผŸ])')
402
+ return ending_pattern.match(input[index]),input[index]
403
+
404
+ def _cleans_text(self, cleans_text):
405
+ cleans_text = re.sub(r'(.*?)([^\w]+)', r'\1 ', cleans_text)
406
+ cleans_text = re.sub(r'(.)\1+', r'\1', cleans_text)
407
+ return cleans_text.strip()
408
+
409
+ def _mean_processing(self, text:str):
410
+ if text is None or (text.strip()) == "":return None , 0.0
411
+ arrs = self._split_camel_case(text).split(" ")
412
+ langs = []
413
+ for t in arrs:
414
+ if len(t.strip()) <= 3:continue
415
+ language, score = self.langid.classify(t)
416
+ langs.append({"lang":language})
417
+ if len(langs) == 0:return None , 0.0
418
+ return Counter([item['lang'] for item in langs]).most_common(1)[0][0],1.0
419
+
420
+ def _lang_classify(self, cleans_text):
421
+ language, score = self.langid.classify(cleans_text)
422
+ # fix: on Hugging Face the classifier score comes back as np.float32; convert it to a plain Python float
423
+ if score is not None and isinstance(score, np.generic) and hasattr(score,"item"):
424
+ score = score.item()
425
+ score = round(score , 3)
426
+ return language, score
427
+
428
+ def _get_filters_string(self):
429
+ filters = self.Langfilters
430
+ return "-".join(filters).lower().strip() if filters is not None else ""
431
+
432
+ def _parse_language(self, words , segment):
433
+ LANG_JA = "ja"
434
+ LANG_ZH = "zh"
435
+ LANG_ZH_JA = f'{LANG_ZH}|{LANG_JA}'
436
+ LANG_JA_ZH = f'{LANG_JA}|{LANG_ZH}'
437
+ language = LANG_ZH
438
+ regex_pattern = re.compile(r'([^\w\s]+)')
439
+ lines = regex_pattern.split(segment)
440
+ lines_max = len(lines)
441
+ LANG_EOS =self._lang_eos
442
+ for index, text in enumerate(lines):
443
+ if len(text) == 0:continue
444
+ EOS = index >= (lines_max - 1)
445
+ nextId = index + 1
446
+ nextText = lines[nextId] if not EOS else ""
447
+ nextPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',nextText)).strip()) == 0
448
+ textPunc = len(re.sub(regex_pattern,'',re.sub(r'\n+','',text)).strip()) == 0
449
+ if not EOS and (textPunc == True or ( len(nextText.strip()) >= 0 and nextPunc == True)):
450
+ lines[nextId] = f'{text}{nextText}'
451
+ continue
452
+ number_tags = re.compile(r'(โ‘ฅ\d{6,}โ‘ฅ)')
453
+ cleans_text = re.sub(number_tags, '' ,text)
454
+ cleans_text = re.sub(r'\d+', '' ,cleans_text)
455
+ cleans_text = self._cleans_text(cleans_text)
456
+ # fix: langid is inaccurate on short fragments, so splice short pieces onto the next segment before classifying.
457
+ if not EOS and len(cleans_text) <= 2:
458
+ lines[nextId] = f'{text}{nextText}'
459
+ continue
460
+ language,score = self._lang_classify(cleans_text)
461
+ prev_language , prev_text = self._get_prev_data(words)
462
+ if language != LANG_ZH and all('\u4e00' <= c <= '\u9fff' for c in re.sub(r'\s','',cleans_text)):language,score = LANG_ZH,1
463
+ if len(cleans_text) <= 5 and self._is_chinese(cleans_text):
464
+ filters_string = self._get_filters_string()
465
+ if score < self.LangPriorityThreshold and len(filters_string) > 0:
466
+ index_ja , index_zh = filters_string.find(LANG_JA) , filters_string.find(LANG_ZH)
467
+ if index_ja != -1 and index_ja < index_zh:language = LANG_JA
468
+ elif index_zh != -1 and index_zh < index_ja:language = LANG_ZH
469
+ if self._is_japanese_kana(cleans_text):language = LANG_JA
470
+ elif len(cleans_text) > 2 and score > 0.90:pass
471
+ elif EOS and LANG_EOS:language = LANG_ZH if len(cleans_text) <= 1 else language
472
+ else:
473
+ LANG_UNKNOWN = LANG_ZH_JA if language == LANG_ZH or (len(cleans_text) <=2 and prev_language == LANG_ZH) else LANG_JA_ZH
474
+ match_end,match_char = self._match_ending(text, -1)
475
+ referen = prev_language in LANG_UNKNOWN or LANG_UNKNOWN in prev_language if prev_language else False
476
+ if match_char in "ใ€‚.": language = prev_language if referen and len(words) > 0 else language
477
+ else:language = f"{LANG_UNKNOWN}|โ€ฆ"
478
+ text,*_ = re.subn(number_tags , self._restore_number , text )
479
+ self._addwords(words,language,text,score)
480
+
481
+ # ----------------------------------------------------------
482
+ # ใ€SSMLใ€‘ไธญๆ–‡ๆ•ฐๅญ—ๅค„็†๏ผšChinese Number Processing (SSML support)
483
+ # ่ฟ™้‡Œ้ป˜่ฎค้ƒฝๆ˜ฏไธญๆ–‡๏ผŒ็”จไบŽๅค„็† SSML ไธญๆ–‡ๆ ‡็ญพใ€‚ๅฝ“็„ถๅฏไปฅๆ”ฏๆŒไปปๆ„่ฏญ่จ€๏ผŒไพ‹ๅฆ‚๏ผš
484
+ # The default here is Chinese, which is used to process SSML Chinese tags. Of course, any language can be supported, for example:
485
+ # ไธญๆ–‡็”ต่ฏๅท็ ๏ผš<telephone>1234567</telephone>
486
+ # ไธญๆ–‡ๆ•ฐๅญ—ๅท็ ๏ผš<number>1234567</number>
487
+ def _process_symbol_SSML(self, words,data):
488
+ tag , match = data
489
+ language = SSML = match[1]
490
+ text = match[2]
491
+ score = 1.0
492
+ if SSML == "telephone":
493
+ # ไธญๆ–‡-็”ต่ฏๅท็ 
494
+ language = "zh"
495
+ text = self.LangSSML.to_chinese_telephone(text)
496
+ elif SSML == "number":
497
+ # Chinese - digit-by-digit number reading
498
+ language = "zh"
499
+ text = self.LangSSML.to_chinese_number(text)
500
+ elif SSML == "currency":
501
+ # ไธญๆ–‡-ๆŒ‰้‡‘้ขๅ‘้Ÿณ
502
+ language = "zh"
503
+ text = self.LangSSML.to_chinese_currency(text)
504
+ elif SSML == "date":
505
+ # ไธญๆ–‡-ๆŒ‰้‡‘้ขๅ‘้Ÿณ
506
+ language = "zh"
507
+ text = self.LangSSML.to_chinese_date(text)
508
+ self._addwords(words,language,text,score,SSML)
509
+
510
+ # ----------------------------------------------------------
511
+ def _restore_number(self, matche):
512
+ value = matche.group(0)
513
+ text_cache = self._text_cache
514
+ if value in text_cache:
515
+ process , data = text_cache[value]
516
+ tag , match = data
517
+ value = match
518
+ return value
519
+
520
+ def _pattern_symbols(self, item , text):
521
+ if text is None:return text
522
+ tag , pattern , process = item
523
+ matches = pattern.findall(text)
524
+ if len(matches) == 1 and "".join(matches[0]) == text:
525
+ return text
526
+ for i , match in enumerate(matches):
527
+ key = f"โ‘ฅ{tag}{i:06d}โ‘ฅ"
528
+ text = re.sub(pattern , key , text , count=1)
529
+ self._text_cache[key] = (process , (tag , match))
530
+ return text
531
+
532
+ def _process_symbol(self, words,data):
533
+ tag , match = data
534
+ language = match[1]
535
+ text = match[2]
536
+ score = 1.0
537
+ filters = self._get_filters_string()
538
+ if language not in filters:
539
+ self._process_symbol_SSML(words,data)
540
+ else:
541
+ self._addwords(words,language,text,score,True)
542
+
543
+ def _process_english(self, words,data):
544
+ tag , match = data
545
+ text = match[0]
546
+ filters = self._get_filters_string()
547
+ priority_language = filters[:2]
548
+ # Preview feature, other language segmentation processing
549
+ enablePreview = self.EnablePreview
550
+ if enablePreview == True:
551
+ # Experimental: Other language support
552
+ regex_pattern = re.compile(r'(.*?[ใ€‚.?๏ผŸ!๏ผ]+[\n]{,1})')
553
+ lines = regex_pattern.split(text)
554
+ for index , text in enumerate(lines):
555
+ if len(text.strip()) == 0:continue
556
+ cleans_text = self._cleans_text(text)
557
+ language,score = self._lang_classify(cleans_text)
558
+ if language not in filters:
559
+ language,score = self._mean_processing(cleans_text)
560
+ if language is None or score <= 0.0:continue
561
+ elif language in filters:pass # pass
562
+ elif score >= 0.95:continue # High score, but not in the filter, excluded.
563
+ elif score <= 0.15 and filters[:2] == "fr":language = priority_language
564
+ else:language = "en"
565
+ self._addwords(words,language,text,score)
566
+ else:
567
+ # Default is English
568
+ language, score = "en", 1.0
569
+ self._addwords(words,language,text,score)
570
+
571
+ def _process_Russian(self, words,data):
572
+ tag , match = data
573
+ text = match[0]
574
+ language = "ru"
575
+ score = 1.0
576
+ self._addwords(words,language,text,score)
577
+
578
+ def _process_Thai(self, words,data):
579
+ tag , match = data
580
+ text = match[0]
581
+ language = "th"
582
+ score = 1.0
583
+ self._addwords(words,language,text,score)
584
+
585
+ def _process_korean(self, words,data):
586
+ tag , match = data
587
+ text = match[0]
588
+ language = "ko"
589
+ score = 1.0
590
+ self._addwords(words,language,text,score)
591
+
592
+ def _process_quotes(self, words,data):
593
+ tag , match = data
594
+ text = "".join(match)
595
+ childs = self.PARSE_TAG.findall(text)
596
+ if len(childs) > 0:
597
+ self._process_tags(words , text , False)
598
+ else:
599
+ cleans_text = self._cleans_text(match[1])
600
+ if len(cleans_text) <= 5:
601
+ self._parse_language(words,text)
602
+ else:
603
+ language,score = self._lang_classify(cleans_text)
604
+ self._addwords(words,language,text,score)
605
+
606
+ def _process_pinyin(self, words,data):
607
+ tag , match = data
608
+ text = match
609
+ language = "zh"
610
+ score = 1.0
611
+ self._addwords(words,language,text,score)
612
+
613
+ def _process_number(self, words,data): # "$0" process only
614
+ """
615
+ Numbers alone cannot accurately identify language.
616
+ Because numbers are universal in all languages.
617
+ So it won't be executed here, just for testing.
618
+ """
619
+ tag , match = data
620
+ language = words[0]["lang"] if len(words) > 0 else "zh"
621
+ text = match
622
+ score = 0.0
623
+ self._addwords(words,language,text,score)
624
+
625
+ def _process_tags(self, words , text , root_tag):
626
+ text_cache = self._text_cache
627
+ segments = re.split(self.PARSE_TAG, text)
628
+ segments_len = len(segments) - 1
629
+ for index , text in enumerate(segments):
630
+ if root_tag:self._lang_eos = index >= segments_len
631
+ if self.PARSE_TAG.match(text):
632
+ process , data = text_cache[text]
633
+ if process:process(words , data)
634
+ else:
635
+ self._parse_language(words , text)
636
+ return words
637
+
638
+ def _merge_results(self, words):
639
+ new_word = []
640
+ for index , cur_data in enumerate(words):
641
+ if "symbol" in cur_data:del cur_data["symbol"]
642
+ if index == 0:new_word.append(cur_data)
643
+ else:
644
+ pre_data = new_word[-1]
645
+ if cur_data["lang"] == pre_data["lang"]:
646
+ pre_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
647
+ else:new_word.append(cur_data)
648
+ return new_word
649
+
650
+ def _parse_symbols(self, text):
651
+ TAG_NUM = "00" # "00" => default channels , "$0" => testing channel
652
+ TAG_S1,TAG_S2,TAG_P1,TAG_P2,TAG_EN,TAG_KO,TAG_RU,TAG_TH = "$1" ,"$2" ,"$3" ,"$4" ,"$5" ,"$6" ,"$7","$8"
653
+ TAG_BASE = re.compile(fr'(([ใ€ใ€Š๏ผˆ(โ€œโ€˜"\']*[LANGUAGE]+[\W\s]*)+)')
654
+ # Get custom language filter
655
+ filters = self.Langfilters
656
+ filters = filters if filters is not None else ""
657
+ # =======================================================================================================
658
+ # Experimental: Other language support.Thแปญ nghiแป‡m: Hแป— trแปฃ ngรดn ngแปฏ khรกc.Expรฉrimentalย : prise en charge dโ€™autres langues.
659
+ # ็›ธๅ…ณ่ฏญ่จ€ๅญ—็ฌฆๅฆ‚ๆœ‰็ผบๅคฑ๏ผŒ็†Ÿๆ‚‰็›ธๅ…ณ่ฏญ่จ€็š„ๆœ‹ๅ‹๏ผŒๅฏไปฅๆไบคๆŠŠ็ผบๅคฑ็š„ๅ‘้Ÿณ็ฌฆๅท่กฅๅ…จใ€‚
660
+ # If relevant language characters are missing, friends who are familiar with the relevant languages can submit a submission to complete the missing pronunciation symbols.
661
+ # S'il manque des caractรจres linguistiques pertinents, les amis qui connaissent les langues concernรฉes peuvent soumettre une soumission pour complรฉter les symboles de prononciation manquants.
662
+ # Nแบฟu thiแบฟu kรฝ tแปฑ ngรดn ngแปฏ liรชn quan, nhแปฏng ngฦฐแปi bแบกn quen thuแป™c vแป›i ngรดn ngแปฏ liรชn quan cรณ thแปƒ gแปญi bร i ฤ‘แปƒ hoร n thร nh cรกc kรฝ hiแป‡u phรกt รขm cรฒn thiแบฟu.
663
+ # -------------------------------------------------------------------------------------------------------
664
+ # Preview feature, other language support
665
+ enablePreview = self.EnablePreview
666
+ if "fr" in filters or \
667
+ "vi" in filters:enablePreview = True
668
+ self.EnablePreview = enablePreview
669
+ # ๅฎž้ชŒๆ€ง๏ผšๆณ•่ฏญๅญ—็ฌฆๆ”ฏๆŒใ€‚Prise en charge des caractรจres franรงais
670
+ RE_FR = "" if not enablePreview else "ร รกรขรฃรครฅรฆรงรจรฉรชรซรฌรญรฎรฏรฐรฑรฒรณรดรตรถรนรบรปรผรฝรพรฟ"
671
+ # ๅฎž้ชŒๆ€ง๏ผš่ถŠๅ—่ฏญๅญ—็ฌฆๆ”ฏๆŒใ€‚Hแป— trแปฃ kรฝ tแปฑ tiแบฟng Viแป‡t
672
+ RE_VI = "" if not enablePreview else "ฤ‘ฦกฦฐฤƒรกร แบฃรฃแบกแบฏแบฑแบณแบตแบทแบฅแบงแบฉแบซแบญรฉรจแบปแบฝแบนแบฟแปแปƒแป…แป‡รญรฌแป‰ฤฉแป‹รณรฒแปรตแปแป‘แป“แป•แป—แป™แป›แปแปŸแปกแปฃรบรนแปงลฉแปฅแปฉแปซแปญแปฏแปฑรดรขรชฦกฦฐแปทแปน"
673
+ # -------------------------------------------------------------------------------------------------------
674
+ # Basic options:
675
+ process_list = [
676
+ ( TAG_S1 , re.compile(self.SYMBOLS_PATTERN) , self._process_symbol ), # Symbol Tag
677
+ ( TAG_KO , re.compile(re.sub(r'LANGUAGE',f'\uac00-\ud7a3',TAG_BASE.pattern)) , self._process_korean ), # Korean words
678
+ ( TAG_TH , re.compile(re.sub(r'LANGUAGE',f'\u0E00-\u0E7F',TAG_BASE.pattern)) , self._process_Thai ), # Thai words support.
679
+ ( TAG_RU , re.compile(re.sub(r'LANGUAGE',f'ะ-ะฏะฐ-ัะั‘',TAG_BASE.pattern)) , self._process_Russian ), # Russian words support.
680
+ ( TAG_NUM , re.compile(r'(\W*\d+\W+\d*\W*\d*)') , self._process_number ), # Number words, Universal in all languages, Ignore it.
681
+ ( TAG_EN , re.compile(re.sub(r'LANGUAGE',f'a-zA-Z{RE_FR}{RE_VI}',TAG_BASE.pattern)) , self._process_english ), # English words + Other language support.
682
+ ( TAG_P1 , re.compile(r'(["\'])(.*?)(\1)') , self._process_quotes ), # Regular quotes
683
+ ( TAG_P2 , re.compile(r'([\n]*[ใ€ใ€Š๏ผˆ(โ€œโ€˜])([^ใ€ใ€Š๏ผˆ(โ€œโ€˜โ€™โ€)๏ผ‰ใ€‹ใ€‘]{3,})([โ€™โ€)๏ผ‰ใ€‹ใ€‘][\W\s]*[\n]{,1})') , self._process_quotes ), # Special quotes, There are left and right.
684
+ ]
685
+ # Extended options: Default False
686
+ if self.keepPinyin == True:process_list.insert(1 ,
687
+ ( TAG_S2 , re.compile(r'([\(๏ผˆ{](?:\s*\w*\d\w*\s*)+[}๏ผ‰\)])') , self._process_pinyin ), # Chinese Pinyin Tag.
688
+ )
689
+ # -------------------------------------------------------------------------------------------------------
690
+ words = []
691
+ lines = re.findall(r'.*\n*', re.sub(self.PARSE_TAG, '' ,text))
692
+ for index , text in enumerate(lines):
693
+ if len(text.strip()) == 0:continue
694
+ self._lang_eos = False
695
+ self._text_cache = {}
696
+ for item in process_list:
697
+ text = self._pattern_symbols(item , text)
698
+ cur_word = self._process_tags([] , text , True)
699
+ if len(cur_word) == 0:continue
700
+ cur_data = cur_word[0] if len(cur_word) > 0 else None
701
+ pre_data = words[-1] if len(words) > 0 else None
702
+ if cur_data and pre_data and cur_data["lang"] == pre_data["lang"] and cur_data["symbol"] == False and pre_data["symbol"] :
703
+ cur_data["text"] = f'{pre_data["text"]}{cur_data["text"]}'
704
+ words.pop()
705
+ words += cur_word
706
+ if self.isLangMerge == True:words = self._merge_results(words)
707
+ lang_count = self._lang_count
708
+ if lang_count and len(lang_count) > 0:
709
+ lang_count = dict(sorted(lang_count.items(), key=lambda x: x[1], reverse=True))
710
+ lang_count = list(lang_count.items())
711
+ self._lang_count = lang_count
712
+ return words
713
+
714
+ def setfilters(self, filters):
715
+ # ๅฝ“่ฟ‡ๆปคๅ™จๆ›ดๆ”นๆ—ถ๏ผŒๆธ…้™ค็ผ“ๅญ˜
716
+ # ํ•„ํ„ฐ๊ฐ€ ๋ณ€๊ฒฝ๋˜๋ฉด ์บ์‹œ๋ฅผ ์ง€์›๋‹ˆ๋‹ค.
717
+ # ใƒ•ใ‚ฃใƒซใ‚ฟใŒๅค‰ๆ›ดใ•ใ‚Œใ‚‹ใจใ€ใ‚ญใƒฃใƒƒใ‚ทใƒฅใŒใ‚ฏใƒชใ‚ขใ•ใ‚Œใพใ™
718
+ # When the filter changes, clear the cache
719
+ if self.Langfilters != filters:
720
+ self._clears()
721
+ self.Langfilters = filters
722
+
723
+ def getfilters(self):
724
+ return self.Langfilters
725
+
726
+ def setPriorityThreshold(self, threshold:float):
727
+ self.LangPriorityThreshold = threshold
728
+
729
+ def getPriorityThreshold(self):
730
+ return self.LangPriorityThreshold
731
+
732
+ def getCounts(self):
733
+ lang_count = self._lang_count
734
+ if lang_count is not None:return lang_count
735
+ text_langs = self._text_langs
736
+ if text_langs is None or len(text_langs) == 0:return [("zh",0)]
737
+ lang_counts = defaultdict(int)
738
+ for d in text_langs:lang_counts[d['lang']] += int(len(d['text'])*2) if d['lang'] == "zh" else len(d['text'])
739
+ lang_counts = dict(sorted(lang_counts.items(), key=lambda x: x[1], reverse=True))
740
+ lang_counts = list(lang_counts.items())
741
+ self._lang_count = lang_counts
742
+ return lang_counts
743
+
744
+ def getTexts(self, text:str):
745
+ if text is None or len(text.strip()) == 0:
746
+ self._clears()
747
+ return []
748
+ # lasts
749
+ text_langs = self._text_langs
750
+ if self._text_lasts == text and text_langs is not None:return text_langs
751
+ # parse
752
+ self._text_waits = []
753
+ self._lang_count = None
754
+ self._text_lasts = text
755
+ text = self._parse_symbols(text)
756
+ self._text_langs = text
757
+ return text
758
+
759
+ def classify(self, text:str):
760
+ return self.getTexts(text)
761
+
762
+ def printList(langlist):
763
+ """
764
+ ๅŠŸ่ƒฝ๏ผšๆ‰“ๅฐๆ•ฐ็ป„็ป“ๆžœ
765
+ ๊ธฐ๋Šฅ: ์–ด๋ ˆ์ด ๊ฒฐ๊ณผ ์ธ์‡„
766
+ ๆฉŸ่ƒฝ:้…ๅˆ—็ตๆžœใ‚’ๅฐๅˆท
767
+ Function: Print array results
768
+ """
769
+ print("\n===================ใ€ๆ‰“ๅฐ็ป“ๆžœใ€‘===================")
770
+ if langlist is None or len(langlist) == 0:
771
+ print("ๆ— ๅ†…ๅฎน็ป“ๆžœ,No content result")
772
+ return
773
+ for line in langlist:
774
+ print(line)
775
+ pass
776
+
777
+
778
+
779
+ def main():
780
+
781
+ # -----------------------------------
782
+ # ๆ›ดๆ–ฐๆ—ฅๅฟ—๏ผšๆ–ฐ็‰ˆๆœฌๅˆ†่ฏๆ›ดๅŠ ็ฒพๅ‡†ใ€‚
783
+ # Changelog: The new version of the word segmentation is more accurate.
784
+ # ใƒใ‚งใƒณใ‚ธใƒญใ‚ฐ:ๆ–ฐใ—ใ„ใƒใƒผใ‚ธใƒงใƒณใฎๅ˜่ชžใ‚ปใ‚ฐใƒกใƒณใƒ†ใƒผใ‚ทใƒงใƒณใฏใ‚ˆใ‚Šๆญฃ็ขบใงใ™ใ€‚
785
+ # Changelog: ๋ถ„ํ• ์ด๋ผ๋Š” ๋‹จ์–ด์˜ ์ƒˆ๋กœ์šด ๋ฒ„์ „์ด ๋” ์ •ํ™•ํ•ฉ๋‹ˆ๋‹ค.
786
+ # -----------------------------------
787
+
788
+ # ่พ“ๅ…ฅ็คบไพ‹1๏ผš๏ผˆๅŒ…ๅซๆ—ฅๆ–‡๏ผŒไธญๆ–‡๏ผ‰Input Example 1: (including Japanese, Chinese)
789
+ # text = "โ€œๆ˜จๆ—ฅใฏ้›จใŒ้™ใฃใŸ๏ผŒ้Ÿณๆฅฝใ€ๆ˜ ็”ปใ€‚ใ€‚ใ€‚โ€ไฝ ไปŠๅคฉๅญฆไน ๆ—ฅ่ฏญไบ†ๅ—๏ผŸๆ˜ฅใฏๆกœใฎๅญฃ็ฏ€ใงใ™ใ€‚่ฏญ็งๅˆ†่ฏๆ˜ฏ่ฏญ้Ÿณๅˆๆˆๅฟ…ไธๅฏๅฐ‘็š„็Žฏ่Š‚ใ€‚่จ€่ชžๅˆ†่ฉžใฏ้Ÿณๅฃฐๅˆๆˆใซๆฌ ใ‹ใ›ใชใ„็’ฐ็ฏ€ใงใ‚ใ‚‹๏ผ"
790
+
791
+ # ่พ“ๅ…ฅ็คบไพ‹2๏ผš๏ผˆๅŒ…ๅซๆ—ฅๆ–‡๏ผŒไธญๆ–‡๏ผ‰Input Example 1: (including Japanese, Chinese)
792
+ # text = "ๆฌข่ฟŽๆฅ็Žฉใ€‚ๆฑไบฌ๏ผŒใฏๆ—ฅๆœฌใฎ้ฆ–้ƒฝใงใ™ใ€‚ๆฌข่ฟŽๆฅ็Žฉ. ๅคชๅฅฝไบ†!"
793
+
794
+ # ่พ“ๅ…ฅ็คบไพ‹3๏ผš๏ผˆๅŒ…ๅซๆ—ฅๆ–‡๏ผŒไธญๆ–‡๏ผ‰Input Example 1: (including Japanese, Chinese)
795
+ # text = "ๆ˜Žๆ—ฅใ€็งใŸใกใฏๆตท่พบใซใƒใ‚ซใƒณใ‚นใซ่กŒใใพใ™ใ€‚ไฝ ไผš่ฏดๆ—ฅ่ฏญๅ—๏ผšโ€œไธญๅ›ฝ่ชžใ€่ฉฑใ›ใพใ™ใ‹โ€ ไฝ ็š„ๆ—ฅ่ฏญ็œŸๅฅฝๅ•Š๏ผ"
796
+
797
+
798
+ # ่พ“ๅ…ฅ็คบไพ‹4๏ผš๏ผˆๅŒ…ๅซๆ—ฅๆ–‡๏ผŒไธญๆ–‡๏ผŒ้Ÿฉ่ฏญ๏ผŒ่‹ฑๆ–‡๏ผ‰Input Example 4: (including Japanese, Chinese, Korean, English)
799
+ # text = "ไฝ ็š„ๅๅญ—ๅซ<ja>ไฝใ€…ๆœจ๏ผŸ<ja>ๅ—๏ผŸ้Ÿฉ่ฏญไธญ็š„์•ˆ๋…• ์˜ค๋น ่ฏปไป€ไนˆๅ‘ข๏ผŸใ‚ใชใŸใฎไฝ“่‚ฒใฎๅ…ˆ็”Ÿใฏ่ชฐใงใ™ใ‹? ๆญคๆฌกๅ‘ๅธƒไผšๅธฆๆฅไบ†ๅ››ๆฌพiPhone 15็ณปๅˆ—ๆœบๅž‹ๅ’Œไธ‰ๆฌพApple Watch็ญ‰ไธ€็ณปๅˆ—ๆ–ฐๅ“๏ผŒ่ฟ™ๆฌก็š„iPad Air้‡‡็”จไบ†LCDๅฑๅน•"
800
+
801
+
802
+ # ่ฏ•้ชŒๆ€งๆ”ฏๆŒ๏ผš"fr"ๆณ•่ฏญ , "vi"่ถŠๅ—่ฏญ , "ru"ไฟ„่ฏญ , "th"ๆณฐ่ฏญใ€‚Experimental: Other language support.
803
+ langsegment = LangSegment()
804
+ langsegment.setfilters(["fr", "vi" , "ja", "zh", "ko", "en" , "ru" , "th"])
805
+ text = """
806
+ ๆˆ‘ๅ–œๆฌขๅœจ้›จๅคฉ้‡Œๅฌ้Ÿณไนใ€‚
807
+ I enjoy listening to music on rainy days.
808
+ ้›จใฎๆ—ฅใซ้Ÿณๆฅฝใ‚’่ดใใฎใŒๅฅฝใใงใ™ใ€‚
809
+ ๋น„ ์˜ค๋Š” ๋‚ ์— ์Œ์•…์„ ๋“ฃ๋Š” ๊ฒƒ์„ ์ฆ๊น๋‹ˆ๋‹คใ€‚
810
+ J'aime รฉcouter de la musique les jours de pluie.
811
+ Tรดi thรญch nghe nhแบกc vร o nhแปฏng ngร y mฦฐa.
812
+ ะœะฝะต ะฝั€ะฐะฒะธั‚ัั ัะปัƒัˆะฐั‚ัŒ ะผัƒะทั‹ะบัƒ ะฒ ะดะพะถะดะปะธะฒัƒัŽ ะฟะพะณะพะดัƒ.
813
+ เธ‰เธฑเธ™เธŠเธญเธšเธŸเธฑเธ‡เน€เธžเธฅเธ‡เนƒเธ™เธงเธฑเธ™เธ—เธตเนˆเธเธ™เธ•เธ
814
+ """
815
+
816
+
817
+
818
+ # ่ฟ›่กŒๅˆ†่ฏ๏ผš๏ผˆๆŽฅๅ…ฅTTS้กน็›ฎไป…้œ€ไธ€่กŒไปฃ็ ่ฐƒ็”จ๏ผ‰Segmentation: (Only one line of code is required to access the TTS project)
819
+ langlist = langsegment.getTexts(text)
820
+ printList(langlist)
821
+
822
+
823
+ # ่ฏญ็ง็ปŸ่ฎก:Language statistics:
824
+ print("\n===================ใ€่ฏญ็ง็ปŸ่ฎกใ€‘===================")
825
+ # ่Žทๅ–ๆ‰€ๆœ‰่ฏญ็งๆ•ฐ็ป„็ป“ๆžœ๏ผŒๆ นๆฎๅ†…ๅฎนๅญ—ๆ•ฐ้™ๅบๆŽ’ๅˆ—
826
+ # Get the array results in all languages, sorted in descending order according to the number of content words
827
+ langCounts = langsegment.getCounts()
828
+ print(langCounts , "\n")
829
+
830
+ # ๆ นๆฎ็ป“ๆžœ่Žทๅ–ๅ†…ๅฎน็š„ไธป่ฆ่ฏญ็ง (่ฏญ่จ€๏ผŒๅญ—ๆ•ฐๅซๆ ‡็‚น)
831
+ # Get the main language of content based on the results (language, word count including punctuation)
832
+ lang , count = langCounts[0]
833
+ print(f"่พ“ๅ…ฅๅ†…ๅฎน็š„ไธป่ฆ่ฏญ่จ€ไธบ = {lang} ๏ผŒๅญ—ๆ•ฐ = {count}")
834
+ print("==================================================\n")
835
+
836
+
837
+ # ๅˆ†่ฏ่พ“ๅ‡บ๏ผšlang=่ฏญ่จ€๏ผŒtext=ๅ†…ๅฎนใ€‚Word output: lang = language, text = content
838
+ # ===================ใ€ๆ‰“ๅฐ็ป“ๆžœใ€‘===================
839
+ # {'lang': 'zh', 'text': 'ไฝ ็š„ๅๅญ—ๅซ'}
840
+ # {'lang': 'ja', 'text': 'ไฝใ€…ๆœจ๏ผŸ'}
841
+ # {'lang': 'zh', 'text': 'ๅ—๏ผŸ้Ÿฉ่ฏญไธญ็š„'}
842
+ # {'lang': 'ko', 'text': '์•ˆ๋…• ์˜ค๋น '}
843
+ # {'lang': 'zh', 'text': '่ฏปไป€ไนˆๅ‘ข๏ผŸ'}
844
+ # {'lang': 'ja', 'text': 'ใ‚ใชใŸใฎไฝ“่‚ฒใฎๅ…ˆ็”Ÿใฏ่ชฐใงใ™ใ‹?'}
845
+ # {'lang': 'zh', 'text': ' ๆญคๆฌกๅ‘ๅธƒไผšๅธฆๆฅไบ†ๅ››ๆฌพ'}
846
+ # {'lang': 'en', 'text': 'i Phone '}
847
+ # {'lang': 'zh', 'text': '15็ณปๅˆ—ๆœบๅž‹ๅ’Œไธ‰ๆฌพ'}
848
+ # {'lang': 'en', 'text': 'Apple Watch '}
849
+ # {'lang': 'zh', 'text': '็ญ‰ไธ€็ณปๅˆ—ๆ–ฐๅ“๏ผŒ่ฟ™ๆฌก็š„'}
850
+ # {'lang': 'en', 'text': 'i Pad Air '}
851
+ # {'lang': 'zh', 'text': '้‡‡็”จไบ†'}
852
+ # {'lang': 'en', 'text': 'L C D '}
853
+ # {'lang': 'zh', 'text': 'ๅฑๅน•'}
854
+ # ===================ใ€่ฏญ็ง็ปŸ่ฎกใ€‘===================
855
+
856
+ # ===================ใ€่ฏญ็ง็ปŸ่ฎกใ€‘===================
857
+ # [('zh', 51), ('ja', 19), ('en', 18), ('ko', 5)]
858
+
859
+ # ่พ“ๅ…ฅๅ†…ๅฎน็š„ไธป่ฆ่ฏญ่จ€ไธบ = zh ๏ผŒๅญ—ๆ•ฐ = 51
860
+ # ==================================================
861
+ # The main language of the input content is = zh, word count = 51
862
+
863
+
864
+ if __name__ == "__main__":
865
+ main()
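Beyond the automatic segmentation demonstrated in main(), a minimal sketch of the manual-tag and SSML paths (handled by SYMBOLS_PATTERN, _process_symbol and _process_symbol_SSML above; the strings are illustrative):

    # seg = LangSegment()
    # seg.setfilters(["zh", "en"])
    # Paired tags force the language of a span instead of classifying it:
    # seg.getTexts("before <en>forced English span</en> after")
    # SSML-style tags are verbalized in Chinese and emitted with lang="zh":
    # seg.getTexts("<number>2024</number> and <date>2024-08-24</date>")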
g2p/language_segmentation/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .LangSegment import LangSegment
2
+
3
+
4
+ # release
5
+ __version__ = '0.3.5'
6
+
7
+
8
+ # develop
9
+ __develop__ = 'dev-0.0.1'
g2p/language_segmentation/utils/__init__.py ADDED
File without changes
g2p/language_segmentation/utils/num.py ADDED
@@ -0,0 +1,327 @@
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Digital processing from GPT_SoVITS num.py ๏ผˆthanks๏ผ‰
15
+ """
16
+ Rules to verbalize numbers into Chinese characters.
17
+ https://zh.wikipedia.org/wiki/ไธญๆ–‡ๆ•ฐๅญ—#็พไปฃไธญๆ–‡
18
+ """
19
+
20
+ import re
21
+ from collections import OrderedDict
22
+ from typing import List
23
+
24
+ DIGITS = {str(i): tran for i, tran in enumerate('้›ถไธ€ไบŒไธ‰ๅ››ไบ”ๅ…ญไธƒๅ…ซไน')}
25
+ UNITS = OrderedDict({
26
+ 1: 'ๅ',
27
+ 2: '็™พ',
28
+ 3: 'ๅƒ',
29
+ 4: 'ไธ‡',
30
+ 8: 'ไบฟ',
31
+ })
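The keys of UNITS are decimal exponents, a detail worth keeping in mind when reading the replacement rules below:

    # 10**1 -> "shi" (ten), 10**2 -> "bai" (hundred), 10**3 -> "qian" (thousand),
    # 10**4 -> "wan" (ten thousand), 10**8 -> "yi" (hundred million).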
32
+
33
+ COM_QUANTIFIERS = '(ๅค„|ๅฐ|ๆžถ|ๆžš|่ถŸ|ๅน…|ๅนณ|ๆ–น|ๅ ต|้—ด|ๅบŠ|ๆ ช|ๆ‰น|้กน|ไพ‹|ๅˆ—|็ฏ‡|ๆ ‹|ๆณจ|ไบฉ|ๅฐ|่‰˜|ๆŠŠ|็›ฎ|ๅฅ—|ๆฎต|ไบบ|ๆ‰€|ๆœต|ๅŒน|ๅผ |ๅบง|ๅ›ž|ๅœบ|ๅฐพ|ๆก|ไธช|้ฆ–|้˜™|้˜ต|็ฝ‘|็‚ฎ|้กถ|ไธ˜|ๆฃต|ๅช|ๆ”ฏ|่ขญ|่พ†|ๆŒ‘|ๆ‹…|้ข—|ๅฃณ|็ช |ๆ›ฒ|ๅข™|็พค|่…”|็ ฃ|ๅบง|ๅฎข|่ดฏ|ๆ‰Ž|ๆ†|ๅˆ€|ไปค|ๆ‰“|ๆ‰‹|็ฝ—|ๅก|ๅฑฑ|ๅฒญ|ๆฑŸ|ๆบช|้’Ÿ|้˜Ÿ|ๅ•|ๅŒ|ๅฏน|ๅ‡บ|ๅฃ|ๅคด|่„š|ๆฟ|่ทณ|ๆž|ไปถ|่ดด|้’ˆ|็บฟ|็ฎก|ๅ|ไฝ|่บซ|ๅ ‚|่ฏพ|ๆœฌ|้กต|ๅฎถ|ๆˆท|ๅฑ‚|ไธ|ๆฏซ|ๅŽ˜|ๅˆ†|้’ฑ|ไธค|ๆ–ค|ๆ‹…|้“ข|็Ÿณ|้’ง|้”ฑ|ๅฟฝ|(ๅƒ|ๆฏซ|ๅพฎ)ๅ…‹|ๆฏซ|ๅŽ˜|(ๅ…ฌ)ๅˆ†|ๅˆ†|ๅฏธ|ๅฐบ|ไธˆ|้‡Œ|ๅฏป|ๅธธ|้“บ|็จ‹|(ๅƒ|ๅˆ†|ๅŽ˜|ๆฏซ|ๅพฎ)็ฑณ|็ฑณ|ๆ’ฎ|ๅ‹บ|ๅˆ|ๅ‡|ๆ–—|็Ÿณ|็›˜|็ข—|็ขŸ|ๅ |ๆกถ|็ฌผ|็›†|็›’|ๆฏ|้’Ÿ|ๆ–›|้”…|็ฐ‹|็ฏฎ|็›˜|ๆกถ|็ฝ|็“ถ|ๅฃถ|ๅฎ|็›|็ฎฉ|็ฎฑ|็…ฒ|ๅ•–|่ข‹|้’ต|ๅนด|ๆœˆ|ๆ—ฅ|ๅญฃ|ๅˆป|ๆ—ถ|ๅ‘จ|ๅคฉ|็ง’|ๅˆ†|ๅฐๆ—ถ|ๆ—ฌ|็บช|ๅฒ|ไธ–|ๆ›ด|ๅคœ|ๆ˜ฅ|ๅค|็ง‹|ๅ†ฌ|ไปฃ|ไผ|่พˆ|ไธธ|ๆณก|็ฒ’|้ข—|ๅนข|ๅ †|ๆก|ๆ น|ๆ”ฏ|้“|้ข|็‰‡|ๅผ |้ข—|ๅ—|ๅ…ƒ|(ไบฟ|ๅƒไธ‡|็™พไธ‡|ไธ‡|ๅƒ|็™พ)|(ไบฟ|ๅƒไธ‡|็™พไธ‡|ไธ‡|ๅƒ|็™พ|็พŽ|)ๅ…ƒ|(ไบฟ|ๅƒไธ‡|็™พไธ‡|ไธ‡|ๅƒ|็™พ|ๅ|)ๅจ|(ไบฟ|ๅƒไธ‡|็™พไธ‡|ไธ‡|ๅƒ|็™พ|)ๅ—|่ง’|ๆฏ›|ๅˆ†)'
34
+
35
+ # Fraction expressions
36
+ RE_FRAC = re.compile(r'(-?)(\d+)/(\d+)')
37
+
38
+
39
+ def replace_frac(match) -> str:
40
+ """
41
+ Args:
42
+ match (re.Match)
43
+ Returns:
44
+ str
45
+ """
46
+ sign = match.group(1)
47
+ nominator = match.group(2)
48
+ denominator = match.group(3)
49
+ sign: str = "่ดŸ" if sign else ""
50
+ nominator: str = num2str(nominator)
51
+ denominator: str = num2str(denominator)
52
+ result = f"{sign}{denominator}ๅˆ†ไน‹{nominator}"
53
+ return result
54
+
55
+
56
+ # Percentage expressions
57
+ RE_PERCENTAGE = re.compile(r'(-?)(\d+(\.\d+)?)%')
58
+
59
+
60
+ def replace_percentage(match) -> str:
61
+ """
62
+ Args:
63
+ match (re.Match)
64
+ Returns:
65
+ str
66
+ """
67
+ sign = match.group(1)
68
+ percent = match.group(2)
69
+ sign: str = "่ดŸ" if sign else ""
70
+ percent: str = num2str(percent)
71
+ result = f"{sign}็™พๅˆ†ไน‹{percent}"
72
+ return result
73
+
74
+
75
+ # ๆ•ดๆ•ฐ่กจ่พพๅผ
76
+ # ๅธฆ่ดŸๅท็š„ๆ•ดๆ•ฐ -10
77
+ RE_INTEGER = re.compile(r'(-)' r'(\d+)')
78
+
79
+
80
+ def replace_negative_num(match) -> str:
81
+ """
82
+ Args:
83
+ match (re.Match)
84
+ Returns:
85
+ str
86
+ """
87
+ sign = match.group(1)
88
+ number = match.group(2)
89
+ sign: str = "่ดŸ" if sign else ""
90
+ number: str = num2str(number)
91
+ result = f"{sign}{number}"
92
+ return result
93
+
94
+
95
+ # ็ผ–ๅท-ๆ— ็ฌฆๅทๆ•ดๅฝข
96
+ # 00078
97
+ RE_DEFAULT_NUM = re.compile(r'\d{3}\d*')
98
+
99
+
100
+ def replace_default_num(match):
101
+ """
102
+ Args:
103
+ match (re.Match)
104
+ Returns:
105
+ str
106
+ """
107
+ number = match.group(0)
108
+ return verbalize_digit(number, alt_one=True)
109
+
110
+
111
+ # ๅŠ ๅ‡ไน˜้™ค
112
+ # RE_ASMD = re.compile(
113
+ # r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))([\+\-\ร—รท=])((-?)((\d+)(\.\d+)?)|(\.(\d+)))')
114
+ RE_ASMD = re.compile(
115
+ r'((-?)((\d+)(\.\d+)?[โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*)|(\.\d+[โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*)|([A-Za-z][โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*))([\+\-\ร—รท=])((-?)((\d+)(\.\d+)?[โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*)|(\.\d+[โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*)|([A-Za-z][โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]*))')
116
+
117
+ asmd_map = {
118
+ '+': 'ๅŠ ',
119
+ '-': 'ๅ‡',
120
+ 'ร—': 'ไน˜',
121
+ 'รท': '้™ค',
122
+ '=': '็ญ‰ไบŽ'
123
+ }
124
+
125
+ def replace_asmd(match) -> str:
126
+ """
127
+ Args:
128
+ match (re.Match)
129
+ Returns:
130
+ str
131
+ """
132
+ result = match.group(1) + asmd_map[match.group(8)] + match.group(9)
133
+ return result
134
+
135
+
136
+ # ๆฌกๆ–นไธ“้กน
137
+ RE_POWER = re.compile(r'[โฐยนยฒยณโดโตโถโทโธโนหฃสธโฟ]+')
138
+
139
+ power_map = {
140
+ 'โฐ': '0',
141
+ 'ยน': '1',
142
+ 'ยฒ': '2',
143
+ 'ยณ': '3',
144
+ 'โด': '4',
145
+ 'โต': '5',
146
+ 'โถ': '6',
147
+ 'โท': '7',
148
+ 'โธ': '8',
149
+ 'โน': '9',
150
+ 'หฃ': 'x',
151
+ 'สธ': 'y',
152
+ 'โฟ': 'n'
153
+ }
154
+
155
+ def replace_power(match) -> str:
156
+ """
157
+ Args:
158
+ match (re.Match)
159
+ Returns:
160
+ str
161
+ """
162
+ power_num = ""
163
+ for m in match.group(0):
164
+ power_num += power_map[m]
165
+ result = "็š„" + power_num + "ๆฌกๆ–น"
166
+ return result
167
+
168
+
169
+ # ๆ•ฐๅญ—่กจ่พพๅผ
170
+ # ็บฏๅฐๆ•ฐ
171
+ RE_DECIMAL_NUM = re.compile(r'(-?)((\d+)(\.\d+))' r'|(\.(\d+))')
172
+ # ๆญฃๆ•ดๆ•ฐ + ้‡่ฏ
173
+ RE_POSITIVE_QUANTIFIERS = re.compile(r"(\d+)([ๅคšไฝ™ๅ‡ \+])?" + COM_QUANTIFIERS)
174
+ RE_NUMBER = re.compile(r'(-?)((\d+)(\.\d+)?)' r'|(\.(\d+))')
175
+
176
+
177
+ def replace_positive_quantifier(match) -> str:
178
+ """
179
+ Args:
180
+ match (re.Match)
181
+ Returns:
182
+ str
183
+ """
184
+ number = match.group(1)
185
+ match_2 = match.group(2)
186
+ if match_2 == "+":
187
+ match_2 = "ๅคš"
188
+ match_2: str = match_2 if match_2 else ""
189
+ quantifiers: str = match.group(3)
190
+ number: str = num2str(number)
191
+ result = f"{number}{match_2}{quantifiers}"
192
+ return result
193
+
194
+
195
+ def replace_number(match) -> str:
196
+ """
197
+ Args:
198
+ match (re.Match)
199
+ Returns:
200
+ str
201
+ """
202
+ sign = match.group(1)
203
+ number = match.group(2)
204
+ pure_decimal = match.group(5)
205
+ if pure_decimal:
206
+ result = num2str(pure_decimal)
207
+ else:
208
+ sign: str = "่ดŸ" if sign else ""
209
+ number: str = num2str(number)
210
+ result = f"{sign}{number}"
211
+ return result
212
+
213
+
214
+ # ่Œƒๅ›ด่กจ่พพๅผ
215
+ # match.group(1) and match.group(8) are copy from RE_NUMBER
216
+
217
+ RE_RANGE = re.compile(
218
+ r"""
219
+ (?<![\d\+\-\ร—รท=]) # ไฝฟ็”จๅๅ‘ๅ‰็žปไปฅ็กฎไฟๆ•ฐๅญ—่Œƒๅ›ดไน‹ๅ‰ๆฒกๆœ‰ๅ…ถไป–ๆ•ฐๅญ—ๅ’Œๆ“ไฝœ็ฌฆ
220
+ ((-?)((\d+)(\.\d+)?)) # ๅŒน้…่Œƒๅ›ด่ตทๅง‹็š„่ดŸๆ•ฐๆˆ–ๆญฃๆ•ฐ๏ผˆๆ•ดๆ•ฐๆˆ–ๅฐๆ•ฐ๏ผ‰
221
+ [-~] # ๅŒน้…่Œƒๅ›ดๅˆ†้š”็ฌฆ
222
+ ((-?)((\d+)(\.\d+)?)) # ๅŒน้…่Œƒๅ›ด็ป“ๆŸ็š„่ดŸๆ•ฐๆˆ–ๆญฃๆ•ฐ๏ผˆๆ•ดๆ•ฐๆˆ–ๅฐๆ•ฐ๏ผ‰
223
+ (?![\d\+\-\ร—รท=]) # ไฝฟ็”จๆญฃๅ‘ๅ‰็žปไปฅ็กฎไฟๆ•ฐๅญ—่Œƒๅ›ดไน‹ๅŽๆฒกๆœ‰ๅ…ถไป–ๆ•ฐๅญ—ๅ’Œๆ“ไฝœ็ฌฆ
224
+ """, re.VERBOSE)
225
+
226
+
227
+ def replace_range(match) -> str:
228
+ """
229
+ Args:
230
+ match (re.Match)
231
+ Returns:
232
+ str
233
+ """
234
+ first, second = match.group(1), match.group(6)
235
+ first = RE_NUMBER.sub(replace_number, first)
236
+ second = RE_NUMBER.sub(replace_number, second)
237
+ result = f"{first}ๅˆฐ{second}"
238
+ return result
239
+
240
+
241
+ # ~่‡ณ่กจ่พพๅผ
242
+ RE_TO_RANGE = re.compile(
243
+ r'((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|ยฐC|โ„ƒ|ๅบฆ|ๆ‘„ๆฐๅบฆ|cm2|cmยฒ|cm3|cmยณ|cm|db|ds|kg|km|m2|mยฒ|mยณ|m3|ml|m|mm|s)[~]((-?)((\d+)(\.\d+)?)|(\.(\d+)))(%|ยฐC|โ„ƒ|ๅบฆ|ๆ‘„ๆฐๅบฆ|cm2|cmยฒ|cm3|cmยณ|cm|db|ds|kg|km|m2|mยฒ|mยณ|m3|ml|m|mm|s)')
244
+
245
+ def replace_to_range(match) -> str:
246
+ """
247
+ Args:
248
+ match (re.Match)
249
+ Returns:
250
+ str
251
+ """
252
+ result = match.group(0).replace('~', '่‡ณ')
253
+ return result
254
+
255
+
256
+ def _get_value(value_string: str, use_zero: bool=True) -> List[str]:
257
+ stripped = value_string.lstrip('0')
258
+ if len(stripped) == 0:
259
+ return []
260
+ elif len(stripped) == 1:
261
+ if use_zero and len(stripped) < len(value_string):
262
+ return [DIGITS['0'], DIGITS[stripped]]
263
+ else:
264
+ return [DIGITS[stripped]]
265
+ else:
266
+ largest_unit = next(
267
+ power for power in reversed(UNITS.keys()) if power < len(stripped))
268
+ first_part = value_string[:-largest_unit]
269
+ second_part = value_string[-largest_unit:]
270
+ return _get_value(first_part) + [UNITS[largest_unit]] + _get_value(
271
+ second_part)
272
+
273
+
274
+ def verbalize_cardinal(value_string: str) -> str:
275
+ if not value_string:
276
+ return ''
277
+
278
+ # 000 -> '้›ถ' , 0 -> '้›ถ'
279
+ value_string = value_string.lstrip('0')
280
+ if len(value_string) == 0:
281
+ return DIGITS['0']
282
+
283
+ result_symbols = _get_value(value_string)
284
+ # verbalized number starting with 'ไธ€ๅ*' is abbreviated as `ๅ*`
285
+ if len(result_symbols) >= 2 and result_symbols[0] == DIGITS[
286
+ '1'] and result_symbols[1] == UNITS[1]:
287
+ result_symbols = result_symbols[1:]
288
+ return ''.join(result_symbols)
289
+
290
+
291
+ def verbalize_digit(value_string: str, alt_one=False) -> str:
292
+ result_symbols = [DIGITS[digit] for digit in value_string]
293
+ result = ''.join(result_symbols)
294
+ if alt_one:
295
+ result = result.replace("ไธ€", "ๅนบ")
296
+ return result
297
+
298
+
299
+ def num2str(value_string: str) -> str:
300
+ integer_decimal = value_string.split('.')
301
+ if len(integer_decimal) == 1:
302
+ integer = integer_decimal[0]
303
+ decimal = ''
304
+ elif len(integer_decimal) == 2:
305
+ integer, decimal = integer_decimal
306
+ else:
307
+ raise ValueError(
308
+ f"The value string: '${value_string}' has more than one point in it."
309
+ )
310
+
311
+ result = verbalize_cardinal(integer)
312
+
313
+ decimal = decimal.rstrip('0')
314
+ if decimal:
315
+ # '.22' is verbalized as '้›ถ็‚นไบŒไบŒ'
316
+ # '3.20' is verbalized as 'ไธ‰็‚นไบŒ
317
+ result = result if result else "้›ถ"
318
+ result += '็‚น' + verbalize_digit(decimal)
319
+ return result
320
+
321
+
322
+ if __name__ == "__main__":
323
+
324
+ text = ""
325
+ text = num2str(text)
326
+ print(text)
327
+ pass
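A minimal usage sketch of the rules above, assuming the file is saved as num.py on the import path (the sample sentence is illustrative). The substitution order matters: fractions and percentages must be rewritten before the bare-number rule, or RE_NUMBER would consume their digits first.

# Hedged example: exercising the substitution rules defined in num.py.
from num import (RE_FRAC, RE_PERCENTAGE, RE_NUMBER,
                 replace_frac, replace_percentage, replace_number, num2str)

text = "我们用了-1/3的时间，完成了25.5%的工作"
text = RE_FRAC.sub(replace_frac, text)              # -1/3  -> 负三分之一
text = RE_PERCENTAGE.sub(replace_percentage, text)  # 25.5% -> 百分之二十五点五
text = RE_NUMBER.sub(replace_number, text)          # any remaining bare digits
print(text)

print(num2str("10086"))  # 一万零八十六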
g2p/sources/bpmf_2_pinyin.txt ADDED
@@ -0,0 +1,41 @@
+ b ㄅ
+ p ㄆ
+ m ㄇ
+ f ㄈ
+ d ㄉ
+ t ㄊ
+ n ㄋ
+ l ㄌ
+ g ㄍ
+ k ㄎ
+ h ㄏ
+ j ㄐ
+ q ㄑ
+ x ㄒ
+ zh ㄓ
+ ch ㄔ
+ sh ㄕ
+ r ㄖ
+ z ㄗ
+ c ㄘ
+ s ㄙ
+ i ㄧ
+ u ㄨ
+ v ㄩ
+ a ㄚ
+ o ㄛ
+ e ㄜ
+ e ㄝ
+ ai ㄞ
+ ei ㄟ
+ ao ㄠ
+ ou ㄡ
+ an ㄢ
+ en ㄣ
+ ang ㄤ
+ eng ㄥ
+ er ㄦ
+ 2 ˊ
+ 3 ˇ
+ 4 ˋ
+ 0 ˙
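A small sketch of how this pinyin-to-bopomofo table might be consumed; the loader and scanner below are assumptions for illustration, not code from this repo. Multi-letter keys such as "zh" and "ang" must be tried before their one-letter prefixes, and the duplicate "e" key keeps its first binding.

# Hypothetical loader for bpmf_2_pinyin.txt: each line is "<pinyin> <bopomofo>".
mapping = {}
with open("g2p/sources/bpmf_2_pinyin.txt", encoding="utf-8") as f:
    for line in f:
        pinyin, bopomofo = line.split()
        mapping.setdefault(pinyin, bopomofo)  # keep the first "e" entry

def to_bopomofo(syllable: str) -> str:
    # Longest-match scan so "zh"/"ch"/"sh" and "ang"/"eng" win over prefixes.
    out, i = [], 0
    while i < len(syllable):
        for width in (3, 2, 1):
            if syllable[i:i + width] in mapping:
                out.append(mapping[syllable[i:i + width]])
                i += width
                break
        else:
            i += 1  # skip anything the table does not cover
    return "".join(out)

print(to_bopomofo("zhang3"))  # ㄓㄤˇ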
g2p/sources/chinese_lexicon.txt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a3a7685d1c3e68eb2fa304bfc63e90c90c3c1a1948839a5b1b507b2131b3e2fb
+ size 14779443
g2p/sources/g2p_chinese_model/config.json ADDED
@@ -0,0 +1,819 @@
+ {
+   "_name_or_path": "/BERT-POLY-v2/pretrained_models/mini_bert",
+   "architectures": [
+     "BertPoly"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "classifier_dropout": null,
+   "directionality": "bidi",
+   "gradient_checkpointing": false,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 384,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1",
+     "2": "LABEL_2",
+     "3": "LABEL_3",
+     "4": "LABEL_4",
+     "5": "LABEL_5",
+     "6": "LABEL_6",
+     "7": "LABEL_7",
+     "8": "LABEL_8",
+     "9": "LABEL_9",
+     "10": "LABEL_10",
+     "11": "LABEL_11",
+     "12": "LABEL_12",
+     "13": "LABEL_13",
+     "14": "LABEL_14",
+     "15": "LABEL_15",
+     "16": "LABEL_16",
+     "17": "LABEL_17",
+     "18": "LABEL_18",
+     "19": "LABEL_19",
+     "20": "LABEL_20",
+     "21": "LABEL_21",
+     "22": "LABEL_22",
+     "23": "LABEL_23",
+     "24": "LABEL_24",
+     "25": "LABEL_25",
+     "26": "LABEL_26",
+     "27": "LABEL_27",
+     "28": "LABEL_28",
+     "29": "LABEL_29",
+     "30": "LABEL_30",
+     "31": "LABEL_31",
+     "32": "LABEL_32",
+     "33": "LABEL_33",
+     "34": "LABEL_34",
+     "35": "LABEL_35",
+     "36": "LABEL_36",
+     "37": "LABEL_37",
+     "38": "LABEL_38",
+     "39": "LABEL_39",
+     "40": "LABEL_40",
+     "41": "LABEL_41",
+     "42": "LABEL_42",
+     "43": "LABEL_43",
+     "44": "LABEL_44",
+     "45": "LABEL_45",
+     "46": "LABEL_46",
+     "47": "LABEL_47",
+     "48": "LABEL_48",
+     "49": "LABEL_49",
+     "50": "LABEL_50",
+     "51": "LABEL_51",
+     "52": "LABEL_52",
+     "53": "LABEL_53",
+     "54": "LABEL_54",
+     "55": "LABEL_55",
+     "56": "LABEL_56",
+     "57": "LABEL_57",
+     "58": "LABEL_58",
+     "59": "LABEL_59",
+     "60": "LABEL_60",
+     "61": "LABEL_61",
+     "62": "LABEL_62",
+     "63": "LABEL_63",
+     "64": "LABEL_64",
+     "65": "LABEL_65",
+     "66": "LABEL_66",
+     "67": "LABEL_67",
+     "68": "LABEL_68",
+     "69": "LABEL_69",
+     "70": "LABEL_70",
+     "71": "LABEL_71",
+     "72": "LABEL_72",
+     "73": "LABEL_73",
+     "74": "LABEL_74",
+     "75": "LABEL_75",
+     "76": "LABEL_76",
+     "77": "LABEL_77",
+     "78": "LABEL_78",
+     "79": "LABEL_79",
+     "80": "LABEL_80",
+     "81": "LABEL_81",
+     "82": "LABEL_82",
+     "83": "LABEL_83",
+     "84": "LABEL_84",
+     "85": "LABEL_85",
+     "86": "LABEL_86",
+     "87": "LABEL_87",
+     "88": "LABEL_88",
+     "89": "LABEL_89",
+     "90": "LABEL_90",
+     "91": "LABEL_91",
+     "92": "LABEL_92",
+     "93": "LABEL_93",
+     "94": "LABEL_94",
+     "95": "LABEL_95",
+     "96": "LABEL_96",
+     "97": "LABEL_97",
+     "98": "LABEL_98",
+     "99": "LABEL_99",
+     "100": "LABEL_100",
+     "101": "LABEL_101",
+     "102": "LABEL_102",
+     "103": "LABEL_103",
+     "104": "LABEL_104",
+     "105": "LABEL_105",
+     "106": "LABEL_106",
+     "107": "LABEL_107",
+     "108": "LABEL_108",
+     "109": "LABEL_109",
+     "110": "LABEL_110",
+     "111": "LABEL_111",
+     "112": "LABEL_112",
+     "113": "LABEL_113",
+     "114": "LABEL_114",
+     "115": "LABEL_115",
+     "116": "LABEL_116",
+     "117": "LABEL_117",
+     "118": "LABEL_118",
+     "119": "LABEL_119",
+     "120": "LABEL_120",
+     "121": "LABEL_121",
+     "122": "LABEL_122",
+     "123": "LABEL_123",
+     "124": "LABEL_124",
+     "125": "LABEL_125",
+     "126": "LABEL_126",
+     "127": "LABEL_127",
+     "128": "LABEL_128",
+     "129": "LABEL_129",
+     "130": "LABEL_130",
+     "131": "LABEL_131",
+     "132": "LABEL_132",
+     "133": "LABEL_133",
+     "134": "LABEL_134",
+     "135": "LABEL_135",
+     "136": "LABEL_136",
+     "137": "LABEL_137",
+     "138": "LABEL_138",
+     "139": "LABEL_139",
+     "140": "LABEL_140",
+     "141": "LABEL_141",
+     "142": "LABEL_142",
+     "143": "LABEL_143",
+     "144": "LABEL_144",
+     "145": "LABEL_145",
+     "146": "LABEL_146",
+     "147": "LABEL_147",
+     "148": "LABEL_148",
+     "149": "LABEL_149",
+     "150": "LABEL_150",
+     "151": "LABEL_151",
+     "152": "LABEL_152",
+     "153": "LABEL_153",
+     "154": "LABEL_154",
+     "155": "LABEL_155",
+     "156": "LABEL_156",
+     "157": "LABEL_157",
+     "158": "LABEL_158",
+     "159": "LABEL_159",
+     "160": "LABEL_160",
+     "161": "LABEL_161",
+     "162": "LABEL_162",
+     "163": "LABEL_163",
+     "164": "LABEL_164",
+     "165": "LABEL_165",
+     "166": "LABEL_166",
+     "167": "LABEL_167",
+     "168": "LABEL_168",
+     "169": "LABEL_169",
+     "170": "LABEL_170",
+     "171": "LABEL_171",
+     "172": "LABEL_172",
+     "173": "LABEL_173",
+     "174": "LABEL_174",
+     "175": "LABEL_175",
+     "176": "LABEL_176",
+     "177": "LABEL_177",
+     "178": "LABEL_178",
+     "179": "LABEL_179",
+     "180": "LABEL_180",
+     "181": "LABEL_181",
+     "182": "LABEL_182",
+     "183": "LABEL_183",
+     "184": "LABEL_184",
+     "185": "LABEL_185",
+     "186": "LABEL_186",
+     "187": "LABEL_187",
+     "188": "LABEL_188",
+     "189": "LABEL_189",
+     "190": "LABEL_190",
+     "191": "LABEL_191",
+     "192": "LABEL_192",
+     "193": "LABEL_193",
+     "194": "LABEL_194",
+     "195": "LABEL_195",
+     "196": "LABEL_196",
+     "197": "LABEL_197",
+     "198": "LABEL_198",
+     "199": "LABEL_199",
+     "200": "LABEL_200",
+     "201": "LABEL_201",
+     "202": "LABEL_202",
+     "203": "LABEL_203",
+     "204": "LABEL_204",
+     "205": "LABEL_205",
+     "206": "LABEL_206",
+     "207": "LABEL_207",
+     "208": "LABEL_208",
+     "209": "LABEL_209",
+     "210": "LABEL_210",
+     "211": "LABEL_211",
+     "212": "LABEL_212",
+     "213": "LABEL_213",
+     "214": "LABEL_214",
+     "215": "LABEL_215",
+     "216": "LABEL_216",
+     "217": "LABEL_217",
+     "218": "LABEL_218",
+     "219": "LABEL_219",
+     "220": "LABEL_220",
+     "221": "LABEL_221",
+     "222": "LABEL_222",
+     "223": "LABEL_223",
+     "224": "LABEL_224",
+     "225": "LABEL_225",
+     "226": "LABEL_226",
+     "227": "LABEL_227",
+     "228": "LABEL_228",
+     "229": "LABEL_229",
+     "230": "LABEL_230",
+     "231": "LABEL_231",
+     "232": "LABEL_232",
+     "233": "LABEL_233",
+     "234": "LABEL_234",
+     "235": "LABEL_235",
+     "236": "LABEL_236",
+     "237": "LABEL_237",
+     "238": "LABEL_238",
+     "239": "LABEL_239",
+     "240": "LABEL_240",
+     "241": "LABEL_241",
+     "242": "LABEL_242",
+     "243": "LABEL_243",
+     "244": "LABEL_244",
+     "245": "LABEL_245",
+     "246": "LABEL_246",
+     "247": "LABEL_247",
+     "248": "LABEL_248",
+     "249": "LABEL_249",
+     "250": "LABEL_250",
+     "251": "LABEL_251",
+     "252": "LABEL_252",
+     "253": "LABEL_253",
+     "254": "LABEL_254",
+     "255": "LABEL_255",
+     "256": "LABEL_256",
+     "257": "LABEL_257",
+     "258": "LABEL_258",
+     "259": "LABEL_259",
+     "260": "LABEL_260",
+     "261": "LABEL_261",
+     "262": "LABEL_262",
+     "263": "LABEL_263",
+     "264": "LABEL_264",
+     "265": "LABEL_265",
+     "266": "LABEL_266",
+     "267": "LABEL_267",
+     "268": "LABEL_268",
+     "269": "LABEL_269",
+     "270": "LABEL_270",
+     "271": "LABEL_271",
+     "272": "LABEL_272",
+     "273": "LABEL_273",
+     "274": "LABEL_274",
+     "275": "LABEL_275",
+     "276": "LABEL_276",
+     "277": "LABEL_277",
+     "278": "LABEL_278",
+     "279": "LABEL_279",
+     "280": "LABEL_280",
+     "281": "LABEL_281",
+     "282": "LABEL_282",
+     "283": "LABEL_283",
+     "284": "LABEL_284",
+     "285": "LABEL_285",
+     "286": "LABEL_286",
+     "287": "LABEL_287",
+     "288": "LABEL_288",
+     "289": "LABEL_289",
+     "290": "LABEL_290",
+     "291": "LABEL_291",
+     "292": "LABEL_292",
+     "293": "LABEL_293",
+     "294": "LABEL_294",
+     "295": "LABEL_295",
+     "296": "LABEL_296",
+     "297": "LABEL_297",
+     "298": "LABEL_298",
+     "299": "LABEL_299",
+     "300": "LABEL_300",
+     "301": "LABEL_301",
+     "302": "LABEL_302",
+     "303": "LABEL_303",
+     "304": "LABEL_304",
+     "305": "LABEL_305",
+     "306": "LABEL_306",
+     "307": "LABEL_307",
+     "308": "LABEL_308",
+     "309": "LABEL_309",
+     "310": "LABEL_310",
+     "311": "LABEL_311",
+     "312": "LABEL_312",
+     "313": "LABEL_313",
+     "314": "LABEL_314",
+     "315": "LABEL_315",
+     "316": "LABEL_316",
+     "317": "LABEL_317",
+     "318": "LABEL_318",
+     "319": "LABEL_319",
+     "320": "LABEL_320",
+     "321": "LABEL_321",
+     "322": "LABEL_322",
+     "323": "LABEL_323",
+     "324": "LABEL_324",
+     "325": "LABEL_325",
+     "326": "LABEL_326",
+     "327": "LABEL_327",
+     "328": "LABEL_328",
+     "329": "LABEL_329",
+     "330": "LABEL_330",
+     "331": "LABEL_331",
+     "332": "LABEL_332",
+     "333": "LABEL_333",
+     "334": "LABEL_334",
+     "335": "LABEL_335",
+     "336": "LABEL_336",
+     "337": "LABEL_337",
+     "338": "LABEL_338",
+     "339": "LABEL_339",
+     "340": "LABEL_340",
+     "341": "LABEL_341",
+     "342": "LABEL_342",
+     "343": "LABEL_343",
+     "344": "LABEL_344",
+     "345": "LABEL_345",
+     "346": "LABEL_346",
+     "347": "LABEL_347",
+     "348": "LABEL_348",
+     "349": "LABEL_349",
+     "350": "LABEL_350",
+     "351": "LABEL_351",
+     "352": "LABEL_352",
+     "353": "LABEL_353",
+     "354": "LABEL_354",
+     "355": "LABEL_355",
+     "356": "LABEL_356",
+     "357": "LABEL_357",
+     "358": "LABEL_358",
+     "359": "LABEL_359",
+     "360": "LABEL_360",
+     "361": "LABEL_361",
+     "362": "LABEL_362",
+     "363": "LABEL_363",
+     "364": "LABEL_364",
+     "365": "LABEL_365",
+     "366": "LABEL_366",
+     "367": "LABEL_367",
+     "368": "LABEL_368",
+     "369": "LABEL_369",
+     "370": "LABEL_370",
+     "371": "LABEL_371",
+     "372": "LABEL_372",
+     "373": "LABEL_373",
+     "374": "LABEL_374",
+     "375": "LABEL_375",
+     "376": "LABEL_376",
+     "377": "LABEL_377",
+     "378": "LABEL_378",
+     "379": "LABEL_379",
+     "380": "LABEL_380",
+     "381": "LABEL_381",
+     "382": "LABEL_382",
+     "383": "LABEL_383",
+     "384": "LABEL_384",
+     "385": "LABEL_385",
+     "386": "LABEL_386",
+     "387": "LABEL_387",
+     "388": "LABEL_388",
+     "389": "LABEL_389",
+     "390": "LABEL_390"
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 1536,
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1,
+     "LABEL_10": 10,
+     "LABEL_100": 100,
+     "LABEL_101": 101,
+     "LABEL_102": 102,
+     "LABEL_103": 103,
+     "LABEL_104": 104,
+     "LABEL_105": 105,
+     "LABEL_106": 106,
+     "LABEL_107": 107,
+     "LABEL_108": 108,
+     "LABEL_109": 109,
+     "LABEL_11": 11,
+     "LABEL_110": 110,
+     "LABEL_111": 111,
+     "LABEL_112": 112,
+     "LABEL_113": 113,
+     "LABEL_114": 114,
+     "LABEL_115": 115,
+     "LABEL_116": 116,
+     "LABEL_117": 117,
+     "LABEL_118": 118,
+     "LABEL_119": 119,
+     "LABEL_12": 12,
+     "LABEL_120": 120,
+     "LABEL_121": 121,
+     "LABEL_122": 122,
+     "LABEL_123": 123,
+     "LABEL_124": 124,
+     "LABEL_125": 125,
+     "LABEL_126": 126,
+     "LABEL_127": 127,
+     "LABEL_128": 128,
+     "LABEL_129": 129,
+     "LABEL_13": 13,
+     "LABEL_130": 130,
+     "LABEL_131": 131,
+     "LABEL_132": 132,
+     "LABEL_133": 133,
+     "LABEL_134": 134,
+     "LABEL_135": 135,
+     "LABEL_136": 136,
+     "LABEL_137": 137,
+     "LABEL_138": 138,
+     "LABEL_139": 139,
+     "LABEL_14": 14,
+     "LABEL_140": 140,
+     "LABEL_141": 141,
+     "LABEL_142": 142,
+     "LABEL_143": 143,
+     "LABEL_144": 144,
+     "LABEL_145": 145,
+     "LABEL_146": 146,
+     "LABEL_147": 147,
+     "LABEL_148": 148,
+     "LABEL_149": 149,
+     "LABEL_15": 15,
+     "LABEL_150": 150,
+     "LABEL_151": 151,
+     "LABEL_152": 152,
+     "LABEL_153": 153,
+     "LABEL_154": 154,
+     "LABEL_155": 155,
+     "LABEL_156": 156,
+     "LABEL_157": 157,
+     "LABEL_158": 158,
+     "LABEL_159": 159,
+     "LABEL_16": 16,
+     "LABEL_160": 160,
+     "LABEL_161": 161,
+     "LABEL_162": 162,
+     "LABEL_163": 163,
+     "LABEL_164": 164,
+     "LABEL_165": 165,
+     "LABEL_166": 166,
+     "LABEL_167": 167,
+     "LABEL_168": 168,
+     "LABEL_169": 169,
+     "LABEL_17": 17,
+     "LABEL_170": 170,
+     "LABEL_171": 171,
+     "LABEL_172": 172,
+     "LABEL_173": 173,
+     "LABEL_174": 174,
+     "LABEL_175": 175,
+     "LABEL_176": 176,
+     "LABEL_177": 177,
+     "LABEL_178": 178,
+     "LABEL_179": 179,
+     "LABEL_18": 18,
+     "LABEL_180": 180,
+     "LABEL_181": 181,
+     "LABEL_182": 182,
+     "LABEL_183": 183,
+     "LABEL_184": 184,
+     "LABEL_185": 185,
+     "LABEL_186": 186,
+     "LABEL_187": 187,
+     "LABEL_188": 188,
+     "LABEL_189": 189,
+     "LABEL_19": 19,
+     "LABEL_190": 190,
+     "LABEL_191": 191,
+     "LABEL_192": 192,
+     "LABEL_193": 193,
+     "LABEL_194": 194,
+     "LABEL_195": 195,
+     "LABEL_196": 196,
+     "LABEL_197": 197,
+     "LABEL_198": 198,
+     "LABEL_199": 199,
+     "LABEL_2": 2,
+     "LABEL_20": 20,
+     "LABEL_200": 200,
+     "LABEL_201": 201,
+     "LABEL_202": 202,
+     "LABEL_203": 203,
+     "LABEL_204": 204,
+     "LABEL_205": 205,
+     "LABEL_206": 206,
+     "LABEL_207": 207,
+     "LABEL_208": 208,
+     "LABEL_209": 209,
+     "LABEL_21": 21,
+     "LABEL_210": 210,
+     "LABEL_211": 211,
+     "LABEL_212": 212,
+     "LABEL_213": 213,
+     "LABEL_214": 214,
+     "LABEL_215": 215,
+     "LABEL_216": 216,
+     "LABEL_217": 217,
+     "LABEL_218": 218,
+     "LABEL_219": 219,
+     "LABEL_22": 22,
+     "LABEL_220": 220,
+     "LABEL_221": 221,
+     "LABEL_222": 222,
+     "LABEL_223": 223,
+     "LABEL_224": 224,
+     "LABEL_225": 225,
+     "LABEL_226": 226,
+     "LABEL_227": 227,
+     "LABEL_228": 228,
+     "LABEL_229": 229,
+     "LABEL_23": 23,
+     "LABEL_230": 230,
+     "LABEL_231": 231,
+     "LABEL_232": 232,
+     "LABEL_233": 233,
+     "LABEL_234": 234,
+     "LABEL_235": 235,
+     "LABEL_236": 236,
+     "LABEL_237": 237,
+     "LABEL_238": 238,
+     "LABEL_239": 239,
+     "LABEL_24": 24,
+     "LABEL_240": 240,
+     "LABEL_241": 241,
+     "LABEL_242": 242,
+     "LABEL_243": 243,
+     "LABEL_244": 244,
+     "LABEL_245": 245,
+     "LABEL_246": 246,
+     "LABEL_247": 247,
+     "LABEL_248": 248,
+     "LABEL_249": 249,
+     "LABEL_25": 25,
+     "LABEL_250": 250,
+     "LABEL_251": 251,
+     "LABEL_252": 252,
+     "LABEL_253": 253,
+     "LABEL_254": 254,
+     "LABEL_255": 255,
+     "LABEL_256": 256,
+     "LABEL_257": 257,
+     "LABEL_258": 258,
+     "LABEL_259": 259,
+     "LABEL_26": 26,
+     "LABEL_260": 260,
+     "LABEL_261": 261,
+     "LABEL_262": 262,
+     "LABEL_263": 263,
+     "LABEL_264": 264,
+     "LABEL_265": 265,
+     "LABEL_266": 266,
+     "LABEL_267": 267,
+     "LABEL_268": 268,
+     "LABEL_269": 269,
+     "LABEL_27": 27,
+     "LABEL_270": 270,
+     "LABEL_271": 271,
+     "LABEL_272": 272,
+     "LABEL_273": 273,
+     "LABEL_274": 274,
+     "LABEL_275": 275,
+     "LABEL_276": 276,
+     "LABEL_277": 277,
+     "LABEL_278": 278,
+     "LABEL_279": 279,
+     "LABEL_28": 28,
+     "LABEL_280": 280,
+     "LABEL_281": 281,
+     "LABEL_282": 282,
+     "LABEL_283": 283,
+     "LABEL_284": 284,
+     "LABEL_285": 285,
+     "LABEL_286": 286,
+     "LABEL_287": 287,
+     "LABEL_288": 288,
+     "LABEL_289": 289,
+     "LABEL_29": 29,
+     "LABEL_290": 290,
+     "LABEL_291": 291,
+     "LABEL_292": 292,
+     "LABEL_293": 293,
+     "LABEL_294": 294,
+     "LABEL_295": 295,
+     "LABEL_296": 296,
+     "LABEL_297": 297,
+     "LABEL_298": 298,
+     "LABEL_299": 299,
+     "LABEL_3": 3,
+     "LABEL_30": 30,
+     "LABEL_300": 300,
+     "LABEL_301": 301,
+     "LABEL_302": 302,
+     "LABEL_303": 303,
+     "LABEL_304": 304,
+     "LABEL_305": 305,
+     "LABEL_306": 306,
+     "LABEL_307": 307,
+     "LABEL_308": 308,
+     "LABEL_309": 309,
+     "LABEL_31": 31,
+     "LABEL_310": 310,
+     "LABEL_311": 311,
+     "LABEL_312": 312,
+     "LABEL_313": 313,
+     "LABEL_314": 314,
+     "LABEL_315": 315,
+     "LABEL_316": 316,
+     "LABEL_317": 317,
+     "LABEL_318": 318,
+     "LABEL_319": 319,
+     "LABEL_32": 32,
+     "LABEL_320": 320,
+     "LABEL_321": 321,
+     "LABEL_322": 322,
+     "LABEL_323": 323,
+     "LABEL_324": 324,
+     "LABEL_325": 325,
+     "LABEL_326": 326,
+     "LABEL_327": 327,
+     "LABEL_328": 328,
+     "LABEL_329": 329,
+     "LABEL_33": 33,
+     "LABEL_330": 330,
+     "LABEL_331": 331,
+     "LABEL_332": 332,
+     "LABEL_333": 333,
+     "LABEL_334": 334,
+     "LABEL_335": 335,
+     "LABEL_336": 336,
+     "LABEL_337": 337,
+     "LABEL_338": 338,
+     "LABEL_339": 339,
+     "LABEL_34": 34,
+     "LABEL_340": 340,
+     "LABEL_341": 341,
+     "LABEL_342": 342,
+     "LABEL_343": 343,
+     "LABEL_344": 344,
+     "LABEL_345": 345,
+     "LABEL_346": 346,
+     "LABEL_347": 347,
+     "LABEL_348": 348,
+     "LABEL_349": 349,
+     "LABEL_35": 35,
+     "LABEL_350": 350,
+     "LABEL_351": 351,
+     "LABEL_352": 352,
+     "LABEL_353": 353,
+     "LABEL_354": 354,
+     "LABEL_355": 355,
+     "LABEL_356": 356,
+     "LABEL_357": 357,
+     "LABEL_358": 358,
+     "LABEL_359": 359,
+     "LABEL_36": 36,
+     "LABEL_360": 360,
+     "LABEL_361": 361,
+     "LABEL_362": 362,
+     "LABEL_363": 363,
+     "LABEL_364": 364,
+     "LABEL_365": 365,
+     "LABEL_366": 366,
+     "LABEL_367": 367,
+     "LABEL_368": 368,
+     "LABEL_369": 369,
+     "LABEL_37": 37,
+     "LABEL_370": 370,
+     "LABEL_371": 371,
+     "LABEL_372": 372,
+     "LABEL_373": 373,
+     "LABEL_374": 374,
+     "LABEL_375": 375,
+     "LABEL_376": 376,
+     "LABEL_377": 377,
+     "LABEL_378": 378,
+     "LABEL_379": 379,
+     "LABEL_38": 38,
+     "LABEL_380": 380,
+     "LABEL_381": 381,
+     "LABEL_382": 382,
+     "LABEL_383": 383,
+     "LABEL_384": 384,
+     "LABEL_385": 385,
+     "LABEL_386": 386,
+     "LABEL_387": 387,
+     "LABEL_388": 388,
+     "LABEL_389": 389,
+     "LABEL_39": 39,
+     "LABEL_390": 390,
+     "LABEL_4": 4,
+     "LABEL_40": 40,
+     "LABEL_41": 41,
+     "LABEL_42": 42,
+     "LABEL_43": 43,
+     "LABEL_44": 44,
+     "LABEL_45": 45,
+     "LABEL_46": 46,
+     "LABEL_47": 47,
+     "LABEL_48": 48,
+     "LABEL_49": 49,
+     "LABEL_5": 5,
+     "LABEL_50": 50,
+     "LABEL_51": 51,
+     "LABEL_52": 52,
+     "LABEL_53": 53,
+     "LABEL_54": 54,
+     "LABEL_55": 55,
+     "LABEL_56": 56,
+     "LABEL_57": 57,
+     "LABEL_58": 58,
+     "LABEL_59": 59,
+     "LABEL_6": 6,
+     "LABEL_60": 60,
+     "LABEL_61": 61,
+     "LABEL_62": 62,
+     "LABEL_63": 63,
+     "LABEL_64": 64,
+     "LABEL_65": 65,
+     "LABEL_66": 66,
+     "LABEL_67": 67,
+     "LABEL_68": 68,
+     "LABEL_69": 69,
+     "LABEL_7": 7,
+     "LABEL_70": 70,
+     "LABEL_71": 71,
+     "LABEL_72": 72,
+     "LABEL_73": 73,
+     "LABEL_74": 74,
+     "LABEL_75": 75,
+     "LABEL_76": 76,
+     "LABEL_77": 77,
+     "LABEL_78": 78,
+     "LABEL_79": 79,
+     "LABEL_8": 8,
+     "LABEL_80": 80,
+     "LABEL_81": 81,
+     "LABEL_82": 82,
+     "LABEL_83": 83,
+     "LABEL_84": 84,
+     "LABEL_85": 85,
+     "LABEL_86": 86,
+     "LABEL_87": 87,
+     "LABEL_88": 88,
+     "LABEL_89": 89,
+     "LABEL_9": 9,
+     "LABEL_90": 90,
+     "LABEL_91": 91,
+     "LABEL_92": 92,
+     "LABEL_93": 93,
+     "LABEL_94": 94,
+     "LABEL_95": 95,
+     "LABEL_96": 96,
+     "LABEL_97": 97,
+     "LABEL_98": 98,
+     "LABEL_99": 99
+   },
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 6,
+   "num_relation_heads": 32,
+   "pad_token_id": 0,
+   "pooler_fc_size": 768,
+   "pooler_num_attention_heads": 12,
+   "pooler_num_fc_layers": 3,
+   "pooler_size_per_head": 128,
+   "pooler_type": "first_token_transform",
+   "position_embedding_type": "absolute",
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.1",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "vocab_size": 21128
+ }
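A quick sanity check tying this config to the polyphone dictionaries added below: the classifier head has 391 labels, one per pronunciation entry in polydict.json (whose keys run from "1" to "391"); the exact label-to-key offset is not documented in this commit, so only the counts are checked here. File paths are as committed in this repo.

import json

with open("g2p/sources/g2p_chinese_model/config.json", encoding="utf-8") as f:
    config = json.load(f)
with open("g2p/sources/g2p_chinese_model/polydict.json", encoding="utf-8") as f:
    id2poly = json.load(f)

# One classifier label per pronunciation entry.
print(len(config["id2label"]), len(id2poly))  # 391 391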
g2p/sources/g2p_chinese_model/poly_bert_model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8765d835ffdf9811c832d4dc7b6a552757aa8615c01d1184db716a50c20aebbc
+ size 76583333
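The ONNX graph's tensor names are not documented in this commit, so the sketch below discovers them at runtime with onnxruntime rather than assuming them.

# Inspect the polyphone classifier without assuming its input/output names.
import onnxruntime as ort

sess = ort.InferenceSession("g2p/sources/g2p_chinese_model/poly_bert_model.onnx")
for tensor in sess.get_inputs():
    print("input :", tensor.name, tensor.shape, tensor.type)
for tensor in sess.get_outputs():
    print("output:", tensor.name, tensor.shape)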
g2p/sources/g2p_chinese_model/polychar.txt ADDED
@@ -0,0 +1,159 @@
+ 丧
+ 中
+ 为
+ 乌
+ 乐
+ 了
+ 什
+ 仔
+ 令
+ 任
+ 会
+ 传
+ 佛
+ 供
+ 便
+ 倒
+ 假
+ 兴
+ 冠
+ 冲
+ 几
+ 分
+ 切
+ 划
+ 创
+ 剥
+ 勒
+ 区
+ 华
+ 单
+ 卜
+ 占
+ 卡
+ 卷
+ 厦
+ 参
+ 发
+ 只
+ 号
+ 同
+ 吐
+ 和
+ 喝
+ 圈
+ 地
+ 塞
+ 壳
+ 处
+ 奇
+ 奔
+ 好
+ 宁
+ 宿
+ 将
+ 少
+ 尽
+ 岗
+ 差
+ 巷
+ 帖
+ 干
+ 应
+ 度
+ 弹
+ 强
+ 当
+ 待
+ 得
+ 恶
+ 扁
+ 扇
+ 扎
+ 扫
+ 担
+ 挑
+ 据
+ 撒
+ 教
+ 散
+ 数
+ 斗
+ 晃
+ 曝
+ 曲
+ 更
+ 曾
+ 朝
+ 朴
+ 杆
+ 查
+ 校
+ 模
+ 横
+ 没
+ 泡
+ 济
+ 混
+ 漂
+ 炸
+ 熟
+ 燕
+ 片
+ 率
+ 畜
+ 的
+ 盛
+ 相
+ 省
+ 看
+ 着
+ 矫
+ 禁
+ 种
+ 称
+ 空
+ 答
+ 粘
+ 糊
+ 系
+ 累
+ 纤
+ 结
+ 给
+ 缝
+ 肖
+ 背
+ 脏
+ 舍
+ 色
+ 落
+ 蒙
+ 薄
+ 藏
+ 血
+ 行
+ 要
+ 观
+ 觉
+ 角
+ 解
+ 说
+ 调
+ 踏
+ 车
+ 转
+ 载
+ 还
+ 随
+ 都
+ 重
+ 量
+ 钻
+ 铺
+ 长
+ 间
+ 降
+ 难
+ 露
+ 鲜
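Presumably this list gates the classifier: only the 159 characters in it need model-based disambiguation, while everything else can be looked up directly in the lexicon. A minimal membership check under that assumption:

# Load the polyphonic-character inventory; one character per line.
with open("g2p/sources/g2p_chinese_model/polychar.txt", encoding="utf-8") as f:
    poly_chars = {line.strip() for line in f if line.strip()}

print(len(poly_chars))      # 159
print("和" in poly_chars)   # True: multiple readings, needs disambiguation
print("你" in poly_chars)   # False: a single fixed reading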
g2p/sources/g2p_chinese_model/polydict.json ADDED
@@ -0,0 +1,393 @@
+ {
+   "1": "丧{sang1}",
+   "2": "丧{sang4}",
+   "3": "中{zhong1}",
+   "4": "中{zhong4}",
+   "5": "为{wei2}",
+   "6": "为{wei4}",
+   "7": "乌{wu1}",
+   "8": "乌{wu4}",
+   "9": "乐{lao4}",
+   "10": "乐{le4}",
+   "11": "乐{le5}",
+   "12": "乐{yao4}",
+   "13": "乐{yve4}",
+   "14": "了{le5}",
+   "15": "了{liao3}",
+   "16": "了{liao5}",
+   "17": "什{shen2}",
+   "18": "什{shi2}",
+   "19": "仔{zai3}",
+   "20": "仔{zai5}",
+   "21": "仔{zi3}",
+   "22": "仔{zi5}",
+   "23": "令{ling2}",
+   "24": "令{ling4}",
+   "25": "任{ren2}",
+   "26": "任{ren4}",
+   "27": "会{hui4}",
+   "28": "会{hui5}",
+   "29": "会{kuai4}",
+   "30": "传{chuan2}",
+   "31": "传{zhuan4}",
+   "32": "佛{fo2}",
+   "33": "佛{fu2}",
+   "34": "供{gong1}",
+   "35": "供{gong4}",
+   "36": "便{bian4}",
+   "37": "便{pian2}",
+   "38": "倒{dao3}",
+   "39": "倒{dao4}",
+   "40": "假{jia3}",
+   "41": "假{jia4}",
+   "42": "兴{xing1}",
+   "43": "兴{xing4}",
+   "44": "冠{guan1}",
+   "45": "冠{guan4}",
+   "46": "冲{chong1}",
+   "47": "冲{chong4}",
+   "48": "几{ji1}",
+   "49": "几{ji2}",
+   "50": "几{ji3}",
+   "51": "分{fen1}",
+   "52": "分{fen4}",
+   "53": "分{fen5}",
+   "54": "切{qie1}",
+   "55": "切{qie4}",
+   "56": "划{hua2}",
+   "57": "划{hua4}",
+   "58": "划{hua5}",
+   "59": "创{chuang1}",
+   "60": "创{chuang4}",
+   "61": "剥{bao1}",
+   "62": "剥{bo1}",
+   "63": "勒{le4}",
+   "64": "勒{le5}",
+   "65": "勒{lei1}",
+   "66": "区{ou1}",
+   "67": "区{qu1}",
+   "68": "华{hua2}",
+   "69": "华{hua4}",
+   "70": "单{chan2}",
+   "71": "单{dan1}",
+   "72": "单{shan4}",
+   "73": "卜{bo5}",
+   "74": "卜{bu3}",
+   "75": "占{zhan1}",
+   "76": "占{zhan4}",
+   "77": "卡{ka2}",
+   "78": "卡{ka3}",
+   "79": "卡{qia3}",
+   "80": "卷{jvan3}",
+   "81": "卷{jvan4}",
+   "82": "厦{sha4}",
+   "83": "厦{xia4}",
+   "84": "参{can1}",
+   "85": "参{cen1}",
+   "86": "参{shen1}",
+   "87": "发{fa1}",
+   "88": "发{fa4}",
+   "89": "发{fa5}",
+   "90": "只{zhi1}",
+   "91": "只{zhi3}",
+   "92": "号{hao2}",
+   "93": "号{hao4}",
+   "94": "号{hao5}",
+   "95": "同{tong2}",
+   "96": "同{tong4}",
+   "97": "同{tong5}",
+   "98": "吐{tu2}",
+   "99": "吐{tu3}",
+   "100": "吐{tu4}",
+   "101": "和{he2}",
+   "102": "和{he4}",
+   "103": "和{he5}",
+   "104": "和{huo2}",
+   "105": "和{huo4}",
+   "106": "和{huo5}",
+   "107": "喝{he1}",
+   "108": "喝{he4}",
+   "109": "圈{jvan4}",
+   "110": "圈{qvan1}",
+   "111": "圈{qvan5}",
+   "112": "地{de5}",
+   "113": "地{di4}",
+   "114": "地{di5}",
+   "115": "塞{sai1}",
+   "116": "塞{sai2}",
+   "117": "塞{sai4}",
+   "118": "塞{se4}",
+   "119": "壳{ke2}",
+   "120": "壳{qiao4}",
+   "121": "处{chu3}",
+   "122": "处{chu4}",
+   "123": "奇{ji1}",
+   "124": "奇{qi2}",
+   "125": "奔{ben1}",
+   "126": "奔{ben4}",
+   "127": "好{hao3}",
+   "128": "好{hao4}",
+   "129": "好{hao5}",
+   "130": "宁{ning2}",
+   "131": "宁{ning4}",
+   "132": "宁{ning5}",
+   "133": "宿{su4}",
+   "134": "宿{xiu3}",
+   "135": "宿{xiu4}",
+   "136": "将{jiang1}",
+   "137": "将{jiang4}",
+   "138": "少{shao3}",
+   "139": "少{shao4}",
+   "140": "尽{jin3}",
+   "141": "尽{jin4}",
+   "142": "岗{gang1}",
+   "143": "岗{gang3}",
+   "144": "差{cha1}",
+   "145": "差{cha4}",
+   "146": "差{chai1}",
+   "147": "差{ci1}",
+   "148": "巷{hang4}",
+   "149": "巷{xiang4}",
+   "150": "帖{tie1}",
+   "151": "帖{tie3}",
+   "152": "帖{tie4}",
+   "153": "干{gan1}",
+   "154": "干{gan4}",
+   "155": "应{ying1}",
+   "156": "应{ying4}",
+   "157": "应{ying5}",
+   "158": "度{du4}",
+   "159": "度{du5}",
+   "160": "度{duo2}",
+   "161": "弹{dan4}",
+   "162": "弹{tan2}",
+   "163": "弹{tan5}",
+   "164": "强{jiang4}",
+   "165": "强{qiang2}",
+   "166": "强{qiang3}",
+   "167": "当{dang1}",
+   "168": "当{dang4}",
+   "169": "当{dang5}",
+   "170": "待{dai1}",
+   "171": "待{dai4}",
+   "172": "得{de2}",
+   "173": "得{de5}",
+   "174": "得{dei3}",
+   "175": "得{dei5}",
+   "176": "恶{e3}",
+   "177": "恶{e4}",
+   "178": "恶{wu4}",
+   "179": "扁{bian3}",
+   "180": "扁{pian1}",
+   "181": "扇{shan1}",
+   "182": "扇{shan4}",
+   "183": "扎{za1}",
+   "184": "扎{zha1}",
+   "185": "扎{zha2}",
+   "186": "扫{sao3}",
+   "187": "扫{sao4}",
+   "188": "担{dan1}",
+   "189": "担{dan4}",
+   "190": "担{dan5}",
+   "191": "挑{tiao1}",
+   "192": "挑{tiao3}",
+   "193": "据{jv1}",
+   "194": "据{jv4}",
+   "195": "撒{sa1}",
+   "196": "撒{sa3}",
+   "197": "撒{sa5}",
+   "198": "教{jiao1}",
+   "199": "教{jiao4}",
+   "200": "散{san3}",
+   "201": "散{san4}",
+   "202": "散{san5}",
+   "203": "数{shu3}",
+   "204": "数{shu4}",
+   "205": "数{shu5}",
+   "206": "斗{dou3}",
+   "207": "斗{dou4}",
+   "208": "晃{huang3}",
+   "209": "曝{bao4}",
+   "210": "曲{qu1}",
+   "211": "曲{qu3}",
+   "212": "更{geng1}",
+   "213": "更{geng4}",
+   "214": "曾{ceng1}",
+   "215": "曾{ceng2}",
+   "216": "曾{zeng1}",
+   "217": "朝{chao2}",
+   "218": "朝{zhao1}",
+   "219": "朴{piao2}",
+   "220": "朴{pu2}",
+   "221": "朴{pu3}",
+   "222": "杆{gan1}",
+   "223": "杆{gan3}",
+   "224": "查{cha2}",
+   "225": "查{zha1}",
+   "226": "校{jiao4}",
+   "227": "校{xiao4}",
+   "228": "模{mo2}",
+   "229": "模{mu2}",
+   "230": "横{heng2}",
+   "231": "横{heng4}",
+   "232": "没{mei2}",
+   "233": "没{mo4}",
+   "234": "泡{pao1}",
+   "235": "泡{pao4}",
+   "236": "泡{pao5}",
+   "237": "济{ji3}",
+   "238": "济{ji4}",
+   "239": "混{hun2}",
+   "240": "混{hun3}",
+   "241": "混{hun4}",
+   "242": "混{hun5}",
+   "243": "漂{piao1}",
+   "244": "漂{piao3}",
+   "245": "漂{piao4}",
+   "246": "炸{zha2}",
+   "247": "炸{zha4}",
+   "248": "熟{shou2}",
+   "249": "熟{shu2}",
+   "250": "燕{yan1}",
+   "251": "燕{yan4}",
+   "252": "片{pian1}",
+   "253": "片{pian4}",
+   "254": "率{lv4}",
+   "255": "率{shuai4}",
+   "256": "畜{chu4}",
+   "257": "畜{xu4}",
+   "258": "的{de5}",
+   "259": "的{di1}",
+   "260": "的{di2}",
+   "261": "的{di4}",
+   "262": "的{di5}",
+   "263": "盛{cheng2}",
+   "264": "盛{sheng4}",
+   "265": "相{xiang1}",
+   "266": "相{xiang4}",
+   "267": "相{xiang5}",
+   "268": "省{sheng3}",
+   "269": "省{xing3}",
+   "270": "看{kan1}",
+   "271": "看{kan4}",
+   "272": "看{kan5}",
+   "273": "着{zhao1}",
+   "274": "着{zhao2}",
+   "275": "着{zhao5}",
+   "276": "着{zhe5}",
+   "277": "着{zhuo2}",
+   "278": "着{zhuo5}",
+   "279": "矫{jiao3}",
+   "280": "禁{jin1}",
+   "281": "禁{jin4}",
+   "282": "种{zhong3}",
+   "283": "种{zhong4}",
+   "284": "称{chen4}",
+   "285": "称{cheng1}",
+   "286": "空{kong1}",
+   "287": "空{kong4}",
+   "288": "答{da1}",
+   "289": "答{da2}",
+   "290": "粘{nian2}",
+   "291": "粘{zhan1}",
+   "292": "糊{hu2}",
+   "293": "糊{hu5}",
+   "294": "系{ji4}",
+   "295": "系{xi4}",
+   "296": "系{xi5}",
+   "297": "累{lei2}",
+   "298": "累{lei3}",
+   "299": "累{lei4}",
+   "300": "累{lei5}",
+   "301": "纤{qian4}",
+   "302": "纤{xian1}",
+   "303": "结{jie1}",
+   "304": "结{jie2}",
+   "305": "结{jie5}",
+   "306": "给{gei3}",
+   "307": "给{gei5}",
+   "308": "给{ji3}",
+   "309": "缝{feng2}",
+   "310": "缝{feng4}",
+   "311": "缝{feng5}",
+   "312": "肖{xiao1}",
+   "313": "肖{xiao4}",
+   "314": "背{bei1}",
+   "315": "背{bei4}",
+   "316": "脏{zang1}",
+   "317": "脏{zang4}",
+   "318": "舍{she3}",
+   "319": "舍{she4}",
+   "320": "色{se4}",
+   "321": "色{shai3}",
+   "322": "落{lao4}",
+   "323": "落{luo4}",
+   "324": "蒙{meng1}",
+   "325": "蒙{meng2}",
+   "326": "蒙{meng3}",
+   "327": "薄{bao2}",
+   "328": "薄{bo2}",
+   "329": "薄{bo4}",
+   "330": "藏{cang2}",
+   "331": "藏{zang4}",
+   "332": "血{xie3}",
+   "333": "血{xue4}",
+   "334": "行{hang2}",
+   "335": "行{hang5}",
+   "336": "行{heng5}",
+   "337": "行{xing2}",
+   "338": "行{xing4}",
+   "339": "要{yao1}",
+   "340": "要{yao4}",
+   "341": "观{guan1}",
+   "342": "观{guan4}",
+   "343": "觉{jiao4}",
+   "344": "觉{jiao5}",
+   "345": "觉{jve2}",
+   "346": "角{jiao3}",
+   "347": "角{jve2}",
+   "348": "解{jie3}",
+   "349": "解{jie4}",
+   "350": "解{xie4}",
+   "351": "说{shui4}",
+   "352": "说{shuo1}",
+   "353": "调{diao4}",
+   "354": "调{tiao2}",
+   "355": "踏{ta1}",
+   "356": "踏{ta4}",
+   "357": "车{che1}",
+   "358": "车{jv1}",
+   "359": "转{zhuan3}",
+   "360": "转{zhuan4}",
+   "361": "载{zai3}",
+   "362": "载{zai4}",
+   "363": "还{hai2}",
+   "364": "还{huan2}",
+   "365": "随{sui2}",
+   "366": "随{sui4}",
+   "367": "都{dou1}",
+   "368": "都{du1}",
+   "369": "重{chong2}",
+   "370": "重{zhong4}",
+   "371": "量{liang2}",
+   "372": "量{liang4}",
+   "373": "量{liang5}",
+   "374": "钻{zuan1}",
+   "375": "钻{zuan4}",
+   "376": "铺{pu1}",
+   "377": "铺{pu4}",
+   "378": "长{chang2}",
+   "379": "长{chang3}",
+   "380": "长{zhang3}",
+   "381": "间{jian1}",
+   "382": "间{jian4}",
+   "383": "降{jiang4}",
+   "384": "降{xiang2}",
+   "385": "难{nan2}",
+   "386": "难{nan4}",
+   "387": "难{nan5}",
+   "388": "露{lou4}",
+   "389": "露{lu4}",
+   "390": "鲜{xian1}",
+   "391": "鲜{xian3}"
+ }
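Each value packs the character and its pinyin into one string, so consumers have to split it back apart; a small parse sketch (the regex is mine, not the repo's). Note the "v" stands in for "ü" (jv/qv/xv/yv/lv) in these entries.

import json
import re

with open("g2p/sources/g2p_chinese_model/polydict.json", encoding="utf-8") as f:
    id2poly = json.load(f)

# "和{he2}" -> ("和", "he2")
char, pinyin = re.fullmatch(r"(.+)\{(.+)\}", id2poly["101"]).groups()
print(char, pinyin)  # 和 he2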
g2p/sources/g2p_chinese_model/polydict_r.json ADDED
@@ -0,0 +1,393 @@
+ {
+   "丧{sang1}": 1,
+   "丧{sang4}": 2,
+   "中{zhong1}": 3,
+   "中{zhong4}": 4,
+   "为{wei2}": 5,
+   "为{wei4}": 6,
+   "乌{wu1}": 7,
+   "乌{wu4}": 8,
+   "乐{lao4}": 9,
+   "乐{le4}": 10,
+   "乐{le5}": 11,
+   "乐{yao4}": 12,
+   "乐{yve4}": 13,
+   "了{le5}": 14,
+   "了{liao3}": 15,
+   "了{liao5}": 16,
+   "什{shen2}": 17,
+   "什{shi2}": 18,
+   "仔{zai3}": 19,
+   "仔{zai5}": 20,
+   "仔{zi3}": 21,
+   "仔{zi5}": 22,
+   "令{ling2}": 23,
+   "令{ling4}": 24,
+   "任{ren2}": 25,
+   "任{ren4}": 26,
+   "会{hui4}": 27,
+   "会{hui5}": 28,
+   "会{kuai4}": 29,
+   "传{chuan2}": 30,
+   "传{zhuan4}": 31,
+   "佛{fo2}": 32,
+   "佛{fu2}": 33,
+   "供{gong1}": 34,
+   "供{gong4}": 35,
+   "便{bian4}": 36,
+   "便{pian2}": 37,
+   "倒{dao3}": 38,
+   "倒{dao4}": 39,
+   "假{jia3}": 40,
+   "假{jia4}": 41,
+   "兴{xing1}": 42,
+   "兴{xing4}": 43,
+   "冠{guan1}": 44,
+   "冠{guan4}": 45,
+   "冲{chong1}": 46,
+   "冲{chong4}": 47,
+   "几{ji1}": 48,
+   "几{ji2}": 49,
+   "几{ji3}": 50,
+   "分{fen1}": 51,
+   "分{fen4}": 52,
+   "分{fen5}": 53,
+   "切{qie1}": 54,
+   "切{qie4}": 55,
+   "划{hua2}": 56,
+   "划{hua4}": 57,
+   "划{hua5}": 58,
+   "创{chuang1}": 59,
+   "创{chuang4}": 60,
+   "剥{bao1}": 61,
+   "剥{bo1}": 62,
+   "勒{le4}": 63,
+   "勒{le5}": 64,
+   "勒{lei1}": 65,
+   "区{ou1}": 66,
+   "区{qu1}": 67,
+   "华{hua2}": 68,
+   "华{hua4}": 69,
+   "单{chan2}": 70,
+   "单{dan1}": 71,
+   "单{shan4}": 72,
+   "卜{bo5}": 73,
+   "卜{bu3}": 74,
+   "占{zhan1}": 75,
+   "占{zhan4}": 76,
+   "卡{ka2}": 77,
+   "卡{ka3}": 78,
+   "卡{qia3}": 79,
+   "卷{jvan3}": 80,
+   "卷{jvan4}": 81,
+   "厦{sha4}": 82,
+   "厦{xia4}": 83,
+   "参{can1}": 84,
+   "参{cen1}": 85,
+   "参{shen1}": 86,
+   "发{fa1}": 87,
+   "发{fa4}": 88,
+   "发{fa5}": 89,
+   "只{zhi1}": 90,
+   "只{zhi3}": 91,
+   "号{hao2}": 92,
+   "号{hao4}": 93,
+   "号{hao5}": 94,
+   "同{tong2}": 95,
+   "同{tong4}": 96,
+   "同{tong5}": 97,
+   "吐{tu2}": 98,
+   "吐{tu3}": 99,
+   "吐{tu4}": 100,
+   "和{he2}": 101,
+   "和{he4}": 102,
+   "和{he5}": 103,
+   "和{huo2}": 104,
+   "和{huo4}": 105,
+   "和{huo5}": 106,
+   "喝{he1}": 107,
+   "喝{he4}": 108,
+   "圈{jvan4}": 109,
+   "圈{qvan1}": 110,
+   "圈{qvan5}": 111,
+   "地{de5}": 112,
+   "地{di4}": 113,
+   "地{di5}": 114,
+   "塞{sai1}": 115,
+   "塞{sai2}": 116,
+   "塞{sai4}": 117,
+   "塞{se4}": 118,
+   "壳{ke2}": 119,
+   "壳{qiao4}": 120,
+   "处{chu3}": 121,
+   "处{chu4}": 122,
+   "奇{ji1}": 123,
+   "奇{qi2}": 124,
+   "奔{ben1}": 125,
+   "奔{ben4}": 126,
+   "好{hao3}": 127,
+   "好{hao4}": 128,
+   "好{hao5}": 129,
+   "宁{ning2}": 130,
+   "宁{ning4}": 131,
+   "宁{ning5}": 132,
+   "宿{su4}": 133,
+   "宿{xiu3}": 134,
+   "宿{xiu4}": 135,
+   "将{jiang1}": 136,
+   "将{jiang4}": 137,
+   "少{shao3}": 138,
+   "少{shao4}": 139,
+   "尽{jin3}": 140,
+   "尽{jin4}": 141,
+   "岗{gang1}": 142,
+   "岗{gang3}": 143,
+   "差{cha1}": 144,
+   "差{cha4}": 145,
+   "差{chai1}": 146,
+   "差{ci1}": 147,
+   "巷{hang4}": 148,
+   "巷{xiang4}": 149,
+   "帖{tie1}": 150,
+   "帖{tie3}": 151,
+   "帖{tie4}": 152,
+   "干{gan1}": 153,
+   "干{gan4}": 154,
+   "应{ying1}": 155,
+   "应{ying4}": 156,
+   "应{ying5}": 157,
+   "度{du4}": 158,
+   "度{du5}": 159,
+   "度{duo2}": 160,
+   "弹{dan4}": 161,
+   "弹{tan2}": 162,
+   "弹{tan5}": 163,
+   "强{jiang4}": 164,
+   "强{qiang2}": 165,
+   "强{qiang3}": 166,
+   "当{dang1}": 167,
+   "当{dang4}": 168,
+   "当{dang5}": 169,
+   "待{dai1}": 170,
+   "待{dai4}": 171,
+   "得{de2}": 172,
+   "得{de5}": 173,
+   "得{dei3}": 174,
+   "得{dei5}": 175,
+   "恶{e3}": 176,
+   "恶{e4}": 177,
+   "恶{wu4}": 178,
+   "扁{bian3}": 179,
+   "扁{pian1}": 180,
+   "扇{shan1}": 181,
+   "扇{shan4}": 182,
+   "扎{za1}": 183,
+   "扎{zha1}": 184,
+   "扎{zha2}": 185,
+   "扫{sao3}": 186,
+   "扫{sao4}": 187,
+   "担{dan1}": 188,
+   "担{dan4}": 189,
+   "担{dan5}": 190,
+   "挑{tiao1}": 191,
+   "挑{tiao3}": 192,
+   "据{jv1}": 193,
+   "据{jv4}": 194,
+   "撒{sa1}": 195,
+   "撒{sa3}": 196,
+   "撒{sa5}": 197,
+   "教{jiao1}": 198,
+   "教{jiao4}": 199,
+   "散{san3}": 200,
+   "散{san4}": 201,
+   "散{san5}": 202,
+   "数{shu3}": 203,
+   "数{shu4}": 204,
+   "数{shu5}": 205,
+   "斗{dou3}": 206,
+   "斗{dou4}": 207,
+   "晃{huang3}": 208,
+   "曝{bao4}": 209,
+   "曲{qu1}": 210,
+   "曲{qu3}": 211,
+   "更{geng1}": 212,
+   "更{geng4}": 213,
+   "曾{ceng1}": 214,
+   "曾{ceng2}": 215,
+   "曾{zeng1}": 216,
+   "朝{chao2}": 217,
+   "朝{zhao1}": 218,
+   "朴{piao2}": 219,
+   "朴{pu2}": 220,
+   "朴{pu3}": 221,
+   "杆{gan1}": 222,
+   "杆{gan3}": 223,
+   "查{cha2}": 224,
+   "查{zha1}": 225,
+   "校{jiao4}": 226,
+   "校{xiao4}": 227,
+   "模{mo2}": 228,
+   "模{mu2}": 229,
+   "横{heng2}": 230,
+   "横{heng4}": 231,
+   "没{mei2}": 232,
+   "没{mo4}": 233,
+   "泡{pao1}": 234,
+   "泡{pao4}": 235,
+   "泡{pao5}": 236,
+   "济{ji3}": 237,
+   "济{ji4}": 238,
+   "混{hun2}": 239,
+   "混{hun3}": 240,
+   "混{hun4}": 241,
+   "混{hun5}": 242,
+   "漂{piao1}": 243,
+   "漂{piao3}": 244,
+   "漂{piao4}": 245,
+   "炸{zha2}": 246,
+   "炸{zha4}": 247,
+   "熟{shou2}": 248,
+   "熟{shu2}": 249,
+   "燕{yan1}": 250,
+   "燕{yan4}": 251,
+   "片{pian1}": 252,
+   "片{pian4}": 253,
+   "率{lv4}": 254,
+   "率{shuai4}": 255,
+   "畜{chu4}": 256,
+   "畜{xu4}": 257,
+   "的{de5}": 258,
+   "的{di1}": 259,
+   "的{di2}": 260,
+   "的{di4}": 261,
+   "的{di5}": 262,
+   "盛{cheng2}": 263,
+   "盛{sheng4}": 264,
+   "相{xiang1}": 265,
+   "相{xiang4}": 266,
+   "相{xiang5}": 267,
+   "省{sheng3}": 268,
+   "省{xing3}": 269,
+   "看{kan1}": 270,
+   "看{kan4}": 271,
+   "看{kan5}": 272,
+   "着{zhao1}": 273,
+   "着{zhao2}": 274,
+   "着{zhao5}": 275,
+   "着{zhe5}": 276,
+   "着{zhuo2}": 277,
+   "着{zhuo5}": 278,
+   "矫{jiao3}": 279,
+   "禁{jin1}": 280,
+   "禁{jin4}": 281,
+   "种{zhong3}": 282,
+   "种{zhong4}": 283,
+   "称{chen4}": 284,
+   "称{cheng1}": 285,
+   "空{kong1}": 286,
+   "空{kong4}": 287,
+   "答{da1}": 288,
+   "答{da2}": 289,
+   "粘{nian2}": 290,
+   "粘{zhan1}": 291,
+   "糊{hu2}": 292,
+   "糊{hu5}": 293,
+   "系{ji4}": 294,
+   "系{xi4}": 295,
+   "系{xi5}": 296,
+   "累{lei2}": 297,
+   "累{lei3}": 298,
+   "累{lei4}": 299,
+   "累{lei5}": 300,
+   "纤{qian4}": 301,
+   "纤{xian1}": 302,
+   "结{jie1}": 303,
+   "结{jie2}": 304,
+   "结{jie5}": 305,
+   "给{gei3}": 306,
+   "给{gei5}": 307,
+   "给{ji3}": 308,
+   "缝{feng2}": 309,
+   "缝{feng4}": 310,
+   "缝{feng5}": 311,
+   "肖{xiao1}": 312,
+   "肖{xiao4}": 313,
+   "背{bei1}": 314,
+   "背{bei4}": 315,
+   "脏{zang1}": 316,
+   "脏{zang4}": 317,
+   "舍{she3}": 318,
+   "舍{she4}": 319,
+   "色{se4}": 320,
+   "色{shai3}": 321,
+   "落{lao4}": 322,
+   "落{luo4}": 323,
+   "蒙{meng1}": 324,
+   "蒙{meng2}": 325,
+   "蒙{meng3}": 326,
+   "薄{bao2}": 327,
+   "薄{bo2}": 328,
+   "薄{bo4}": 329,
+   "藏{cang2}": 330,
+   "藏{zang4}": 331,
+   "血{xie3}": 332,
+   "血{xue4}": 333,
+   "行{hang2}": 334,
+   "行{hang5}": 335,
+   "行{heng5}": 336,
+   "行{xing2}": 337,
+   "行{xing4}": 338,
+   "要{yao1}": 339,
+   "要{yao4}": 340,
+   "观{guan1}": 341,
+   "观{guan4}": 342,
+   "觉{jiao4}": 343,
+   "觉{jiao5}": 344,
+   "觉{jve2}": 345,
+   "角{jiao3}": 346,
+   "角{jve2}": 347,
+   "解{jie3}": 348,
+   "解{jie4}": 349,
+   "解{xie4}": 350,
+   "说{shui4}": 351,
+   "说{shuo1}": 352,
+   "调{diao4}": 353,
+   "调{tiao2}": 354,
+   "踏{ta1}": 355,
+   "踏{ta4}": 356,
+   "车{che1}": 357,
+   "车{jv1}": 358,
+   "转{zhuan3}": 359,
+   "转{zhuan4}": 360,
+   "载{zai3}": 361,
+   "载{zai4}": 362,
+   "还{hai2}": 363,
+   "还{huan2}": 364,
+   "随{sui2}": 365,
+   "随{sui4}": 366,
+   "都{dou1}": 367,
+   "都{du1}": 368,
+   "重{chong2}": 369,
+   "重{zhong4}": 370,
+   "量{liang2}": 371,
+   "量{liang4}": 372,
+   "量{liang5}": 373,
+   "钻{zuan1}": 374,
+   "钻{zuan4}": 375,
+   "铺{pu1}": 376,
+   "铺{pu4}": 377,
+   "长{chang2}": 378,
+   "长{chang3}": 379,
+   "长{zhang3}": 380,
+   "间{jian1}": 381,
+   "间{jian4}": 382,
+   "降{jiang4}": 383,
+   "降{xiang2}": 384,
+   "难{nan2}": 385,
+   "难{nan4}": 386,
+   "难{nan5}": 387,
+   "露{lou4}": 388,
+   "露{lu4}": 389,
+   "鲜{xian1}": 390,
+   "鲜{xian3}": 391
+ }
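polydict_r.json is simply the inverse of polydict.json, mapping each pronunciation entry back to its class id, which a round-trip check can confirm:

import json

with open("g2p/sources/g2p_chinese_model/polydict.json", encoding="utf-8") as f:
    id2poly = json.load(f)
with open("g2p/sources/g2p_chinese_model/polydict_r.json", encoding="utf-8") as f:
    poly2id = json.load(f)

# Every id -> entry -> id round trip should land where it started
# (polydict keys are strings, polydict_r values are ints).
assert all(poly2id[entry] == int(idx) for idx, entry in id2poly.items())
print("polydict and polydict_r are consistent:", len(id2poly), "entries")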