artificialguybr committed on
Commit ca19c59
1 Parent(s): 8ae6d76

Updated TTS to latest version

TTS/.github/workflows/zoo_tests_tortoise.yml ADDED
@@ -0,0 +1,52 @@
+ name: zoo-tests-tortoise
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     types: [opened, synchronize, reopened]
+ jobs:
+   check_skip:
+     runs-on: ubuntu-latest
+     if: "! contains(github.event.head_commit.message, '[ci skip]')"
+     steps:
+       - run: echo "${{ github.event.head_commit.message }}"
+
+   test:
+     runs-on: ubuntu-latest
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: [3.9, "3.10", "3.11"]
+         experimental: [false]
+     steps:
+       - uses: actions/checkout@v3
+       - name: Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+           architecture: x64
+           cache: 'pip'
+           cache-dependency-path: 'requirements*'
+       - name: check OS
+         run: cat /etc/os-release
+       - name: set ENV
+         run: export TRAINER_TELEMETRY=0
+       - name: Install dependencies
+         run: |
+           sudo apt-get update
+           sudo apt-get install -y git make gcc
+           sudo apt-get install espeak espeak-ng
+           make system-deps
+       - name: Install/upgrade Python setup deps
+         run: python3 -m pip install --upgrade pip setuptools wheel
+       - name: Replace scarf urls
+         run: |
+           sed -i 's/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g' TTS/.models.json
+       - name: Install TTS
+         run: |
+           python3 -m pip install .[all]
+           python3 setup.py egg_info
+       - name: Unit tests
+         run: nose2 -F -v -B --with-coverage --coverage TTS tests.zoo_tests.test_models.test_tortoise
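The job above boils down to two commands once the environment is set up. A minimal local-reproduction sketch, assuming the repository root as working directory, GNU sed, and `nose2` installed (e.g. via the dev requirements):

```python
import subprocess

# Rewrite the scarf.sh download URLs to direct GitHub release URLs, as the CI job does.
subprocess.run(
    [
        "sed", "-i",
        r"s/https:\/\/coqui.gateway.scarf.sh\//https:\/\/github.com\/coqui-ai\/TTS\/releases\/download\//g",
        "TTS/.models.json",
    ],
    check=True,
)

# Run only the Tortoise zoo test, mirroring the CI invocation.
subprocess.run(
    [
        "nose2", "-F", "-v", "-B",
        "--with-coverage", "--coverage", "TTS",
        "tests.zoo_tests.test_models.test_tortoise",
    ],
    check=True,
)
```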
TTS/README.md CHANGED
@@ -146,7 +146,7 @@ Underlined "TTS*" and "Judy*" are **internal** 🐸TTS models that are not relea
  You can also help us implement more models.
 
  ## Installation
- 🐸TTS is tested on Ubuntu 18.04 with **python >= 3.7, < 3.11.**.
+ 🐸TTS is tested on Ubuntu 18.04 with **python >= 3.9, < 3.12.**.
 
  If you are only interested in [synthesizing speech](https://tts.readthedocs.io/en/latest/inference.html) with the released 🐸TTS models, installing from PyPI is the easiest option.
 
@@ -198,17 +198,18 @@ from TTS.api import TTS
  # Get device
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
- # List available 🐸TTS models and choose the first one
- model_name = TTS().list_models()[0]
+ # List available 🐸TTS models
+ print(TTS().list_models())
+
  # Init TTS
- tts = TTS(model_name).to(device)
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
 
  # Run TTS
- # ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
- # Text to speech with a numpy output
- wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
+ # ❗ Since this model is a multi-lingual voice cloning model, we must set the target speaker_wav and language
+ # Text to speech to a list of amplitude values as output
+ wav = tts.tts(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en")
  # Text to speech to a file
- tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
+ tts.tts_to_file(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
  ```
 
  #### Running a single speaker model
@@ -347,6 +348,18 @@ If you don't specify any models, then it uses LJSpeech based English model.
  $ tts --text "Text for TTS" --out_path output/path/speech.wav
  ```
 
+ - Run TTS and pipe out the generated TTS wav file data:
+
+ ```
+ $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
+ ```
+
+ - Run TTS and define the speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0:
+
+ ```
+ $ tts --text "Text for TTS" --model_name "coqui_studio/<language>/<dataset>/<model_name>" --speed 1.2 --out_path output/path/speech.wav
+ ```
+
  - Run a TTS model with its default vocoder model:
 
  ```
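Put together, the updated quick-start amounts to the script below; it only uses calls shown in the README hunks above and assumes a short reference clip exists at `my/cloning/audio.wav`:

```python
import torch
from TTS.api import TTS

# Pick GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# List the released 🐸TTS models, then load XTTS v1.
print(TTS().list_models())
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)

# Cross-language voice cloning straight to a wav file.
tts.tts_to_file(
    text="Hello world!",
    speaker_wav="my/cloning/audio.wav",  # reference clip to clone (assumed to exist)
    language="en",
    file_path="output.wav",
)
```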
TTS/TTS/.models.json CHANGED
@@ -5,12 +5,27 @@
  "xtts_v1": {
  "description": "XTTS-v1 by Coqui with 13 languages and cross-language voice cloning.",
  "hf_url": [
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/model.pth",
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/config.json",
- "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/vocab.json"
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/model.pth",
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/config.json",
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/hifigan/vocab.json"
  ],
  "default_vocoder": null,
- "commit": "e9a1953e",
+ "commit": "e5140314",
+ "license": "CPML",
+ "contact": "info@coqui.ai",
+ "tos_required": true
+ },
+ "xtts_v1.1": {
+ "description": "XTTS-v1.1 by Coqui with 14 languages, cross-language voice cloning and reference leak fixed.",
+ "hf_url": [
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/model.pth",
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/config.json",
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/vocab.json",
+ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v1/v1.1.1/hash.md5"
+ ],
+ "model_hash": "ae9e4b39e095fd5728fe7f7931ec66ad",
+ "default_vocoder": null,
+ "commit": "82910a63",
  "license": "CPML",
  "contact": "info@coqui.ai",
  "tos_required": true
@@ -917,4 +932,4 @@
  }
  }
  }
- }
+ }
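If the new `xtts_v1.1` entry is registered under the same `tts_models/multilingual/multi-dataset` branch as `xtts_v1` (an assumption based on the JSON layout, not stated in this hunk), it should be loadable by name like any other zoo model:

```python
from TTS.api import TTS

# Hypothetical full registry name inferred from the .models.json layout;
# confirm the exact string with TTS().list_models() before relying on it.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1.1")
```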
TTS/TTS/VERSION CHANGED
@@ -1 +1 @@
- 0.17.5
+ 0.18.2
TTS/TTS/api.py CHANGED
@@ -17,7 +17,7 @@ class TTS(nn.Module):
 
  def __init__(
  self,
- model_name: str = None,
+ model_name: str = "",
  model_path: str = None,
  config_path: str = None,
  vocoder_path: str = None,
@@ -105,8 +105,8 @@ class TTS(nn.Module):
 
  @property
  def is_multi_lingual(self):
- # TODO: fix this
- if "xtts" in self.model_name:
+ # Not sure what sets this to None, but applied a fix to prevent crashing.
+ if isinstance(self.model_name, str) and "xtts" in self.model_name:
  return True
  if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
  return self.synthesizer.tts_model.language_manager.num_languages > 1
@@ -264,6 +264,7 @@ class TTS(nn.Module):
  language: str = None,
  emotion: str = None,
  speed: float = 1.0,
+ pipe_out = None,
  file_path: str = None,
  ) -> Union[np.ndarray, str]:
  """Convert text to speech using Coqui Studio models. Use `CS_API` class if you are only interested in the API.
@@ -280,6 +281,8 @@ class TTS(nn.Module):
  with "V1" model. Defaults to None.
  speed (float, optional):
  Speed of the speech. Defaults to 1.0.
+ pipe_out (BytesIO, optional):
+ Flag to stdout the generated TTS wav file for shell pipe.
  file_path (str, optional):
  Path to save the output file. When None it returns the `np.ndarray` of waveform. Defaults to None.
 
@@ -293,6 +296,7 @@ class TTS(nn.Module):
  speaker_name=speaker_name,
  language=language,
  speed=speed,
+ pipe_out=pipe_out,
  emotion=emotion,
  file_path=file_path,
  )[0]
@@ -355,6 +359,7 @@ class TTS(nn.Module):
  speaker_wav: str = None,
  emotion: str = None,
  speed: float = 1.0,
+ pipe_out = None,
  file_path: str = "output.wav",
  **kwargs,
  ):
@@ -376,6 +381,8 @@ class TTS(nn.Module):
  Emotion to use for 🐸Coqui Studio models. Defaults to "Neutral".
  speed (float, optional):
  Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0. Defaults to None.
+ pipe_out (BytesIO, optional):
+ Flag to stdout the generated TTS wav file for shell pipe.
  file_path (str, optional):
  Output file path. Defaults to "output.wav".
  kwargs (dict, optional):
@@ -385,10 +392,16 @@ class TTS(nn.Module):
 
  if self.csapi is not None:
  return self.tts_coqui_studio(
- text=text, speaker_name=speaker, language=language, emotion=emotion, speed=speed, file_path=file_path
+ text=text,
+ speaker_name=speaker,
+ language=language,
+ emotion=emotion,
+ speed=speed,
+ file_path=file_path,
+ pipe_out=pipe_out,
  )
  wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
- self.synthesizer.save_wav(wav=wav, path=file_path)
+ self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
  return file_path
 
  def voice_conversion(
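A minimal sketch of the new `pipe_out` argument from Python, assuming it accepts any writable binary buffer as the `BytesIO` docstring suggests (the CLI itself passes `sys.stdout`):

```python
import io

from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")

# Besides writing output.wav, also duplicate the encoded wav bytes into a buffer.
buffer = io.BytesIO()
tts.tts_to_file(
    text="Hello world!",
    speaker_wav="my/cloning/audio.wav",  # assumed reference clip
    language="en",
    file_path="output.wav",
    pipe_out=buffer,
)
wav_bytes = buffer.getvalue()  # raw wav data, e.g. for streaming elsewhere
```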
TTS/TTS/bin/synthesize.py CHANGED
@@ -2,6 +2,7 @@
2
  # -*- coding: utf-8 -*-
3
 
4
  import argparse
 
5
  import sys
6
  from argparse import RawTextHelpFormatter
7
 
@@ -59,6 +60,18 @@ If you don't specify any models, then it uses LJSpeech based English model.
59
  $ tts --text "Text for TTS" --out_path output/path/speech.wav
60
  ```
61
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  - Run a TTS model with its default vocoder model:
63
 
64
  ```
@@ -228,6 +241,20 @@ def main():
228
  help="Language to condition the model with. Only available for 🐸Coqui Studio `XTTS-multilingual` model.",
229
  default=None,
230
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # args for multi-speaker synthesis
233
  parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
@@ -335,167 +362,177 @@ def main():
335
  if not any(check_args):
336
  parser.parse_args(["-h"])
337
 
338
- # Late-import to make things load faster
339
- from TTS.api import TTS
340
- from TTS.utils.manage import ModelManager
341
- from TTS.utils.synthesizer import Synthesizer
342
-
343
- # load model manager
344
- path = Path(__file__).parent / "../.models.json"
345
- manager = ModelManager(path, progress_bar=args.progress_bar)
346
- api = TTS()
347
-
348
- tts_path = None
349
- tts_config_path = None
350
- speakers_file_path = None
351
- language_ids_file_path = None
352
- vocoder_path = None
353
- vocoder_config_path = None
354
- encoder_path = None
355
- encoder_config_path = None
356
- vc_path = None
357
- vc_config_path = None
358
- model_dir = None
359
-
360
- # CASE1 #list : list pre-trained TTS models
361
- if args.list_models:
362
- manager.add_cs_api_models(api.list_models())
363
- manager.list_models()
364
- sys.exit()
365
-
366
- # CASE2 #info : model info for pre-trained TTS models
367
- if args.model_info_by_idx:
368
- model_query = args.model_info_by_idx
369
- manager.model_info_by_idx(model_query)
370
- sys.exit()
371
-
372
- if args.model_info_by_name:
373
- model_query_full_name = args.model_info_by_name
374
- manager.model_info_by_full_name(model_query_full_name)
375
- sys.exit()
376
-
377
- # CASE3: TTS with coqui studio models
378
- if "coqui_studio" in args.model_name:
379
- print(" > Using 🐸Coqui Studio model: ", args.model_name)
380
- api = TTS(model_name=args.model_name, cs_api_model=args.cs_model)
381
- api.tts_to_file(text=args.text, emotion=args.emotion, file_path=args.out_path, language=args.language)
382
- print(" > Saving output to ", args.out_path)
383
- return
384
-
385
- # CASE4: load pre-trained model paths
386
- if args.model_name is not None and not args.model_path:
387
- model_path, config_path, model_item = manager.download_model(args.model_name)
388
- # tts model
389
- if model_item["model_type"] == "tts_models":
390
- tts_path = model_path
391
- tts_config_path = config_path
392
- if "default_vocoder" in model_item:
393
- args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
394
-
395
- # voice conversion model
396
- if model_item["model_type"] == "voice_conversion_models":
397
- vc_path = model_path
398
- vc_config_path = config_path
399
-
400
- # tts model with multiple files to be loaded from the directory path
401
- if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
402
- model_dir = model_path
403
- tts_path = None
404
- tts_config_path = None
405
- args.vocoder_name = None
406
-
407
- # load vocoder
408
- if args.vocoder_name is not None and not args.vocoder_path:
409
- vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
410
-
411
- # CASE5: set custom model paths
412
- if args.model_path is not None:
413
- tts_path = args.model_path
414
- tts_config_path = args.config_path
415
- speakers_file_path = args.speakers_file_path
416
- language_ids_file_path = args.language_ids_file_path
417
-
418
- if args.vocoder_path is not None:
419
- vocoder_path = args.vocoder_path
420
- vocoder_config_path = args.vocoder_config_path
421
-
422
- if args.encoder_path is not None:
423
- encoder_path = args.encoder_path
424
- encoder_config_path = args.encoder_config_path
425
-
426
- device = args.device
427
- if args.use_cuda:
428
- device = "cuda"
429
-
430
- # load models
431
- synthesizer = Synthesizer(
432
- tts_path,
433
- tts_config_path,
434
- speakers_file_path,
435
- language_ids_file_path,
436
- vocoder_path,
437
- vocoder_config_path,
438
- encoder_path,
439
- encoder_config_path,
440
- vc_path,
441
- vc_config_path,
442
- model_dir,
443
- args.voice_dir,
444
- ).to(device)
445
-
446
- # query speaker ids of a multi-speaker model.
447
- if args.list_speaker_idxs:
448
- print(
449
- " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
450
- )
451
- print(synthesizer.tts_model.speaker_manager.name_to_id)
452
- return
453
-
454
- # query langauge ids of a multi-lingual model.
455
- if args.list_language_idxs:
456
- print(
457
- " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
458
- )
459
- print(synthesizer.tts_model.language_manager.name_to_id)
460
- return
461
-
462
- # check the arguments against a multi-speaker model.
463
- if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
464
- print(
465
- " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
466
- "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
467
- )
468
- return
469
-
470
- # RUN THE SYNTHESIS
471
- if args.text:
472
- print(" > Text: {}".format(args.text))
473
-
474
- # kick it
475
- if tts_path is not None:
476
- wav = synthesizer.tts(
477
- args.text,
478
- speaker_name=args.speaker_idx,
479
- language_name=args.language_idx,
480
- speaker_wav=args.speaker_wav,
481
- reference_wav=args.reference_wav,
482
- style_wav=args.capacitron_style_wav,
483
- style_text=args.capacitron_style_text,
484
- reference_speaker_name=args.reference_speaker_idx,
485
- )
486
- elif vc_path is not None:
487
- wav = synthesizer.voice_conversion(
488
- source_wav=args.source_wav,
489
- target_wav=args.target_wav,
490
- )
491
- elif model_dir is not None:
492
- wav = synthesizer.tts(
493
- args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
494
- )
495
-
496
- # save the results
497
- print(" > Saving output to {}".format(args.out_path))
498
- synthesizer.save_wav(wav, args.out_path)
 
 
 
 
 
 
 
 
 
 
499
 
500
 
501
  if __name__ == "__main__":
 
2
  # -*- coding: utf-8 -*-
3
 
4
  import argparse
5
+ import contextlib
6
  import sys
7
  from argparse import RawTextHelpFormatter
8
 
 
60
  $ tts --text "Text for TTS" --out_path output/path/speech.wav
61
  ```
62
 
63
+ - Run TTS and pipe out the generated TTS wav file data:
64
+
65
+ ```
66
+ $ tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
67
+ ```
68
+
69
+ - Run TTS and define speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0:
70
+
71
+ ```
72
+ $ tts --text "Text for TTS" --model_name "coqui_studio/<language>/<dataset>/<model_name>" --speed 1.2 --out_path output/path/speech.wav
73
+ ```
74
+
75
  - Run a TTS model with its default vocoder model:
76
 
77
  ```
 
241
  help="Language to condition the model with. Only available for 🐸Coqui Studio `XTTS-multilingual` model.",
242
  default=None,
243
  )
244
+ parser.add_argument(
245
+ "--pipe_out",
246
+ help="stdout the generated TTS wav file for shell pipe.",
247
+ type=str2bool,
248
+ nargs="?",
249
+ const=True,
250
+ default=False,
251
+ )
252
+ parser.add_argument(
253
+ "--speed",
254
+ type=float,
255
+ help="Speed factor to use for 🐸Coqui Studio models, between 0.0 and 2.0.",
256
+ default=None,
257
+ )
258
 
259
  # args for multi-speaker synthesis
260
  parser.add_argument("--speakers_file_path", type=str, help="JSON file for multi-speaker model.", default=None)
 
362
  if not any(check_args):
363
  parser.parse_args(["-h"])
364
 
365
+ pipe_out = sys.stdout if args.pipe_out else None
366
+
367
+ with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
368
+ # Late-import to make things load faster
369
+ from TTS.api import TTS
370
+ from TTS.utils.manage import ModelManager
371
+ from TTS.utils.synthesizer import Synthesizer
372
+
373
+ # load model manager
374
+ path = Path(__file__).parent / "../.models.json"
375
+ manager = ModelManager(path, progress_bar=args.progress_bar)
376
+ api = TTS()
377
+
378
+ tts_path = None
379
+ tts_config_path = None
380
+ speakers_file_path = None
381
+ language_ids_file_path = None
382
+ vocoder_path = None
383
+ vocoder_config_path = None
384
+ encoder_path = None
385
+ encoder_config_path = None
386
+ vc_path = None
387
+ vc_config_path = None
388
+ model_dir = None
389
+
390
+ # CASE1 #list : list pre-trained TTS models
391
+ if args.list_models:
392
+ manager.add_cs_api_models(api.list_models())
393
+ manager.list_models()
394
+ sys.exit()
395
+
396
+ # CASE2 #info : model info for pre-trained TTS models
397
+ if args.model_info_by_idx:
398
+ model_query = args.model_info_by_idx
399
+ manager.model_info_by_idx(model_query)
400
+ sys.exit()
401
+
402
+ if args.model_info_by_name:
403
+ model_query_full_name = args.model_info_by_name
404
+ manager.model_info_by_full_name(model_query_full_name)
405
+ sys.exit()
406
+
407
+ # CASE3: TTS with coqui studio models
408
+ if "coqui_studio" in args.model_name:
409
+ print(" > Using 🐸Coqui Studio model: ", args.model_name)
410
+ api = TTS(model_name=args.model_name, cs_api_model=args.cs_model)
411
+ api.tts_to_file(
412
+ text=args.text,
413
+ emotion=args.emotion,
414
+ file_path=args.out_path,
415
+ language=args.language,
416
+ speed=args.speed,
417
+ pipe_out=pipe_out,
418
+ )
419
+ print(" > Saving output to ", args.out_path)
420
+ return
421
+
422
+ # CASE4: load pre-trained model paths
423
+ if args.model_name is not None and not args.model_path:
424
+ model_path, config_path, model_item = manager.download_model(args.model_name)
425
+ # tts model
426
+ if model_item["model_type"] == "tts_models":
427
+ tts_path = model_path
428
+ tts_config_path = config_path
429
+ if "default_vocoder" in model_item:
430
+ args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
431
+
432
+ # voice conversion model
433
+ if model_item["model_type"] == "voice_conversion_models":
434
+ vc_path = model_path
435
+ vc_config_path = config_path
436
+
437
+ # tts model with multiple files to be loaded from the directory path
438
+ if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
439
+ model_dir = model_path
440
+ tts_path = None
441
+ tts_config_path = None
442
+ args.vocoder_name = None
443
+
444
+ # load vocoder
445
+ if args.vocoder_name is not None and not args.vocoder_path:
446
+ vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
447
+
448
+ # CASE5: set custom model paths
449
+ if args.model_path is not None:
450
+ tts_path = args.model_path
451
+ tts_config_path = args.config_path
452
+ speakers_file_path = args.speakers_file_path
453
+ language_ids_file_path = args.language_ids_file_path
454
+
455
+ if args.vocoder_path is not None:
456
+ vocoder_path = args.vocoder_path
457
+ vocoder_config_path = args.vocoder_config_path
458
+
459
+ if args.encoder_path is not None:
460
+ encoder_path = args.encoder_path
461
+ encoder_config_path = args.encoder_config_path
462
+
463
+ device = args.device
464
+ if args.use_cuda:
465
+ device = "cuda"
466
+
467
+ # load models
468
+ synthesizer = Synthesizer(
469
+ tts_path,
470
+ tts_config_path,
471
+ speakers_file_path,
472
+ language_ids_file_path,
473
+ vocoder_path,
474
+ vocoder_config_path,
475
+ encoder_path,
476
+ encoder_config_path,
477
+ vc_path,
478
+ vc_config_path,
479
+ model_dir,
480
+ args.voice_dir,
481
+ ).to(device)
482
+
483
+ # query speaker ids of a multi-speaker model.
484
+ if args.list_speaker_idxs:
485
+ print(
486
+ " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
487
+ )
488
+ print(synthesizer.tts_model.speaker_manager.name_to_id)
489
+ return
490
+
491
+ # query langauge ids of a multi-lingual model.
492
+ if args.list_language_idxs:
493
+ print(
494
+ " > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
495
+ )
496
+ print(synthesizer.tts_model.language_manager.name_to_id)
497
+ return
498
+
499
+ # check the arguments against a multi-speaker model.
500
+ if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
501
+ print(
502
+ " [!] Looks like you use a multi-speaker model. Define `--speaker_idx` to "
503
+ "select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
504
+ )
505
+ return
506
+
507
+ # RUN THE SYNTHESIS
508
+ if args.text:
509
+ print(" > Text: {}".format(args.text))
510
+
511
+ # kick it
512
+ if tts_path is not None:
513
+ wav = synthesizer.tts(
514
+ args.text,
515
+ speaker_name=args.speaker_idx,
516
+ language_name=args.language_idx,
517
+ speaker_wav=args.speaker_wav,
518
+ reference_wav=args.reference_wav,
519
+ style_wav=args.capacitron_style_wav,
520
+ style_text=args.capacitron_style_text,
521
+ reference_speaker_name=args.reference_speaker_idx,
522
+ )
523
+ elif vc_path is not None:
524
+ wav = synthesizer.voice_conversion(
525
+ source_wav=args.source_wav,
526
+ target_wav=args.target_wav,
527
+ )
528
+ elif model_dir is not None:
529
+ wav = synthesizer.tts(
530
+ args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
531
+ )
532
+
533
+ # save the results
534
+ print(" > Saving output to {}".format(args.out_path))
535
+ synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
536
 
537
 
538
  if __name__ == "__main__":
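The new `--pipe_out` flag writes the wav data to stdout so it can be piped into a player. A sketch of driving the updated CLI that way from Python, assuming `tts` and `aplay` are on `PATH`:

```python
import subprocess

# Equivalent of: tts --text "Text for TTS" --pipe_out --out_path output/path/speech.wav | aplay
tts_proc = subprocess.Popen(
    ["tts", "--text", "Text for TTS", "--pipe_out", "--out_path", "output/path/speech.wav"],
    stdout=subprocess.PIPE,
)
subprocess.run(["aplay"], stdin=tts_proc.stdout, check=True)
tts_proc.stdout.close()
tts_proc.wait()
```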
TTS/TTS/cs_api.py CHANGED
@@ -9,6 +9,8 @@ import numpy as np
  import requests
  from scipy.io import wavfile
 
+ from TTS.utils.audio.numpy_transforms import save_wav
+
 
  class Speaker(object):
  """Convert dict to object."""
@@ -288,6 +290,7 @@ class CS_API:
  speaker_id=None,
  emotion=None,
  speed=1.0,
+ pipe_out=None,
  language=None,
  file_path: str = None,
  ) -> str:
@@ -300,6 +303,7 @@
  speaker_id (str): Speaker ID. If None, the speaker name is used.
  emotion (str): Emotion of the speaker. One of "Neutral", "Happy", "Sad", "Angry", "Dull".
  speed (float): Speed of the speech. 1.0 is normal speed.
+ pipe_out (BytesIO, optional): Flag to stdout the generated TTS wav file for shell pipe.
  language (str): Language of the text. If None, the default language of the speaker is used. Language is only
  supported by `XTTS-multilang` model. Currently supports en, de, es, fr, it, pt, pl. Defaults to "en".
  file_path (str): Path to save the file. If None, a temporary file is created.
@@ -307,7 +311,7 @@
  if file_path is None:
  file_path = tempfile.mktemp(".wav")
  wav, sr = self.tts(text, speaker_name, speaker_id, emotion, speed, language)
- wavfile.write(file_path, sr, wav)
+ save_wav(wav=wav, path=file_path, sample_rate=sr, pipe_out=pipe_out)
  return file_path
 
 
TTS/TTS/tts/configs/xtts_config.py CHANGED
@@ -78,13 +78,13 @@ class XttsConfig(BaseTTSConfig):
  )
 
  # inference params
- temperature: float = 0.2
+ temperature: float = 0.85
  length_penalty: float = 1.0
  repetition_penalty: float = 2.0
  top_k: int = 50
- top_p: float = 0.8
+ top_p: float = 0.85
  cond_free_k: float = 2.0
  diffusion_temperature: float = 1.0
- num_gpt_outputs: int = 16
+ num_gpt_outputs: int = 1
  decoder_iterations: int = 30
  decoder_sampler: str = "ddim"
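A small sketch of overriding the retuned sampling defaults; it assumes `XttsConfig` is default-constructible like the other Coqpit-based configs in the repo:

```python
from TTS.tts.configs.xtts_config import XttsConfig

config = XttsConfig()
# New defaults after this change: temperature=0.85, top_p=0.85, num_gpt_outputs=1.
print(config.temperature, config.top_p, config.num_gpt_outputs)

# Make sampling more conservative for a given run.
config.temperature = 0.65
config.top_p = 0.8
```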
TTS/TTS/tts/layers/tortoise/tokenizer.py CHANGED
@@ -5,9 +5,13 @@ from tokenizers import Tokenizer
 
  from TTS.tts.utils.text.cleaners import english_cleaners
 
+ DEFAULT_VOCAB_FILE = os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), "../../utils/assets/tortoise/tokenizer.json"
+ )
+
 
  class VoiceBpeTokenizer:
- def __init__(self, vocab_file=None, vocab_str=None):
+ def __init__(self, vocab_file=DEFAULT_VOCAB_FILE, vocab_str=None):
  self.tokenizer = None
  if vocab_file is not None:
  self.tokenizer = Tokenizer.from_file(vocab_file)
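With the new module-level default, the Tortoise tokenizer can be built without passing a vocab path; a quick sketch:

```python
from TTS.tts.layers.tortoise.tokenizer import DEFAULT_VOCAB_FILE, VoiceBpeTokenizer

# Falls back to the bundled tokenizer.json resolved relative to the module.
tokenizer = VoiceBpeTokenizer()
print(DEFAULT_VOCAB_FILE)

# Passing an explicit vocab_file (or vocab_str) still overrides the default, as before.
```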
TTS/TTS/tts/layers/xtts/gpt.py CHANGED
@@ -172,7 +172,7 @@ class GPT(nn.Module):
  "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
  }
 
- def init_gpt_for_inference(self, kv_cache=True):
+ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
  seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
  gpt_config = GPT2Config(
  vocab_size=self.max_mel_tokens,
@@ -195,6 +195,17 @@ class GPT(nn.Module):
  )
  self.gpt.wte = self.mel_embedding
 
+ if use_deepspeed:
+ import deepspeed
+ self.ds_engine = deepspeed.init_inference(
+ model=self.gpt_inference.half(),  # Transformers model
+ mp_size=1,  # number of GPUs
+ dtype=torch.float32,  # desired data type of output
+ replace_method="auto",  # lets DeepSpeed automatically identify the layers to replace
+ replace_with_kernel_inject=True,  # replace the model with the kernel injector
+ )
+ self.gpt_inference = self.ds_engine.module.eval()
+
  def set_inputs_and_targets(self, input, start_token, stop_token):
  inp = F.pad(input, (1, 0), value=start_token)
  tar = F.pad(input, (0, 1), value=stop_token)
@@ -543,3 +554,14 @@ class GPT(nn.Module):
  if "return_dict_in_generate" in hf_generate_kwargs:
  return gen.sequences[:, gpt_inputs.shape[1] :], gen
  return gen[:, gpt_inputs.shape[1] :]
+
+ def get_generator(self, fake_inputs, **hf_generate_kwargs):
+ return self.gpt_inference.generate_stream(
+ fake_inputs,
+ bos_token_id=self.start_audio_token,
+ pad_token_id=self.stop_audio_token,
+ eos_token_id=self.stop_audio_token,
+ max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
+ do_stream=True,
+ **hf_generate_kwargs,
+ )
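A minimal sketch of the new DeepSpeed/streaming hooks, assuming `gpt` is an already-built `GPT` module loaded from an XTTS checkpoint and that the `deepspeed` package is installed; this is internal API, not part of `TTS.api`:

```python
def build_streaming_generator(gpt, fake_inputs, **hf_generate_kwargs):
    # use_deepspeed=True wraps the HF inference model with DeepSpeed inference kernels.
    gpt.init_gpt_for_inference(kv_cache=True, use_deepspeed=True)
    # get_generator() streams audio tokens incrementally via generate_stream().
    return gpt.get_generator(fake_inputs, **hf_generate_kwargs)
```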
TTS/TTS/tts/layers/xtts/gpt_encoder_eren.py DELETED
@@ -1,658 +0,0 @@
1
- import functools
2
-
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- from transformers import GPT2Config, GPT2Model, GPT2PreTrainedModel
7
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
-
9
-
10
- def null_position_embeddings(range, dim):
11
- return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
12
-
13
-
14
- class GPT2InferenceModel(GPT2PreTrainedModel):
15
- """Override GPT2LMHeadModel to allow for prefix conditioning."""
16
-
17
- def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
18
- super().__init__(config)
19
- self.transformer = gpt
20
- self.pos_embedding = pos_emb
21
- self.embeddings = embeddings
22
- self.final_norm = norm
23
- self.lm_head = nn.Sequential(norm, linear)
24
- self.kv_cache = kv_cache
25
-
26
- def store_prefix_emb(self, prefix_emb):
27
- self.cached_prefix_emb = prefix_emb
28
-
29
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
30
- token_type_ids = kwargs.get("token_type_ids", None) # usually None
31
- if not self.kv_cache:
32
- past_key_values = None
33
-
34
- # only last token for inputs_ids if past is defined in kwargs
35
- if past_key_values is not None:
36
- input_ids = input_ids[:, -1].unsqueeze(-1)
37
- if token_type_ids is not None:
38
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
39
-
40
- attention_mask = kwargs.get("attention_mask", None)
41
- position_ids = kwargs.get("position_ids", None)
42
-
43
- if attention_mask is not None and position_ids is None:
44
- # create position_ids on the fly for batch generation
45
- position_ids = attention_mask.long().cumsum(-1) - 1
46
- position_ids.masked_fill_(attention_mask == 0, 1)
47
- if past_key_values is not None:
48
- position_ids = position_ids[:, -1].unsqueeze(-1)
49
- else:
50
- position_ids = None
51
- return {
52
- "input_ids": input_ids,
53
- "past_key_values": past_key_values,
54
- "use_cache": kwargs.get("use_cache"),
55
- "position_ids": position_ids,
56
- "attention_mask": attention_mask,
57
- "token_type_ids": token_type_ids,
58
- }
59
-
60
- def forward(
61
- self,
62
- input_ids=None,
63
- past_key_values=None,
64
- attention_mask=None,
65
- token_type_ids=None,
66
- position_ids=None,
67
- head_mask=None,
68
- inputs_embeds=None,
69
- encoder_hidden_states=None,
70
- encoder_attention_mask=None,
71
- labels=None,
72
- use_cache=None,
73
- output_attentions=None,
74
- output_hidden_states=None,
75
- return_dict=None,
76
- ):
77
- assert self.cached_prefix_emb is not None
78
- assert inputs_embeds is None # Not supported by this inference model.
79
- assert labels is None # Training not supported by this inference model.
80
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
81
-
82
- # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
83
-
84
- # Create embedding
85
- prefix_len = self.cached_prefix_emb.shape[1]
86
- if input_ids.shape[1] != 1:
87
- gen_inputs = input_ids[:, prefix_len:]
88
- gen_emb = self.embeddings(gen_inputs)
89
- gen_emb = gen_emb + self.pos_embedding(gen_emb)
90
- if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
91
- prefix_emb = self.cached_prefix_emb.repeat_interleave(
92
- gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
93
- )
94
- else:
95
- prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
96
- emb = torch.cat([prefix_emb, gen_emb], dim=1)
97
- else:
98
- emb = self.embeddings(input_ids)
99
- emb = emb + self.pos_embedding.get_fixed_embedding(
100
- attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
101
- )
102
- transformer_outputs = self.transformer(
103
- inputs_embeds=emb,
104
- past_key_values=past_key_values,
105
- attention_mask=attention_mask,
106
- token_type_ids=token_type_ids,
107
- position_ids=position_ids,
108
- head_mask=head_mask,
109
- encoder_hidden_states=encoder_hidden_states,
110
- encoder_attention_mask=encoder_attention_mask,
111
- use_cache=use_cache,
112
- output_attentions=output_attentions,
113
- output_hidden_states=output_hidden_states,
114
- return_dict=return_dict,
115
- )
116
- hidden_states = transformer_outputs[0]
117
- lm_logits = self.lm_head(hidden_states)
118
-
119
- if not return_dict:
120
- return (lm_logits,) + transformer_outputs[1:]
121
-
122
- return CausalLMOutputWithCrossAttentions(
123
- loss=None,
124
- logits=lm_logits,
125
- past_key_values=transformer_outputs.past_key_values,
126
- hidden_states=transformer_outputs.hidden_states,
127
- attentions=transformer_outputs.attentions,
128
- cross_attentions=transformer_outputs.cross_attentions,
129
- )
130
-
131
- @staticmethod
132
- def _reorder_cache(past, beam_idx):
133
- """
134
- This function is used to re-order the :obj:`past_key_values` cache if
135
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
136
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
137
- """
138
- return tuple(
139
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
140
- for layer_past in past
141
- )
142
-
143
-
144
- class LearnedPositionEmbeddings(nn.Module):
145
- def __init__(self, seq_len, model_channels, init_std=0.02, relative=False):
146
- super().__init__()
147
- self.emb = nn.Embedding(seq_len, model_channels)
148
- nn.init.normal_(self.emb.weight, mean=0.0, std=init_std)
149
- self.relative = relative
150
-
151
- def forward(self, x):
152
- seq_len = x.shape[1]
153
- if self.relative:
154
- start = torch.randint(seq_len, (1,), device=x.device).item()
155
- positions = torch.arange(start, start + seq_len, device=x.device)
156
- else:
157
- positions = torch.arange(seq_len, device=x.device)
158
- return self.emb(positions)
159
-
160
- def get_fixed_embedding(self, ind, dev):
161
- return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
162
-
163
-
164
- def init_gpt(layers, model_channels, heads, max_mel_seq_len, max_text_seq_len, max_prompt_len, checkpointing):
165
- """
166
- Initializes a GPT-2 model and its position embeddings for a text-to-speech system.
167
-
168
- Args:
169
- layers (int): Number of layers in the GPT-2 model.
170
- model_channels (int): Dimension of the GPT-2 model.
171
- heads (int): Number of heads in the GPT-2 model.
172
- max_mel_seq_len (int): Maximum sequence length for the mel spectrogram.
173
- max_text_seq_len (int): Maximum sequence length for the text.
174
- max_prompt_len (int): Maximum length of the prompt.
175
- checkpointing (bool): Whether to use gradient checkpointing.
176
-
177
- Returns:
178
- gpt (GPT2Model): GPT-2 model.
179
- mel_pos_emb (LearnedPositionEmbeddings): Position embeddings for the mel spectrogram.
180
- text_pos_emb (LearnedPositionEmbeddings): Position embeddings for the text.
181
- """
182
- gpt_config = GPT2Config(
183
- vocab_size=123,
184
- n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
185
- n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
186
- n_embd=model_channels,
187
- n_layer=layers,
188
- n_head=heads,
189
- gradient_checkpointing=checkpointing,
190
- use_cache=not checkpointing,
191
- )
192
- gpt = GPT2Model(gpt_config)
193
-
194
- del gpt.wpe
195
- del gpt.wte
196
-
197
- gpt.wpe = functools.partial(null_position_embeddings, dim=model_channels)
198
-
199
- audio_pos_emb = (
200
- LearnedPositionEmbeddings(max_mel_seq_len, model_channels)
201
- if max_mel_seq_len != -1
202
- else functools.partial(null_position_embeddings, dim=model_channels)
203
- )
204
- text_pos_emb = (
205
- LearnedPositionEmbeddings(max_text_seq_len, model_channels)
206
- if max_mel_seq_len != -1
207
- else functools.partial(null_position_embeddings, dim=model_channels)
208
- )
209
-
210
- return gpt, audio_pos_emb, text_pos_emb
211
-
212
-
213
- class XTTSGPTEncoder(nn.Module):
214
- """XTTS GPT Encoder model implementation.
215
- Args:
216
- start_text_token (int): Index of the start token in the text vocabulary.
217
- stop_text_token (int): Index of the stop token in the text vocabulary.
218
- n_layers (int): Number of layers in the GPT-2 model.
219
- n_model_channels (int): Dimension of the GPT-2 model.
220
- n_heads (int): Number of heads in the GPT-2 model.
221
- max_text_tokens (int): Maximum number of text tokens.
222
- max_audio_tokens (int): Maximum number of audio tokens.
223
- max_prompt_tokens (int): Maximum number of prompt tokens.
224
- audio_len_compression (int): Compression factor for the audio length.
225
- number_text_tokens (int): Number of text tokens.
226
- number_audio_codes (int): Number of audio codes.
227
- start_mel_token (int): Index of the start token in the mel code vocabulary.
228
- stop_mel_token (int): Index of the stop token in the mel code vocabulary.
229
- checkpointing (bool): Whether or not to use gradient checkpointing at training.
230
- """
231
-
232
- _inference_flag = False
233
-
234
- def __init__(
235
- self,
236
- start_text_token=261,
237
- stop_text_token=0,
238
- n_layers=8,
239
- n_model_channels=512,
240
- n_heads=8,
241
- max_text_tokens=120,
242
- max_audio_tokens=250,
243
- max_prompt_tokens=70,
244
- audio_len_compression=1024,
245
- number_text_tokens=256,
246
- number_audio_codes=8194,
247
- start_mel_token=8192,
248
- stop_mel_token=8193,
249
- checkpointing=True,
250
- label_smoothing=0.0,
251
- ):
252
- super().__init__()
253
-
254
- self.label_smoothing = label_smoothing
255
- self.number_text_tokens = number_text_tokens
256
- self.start_text_token = start_text_token
257
- self.stop_text_token = stop_text_token
258
- self.number_audio_codes = number_audio_codes
259
- self.start_mel_token = start_mel_token
260
- self.stop_mel_token = stop_mel_token
261
- self.start_prompt_token = start_mel_token
262
- self.stop_prompt_token = stop_mel_token
263
- self.n_layers = n_layers
264
- self.n_heads = n_heads
265
- self.n_model_channels = n_model_channels
266
- self.max_audio_tokens = -1 if max_audio_tokens == -1 else max_audio_tokens + 2 + self.max_conditioning_inputs
267
- self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
268
- self.max_prompt_tokens = max_prompt_tokens
269
- self.audio_len_compression = audio_len_compression
270
-
271
- # embedding layers
272
- self.text_embedding = nn.Embedding(self.number_text_tokens, n_model_channels)
273
- self.audio_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
274
- self.prompt_embedding = nn.Embedding(self.number_audio_codes, n_model_channels)
275
- self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, n_model_channels)
276
-
277
- # initialize the GPT-2 model
278
- (
279
- self.gpt,
280
- self.audio_pos_embedding,
281
- self.text_pos_embedding,
282
- ) = init_gpt(
283
- n_layers,
284
- n_model_channels,
285
- n_heads,
286
- self.max_audio_tokens,
287
- self.max_text_tokens,
288
- self.max_prompt_tokens,
289
- checkpointing,
290
- )
291
-
292
- # output layers
293
- self.final_norm = nn.LayerNorm(n_model_channels)
294
- self.text_head = nn.Linear(n_model_channels, self.number_text_tokens)
295
- self.mel_head = nn.Linear(n_model_channels, self.number_audio_codes)
296
-
297
- def get_grad_norm_parameter_groups(self):
298
- return {
299
- "conditioning_encoder": list(self.conditioning_encoder.parameters()),
300
- "gpt": list(self.gpt.parameters()),
301
- "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
302
- }
303
-
304
- def init_model_for_inference(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
305
- self._inference_flag = True
306
- seq_length = self.max_prompt_tokens + self.max_audio_tokens + self.max_text_tokens
307
- gpt_config = GPT2Config(
308
- vocab_size=self.max_audio_tokens,
309
- n_positions=seq_length,
310
- n_ctx=seq_length,
311
- n_embd=self.n_model_channels,
312
- n_layer=self.n_layers,
313
- n_head=self.n_heads,
314
- gradient_checkpointing=False,
315
- use_cache=True,
316
- )
317
- self.inference_model = GPT2InferenceModel(
318
- gpt_config,
319
- self.gpt,
320
- self.audio_pos_embedding,
321
- self.audio_embedding,
322
- self.final_norm,
323
- self.mel_head,
324
- kv_cache=kv_cache,
325
- )
326
- self.gpt.wte = self.audio_embedding
327
-
328
- def set_inputs_and_targets(self, input, start_token, stop_token):
329
- inp = F.pad(input, (1, 0), value=start_token)
330
- tar = F.pad(input, (0, 1), value=stop_token)
331
- return inp, tar
332
-
333
- def set_audio_tokens_padding(self, audio_tokens, audio_token_lens):
334
- # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
335
- for b in range(len(audio_token_lens)):
336
- actual_end = audio_token_lens[b]
337
- if actual_end < audio_tokens.shape[-1]:
338
- audio_tokens[b, actual_end:] = self.stop_mel_token
339
- return audio_tokens
340
-
341
- def get_logits(
342
- self,
343
- speech_conditioning_inputs,
344
- first_inputs,
345
- first_head,
346
- second_inputs=None,
347
- second_head=None,
348
- prompt=None,
349
- get_attns=False,
350
- return_latent=False,
351
- attn_mask_text=None,
352
- attn_mask_mel=None,
353
- ):
354
- if prompt is not None and speech_conditioning_inputs is not None:
355
- offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
356
- if second_inputs is not None:
357
- emb = torch.cat(
358
- [speech_conditioning_inputs, prompt, first_inputs, second_inputs],
359
- dim=1,
360
- )
361
- else:
362
- emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
363
- elif speech_conditioning_inputs is not None:
364
- offset = speech_conditioning_inputs.shape[1]
365
- if second_inputs is not None:
366
- emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
367
- else:
368
- emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
369
- elif prompt is not None:
370
- offset = prompt.shape[1]
371
- if second_inputs is not None:
372
- emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
373
- else:
374
- emb = torch.cat([prompt, first_inputs], dim=1)
375
-
376
- # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
377
- attn_mask = None
378
- if attn_mask_text is not None:
379
- attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
380
- if prompt is not None:
381
- attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
382
- attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
383
-
384
- gpt_out = self.gpt(
385
- inputs_embeds=emb,
386
- return_dict=True,
387
- output_attentions=get_attns,
388
- attention_mask=attn_mask,
389
- )
390
-
391
- if get_attns:
392
- return gpt_out.attentions
393
-
394
- enc = gpt_out.last_hidden_state[:, offset:]
395
- enc = self.final_norm(enc)
396
-
397
- if return_latent:
398
- return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
399
-
400
- first_logits = enc[:, : first_inputs.shape[1]]
401
- first_logits = first_head(first_logits)
402
- first_logits = first_logits.permute(0, 2, 1)
403
- if second_inputs is not None:
404
- second_logits = enc[:, -second_inputs.shape[1] :]
405
- second_logits = second_head(second_logits)
406
- second_logits = second_logits.permute(0, 2, 1)
407
- return first_logits, second_logits
408
- else:
409
- return first_logits
410
-
411
- def get_conditioning(self, speech_conditioning_input):
412
- speech_conditioning_input = (
413
- speech_conditioning_input.unsqueeze(1)
414
- if len(speech_conditioning_input.shape) == 3
415
- else speech_conditioning_input
416
- )
417
- conds = []
418
- for j in range(speech_conditioning_input.shape[1]):
419
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
420
- conds = torch.stack(conds, dim=1)
421
- conds = conds.mean(dim=1)
422
- return conds
423
-
424
- def get_prompts(self, prompt_codes):
425
- prompt = F.pad(prompt_codes, (1, 0), value=self.start_prompt_token)
426
- prompt = F.pad(prompt_codes, (0, 1), value=self.stop_prompt_token)
427
- return prompt
428
-
429
- def forward(
430
- self,
431
- text_inputs,
432
- text_lengths,
433
- audio_codes,
434
- wav_lengths,
435
- prompt_codes,
436
- return_attentions=False,
437
- return_latent=False,
438
- ):
439
- max_text_len = text_lengths.max()
440
-
441
- # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
442
- # Like [..., 186, 45, 45, 83] where actually it should end with 186.
443
- # We take last 3 codes to prevent abrupt ending of the audio.
444
- # TODO: This is might need some testing.
445
- mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
446
-
447
- # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
448
- max_mel_len = mel_lengths.max()
449
-
450
- if max_mel_len > audio_codes.shape[-1]:
451
- audio_codes = F.pad(audio_codes, (0, max_mel_len - audio_codes.shape[-1]))
452
-
453
- # silence aware lengths, skip the silence tokens at the end of the mel codes.
454
- silence = True
455
- for idx, l in enumerate(mel_lengths):
456
- length = l.item()
457
- while silence:
458
- if audio_codes[idx, length - 1] != 83:
459
- break
460
- length -= 1
461
- mel_lengths[idx] = length
462
-
463
- # Lovely assertions
464
- assert (
465
- max_mel_len <= audio_codes.shape[-1]
466
- ), f" ❗ max_mel_len ({max_mel_len}) > audio_codes.shape[-1] ({audio_codes.shape[-1]})"
467
- assert (
468
- max_text_len <= text_inputs.shape[-1]
469
- ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
470
-
471
- # Append stop token to text inputs
472
- text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
473
-
474
- # Append silence token to mel codes
475
- audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
476
-
477
- # Pad mel codes with STOP_MEL_TOKEN
478
- audio_codes = self.set_mel_padding(audio_codes, mel_lengths)
479
-
480
- # Compute speech conditioning input
481
- conds = None
482
- if speech_conditioning_input is not None:
483
- if not return_latent:
484
- # Compute speech conditioning input
485
- speech_conditioning_input = (
486
- speech_conditioning_input.unsqueeze(1)
487
- if len(speech_conditioning_input.shape) == 3
488
- else speech_conditioning_input
489
- )
490
-
491
- conds = []
492
- for j in range(speech_conditioning_input.shape[1]):
493
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
494
- conds = torch.stack(conds, dim=1)
495
- if self.average_conditioning_embeddings:
496
- conds = conds.mean(dim=1).unsqueeze(1)
497
- else:
498
- # already computed
499
- conds = speech_conditioning_input.unsqueeze(1)
500
-
501
- # Build input and target tensors
502
- # Prepend start token to inputs and append stop token to targets
503
- text_inputs, _ = self.set_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
504
- audio_codes, _ = self.set_inputs_and_targets(audio_codes, self.start_mel_token, self.stop_mel_token)
505
-
506
- # Set attn_mask
507
- attn_mask_text = None
508
- attn_mask_mel = None
509
- if not return_latent:
510
- attn_mask_text = torch.ones(
511
- text_inputs.shape[0],
512
- text_inputs.shape[1],
513
- dtype=torch.bool,
514
- device=text_inputs.device,
515
- )
516
- attn_mask_mel = torch.ones(
517
- audio_codes.shape[0],
518
- audio_codes.shape[1],
519
- dtype=torch.bool,
520
- device=audio_codes.device,
521
- )
522
-
523
- for idx, l in enumerate(text_lengths):
524
- attn_mask_text[idx, l + 1 :] = 0.0
525
-
526
- for idx, l in enumerate(mel_lengths):
527
- attn_mask_mel[idx, l + 1 :] = 0.0
528
-
529
- # Compute text embeddings + positional embeddings
530
- # print(" > text input latent:", text_inputs)
531
- text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
532
-
533
- # Compute mel embeddings + positional embeddings
534
- audio_emb = self.audio_embedding(audio_codes) + self.audio_embedding(audio_codes)
535
-
536
- # Compute prompt embeddings + positional embeddings
537
- prompt = self.get_prompts(prompt_codes)
538
-
539
- # prompt_emb = self.audio_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
540
- prompt_emb = self.prompt_embedding(prompt) + self.prompt_pos_embedding(prompt)
541
-
542
- # dropout prompt embeddings
543
- prompt_emb = F.dropout(prompt_emb, p=0.1, training=self.training)
544
-
545
- # Get logits
546
- sub = -4 # don't ask me why 😄
547
- if self.training:
548
- sub = -1
549
- _, audio_logits = self.get_logits(
550
- conds,
551
- text_emb,
552
- self.text_head,
553
- audio_emb,
554
- self.mel_head,
555
- prompt=prompt_emb,
556
- get_attns=return_attentions,
557
- return_latent=return_latent,
558
- attn_mask_text=attn_mask_text,
559
- attn_mask_mel=attn_mask_mel,
560
- )
561
- return audio_logits[:, :sub] # sub to prevent bla.
562
-
563
- def compute_embeddings(
564
- self,
565
- speech_conditioning_latent,
566
- text_inputs,
567
- input_tokens=None,
568
- prompt_codes=None,
569
- pad_input_text=False,
570
- ):
571
- """Compute all the embeddings needed for inference."""
572
- if pad_input_text and text_inputs.shape[1] < 250:
573
- text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
574
- else:
575
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
576
- text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
577
-
578
- emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
579
-
580
- print(" > Text inputs:", text_inputs)
581
- if prompt_codes is not None:
582
- prompt_codes = self.get_prompts(prompt_codes)
583
- # prompt_emb = self.audio_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
584
- prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
585
-
586
- print(" > Prompt inputs:", prompt_codes)
587
- print(" > Prompt inputs shape:", prompt_codes.shape)
588
- emb = torch.cat([prompt_emb, emb], dim=1)
589
-
590
- if speech_conditioning_latent is not None:
591
- conds = speech_conditioning_latent.unsqueeze(1)
592
- emb = torch.cat([conds, emb], dim=1)
593
-
594
- self.inference_model.store_prefix_emb(emb)
595
-
596
- fake_inputs = torch.full(
597
- (
598
- emb.shape[0],
599
- emb.shape[1] + 1, # +1 for the start_mel_token
600
- ),
601
- fill_value=1,
602
- dtype=torch.long,
603
- device=text_inputs.device,
604
- )
605
- fake_inputs[:, -1] = self.start_mel_token
606
-
607
- if input_tokens is not None:
608
- fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
609
- return fake_inputs
610
-
611
- def inference(
612
- self,
613
- text_inputs,
614
- input_tokens=None,
615
- prompt_codes=None,
616
- pad_input_text=False,
617
- **hf_generate_kwargs,
618
- ):
619
- if pad_input_text and text_inputs.shape[1] < 250:
620
- text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
621
- else:
622
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
623
- text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
624
-
625
- emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
626
-
627
- if prompt_codes is not None:
628
- prompt_codes = self.get_prompts(prompt_codes)
629
- prompt_emb = self.prompt_embedding(prompt_codes) + self.prompt_pos_embedding(prompt_codes)
630
- emb = torch.cat([prompt_emb, emb], dim=1)
631
-
632
- self.inference_model.store_prefix_emb(emb)
633
-
634
- fake_inputs = torch.full(
635
- (
636
- emb.shape[0],
637
- emb.shape[1] + 1, # +1 for the start_mel_token
638
- ),
639
- fill_value=1,
640
- dtype=torch.long,
641
- device=text_inputs.device,
642
- )
643
- fake_inputs[:, -1] = self.start_mel_token
644
-
645
- if input_tokens is not None:
646
- fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
647
-
648
- gen = self.inference_model.generate(
649
- fake_inputs,
650
- bos_token_id=self.start_mel_token,
651
- pad_token_id=self.stop_mel_token,
652
- eos_token_id=self.stop_mel_token,
653
- max_length=self.max_audio_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
654
- **hf_generate_kwargs,
655
- )
656
- if "return_dict_in_generate" in hf_generate_kwargs:
657
- return gen.sequences[:, fake_inputs.shape[1] :], gen
658
- return gen[:, fake_inputs.shape[1] :]
TTS/TTS/tts/layers/xtts/gpt_encoder_old.py DELETED
@@ -1,1057 +0,0 @@
1
- import functools
2
- import math
3
- import random
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
- try:
10
- import deepspeed
11
- from deepspeed.ops.transformer.inference import DeepSpeedTransformerInferenceKernel
12
- except ImportError:
13
- pass
14
-
15
- import dlas.codes.torch_intermediary as ml
16
- from dlas.codes.models.arch_util import AttentionBlock
17
- from dlas.codes.trainer.networks import register_model
18
- from dlas.codes.utils.transformers.stream_generator import init_stream_support
19
- from dlas.codes.utils.util import opt_get
20
- from transformers import GPT2Config, GPT2PreTrainedModel
21
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
22
-
23
- init_stream_support()
24
-
25
-
26
- def null_position_embeddings(range, dim):
27
- return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
28
-
29
-
30
- class ResBlock(nn.Module):
31
- """
32
- Basic residual convolutional block that uses GroupNorm.
33
- """
34
-
35
- def __init__(self, chan):
36
- super().__init__()
37
- self.net = nn.Sequential(
38
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
39
- nn.GroupNorm(chan // 8, chan),
40
- nn.ReLU(),
41
- nn.Conv1d(chan, chan, kernel_size=3, padding=1),
42
- nn.GroupNorm(chan // 8, chan),
43
- )
44
-
45
- def forward(self, x):
46
- return F.relu(self.net(x) + x)
47
-
48
-
49
- class GPT2InferenceModel(GPT2PreTrainedModel):
50
- """Override GPT2LMHeadModel to allow for prefix conditioning."""
51
-
52
- def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
53
- super().__init__(config)
54
- self.transformer = gpt
55
- self.pos_embedding = pos_emb
56
- self.embeddings = embeddings
57
- self.final_norm = norm
58
- self.lm_head = nn.Sequential(norm, linear)
59
- self.kv_cache = kv_cache
60
-
61
- def store_prefix_emb(self, prefix_emb):
62
- self.cached_prefix_emb = prefix_emb
63
-
64
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
65
- token_type_ids = kwargs.get("token_type_ids", None) # usually None
66
- if not self.kv_cache:
67
- past_key_values = None
68
-
69
- # only last token for inputs_ids if past is defined in kwargs
70
- if past_key_values is not None:
71
- input_ids = input_ids[:, -1].unsqueeze(-1)
72
- if token_type_ids is not None:
73
- token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
74
-
75
- attention_mask = kwargs.get("attention_mask", None)
76
- position_ids = kwargs.get("position_ids", None)
77
-
78
- if attention_mask is not None and position_ids is None:
79
- # create position_ids on the fly for batch generation
80
- position_ids = attention_mask.long().cumsum(-1) - 1
81
- position_ids.masked_fill_(attention_mask == 0, 1)
82
- if past_key_values is not None:
83
- position_ids = position_ids[:, -1].unsqueeze(-1)
84
- else:
85
- position_ids = None
86
- return {
87
- "input_ids": input_ids,
88
- "past_key_values": past_key_values,
89
- "use_cache": kwargs.get("use_cache"),
90
- "position_ids": position_ids,
91
- "attention_mask": attention_mask,
92
- "token_type_ids": token_type_ids,
93
- }
94
-
95
- def forward(
96
- self,
97
- input_ids=None,
98
- past_key_values=None,
99
- attention_mask=None,
100
- token_type_ids=None,
101
- position_ids=None,
102
- head_mask=None,
103
- inputs_embeds=None,
104
- encoder_hidden_states=None,
105
- encoder_attention_mask=None,
106
- labels=None,
107
- use_cache=None,
108
- output_attentions=None,
109
- output_hidden_states=None,
110
- return_dict=None,
111
- ):
112
- assert self.cached_prefix_emb is not None
113
- assert inputs_embeds is None # Not supported by this inference model.
114
- assert labels is None # Training not supported by this inference model.
115
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
116
-
117
- # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
118
-
119
- # Create embedding
120
- prefix_len = self.cached_prefix_emb.shape[1]
121
- if input_ids.shape[1] != 1:
122
- gen_inputs = input_ids[:, prefix_len:]
123
- gen_emb = self.embeddings(gen_inputs)
124
- gen_emb = gen_emb + self.pos_embedding(gen_emb)
125
- if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
126
- prefix_emb = self.cached_prefix_emb.repeat_interleave(
127
- gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
128
- )
129
- else:
130
- prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
131
- emb = torch.cat([prefix_emb, gen_emb], dim=1)
132
- else:
133
- emb = self.embeddings(input_ids)
134
- emb = emb + self.pos_embedding.get_fixed_embedding(
135
- attention_mask.shape[1] - (prefix_len + 1), attention_mask.device
136
- )
137
- transformer_outputs = self.transformer(
138
- inputs_embeds=emb,
139
- past_key_values=past_key_values,
140
- attention_mask=attention_mask,
141
- token_type_ids=token_type_ids,
142
- position_ids=position_ids,
143
- head_mask=head_mask,
144
- encoder_hidden_states=encoder_hidden_states,
145
- encoder_attention_mask=encoder_attention_mask,
146
- use_cache=use_cache,
147
- output_attentions=output_attentions,
148
- output_hidden_states=output_hidden_states,
149
- return_dict=return_dict,
150
- )
151
- hidden_states = transformer_outputs[0]
152
- lm_logits = self.lm_head(hidden_states)
153
-
154
- if not return_dict:
155
- return (lm_logits,) + transformer_outputs[1:]
156
-
157
- return CausalLMOutputWithCrossAttentions(
158
- loss=None,
159
- logits=lm_logits,
160
- past_key_values=transformer_outputs.past_key_values,
161
- hidden_states=transformer_outputs.hidden_states,
162
- attentions=transformer_outputs.attentions,
163
- cross_attentions=transformer_outputs.cross_attentions,
164
- )
165
-
166
- @staticmethod
167
- def _reorder_cache(past, beam_idx):
168
- """
169
- This function is used to re-order the :obj:`past_key_values` cache if
170
- :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
171
- called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
172
- """
173
- return tuple(
174
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
175
- for layer_past in past
176
- )
177
-
178
-
179
- class ConditioningEncoder(nn.Module):
180
- def __init__(
181
- self,
182
- spec_dim,
183
- embedding_dim,
184
- attn_blocks=6,
185
- num_attn_heads=4,
186
- do_checkpointing=False,
187
- mean=False,
188
- ):
189
- super().__init__()
190
- attn = []
191
- self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
192
- for a in range(attn_blocks):
193
- attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=do_checkpointing))
194
- self.attn = nn.Sequential(*attn)
195
- self.dim = embedding_dim
196
- self.do_checkpointing = do_checkpointing
197
- self.mean = mean
198
-
199
- def forward(self, x):
200
- h = self.init(x)
201
- h = self.attn(h)
202
- if self.mean:
203
- return h.mean(dim=2)
204
- else:
205
- return h[:, :, 0]
206
-
207
-
208
- class LearnedPositionEmbeddings(nn.Module):
209
- def __init__(self, seq_len, model_dim, init=0.02, relative=False):
210
- super().__init__()
211
- # nn.Embedding
212
- self.emb = torch.nn.Embedding(seq_len, model_dim)
213
- # Initializing this way is standard for GPT-2
214
- self.emb.weight.data.normal_(mean=0.0, std=init)
215
- self.relative = relative
216
- self.seq_len = seq_len
217
-
218
- def forward(self, x):
219
- sl = x.shape[1]
220
- if self.relative:
221
- start = random.randint(sl, self.seq_len) - sl
222
- return self.emb(torch.arange(start, start + sl, device=x.device))
223
- else:
224
- return self.emb(torch.arange(0, sl, device=x.device))
225
-
226
- def get_fixed_embedding(self, ind, dev):
227
- return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
228
-
229
-
230
- def build_hf_gpt_transformer(
231
- layers,
232
- model_dim,
233
- heads,
234
- max_mel_seq_len,
235
- max_text_seq_len,
236
- max_prompt_len,
237
- checkpointing,
238
- ):
239
- """
240
- GPT-2 implemented by the HuggingFace library.
241
- """
242
- from transformers import GPT2Config, GPT2Model
243
-
244
- gpt_config = GPT2Config(
245
- vocab_size=256, # Unused.
246
- n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len,
247
- n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len,
248
- n_embd=model_dim,
249
- n_layer=layers,
250
- n_head=heads,
251
- gradient_checkpointing=checkpointing,
252
- use_cache=not checkpointing,
253
- )
254
- gpt = GPT2Model(gpt_config)
255
- # Override the built in positional embeddings
256
- del gpt.wpe
257
- gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
258
- # Built-in token embeddings are unused.
259
- del gpt.wte
260
-
261
- # def _attn(self, query, key, value, attention_mask=None, head_mask=None):
262
- # attn_output = torch.nn.functional.scaled_dot_product_attention(
263
- # query, key, value, dropout_p=self.attn_dropout.p, is_causal=True
264
- # )
265
- # return attn_output, None
266
-
267
- # for i in range(len(gpt.h)):
268
- # gpt.h[i].attn._attn = types.MethodType(
269
- # _attn, gpt.h[i].attn
270
- # )
271
-
272
- mel_pos_emb = (
273
- LearnedPositionEmbeddings(max_mel_seq_len, model_dim)
274
- if max_mel_seq_len != -1
275
- else functools.partial(null_position_embeddings, dim=model_dim)
276
- )
277
- text_pos_emb = (
278
- LearnedPositionEmbeddings(max_text_seq_len, model_dim)
279
- if max_mel_seq_len != -1
280
- else functools.partial(null_position_embeddings, dim=model_dim)
281
- )
282
- # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True)
283
- return gpt, mel_pos_emb, text_pos_emb, None, None
284
-
285
-
286
- class MelEncoder(nn.Module):
287
- def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
288
- super().__init__()
289
- self.channels = channels
290
- self.encoder = nn.Sequential(
291
- nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
292
- nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
293
- nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
294
- nn.GroupNorm(channels // 16, channels // 2),
295
- nn.ReLU(),
296
- nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
297
- nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
298
- nn.GroupNorm(channels // 8, channels),
299
- nn.ReLU(),
300
- nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
301
- )
302
- self.reduction = 4
303
-
304
- def forward(self, x):
305
- for e in self.encoder:
306
- x = e(x)
307
- return x.permute(0, 2, 1)
308
-
309
-
310
- class UnifiedVoice(nn.Module):
311
- def __init__(
312
- self,
313
- start_text_token=261,
314
- stop_text_token=0,
315
- layers=8,
316
- model_dim=512,
317
- heads=8,
318
- max_text_tokens=120,
319
- max_mel_tokens=250,
320
- max_prompt_tokens=70,
321
- max_conditioning_inputs=1,
322
- mel_length_compression=1024,
323
- number_text_tokens=256,
324
- number_mel_codes=8194,
325
- start_mel_token=8192,
326
- stop_mel_token=8193,
327
- train_solo_embeddings=False,
328
- use_mel_codes_as_input=True,
329
- checkpointing=True,
330
- average_conditioning_embeddings=False,
331
- freeze_everything_but_position_embeddings=False,
332
- freeze_conditioning_encoder=False,
333
- tortoise_compat=True,
334
- label_smoothing=0.0,
335
- ):
336
- """
337
- Args:
338
- layers: Number of layers in transformer stack.
339
- model_dim: Operating dimensions of the transformer
340
- heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
341
- max_text_tokens: Maximum number of text tokens that will be encountered by model.
342
- max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
343
- max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
344
- mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
345
- number_text_tokens:
346
- start_text_token:
347
- stop_text_token:
348
- number_mel_codes:
349
- start_mel_token:
350
- stop_mel_token:
351
- train_solo_embeddings:
352
- use_mel_codes_as_input:
353
- checkpointing:
354
- average_conditioning_embeddings: Whether or not conditioning embeddings should be averaged, instead of fed piecewise into the model.
355
- """
356
- super().__init__()
357
-
358
- self.label_smoothing = label_smoothing
359
- self.number_text_tokens = number_text_tokens
360
- self.start_text_token = start_text_token
361
- self.stop_text_token = stop_text_token
362
- self.number_mel_codes = number_mel_codes
363
- self.start_mel_token = start_mel_token
364
- self.stop_mel_token = stop_mel_token
365
- self.start_prompt_token = start_mel_token
366
- self.stop_prompt_token = stop_mel_token
367
- self.layers = layers
368
- self.heads = heads
369
- self.model_dim = model_dim
370
- self.max_conditioning_inputs = max_conditioning_inputs
371
- self.max_mel_tokens = -1 if max_mel_tokens == -1 else max_mel_tokens + 2 + self.max_conditioning_inputs
372
- self.max_text_tokens = -1 if max_text_tokens == -1 else max_text_tokens + 2
373
- self.max_prompt_tokens = max_prompt_tokens
374
- self.mel_length_compression = mel_length_compression
375
- # self.conditioning_encoder = ConditioningEncoder(
376
- # 80, model_dim, num_attn_heads=heads
377
- # )
378
- self.average_conditioning_embeddings = average_conditioning_embeddings
379
- self.tortoise_compat = tortoise_compat # credit to https://github.com/152334H/DL-Art-School/commit/ae80992817059acf6eef38a680efa5124cee570b
380
- # nn.Embedding
381
- self.text_embedding = ml.Embedding(self.number_text_tokens, model_dim)
382
- if use_mel_codes_as_input:
383
- # nn.Embedding
384
- self.mel_embedding = ml.Embedding(self.number_mel_codes, model_dim)
385
- else:
386
- self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
387
- (
388
- self.gpt,
389
- self.mel_pos_embedding,
390
- self.text_pos_embedding,
391
- self.mel_layer_pos_embedding,
392
- self.text_layer_pos_embedding,
393
- ) = build_hf_gpt_transformer(
394
- layers,
395
- model_dim,
396
- heads,
397
- self.max_mel_tokens,
398
- self.max_text_tokens,
399
- self.max_prompt_tokens,
400
- checkpointing,
401
- )
402
- if train_solo_embeddings:
403
- self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
404
- self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True)
405
- else:
406
- self.mel_solo_embedding = 0
407
- self.text_solo_embedding = 0
408
-
409
- self.final_norm = nn.LayerNorm(model_dim)
410
- self.text_head = ml.Linear(model_dim, self.number_text_tokens)
411
- self.mel_head = ml.Linear(model_dim, self.number_mel_codes)
412
-
413
- # Initialize the embeddings per the GPT-2 scheme
414
- embeddings = [self.text_embedding]
415
- if use_mel_codes_as_input:
416
- embeddings.append(self.mel_embedding)
417
- for module in embeddings:
418
- module.weight.data.normal_(mean=0.0, std=0.02)
419
-
420
- if freeze_conditioning_encoder:
421
- print(" > Freezing conditioning encoder.")
422
- for p in self.conditioning_encoder.parameters():
423
- p.requires_grad = False
424
- p.DO_NOT_TRAIN = True
425
-
426
- if freeze_everything_but_position_embeddings:
427
- for p in self.parameters():
428
- p.requires_grad = False
429
- p.DO_NOT_TRAIN = True
430
- for m in [self.mel_pos_embedding, self.text_pos_embedding]:
431
- for p in m.parameters():
432
- del p.DO_NOT_TRAIN
433
- p.requires_grad = True
434
-
435
- def get_grad_norm_parameter_groups(self):
436
- return {
437
- "conditioning_encoder": list(self.conditioning_encoder.parameters()),
438
- "gpt": list(self.gpt.parameters()),
439
- "heads": list(self.text_head.parameters()) + list(self.mel_head.parameters()),
440
- }
441
-
442
- def post_init_gpt2_config(self, kv_cache=True, use_deepspeed=False, use_deepspeed_f16=False):
443
- seq_length = self.max_prompt_tokens + self.max_mel_tokens + self.max_text_tokens + 1
444
- gpt_config = GPT2Config(
445
- vocab_size=self.max_mel_tokens,
446
- n_positions=seq_length,
447
- n_ctx=seq_length,
448
- n_embd=self.model_dim,
449
- n_layer=self.layers,
450
- n_head=self.heads,
451
- gradient_checkpointing=False,
452
- use_cache=True,
453
- )
454
- self.inference_model = GPT2InferenceModel(
455
- gpt_config,
456
- self.gpt,
457
- self.mel_pos_embedding,
458
- self.mel_embedding,
459
- self.final_norm,
460
- self.mel_head,
461
- kv_cache=kv_cache,
462
- )
463
- # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
464
- self.gpt.wte = self.mel_embedding
465
-
466
- if use_deepspeed:
467
- # init deepspeed inference engine
468
- if use_deepspeed_f16:
469
- self.gpt.wte = self.mel_embedding.half()
470
- self.gpt.wpe = self.mel_pos_embedding.half()
471
- self.ds_engine = deepspeed.init_inference(
472
- model=self.inference_model.half(), # Transformers models
473
- mp_size=1, # Number of GPU
474
- dtype=torch.float16 if use_deepspeed_f16 else torch.float32, # desired data type of output
475
- replace_method="auto", # Lets DS autmatically identify the layer to replace
476
- replace_with_kernel_inject=True, # replace the model with the kernel injector
477
- )
478
- self.inference_model = self.ds_engine.module.eval()
479
-
480
- def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
481
- inp = F.pad(input, (1, 0), value=start_token)
482
- tar = F.pad(input, (0, 1), value=stop_token)
483
- return inp, tar
484
-
485
- def set_mel_padding(self, mel_input_tokens, mel_lengths):
486
- """
487
- Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
488
- that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
489
- preformatting to create a working TTS model.
490
- """
491
- # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
492
- for b in range(len(mel_lengths)):
493
- actual_end = mel_lengths[b]
494
- if actual_end < mel_input_tokens.shape[-1]:
495
- mel_input_tokens[b, actual_end:] = self.stop_mel_token
496
- return mel_input_tokens
497
-
498
- def get_logits(
499
- self,
500
- speech_conditioning_inputs,
501
- first_inputs,
502
- first_head,
503
- second_inputs=None,
504
- second_head=None,
505
- prompt=None,
506
- get_attns=False,
507
- return_latent=False,
508
- attn_mask_text=None,
509
- attn_mask_mel=None,
510
- ):
511
- if prompt is not None and speech_conditioning_inputs is not None:
512
- offset = speech_conditioning_inputs.shape[1] + prompt.shape[1]
513
- if second_inputs is not None:
514
- emb = torch.cat(
515
- [speech_conditioning_inputs, prompt, first_inputs, second_inputs],
516
- dim=1,
517
- )
518
- else:
519
- emb = torch.cat([speech_conditioning_inputs, prompt, first_inputs], dim=1)
520
- elif speech_conditioning_inputs is not None:
521
- offset = speech_conditioning_inputs.shape[1]
522
- if second_inputs is not None:
523
- emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
524
- else:
525
- emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
526
- elif prompt is not None:
527
- offset = prompt.shape[1]
528
- if second_inputs is not None:
529
- emb = torch.cat([prompt, first_inputs, second_inputs], dim=1)
530
- else:
531
- emb = torch.cat([prompt, first_inputs], dim=1)
532
-
533
- # with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
534
- attn_mask = None
535
- if attn_mask_text is not None:
536
- attn_mask = torch.cat([attn_mask_text, attn_mask_mel], dim=1)
537
- if prompt is not None:
538
- attn_mask_prompt = torch.ones(prompt.shape[0], offset, dtype=torch.bool, device=emb.device)
539
- attn_mask = torch.cat([attn_mask_prompt, attn_mask], dim=1)
540
-
541
- gpt_out = self.gpt(
542
- inputs_embeds=emb,
543
- return_dict=True,
544
- output_attentions=get_attns,
545
- attention_mask=attn_mask,
546
- )
547
-
548
- if get_attns:
549
- return gpt_out.attentions
550
-
551
- enc = gpt_out.last_hidden_state[:, offset:]
552
- enc = self.final_norm(enc)
553
-
554
- if return_latent:
555
- return enc[:, : first_inputs.shape[1]], enc[:, -second_inputs.shape[1] :]
556
-
557
- first_logits = enc[:, : first_inputs.shape[1]]
558
- first_logits = first_head(first_logits)
559
- first_logits = first_logits.permute(0, 2, 1)
560
- if second_inputs is not None:
561
- second_logits = enc[:, -second_inputs.shape[1] :]
562
- second_logits = second_head(second_logits)
563
- second_logits = second_logits.permute(0, 2, 1)
564
- return first_logits, second_logits
565
- else:
566
- return first_logits
567
-
568
- def get_conditioning(self, speech_conditioning_input):
569
- speech_conditioning_input = (
570
- speech_conditioning_input.unsqueeze(1)
571
- if len(speech_conditioning_input.shape) == 3
572
- else speech_conditioning_input
573
- )
574
- conds = []
575
- for j in range(speech_conditioning_input.shape[1]):
576
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
577
- conds = torch.stack(conds, dim=1)
578
- conds = conds.mean(dim=1)
579
- return conds
580
-
581
- def get_prompts(self, prompt_codes):
582
- """
583
- Create a prompt from the mel codes. This is used to condition the model on the mel codes.
584
- Pad the prompt with start and stop mel tokens.
585
- """
586
- prompt = prompt_codes
587
- if self.training:
588
- prompt_len = random.randint(1, 9) # in secs
589
- prompt_len = prompt_len * 24 # in frames
590
-
591
- if prompt_codes.shape[1] < prompt_len:
592
- prompt_len = prompt_codes.shape[-1]
593
- start = 0
594
- else:
595
- start = random.randint(0, prompt_codes.shape[-1] - prompt_len)
596
-
597
- prompt = prompt_codes[:, start : start + prompt_len]
598
-
599
- # add start and stop tokens
600
- prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
601
- prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
602
- return prompt
603
-
604
- # def get_prompts(self, prompt_codes):
605
- # """
606
- # Create a prompt from the mel codes. This is used to condition the model on the mel codes.
607
- # Pad the prompt with start and stop mel tokens.
608
- # """
609
- # prompt = prompt_codes
610
- # if self.training:
611
- # max_prompt_len = 9 * 24
612
- # if prompt_codes.shape[1] < max_prompt_len:
613
- # prompt = prompt_codes
614
- # else:
615
- # start = random.randint(0, prompt_codes.shape[1] - max_prompt_len)
616
- # prompt = prompt_codes[:, start : start + max_prompt_len]
617
-
618
- # # add start and stop tokens
619
- # prompt = F.pad(prompt, (1, 0), value=self.start_prompt_token)
620
- # prompt = F.pad(prompt, (0, 1), value=self.stop_prompt_token)
621
- # return prompt
622
-
623
- def forward(
624
- self,
625
- speech_conditioning_input,
626
- text_inputs,
627
- text_lengths,
628
- mel_codes,
629
- wav_lengths,
630
- prompt_codes,
631
- loss_weights=None,
632
- text_first=True,
633
- return_attentions=False,
634
- return_latent=False,
635
- ):
636
- """
637
- Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
638
- (actuated by `text_first`).
639
-
640
- speech_conditioning_input: MEL float tensor, (b,80,s)
641
- text_inputs: long tensor, (b,t)
642
- text_lengths: long tensor, (b,)
643
- mel_inputs: long tensor, (b,m)
644
- wav_lengths: long tensor, (b,)
645
-
646
- If return_attentions is specified, only logits are returned.
647
- If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
648
- """
649
-
650
- # ❗ FIXIT
651
- speech_conditioning_input = None
652
- if self.max_conditioning_inputs == 0:
653
- assert (
654
- speech_conditioning_input is None
655
- ), " ❗ speech_conditioning_input is not None, but max_conditioning_inputs == 0"
656
-
657
- max_text_len = text_lengths.max()
658
- # Due to the convolution in DVAE, codes do not end with silence at the right place. Rather it predicts some intermediate values
659
- # Like [..., 186, 45, 45, 83] where actually it should end with 186.
660
- # We take last 3 codes to prevent abrupt ending of the audio.
661
- # TODO: This is might need some testing.
662
- mel_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 3
663
-
664
- # If len(codes) + 3 is larger than maxiumum allowed length, we truncate the codes.
665
- max_mel_len = mel_lengths.max()
666
-
667
- if max_mel_len > mel_codes.shape[-1]:
668
- mel_codes = F.pad(mel_codes, (0, max_mel_len - mel_codes.shape[-1]))
669
-
670
- # mel_lengths[mel_lengths >= max_mel_len] = max_mel_len
671
-
672
- # silence aware lengths, skip the silence tokens at the end of the mel codes.
673
- silence = True
674
- for idx, l in enumerate(mel_lengths):
675
- length = l.item()
676
- while silence:
677
- if mel_codes[idx, length - 1] != 83:
678
- break
679
- length -= 1
680
- mel_lengths[idx] = length
681
-
682
- # Lovely assertions
683
- assert (
684
- max_mel_len <= mel_codes.shape[-1]
685
- ), f" ❗ max_mel_len ({max_mel_len}) > mel_codes.shape[-1] ({mel_codes.shape[-1]})"
686
- assert (
687
- max_text_len <= text_inputs.shape[-1]
688
- ), f" ❗ max_text_len ({max_text_len}) > text_inputs.shape[-1] ({text_inputs.shape[-1]})"
689
-
690
- # Append stop token to text inputs
691
- text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
692
-
693
- # Append silence token to mel codes
694
- mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
695
-
696
- # Pad mel codes with STOP_MEL_TOKEN
697
- mel_codes = self.set_mel_padding(mel_codes, mel_lengths)
698
-
699
- # Compute speech conditioning input
700
- conds = None
701
- if speech_conditioning_input is not None:
702
- if not return_latent:
703
- # Compute speech conditioning input
704
- speech_conditioning_input = (
705
- speech_conditioning_input.unsqueeze(1)
706
- if len(speech_conditioning_input.shape) == 3
707
- else speech_conditioning_input
708
- )
709
-
710
- conds = []
711
- for j in range(speech_conditioning_input.shape[1]):
712
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
713
- conds = torch.stack(conds, dim=1)
714
- if self.average_conditioning_embeddings:
715
- conds = conds.mean(dim=1).unsqueeze(1)
716
- else:
717
- # already computed
718
- conds = speech_conditioning_input.unsqueeze(1)
719
-
720
- # Build input and target tensors
721
- # Prepend start token to inputs and append stop token to targets
722
- text_inputs, text_targets = self.build_aligned_inputs_and_targets(
723
- text_inputs, self.start_text_token, self.stop_text_token
724
- )
725
- mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
726
- mel_codes, self.start_mel_token, self.stop_mel_token
727
- )
728
-
729
- # Set attn_mask
730
- attn_mask_text = None
731
- attn_mask_mel = None
732
- if not return_latent:
733
- attn_mask_text = torch.ones(
734
- text_inputs.shape[0],
735
- text_inputs.shape[1],
736
- dtype=torch.bool,
737
- device=text_inputs.device,
738
- )
739
- attn_mask_mel = torch.ones(
740
- mel_codes.shape[0],
741
- mel_codes.shape[1],
742
- dtype=torch.bool,
743
- device=mel_codes.device,
744
- )
745
-
746
- for idx, l in enumerate(text_lengths):
747
- attn_mask_text[idx, l + 1 :] = 0.0
748
-
749
- for idx, l in enumerate(mel_lengths):
750
- attn_mask_mel[idx, l + 1 :] = 0.0
751
-
752
- # Compute text embeddings + positional embeddings
753
- # print(" > text input latent:", text_inputs)
754
- text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
755
-
756
- # Compute mel embeddings + positional embeddings
757
- mel_emb = self.mel_embedding(mel_codes) + self.mel_pos_embedding(mel_codes)
758
-
759
- # Compute prompt embeddings + positional embeddings
760
- prompt = self.get_prompts(prompt_codes)
761
-
762
- prompt_emb = self.mel_embedding(prompt).detach() + self.mel_pos_embedding(prompt).detach()
763
-
764
- # Get logits
765
- sub = -4 # don't ask me why 😄
766
- if self.training:
767
- sub = -1
768
- text_logits, mel_logits = self.get_logits(
769
- conds,
770
- text_emb,
771
- self.text_head,
772
- mel_emb,
773
- self.mel_head,
774
- prompt=prompt_emb,
775
- get_attns=return_attentions,
776
- return_latent=return_latent,
777
- attn_mask_text=attn_mask_text,
778
- attn_mask_mel=attn_mask_mel,
779
- )
780
- if return_latent:
781
- return mel_logits[:, :sub] # sub to prevent bla.
782
-
783
- if return_attentions:
784
- return mel_logits
785
-
786
- # Set paddings to -1 to ignore them in loss
787
- for idx, l in enumerate(text_lengths):
788
- text_targets[idx, l + 1 :] = -1
789
-
790
- for idx, l in enumerate(mel_lengths):
791
- mel_targets[idx, l + 1 :] = -1
792
-
793
- # check if stoptoken is in every row of mel_targets
794
- assert (mel_targets == self.stop_mel_token).sum() >= mel_targets.shape[
795
- 0
796
- ], f" ❗ mel_targets does not contain stop token ({self.stop_mel_token}) in every row."
797
-
798
- # Compute losses
799
- loss_text = F.cross_entropy(
800
- text_logits, text_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
801
- )
802
- loss_mel = F.cross_entropy(
803
- mel_logits, mel_targets.long(), ignore_index=-1, label_smoothing=self.label_smoothing
804
- )
805
-
806
- # if loss_weights is not None:
807
- # loss_text = loss_text * loss_weights[:, None]
808
- # loss_mel = loss_mel * loss_weights[:, None]
809
- return loss_text.mean(), loss_mel.mean(), mel_logits
810
-
811
- def text_forward(self, speech_conditioning_input, text_inputs, text_lengths):
812
- """
813
- Performs autoregressive modeling on only text. Still requires a speech_conditioning_input due to the way the
814
- model inputs are formatted. Just provide any audio clip (arguably, zeros could be provided).
815
- """
816
- # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
817
- # chopping the inputs by the maximum actual length.
818
- max_text_len = text_lengths.max()
819
- text_inputs = F.pad(text_inputs[:, :max_text_len], (0, 1), value=self.stop_text_token)
820
-
821
- speech_conditioning_input = (
822
- speech_conditioning_input.unsqueeze(1)
823
- if len(speech_conditioning_input.shape) == 3
824
- else speech_conditioning_input
825
- )
826
- conds = []
827
- for j in range(speech_conditioning_input.shape[1]):
828
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
829
- conds = torch.stack(conds, dim=1)
830
- if self.average_conditioning_embeddings:
831
- conds = conds.mean(dim=1).unsqueeze(1)
832
-
833
- text_inputs, text_targets = self.build_aligned_inputs_and_targets(
834
- text_inputs, self.start_text_token, self.stop_text_token
835
- )
836
- text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs) + self.text_solo_embedding
837
- text_logits = self.get_logits(conds, text_emb, self.text_head)
838
- loss_text = F.cross_entropy(text_logits, text_targets.long())
839
- return loss_text.mean()
840
-
841
- def speech_forward(self, speech_conditioning_input, mel_codes, wav_lengths, raw_mels=None):
842
- """
843
- Performs autoregressive modeling on only speech data.
844
- """
845
- assert self.max_mel_tokens >= mel_codes.shape[1], f"{mel_codes.shape[1]}"
846
-
847
- # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
848
- # chopping the inputs by the maximum actual length.
849
- max_mel_len = wav_lengths.max() // self.mel_length_compression
850
- mel_codes = F.pad(mel_codes[:, :max_mel_len], (0, 1), value=self.stop_mel_token)
851
- mel_codes = self.set_mel_padding(mel_codes, wav_lengths)
852
- if raw_mels is not None:
853
- raw_mels = raw_mels[:, :, : max_mel_len * 4]
854
-
855
- speech_conditioning_input = (
856
- speech_conditioning_input.unsqueeze(1)
857
- if len(speech_conditioning_input.shape) == 3
858
- else speech_conditioning_input
859
- )
860
- conds = []
861
- for j in range(speech_conditioning_input.shape[1]):
862
- conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
863
- conds = torch.stack(conds, dim=1)
864
- if self.average_conditioning_embeddings:
865
- conds = conds.mean(dim=1).unsqueeze(1)
866
-
867
- mel_codes, mel_targets = self.build_aligned_inputs_and_targets(
868
- mel_codes, self.start_mel_token, self.stop_mel_token
869
- )
870
- if raw_mels is not None:
871
- mel_inp = F.pad(raw_mels, (0, 4))
872
- else:
873
- mel_inp = mel_codes
874
- mel_emb = self.mel_embedding(mel_inp)
875
- mel_emb = mel_emb + self.mel_pos_embedding(mel_codes) + self.mel_solo_embedding
876
- mel_logits = self.get_logits(conds, mel_emb, self.mel_head)
877
- loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
878
- return loss_mel.mean()
879
-
880
- def get_generator(self, fake_inputs, **hf_generate_kwargs):
881
- return self.inference_model.generate_stream(
882
- fake_inputs,
883
- bos_token_id=self.start_mel_token,
884
- pad_token_id=self.stop_mel_token,
885
- eos_token_id=self.stop_mel_token,
886
- max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
887
- do_stream=True,
888
- **hf_generate_kwargs,
889
- )
890
-
891
- def compute_embeddings(
892
- self,
893
- speech_conditioning_latent,
894
- text_inputs,
895
- input_tokens=None,
896
- prompt_codes=None,
897
- pad_input_text=False,
898
- ):
899
- if pad_input_text and text_inputs.shape[1] < 250:
900
- text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
901
- else:
902
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
903
- text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
904
-
905
- emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
906
-
907
- print(" > Text inputs:", text_inputs)
908
- if prompt_codes is not None:
909
- prompt_codes = self.get_prompts(prompt_codes)
910
- prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
911
- print(" > Prompt inputs:", prompt_codes)
912
- print(" > Prompt inputs shape:", prompt_codes.shape)
913
- emb = torch.cat([prompt_emb, emb], dim=1)
914
-
915
- if speech_conditioning_latent is not None:
916
- conds = speech_conditioning_latent.unsqueeze(1)
917
- emb = torch.cat([conds, emb], dim=1)
918
-
919
- self.inference_model.store_prefix_emb(emb)
920
-
921
- fake_inputs = torch.full(
922
- (
923
- emb.shape[0],
924
- emb.shape[1] + 1, # +1 for the start_mel_token
925
- ),
926
- fill_value=1,
927
- dtype=torch.long,
928
- device=text_inputs.device,
929
- )
930
- fake_inputs[:, -1] = self.start_mel_token
931
-
932
- if input_tokens is not None:
933
- fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
934
- return fake_inputs
935
-
936
- def inference_speech(
937
- self,
938
- speech_conditioning_latent,
939
- text_inputs,
940
- input_tokens=None,
941
- prompt_codes=None,
942
- pad_input_text=False,
943
- **hf_generate_kwargs,
944
- ):
945
- if pad_input_text and text_inputs.shape[1] < 250:
946
- text_inputs = F.pad(text_inputs, (0, 250 - text_inputs.shape[1]), value=self.stop_text_token)
947
- else:
948
- text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
949
- text_inputs = F.pad(text_inputs, (1, 0), value=self.start_text_token)
950
-
951
- emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
952
-
953
- print(" > Text inputs:", text_inputs)
954
- if prompt_codes is not None:
955
- prompt_codes = self.get_prompts(prompt_codes)
956
- prompt_emb = self.mel_embedding(prompt_codes) + self.mel_pos_embedding(prompt_codes)
957
- print(" > Prompt inputs:", prompt_codes)
958
- print(" > Prompt inputs shape:", prompt_codes.shape)
959
- emb = torch.cat([prompt_emb, emb], dim=1)
960
-
961
- if speech_conditioning_latent is not None:
962
- conds = speech_conditioning_latent.unsqueeze(1)
963
- emb = torch.cat([conds, emb], dim=1)
964
-
965
- self.inference_model.store_prefix_emb(emb)
966
-
967
- fake_inputs = torch.full(
968
- (
969
- emb.shape[0],
970
- emb.shape[1] + 1, # +1 for the start_mel_token
971
- ),
972
- fill_value=1,
973
- dtype=torch.long,
974
- device=text_inputs.device,
975
- )
976
- fake_inputs[:, -1] = self.start_mel_token
977
-
978
- if input_tokens is not None:
979
- fake_inputs = torch.cat([fake_inputs, input_tokens], dim=1)
980
-
981
- gen = self.inference_model.generate(
982
- fake_inputs,
983
- bos_token_id=self.start_mel_token,
984
- pad_token_id=self.stop_mel_token,
985
- eos_token_id=self.stop_mel_token,
986
- max_length=self.max_mel_tokens * 2 + self.max_prompt_tokens + self.max_text_tokens,
987
- **hf_generate_kwargs,
988
- )
989
- if "return_dict_in_generate" in hf_generate_kwargs:
990
- return gen.sequences[:, fake_inputs.shape[1] :], gen
991
- return gen[:, fake_inputs.shape[1] :]
992
-
993
- # Turns the (utterly insane) output of HF.generate() into a far more sane output:
994
- # [tensors(B,H,S,S)]. Outer=layers, B=batch,H=head,S=sequence
995
- def make_hf_generate_attentions_sane(self, attentions):
996
- layers = [[] for _ in range(len(attentions[0]))]
997
- full_attention_size = attentions[-1][0].shape[-1]
998
- for i, gen in enumerate(attentions):
999
- for j, lyr in enumerate(gen):
1000
- layers[j].append(F.pad(lyr, (0, full_attention_size - lyr.shape[-1])))
1001
- catted = []
1002
- for lyr in layers:
1003
- catted.append(torch.cat(lyr, dim=2))
1004
- return catted
1005
-
1006
- def convert_attentions_to_aligned_codes(self, text, attentions, codes, num_conds):
1007
- """
1008
- This was an attempt to make some sense out of the attention matrix retrieved from the unified_voice model. Unfortunately, I can't use it for aligning text & voice.
1009
- """
1010
- text_padding = num_conds + 2
1011
- num_text = text.shape[-1]
1012
- num_context = num_text + text_padding
1013
- assert num_context + 1 == attentions[0][0].shape[-1]
1014
- attentions = self.make_hf_generate_attentions_sane(attentions)
1015
- results = [torch.empty_like(codes) for _ in range(len(attentions))]
1016
- for l, layer in enumerate(attentions):
1017
- dec_context = layer[:, :, num_context:, :]
1018
- # Mask out everything that isn't text (including the start token, which gets a LOT of attention)
1019
- dec_context[:, :, :, : text_padding + 1] = 0
1020
- dec_context[:, :, :, num_context:] = 0
1021
- for h in range(dec_context.shape[1]):
1022
- dec_context_indices = torch.argmax(dec_context[0, h], dim=-1)
1023
- print(f"layer_{l};head_{h}: " + str(dec_context_indices))
1024
- for t, att_tok in enumerate(attentions):
1025
- combined_attention_weights = torch.zeros((codes.shape[0], num_text), device=codes.device)
1026
- for lyr in att_tok:
1027
- token_to_text_attentions = lyr[:, :, -1, text_padding : (text_padding + num_text)].sum(dim=1)
1028
- combined_attention_weights = combined_attention_weights + token_to_text_attentions
1029
- break
1030
- most_attended_text_token = combined_attention_weights.argmax(dim=-1)
1031
- results[:, t] = most_attended_text_token
1032
- eos_token_mask = codes != self.stop_mel_token
1033
- return results * eos_token_mask
1034
-
1035
-
1036
- @register_model
1037
- def register_unified_voice_prompt(opt_net, opt):
1038
- return UnifiedVoice(**opt_get(opt_net, ["kwargs"], {}))
1039
-
1040
-
1041
- if __name__ == "__main__":
1042
- gpt = UnifiedVoice(
1043
- model_dim=256,
1044
- heads=4,
1045
- train_solo_embeddings=True,
1046
- use_mel_codes_as_input=True,
1047
- max_conditioning_inputs=4,
1048
- freeze_everything_but_position_embeddings=True,
1049
- )
1050
- l = gpt(
1051
- torch.randn(2, 3, 80, 800),
1052
- torch.randint(high=256, size=(2, 120)),
1053
- torch.tensor([32, 120]),
1054
- torch.randint(high=8192, size=(2, 250)),
1055
- torch.tensor([250 * 256, 195 * 256]),
1056
- )
1057
- # gpt.text_forward(torch.randn(2,80,800), torch.randint(high=50, size=(2,80)), torch.tensor([32, 80]))
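
The deleted `UnifiedVoice` above trains text and mel heads on a shared GPT trunk; the part that is easiest to get wrong is how it aligns inputs and targets (`build_aligned_inputs_and_targets`) and then masks padded frames out of the cross-entropy with `ignore_index=-1`. The toy snippet below reproduces just that bookkeeping with invented token ids and lengths; it is a hedged sketch, not code from the repository (the real `forward` additionally replaces trailing padding codes with `stop_mel_token` via `set_mel_padding` before masking).

```python
# Hedged sketch (not from the repo): aligned input/target construction and
# loss masking as done in the deleted UnifiedVoice.forward(). The start/stop
# token values match the class defaults; codes and lengths are toy values.
import torch
import torch.nn.functional as F

START, STOP, IGNORE = 8192, 8193, -1


def build_aligned_inputs_and_targets(codes, start_token, stop_token):
    # inputs get a start token prepended, targets get a stop token appended,
    # so position t of the inputs predicts position t of the targets
    inp = F.pad(codes, (1, 0), value=start_token)
    tar = F.pad(codes, (0, 1), value=stop_token)
    return inp, tar


codes = torch.tensor([[5, 9, 2, 0, 0],      # 0 = padding code
                      [7, 1, 4, 6, 3]])
lengths = torch.tensor([3, 5])

inputs, targets = build_aligned_inputs_and_targets(codes, START, STOP)

# mask everything after the true length (+1 for the stop token) so that
# F.cross_entropy(..., ignore_index=-1) skips the padded frames
for b, l in enumerate(lengths):
    targets[b, l + 1:] = IGNORE

logits = torch.randn(2, 8194, targets.shape[1])  # (B, n_codes, T), as in the deleted code
loss = F.cross_entropy(logits, targets, ignore_index=IGNORE)
print(inputs, targets, loss.item(), sep="\n")
```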
TTS/TTS/tts/layers/xtts/hifigan_decoder.py ADDED
@@ -0,0 +1,742 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import Conv1d, ConvTranspose1d
4
+ from torch.nn import functional as F
5
+ from torch.nn.utils import remove_weight_norm, weight_norm
6
+ import torchaudio
7
+
8
+ from TTS.utils.io import load_fsspec
9
+
10
+
11
+ LRELU_SLOPE = 0.1
12
+
13
+
14
+ def get_padding(k, d):
15
+ return int((k * d - d) / 2)
16
+
17
+
18
+ class ResBlock1(torch.nn.Module):
19
+ """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
20
+
21
+ Network::
22
+
23
+ x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
24
+ |--------------------------------------------------------------------------------------------------|
25
+
26
+
27
+ Args:
28
+ channels (int): number of hidden channels for the convolutional layers.
29
+ kernel_size (int): size of the convolution filter in each layer.
30
+ dilations (list): list of dilation value for each conv layer in a block.
31
+ """
32
+
33
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
34
+ super().__init__()
35
+ self.convs1 = nn.ModuleList(
36
+ [
37
+ weight_norm(
38
+ Conv1d(
39
+ channels,
40
+ channels,
41
+ kernel_size,
42
+ 1,
43
+ dilation=dilation[0],
44
+ padding=get_padding(kernel_size, dilation[0]),
45
+ )
46
+ ),
47
+ weight_norm(
48
+ Conv1d(
49
+ channels,
50
+ channels,
51
+ kernel_size,
52
+ 1,
53
+ dilation=dilation[1],
54
+ padding=get_padding(kernel_size, dilation[1]),
55
+ )
56
+ ),
57
+ weight_norm(
58
+ Conv1d(
59
+ channels,
60
+ channels,
61
+ kernel_size,
62
+ 1,
63
+ dilation=dilation[2],
64
+ padding=get_padding(kernel_size, dilation[2]),
65
+ )
66
+ ),
67
+ ]
68
+ )
69
+
70
+ self.convs2 = nn.ModuleList(
71
+ [
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ weight_norm(
93
+ Conv1d(
94
+ channels,
95
+ channels,
96
+ kernel_size,
97
+ 1,
98
+ dilation=1,
99
+ padding=get_padding(kernel_size, 1),
100
+ )
101
+ ),
102
+ ]
103
+ )
104
+
105
+ def forward(self, x):
106
+ """
107
+ Args:
108
+ x (Tensor): input tensor.
109
+ Returns:
110
+ Tensor: output tensor.
111
+ Shapes:
112
+ x: [B, C, T]
113
+ """
114
+ for c1, c2 in zip(self.convs1, self.convs2):
115
+ xt = F.leaky_relu(x, LRELU_SLOPE)
116
+ xt = c1(xt)
117
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
118
+ xt = c2(xt)
119
+ x = xt + x
120
+ return x
121
+
122
+ def remove_weight_norm(self):
123
+ for l in self.convs1:
124
+ remove_weight_norm(l)
125
+ for l in self.convs2:
126
+ remove_weight_norm(l)
127
+
128
+
129
+ class ResBlock2(torch.nn.Module):
130
+ """Residual Block Type 2. It has 1 convolutional layers in each convolutional block.
131
+
132
+ Network::
133
+
134
+ x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
135
+ |---------------------------------------------------|
136
+
137
+
138
+ Args:
139
+ channels (int): number of hidden channels for the convolutional layers.
140
+ kernel_size (int): size of the convolution filter in each layer.
141
+ dilations (list): list of dilation value for each conv layer in a block.
142
+ """
143
+
144
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
145
+ super().__init__()
146
+ self.convs = nn.ModuleList(
147
+ [
148
+ weight_norm(
149
+ Conv1d(
150
+ channels,
151
+ channels,
152
+ kernel_size,
153
+ 1,
154
+ dilation=dilation[0],
155
+ padding=get_padding(kernel_size, dilation[0]),
156
+ )
157
+ ),
158
+ weight_norm(
159
+ Conv1d(
160
+ channels,
161
+ channels,
162
+ kernel_size,
163
+ 1,
164
+ dilation=dilation[1],
165
+ padding=get_padding(kernel_size, dilation[1]),
166
+ )
167
+ ),
168
+ ]
169
+ )
170
+
171
+ def forward(self, x):
172
+ for c in self.convs:
173
+ xt = F.leaky_relu(x, LRELU_SLOPE)
174
+ xt = c(xt)
175
+ x = xt + x
176
+ return x
177
+
178
+ def remove_weight_norm(self):
179
+ for l in self.convs:
180
+ remove_weight_norm(l)
181
+
182
+
183
+ class HifiganGenerator(torch.nn.Module):
184
+ def __init__(
185
+ self,
186
+ in_channels,
187
+ out_channels,
188
+ resblock_type,
189
+ resblock_dilation_sizes,
190
+ resblock_kernel_sizes,
191
+ upsample_kernel_sizes,
192
+ upsample_initial_channel,
193
+ upsample_factors,
194
+ inference_padding=5,
195
+ cond_channels=0,
196
+ conv_pre_weight_norm=True,
197
+ conv_post_weight_norm=True,
198
+ conv_post_bias=True,
199
+ cond_in_each_up_layer=False,
200
+ ):
201
+ r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
202
+
203
+ Network:
204
+ x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
205
+ .. -> zI ---|
206
+ resblockN_kNx1 -> zN ---'
207
+
208
+ Args:
209
+ in_channels (int): number of input tensor channels.
210
+ out_channels (int): number of output tensor channels.
211
+ resblock_type (str): type of the `ResBlock`. '1' or '2'.
212
+ resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
213
+ resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
214
+ upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
215
+ upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
216
+ for each consecutive upsampling layer.
217
+ upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
218
+ inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
219
+ """
220
+ super().__init__()
221
+ self.inference_padding = inference_padding
222
+ self.num_kernels = len(resblock_kernel_sizes)
223
+ self.num_upsamples = len(upsample_factors)
224
+ self.cond_in_each_up_layer = cond_in_each_up_layer
225
+
226
+ # initial upsampling layers
227
+ self.conv_pre = weight_norm(
228
+ Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
229
+ )
230
+ resblock = ResBlock1 if resblock_type == "1" else ResBlock2
231
+ # upsampling layers
232
+ self.ups = nn.ModuleList()
233
+ for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
234
+ self.ups.append(
235
+ weight_norm(
236
+ ConvTranspose1d(
237
+ upsample_initial_channel // (2**i),
238
+ upsample_initial_channel // (2 ** (i + 1)),
239
+ k,
240
+ u,
241
+ padding=(k - u) // 2,
242
+ )
243
+ )
244
+ )
245
+ # MRF blocks
246
+ self.resblocks = nn.ModuleList()
247
+ for i in range(len(self.ups)):
248
+ ch = upsample_initial_channel // (2 ** (i + 1))
249
+ for _, (k, d) in enumerate(
250
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
251
+ ):
252
+ self.resblocks.append(resblock(ch, k, d))
253
+ # post convolution layer
254
+ self.conv_post = weight_norm(
255
+ Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)
256
+ )
257
+ if cond_channels > 0:
258
+ self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1)
259
+
260
+ if not conv_pre_weight_norm:
261
+ remove_weight_norm(self.conv_pre)
262
+
263
+ if not conv_post_weight_norm:
264
+ remove_weight_norm(self.conv_post)
265
+
266
+ if self.cond_in_each_up_layer:
267
+ self.conds = nn.ModuleList()
268
+ for i in range(len(self.ups)):
269
+ ch = upsample_initial_channel // (2 ** (i + 1))
270
+ self.conds.append(nn.Conv1d(cond_channels, ch, 1))
271
+
272
+ def forward(self, x, g=None):
273
+ """
274
+ Args:
275
+ x (Tensor): feature input tensor.
276
+ g (Tensor): global conditioning input tensor.
277
+
278
+ Returns:
279
+ Tensor: output waveform.
280
+
281
+ Shapes:
282
+ x: [B, C, T]
283
+ Tensor: [B, 1, T]
284
+ """
285
+ o = self.conv_pre(x)
286
+ if hasattr(self, "cond_layer"):
287
+ o = o + self.cond_layer(g)
288
+ for i in range(self.num_upsamples):
289
+ o = F.leaky_relu(o, LRELU_SLOPE)
290
+ o = self.ups[i](o)
291
+
292
+ if self.cond_in_each_up_layer:
293
+ o = o + self.conds[i](g)
294
+
295
+ z_sum = None
296
+ for j in range(self.num_kernels):
297
+ if z_sum is None:
298
+ z_sum = self.resblocks[i * self.num_kernels + j](o)
299
+ else:
300
+ z_sum += self.resblocks[i * self.num_kernels + j](o)
301
+ o = z_sum / self.num_kernels
302
+ o = F.leaky_relu(o)
303
+ o = self.conv_post(o)
304
+ o = torch.tanh(o)
305
+ return o
306
+
307
+ @torch.no_grad()
308
+ def inference(self, c):
309
+ """
310
+ Args:
311
+ x (Tensor): conditioning input tensor.
312
+
313
+ Returns:
314
+ Tensor: output waveform.
315
+
316
+ Shapes:
317
+ x: [B, C, T]
318
+ Tensor: [B, 1, T]
319
+ """
320
+ c = c.to(self.conv_pre.weight.device)
321
+ c = torch.nn.functional.pad(
322
+ c, (self.inference_padding, self.inference_padding), "replicate"
323
+ )
324
+ return self.forward(c)
325
+
326
+ def remove_weight_norm(self):
327
+ print("Removing weight norm...")
328
+ for l in self.ups:
329
+ remove_weight_norm(l)
330
+ for l in self.resblocks:
331
+ l.remove_weight_norm()
332
+ remove_weight_norm(self.conv_pre)
333
+ remove_weight_norm(self.conv_post)
334
+
335
+ def load_checkpoint(
336
+ self, config, checkpoint_path, eval=False, cache=False
337
+ ): # pylint: disable=unused-argument, redefined-builtin
338
+ state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
339
+ self.load_state_dict(state["model"])
340
+ if eval:
341
+ self.eval()
342
+ assert not self.training
343
+ self.remove_weight_norm()
344
+
345
+ class SELayer(nn.Module):
346
+ def __init__(self, channel, reduction=8):
347
+ super(SELayer, self).__init__()
348
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
349
+ self.fc = nn.Sequential(
350
+ nn.Linear(channel, channel // reduction),
351
+ nn.ReLU(inplace=True),
352
+ nn.Linear(channel // reduction, channel),
353
+ nn.Sigmoid(),
354
+ )
355
+
356
+ def forward(self, x):
357
+ b, c, _, _ = x.size()
358
+ y = self.avg_pool(x).view(b, c)
359
+ y = self.fc(y).view(b, c, 1, 1)
360
+ return x * y
361
+
362
+
363
+ class SEBasicBlock(nn.Module):
364
+ expansion = 1
365
+
366
+ def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
367
+ super(SEBasicBlock, self).__init__()
368
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
369
+ self.bn1 = nn.BatchNorm2d(planes)
370
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
371
+ self.bn2 = nn.BatchNorm2d(planes)
372
+ self.relu = nn.ReLU(inplace=True)
373
+ self.se = SELayer(planes, reduction)
374
+ self.downsample = downsample
375
+ self.stride = stride
376
+
377
+ def forward(self, x):
378
+ residual = x
379
+
380
+ out = self.conv1(x)
381
+ out = self.relu(out)
382
+ out = self.bn1(out)
383
+
384
+ out = self.conv2(out)
385
+ out = self.bn2(out)
386
+ out = self.se(out)
387
+
388
+ if self.downsample is not None:
389
+ residual = self.downsample(x)
390
+
391
+ out += residual
392
+ out = self.relu(out)
393
+ return out
394
+
395
+
396
+ def set_init_dict(model_dict, checkpoint_state, c):
397
+ # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
398
+ for k, v in checkpoint_state.items():
399
+ if k not in model_dict:
400
+ print(" | > Layer missing in the model definition: {}".format(k))
401
+ # 1. filter out unnecessary keys
402
+ pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
403
+ # 2. filter out different size layers
404
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
405
+ # 3. skip reinit layers
406
+ if c.has("reinit_layers") and c.reinit_layers is not None:
407
+ for reinit_layer_name in c.reinit_layers:
408
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
409
+ # 4. overwrite entries in the existing state dict
410
+ model_dict.update(pretrained_dict)
411
+ print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
412
+ return model_dict
413
+
414
+
415
+ class PreEmphasis(nn.Module):
416
+ def __init__(self, coefficient=0.97):
417
+ super().__init__()
418
+ self.coefficient = coefficient
419
+ self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0))
420
+
421
+ def forward(self, x):
422
+ assert len(x.size()) == 2
423
+
424
+ x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect")
425
+ return torch.nn.functional.conv1d(x, self.filter).squeeze(1)
426
+
427
+
428
+
429
+ class ResNetSpeakerEncoder(nn.Module):
430
+ """This is copied from 🐸TTS to remove it from the dependencies.
431
+ """
432
+
433
+ # pylint: disable=W0102
434
+ def __init__(
435
+ self,
436
+ input_dim=64,
437
+ proj_dim=512,
438
+ layers=[3, 4, 6, 3],
439
+ num_filters=[32, 64, 128, 256],
440
+ encoder_type="ASP",
441
+ log_input=False,
442
+ use_torch_spec=False,
443
+ audio_config=None,
444
+ ):
445
+ super(ResNetSpeakerEncoder, self).__init__()
446
+
447
+ self.encoder_type = encoder_type
448
+ self.input_dim = input_dim
449
+ self.log_input = log_input
450
+ self.use_torch_spec = use_torch_spec
451
+ self.audio_config = audio_config
452
+ self.proj_dim = proj_dim
453
+
454
+ self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1)
455
+ self.relu = nn.ReLU(inplace=True)
456
+ self.bn1 = nn.BatchNorm2d(num_filters[0])
457
+
458
+ self.inplanes = num_filters[0]
459
+ self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0])
460
+ self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2))
461
+ self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2))
462
+ self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2))
463
+
464
+ self.instancenorm = nn.InstanceNorm1d(input_dim)
465
+
466
+ if self.use_torch_spec:
467
+ self.torch_spec = torch.nn.Sequential(
468
+ PreEmphasis(audio_config["preemphasis"]),
469
+ torchaudio.transforms.MelSpectrogram(
470
+ sample_rate=audio_config["sample_rate"],
471
+ n_fft=audio_config["fft_size"],
472
+ win_length=audio_config["win_length"],
473
+ hop_length=audio_config["hop_length"],
474
+ window_fn=torch.hamming_window,
475
+ n_mels=audio_config["num_mels"],
476
+ ),
477
+ )
478
+
479
+ else:
480
+ self.torch_spec = None
481
+
482
+ outmap_size = int(self.input_dim / 8)
483
+
484
+ self.attention = nn.Sequential(
485
+ nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1),
486
+ nn.ReLU(),
487
+ nn.BatchNorm1d(128),
488
+ nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1),
489
+ nn.Softmax(dim=2),
490
+ )
491
+
492
+ if self.encoder_type == "SAP":
493
+ out_dim = num_filters[3] * outmap_size
494
+ elif self.encoder_type == "ASP":
495
+ out_dim = num_filters[3] * outmap_size * 2
496
+ else:
497
+ raise ValueError("Undefined encoder")
498
+
499
+ self.fc = nn.Linear(out_dim, proj_dim)
500
+
501
+ self._init_layers()
502
+
503
+ def _init_layers(self):
504
+ for m in self.modules():
505
+ if isinstance(m, nn.Conv2d):
506
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
507
+ elif isinstance(m, nn.BatchNorm2d):
508
+ nn.init.constant_(m.weight, 1)
509
+ nn.init.constant_(m.bias, 0)
510
+
511
+ def create_layer(self, block, planes, blocks, stride=1):
512
+ downsample = None
513
+ if stride != 1 or self.inplanes != planes * block.expansion:
514
+ downsample = nn.Sequential(
515
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
516
+ nn.BatchNorm2d(planes * block.expansion),
517
+ )
518
+
519
+ layers = []
520
+ layers.append(block(self.inplanes, planes, stride, downsample))
521
+ self.inplanes = planes * block.expansion
522
+ for _ in range(1, blocks):
523
+ layers.append(block(self.inplanes, planes))
524
+
525
+ return nn.Sequential(*layers)
526
+
527
+ # pylint: disable=R0201
528
+ def new_parameter(self, *size):
529
+ out = nn.Parameter(torch.FloatTensor(*size))
530
+ nn.init.xavier_normal_(out)
531
+ return out
532
+
533
+ def forward(self, x, l2_norm=False):
534
+ """Forward pass of the model.
535
+
536
+ Args:
537
+ x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True`
538
+ to compute the spectrogram on-the-fly.
539
+ l2_norm (bool): Whether to L2-normalize the outputs.
540
+
541
+ Shapes:
542
+ - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})`
543
+ """
544
+ x.squeeze_(1)
545
+ # if use_torch_spec is set, compute the spectrogram on the fly; otherwise use the mel spec computed by the AP
546
+ if self.use_torch_spec:
547
+ x = self.torch_spec(x)
548
+
549
+ if self.log_input:
550
+ x = (x + 1e-6).log()
551
+ x = self.instancenorm(x).unsqueeze(1)
552
+
553
+ x = self.conv1(x)
554
+ x = self.relu(x)
555
+ x = self.bn1(x)
556
+
557
+ x = self.layer1(x)
558
+ x = self.layer2(x)
559
+ x = self.layer3(x)
560
+ x = self.layer4(x)
561
+
562
+ x = x.reshape(x.size()[0], -1, x.size()[-1])
563
+
564
+ w = self.attention(x)
565
+
566
+ if self.encoder_type == "SAP":
567
+ x = torch.sum(x * w, dim=2)
568
+ elif self.encoder_type == "ASP":
569
+ mu = torch.sum(x * w, dim=2)
570
+ sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5))
571
+ x = torch.cat((mu, sg), 1)
572
+
573
+ x = x.view(x.size()[0], -1)
574
+ x = self.fc(x)
575
+
576
+ if l2_norm:
577
+ x = torch.nn.functional.normalize(x, p=2, dim=1)
578
+ return x
579
+
580
+ def load_checkpoint(
581
+ self,
582
+ checkpoint_path: str,
583
+ eval: bool = False,
584
+ use_cuda: bool = False,
585
+ criterion=None,
586
+ cache=False,
587
+ ):
588
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
589
+ try:
590
+ self.load_state_dict(state["model"])
591
+ print(" > Model fully restored. ")
592
+ except (KeyError, RuntimeError) as error:
593
+ # If eval raise the error
594
+ if eval:
595
+ raise error
596
+
597
+ print(" > Partial model initialization.")
598
+ model_dict = self.state_dict()
599
+ model_dict = set_init_dict(model_dict, state["model"])
600
+ self.load_state_dict(model_dict)
601
+ del model_dict
602
+
603
+ # load the criterion for restore_path
604
+ if criterion is not None and "criterion" in state:
605
+ try:
606
+ criterion.load_state_dict(state["criterion"])
607
+ except (KeyError, RuntimeError) as error:
608
+ print(" > Criterion load ignored because of:", error)
609
+
610
+ if use_cuda:
611
+ self.cuda()
612
+ if criterion is not None:
613
+ criterion = criterion.cuda()
614
+
615
+ if eval:
616
+ self.eval()
617
+ assert not self.training
618
+
619
+ if not eval:
620
+ return criterion, state["step"]
621
+ return criterion
622
+
623
+ class HifiDecoder(torch.nn.Module):
624
+ def __init__(
625
+ self,
626
+ input_sample_rate=22050,
627
+ output_sample_rate=24000,
628
+ output_hop_length=256,
629
+ ar_mel_length_compression=1024,
630
+ decoder_input_dim=1024,
631
+ resblock_type_decoder="1",
632
+ resblock_dilation_sizes_decoder=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
633
+ resblock_kernel_sizes_decoder=[3, 7, 11],
634
+ upsample_rates_decoder=[8, 8, 2, 2],
635
+ upsample_initial_channel_decoder=512,
636
+ upsample_kernel_sizes_decoder=[16, 16, 4, 4],
637
+ d_vector_dim=512,
638
+ cond_d_vector_in_each_upsampling_layer=True,
639
+ speaker_encoder_audio_config={
640
+ "fft_size": 512,
641
+ "win_length": 400,
642
+ "hop_length": 160,
643
+ "sample_rate": 16000,
644
+ "preemphasis": 0.97,
645
+ "num_mels": 64,
646
+ },
647
+ ):
648
+ super().__init__()
649
+ self.input_sample_rate = input_sample_rate
650
+ self.output_sample_rate = output_sample_rate
651
+ self.output_hop_length = output_hop_length
652
+ self.ar_mel_length_compression = ar_mel_length_compression
653
+ self.speaker_encoder_audio_config = speaker_encoder_audio_config
654
+ self.waveform_decoder = HifiganGenerator(
655
+ decoder_input_dim,
656
+ 1,
657
+ resblock_type_decoder,
658
+ resblock_dilation_sizes_decoder,
659
+ resblock_kernel_sizes_decoder,
660
+ upsample_kernel_sizes_decoder,
661
+ upsample_initial_channel_decoder,
662
+ upsample_rates_decoder,
663
+ inference_padding=0,
664
+ cond_channels=d_vector_dim,
665
+ conv_pre_weight_norm=False,
666
+ conv_post_weight_norm=False,
667
+ conv_post_bias=False,
668
+ cond_in_each_up_layer=cond_d_vector_in_each_upsampling_layer,
669
+ )
670
+ self.speaker_encoder = ResNetSpeakerEncoder(
671
+ input_dim=64,
672
+ proj_dim=512,
673
+ log_input=True,
674
+ use_torch_spec=True,
675
+ audio_config=speaker_encoder_audio_config,
676
+ )
677
+
678
+ @property
679
+ def device(self):
680
+ return next(self.parameters()).device
681
+
682
+ def forward(self, latents, g=None):
683
+ """
684
+ Args:
685
+ latents (Tensor): feature input tensor (GPT latent).
686
+ g (Tensor): global conditioning input tensor.
687
+
688
+ Returns:
689
+ Tensor: output waveform.
690
+
691
+ Shapes:
692
+ latents: [B, C, T]
693
+ Tensor: [B, 1, T]
694
+ """
695
+
696
+ z = torch.nn.functional.interpolate(
697
+ latents.transpose(1, 2),
698
+ scale_factor=[self.ar_mel_length_compression / self.output_hop_length],
699
+ mode="linear",
700
+ ).squeeze(1)
701
+ # upsample to the right sr
702
+ if self.output_sample_rate != self.input_sample_rate:
703
+ z = torch.nn.functional.interpolate(
704
+ z,
705
+ scale_factor=[self.output_sample_rate / self.input_sample_rate],
706
+ mode="linear",
707
+ ).squeeze(0)
708
+ o = self.waveform_decoder(z, g=g)
709
+ return o
710
+
711
+ @torch.no_grad()
712
+ def inference(self, c, g):
713
+ """
714
+ Args:
715
+ c (Tensor): feature input tensor (GPT latent).
716
+ g (Tensor): global conditioning input tensor.
717
+
718
+ Returns:
719
+ Tensor: output waveform.
720
+
721
+ Shapes:
722
+ c: [B, C, T]
723
+ Tensor: [B, 1, T]
724
+ """
725
+ return self.forward(c, g=g)
726
+
727
+ def load_checkpoint(
728
+ self, checkpoint_path, eval=False
729
+ ): # pylint: disable=unused-argument, redefined-builtin
730
+ state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))
731
+ # remove unused keys
732
+ state = state["model"]
733
+ states_keys = list(state.keys())
734
+ for key in states_keys:
735
+ if "waveform_decoder." not in key and "speaker_encoder." not in key:
736
+ del state[key]
737
+
738
+ self.load_state_dict(state)
739
+ if eval:
740
+ self.eval()
741
+ assert not self.training
742
+ self.waveform_decoder.remove_weight_norm()
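The new `HifiDecoder` above bundles a HiFi-GAN waveform generator with the copied `ResNetSpeakerEncoder`, so GPT latents plus a short reference waveform are enough to get audio out. Below is a minimal sketch of that flow, assuming the classes above are in scope and a recent PyTorch/torchaudio; the tensor shapes and variable names are illustrative and are not part of the commit.

```python
import torch

# Hypothetical driver for the classes added above (not part of the commit).
decoder = HifiDecoder()  # defaults: 22.05 kHz latent frame rate in, 24 kHz audio out, hop 256
decoder.eval()

latents = torch.randn(1, 100, 1024)  # [B, T_latent, decoder_input_dim] fake GPT latents
ref_wav = torch.randn(1, 16000)      # ~1 s of 16 kHz reference speech for the d-vector

# Speaker embedding from the bundled encoder (ASP pooling, 512-dim, L2-normalized).
g = decoder.speaker_encoder(ref_wav, l2_norm=True).unsqueeze(-1)  # [B, 512, 1]

with torch.no_grad():
    wav = decoder.inference(latents, g=g)  # 24 kHz waveform, roughly [1, 1, T_audio]
print(wav.shape)
```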
TTS/TTS/tts/layers/xtts/stream_generator.py ADDED
@@ -0,0 +1,1057 @@
 
 
1
+ # Adapted from: https://github.com/LowinLi/transformers-stream-generator
2
+
3
+ from transformers import (
4
+ GenerationConfig,
5
+ GenerationMixin,
6
+ LogitsProcessorList,
7
+ StoppingCriteriaList,
8
+ DisjunctiveConstraint,
9
+ BeamSearchScorer,
10
+ PhrasalConstraint,
11
+ ConstrainedBeamSearchScorer,
12
+ PreTrainedModel,
13
+ )
14
+ import numpy as np
15
+ import random
16
+ import warnings
17
+ import inspect
18
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger, validate_stopping_criteria
19
+ import torch
20
+ from typing import Callable, List, Optional, Union
21
+ from torch import nn
22
+ import torch.distributed as dist
23
+ import copy
24
+
25
+
26
+ def setup_seed(seed):
27
+ if seed == -1:
28
+ return
29
+ torch.manual_seed(seed)
30
+ if torch.cuda.is_available():
31
+ torch.cuda.manual_seed_all(seed)
32
+ np.random.seed(seed)
33
+ random.seed(seed)
34
+ torch.backends.cudnn.deterministic = True
35
+
36
+
37
+ class StreamGenerationConfig(GenerationConfig):
38
+ def __init__(self, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.do_stream = kwargs.pop("do_stream", False)
41
+
42
+
43
+ class NewGenerationMixin(GenerationMixin):
44
+ @torch.no_grad()
45
+ def generate(
46
+ self,
47
+ inputs: Optional[torch.Tensor] = None,
48
+ generation_config: Optional[StreamGenerationConfig] = None,
49
+ logits_processor: Optional[LogitsProcessorList] = None,
50
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
51
+ prefix_allowed_tokens_fn: Optional[
52
+ Callable[[int, torch.Tensor], List[int]]
53
+ ] = None,
54
+ synced_gpus: Optional[bool] = False,
55
+ seed=0,
56
+ **kwargs,
57
+ ) -> Union[GenerateOutput, torch.LongTensor]:
58
+ r"""
59
+
60
+ Generates sequences of token ids for models with a language modeling head.
61
+
62
+ <Tip warning={true}>
63
+
64
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
65
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
66
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
67
+
68
+ For an overview of generation strategies and code examples, check out the [following
69
+ guide](./generation_strategies).
70
+
71
+ </Tip>
72
+
73
+ Parameters:
74
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
75
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
76
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
77
+ should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
78
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
79
+ generation_config (`~generation.GenerationConfig`, *optional*):
80
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
81
+ passed to generate matching the attributes of `generation_config` will override them. If
82
+ `generation_config` is not provided, the default will be used, which has the following loading
83
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
84
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
85
+ default values, whose documentation should be checked to parameterize generation.
86
+ logits_processor (`LogitsProcessorList`, *optional*):
87
+ Custom logits processors that complement the default logits processors built from arguments and
88
+ generation config. If a logit processor is passed that is already created with the arguments or a
89
+ generation config an error is thrown. This feature is intended for advanced users.
90
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
91
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
92
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
93
+ generation config an error is thrown. This feature is intended for advanced users.
94
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
95
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
96
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
97
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
98
+ on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
99
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
100
+ Retrieval](https://arxiv.org/abs/2010.00904).
101
+ synced_gpus (`bool`, *optional*, defaults to `False`):
102
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
103
+ kwargs:
104
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
105
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
106
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
107
+
108
+ Return:
109
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
110
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
111
+
112
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
113
+ [`~utils.ModelOutput`] types are:
114
+
115
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
116
+ - [`~generation.SampleDecoderOnlyOutput`],
117
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
118
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
119
+
120
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
121
+ [`~utils.ModelOutput`] types are:
122
+
123
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
124
+ - [`~generation.SampleEncoderDecoderOutput`],
125
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
126
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
127
+ """
128
+ #setup_seed(seed)
129
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
130
+ self._validate_model_class()
131
+
132
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
133
+ if generation_config is None:
134
+ # legacy: users may modify the model configuration to control generation -- update the generation config
135
+ # model attribute accordingly, if it was created from the model config
136
+ if self.generation_config._from_model_config:
137
+ new_generation_config = StreamGenerationConfig.from_model_config(
138
+ self.config
139
+ )
140
+ if new_generation_config != self.generation_config:
141
+ warnings.warn(
142
+ "You have modified the pretrained model configuration to control generation. This is a"
143
+ " deprecated strategy to control generation and will be removed soon, in a future version."
144
+ " Please use a generation configuration file (see"
145
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
146
+ )
147
+ self.generation_config = new_generation_config
148
+ generation_config = self.generation_config
149
+
150
+ generation_config = copy.deepcopy(generation_config)
151
+ model_kwargs = generation_config.update(
152
+ **kwargs
153
+ ) # All unused kwargs must be model kwargs
154
+ # self._validate_model_kwargs(model_kwargs.copy())
155
+
156
+ # 2. Set generation parameters if not already defined
157
+ logits_processor = (
158
+ logits_processor if logits_processor is not None else LogitsProcessorList()
159
+ )
160
+ stopping_criteria = (
161
+ stopping_criteria
162
+ if stopping_criteria is not None
163
+ else StoppingCriteriaList()
164
+ )
165
+
166
+ if (
167
+ generation_config.pad_token_id is None
168
+ and generation_config.eos_token_id is not None
169
+ ):
170
+ if model_kwargs.get("attention_mask", None) is None:
171
+ logger.warning(
172
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
173
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
174
+ )
175
+ eos_token_id = generation_config.eos_token_id
176
+ if isinstance(eos_token_id, list):
177
+ eos_token_id = eos_token_id[0]
178
+ logger.warning(
179
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
180
+ )
181
+ generation_config.pad_token_id = eos_token_id
182
+
183
+ # 3. Define model inputs
184
+ # inputs_tensor has to be defined
185
+ # model_input_name is defined if model-specific keyword input is passed
186
+ # otherwise model_input_name is None
187
+ # all model-specific keyword inputs are removed from `model_kwargs`
188
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
189
+ inputs, generation_config.bos_token_id, model_kwargs
190
+ )
191
+ batch_size = inputs_tensor.shape[0]
192
+
193
+ # 4. Define other model kwargs
194
+ model_kwargs["output_attentions"] = generation_config.output_attentions
195
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
196
+ model_kwargs["use_cache"] = generation_config.use_cache
197
+
198
+ accepts_attention_mask = "attention_mask" in set(
199
+ inspect.signature(self.forward).parameters.keys()
200
+ )
201
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
202
+
203
+ if (
204
+ model_kwargs.get("attention_mask", None) is None
205
+ and requires_attention_mask
206
+ and accepts_attention_mask
207
+ ):
208
+ model_kwargs[
209
+ "attention_mask"
210
+ ] = self._prepare_attention_mask_for_generation(
211
+ inputs_tensor,
212
+ generation_config.pad_token_id,
213
+ generation_config.eos_token_id,
214
+ )
215
+
216
+ # decoder-only models should use left-padding for generation
217
+ if not self.config.is_encoder_decoder:
218
+ if (
219
+ generation_config.pad_token_id is not None
220
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
221
+ > 0
222
+ ):
223
+ logger.warning(
224
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
225
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
226
+ )
227
+
228
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
229
+ # if model is encoder decoder encoder_outputs are created
230
+ # and added to `model_kwargs`
231
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
232
+ inputs_tensor, model_kwargs, model_input_name
233
+ )
234
+
235
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
236
+ if self.config.is_encoder_decoder:
237
+ input_ids = self._prepare_decoder_input_ids_for_generation(
238
+ batch_size,
239
+ decoder_start_token_id=generation_config.decoder_start_token_id,
240
+ bos_token_id=generation_config.bos_token_id,
241
+ model_kwargs=model_kwargs,
242
+ device=inputs_tensor.device,
243
+ )
244
+ else:
245
+ # if decoder-only then inputs_tensor has to be `input_ids`
246
+ input_ids = inputs_tensor
247
+
248
+ # 6. Prepare `max_length` depending on other stopping criteria.
249
+ input_ids_seq_length = input_ids.shape[-1]
250
+ has_default_max_length = (
251
+ kwargs.get("max_length") is None
252
+ and generation_config.max_length is not None
253
+ )
254
+ if has_default_max_length and generation_config.max_new_tokens is None:
255
+ warnings.warn(
256
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
257
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
258
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
259
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
260
+ UserWarning,
261
+ )
262
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
263
+ generation_config.max_length = (
264
+ generation_config.max_new_tokens + input_ids_seq_length
265
+ )
266
+ elif (
267
+ not has_default_max_length and generation_config.max_new_tokens is not None
268
+ ):
269
+ raise ValueError(
270
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
271
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
272
+ " documentation for more information. "
273
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
274
+ )
275
+
276
+ if (
277
+ generation_config.min_length is not None
278
+ and generation_config.min_length > generation_config.max_length
279
+ ):
280
+ raise ValueError(
281
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
282
+ f" the maximum length ({generation_config.max_length})"
283
+ )
284
+ if input_ids_seq_length >= generation_config.max_length:
285
+ input_ids_string = (
286
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
287
+ )
288
+ logger.warning(
289
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
290
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
291
+ " increasing `max_new_tokens`."
292
+ )
293
+
294
+ # 7. determine generation mode
295
+ is_constraint_gen_mode = (
296
+ generation_config.constraints is not None
297
+ or generation_config.force_words_ids is not None
298
+ )
299
+
300
+ is_contrastive_search_gen_mode = (
301
+ generation_config.top_k is not None
302
+ and generation_config.top_k > 1
303
+ and generation_config.do_sample is False
304
+ and generation_config.penalty_alpha is not None
305
+ and generation_config.penalty_alpha > 0
306
+ )
307
+
308
+ is_greedy_gen_mode = (
309
+ (generation_config.num_beams == 1)
310
+ and (generation_config.num_beam_groups == 1)
311
+ and generation_config.do_sample is False
312
+ and not is_constraint_gen_mode
313
+ and not is_contrastive_search_gen_mode
314
+ )
315
+ is_sample_gen_mode = (
316
+ (generation_config.num_beams == 1)
317
+ and (generation_config.num_beam_groups == 1)
318
+ and generation_config.do_sample is True
319
+ and generation_config.do_stream is False
320
+ and not is_constraint_gen_mode
321
+ and not is_contrastive_search_gen_mode
322
+ )
323
+ is_sample_gen_stream_mode = (
324
+ (generation_config.num_beams == 1)
325
+ and (generation_config.num_beam_groups == 1)
326
+ and generation_config.do_stream is True
327
+ and not is_constraint_gen_mode
328
+ and not is_contrastive_search_gen_mode
329
+ )
330
+ is_beam_gen_mode = (
331
+ (generation_config.num_beams > 1)
332
+ and (generation_config.num_beam_groups == 1)
333
+ and generation_config.do_sample is False
334
+ and not is_constraint_gen_mode
335
+ and not is_contrastive_search_gen_mode
336
+ )
337
+ is_beam_sample_gen_mode = (
338
+ (generation_config.num_beams > 1)
339
+ and (generation_config.num_beam_groups == 1)
340
+ and generation_config.do_sample is True
341
+ and not is_constraint_gen_mode
342
+ and not is_contrastive_search_gen_mode
343
+ )
344
+ is_group_beam_gen_mode = (
345
+ (generation_config.num_beams > 1)
346
+ and (generation_config.num_beam_groups > 1)
347
+ and not is_constraint_gen_mode
348
+ and not is_contrastive_search_gen_mode
349
+ )
350
+
351
+ if generation_config.num_beam_groups > generation_config.num_beams:
352
+ raise ValueError(
353
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
354
+ )
355
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
356
+ raise ValueError(
357
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
358
+ )
359
+
360
+ if self.device.type != input_ids.device.type:
361
+ warnings.warn(
362
+ "You are calling .generate() with the `input_ids` being on a device type different"
363
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
364
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
365
+ " Please make sure that you have put `input_ids` to the"
366
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
367
+ " running `.generate()`.",
368
+ UserWarning,
369
+ )
370
+ # 8. prepare distribution pre_processing samplers
371
+ logits_processor = self._get_logits_processor(
372
+ generation_config=generation_config,
373
+ input_ids_seq_length=input_ids_seq_length,
374
+ encoder_input_ids=inputs_tensor,
375
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
376
+ logits_processor=logits_processor,
377
+ )
378
+
379
+ # 9. prepare stopping criteria
380
+ stopping_criteria = self._get_stopping_criteria(
381
+ generation_config=generation_config, stopping_criteria=stopping_criteria
382
+ )
383
+ # 10. go into different generation modes
384
+ if is_greedy_gen_mode:
385
+ if generation_config.num_return_sequences > 1:
386
+ raise ValueError(
387
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
388
+ " greedy search."
389
+ )
390
+
391
+ # 11. run greedy search
392
+ return self.greedy_search(
393
+ input_ids,
394
+ logits_processor=logits_processor,
395
+ stopping_criteria=stopping_criteria,
396
+ pad_token_id=generation_config.pad_token_id,
397
+ eos_token_id=generation_config.eos_token_id,
398
+ output_scores=generation_config.output_scores,
399
+ return_dict_in_generate=generation_config.return_dict_in_generate,
400
+ synced_gpus=synced_gpus,
401
+ **model_kwargs,
402
+ )
403
+
404
+ elif is_contrastive_search_gen_mode:
405
+ if generation_config.num_return_sequences > 1:
406
+ raise ValueError(
407
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
408
+ " contrastive search."
409
+ )
410
+
411
+ return self.contrastive_search(
412
+ input_ids,
413
+ top_k=generation_config.top_k,
414
+ penalty_alpha=generation_config.penalty_alpha,
415
+ logits_processor=logits_processor,
416
+ stopping_criteria=stopping_criteria,
417
+ pad_token_id=generation_config.pad_token_id,
418
+ eos_token_id=generation_config.eos_token_id,
419
+ output_scores=generation_config.output_scores,
420
+ return_dict_in_generate=generation_config.return_dict_in_generate,
421
+ synced_gpus=synced_gpus,
422
+ **model_kwargs,
423
+ )
424
+
425
+ elif is_sample_gen_mode:
426
+ # 11. prepare logits warper
427
+ logits_warper = self._get_logits_warper(generation_config)
428
+
429
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
430
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
431
+ input_ids=input_ids,
432
+ expand_size=generation_config.num_return_sequences,
433
+ is_encoder_decoder=self.config.is_encoder_decoder,
434
+ **model_kwargs,
435
+ )
436
+
437
+ # 13. run sample
438
+ return self.sample(
439
+ input_ids,
440
+ logits_processor=logits_processor,
441
+ logits_warper=logits_warper,
442
+ stopping_criteria=stopping_criteria,
443
+ pad_token_id=generation_config.pad_token_id,
444
+ eos_token_id=generation_config.eos_token_id,
445
+ output_scores=generation_config.output_scores,
446
+ return_dict_in_generate=generation_config.return_dict_in_generate,
447
+ synced_gpus=synced_gpus,
448
+ **model_kwargs,
449
+ )
450
+ elif is_sample_gen_stream_mode:
451
+ # 11. prepare logits warper
452
+ logits_warper = self._get_logits_warper(generation_config)
453
+
454
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
455
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
456
+ input_ids=input_ids,
457
+ expand_size=generation_config.num_return_sequences,
458
+ is_encoder_decoder=self.config.is_encoder_decoder,
459
+ **model_kwargs,
460
+ )
461
+
462
+ # 13. run sample
463
+ return self.sample_stream(
464
+ input_ids,
465
+ logits_processor=logits_processor,
466
+ logits_warper=logits_warper,
467
+ stopping_criteria=stopping_criteria,
468
+ pad_token_id=generation_config.pad_token_id,
469
+ eos_token_id=generation_config.eos_token_id,
470
+ output_scores=generation_config.output_scores,
471
+ return_dict_in_generate=generation_config.return_dict_in_generate,
472
+ synced_gpus=synced_gpus,
473
+ **model_kwargs,
474
+ )
475
+ elif is_beam_gen_mode:
476
+ if generation_config.num_return_sequences > generation_config.num_beams:
477
+ raise ValueError(
478
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
479
+ )
480
+
481
+ if stopping_criteria.max_length is None:
482
+ raise ValueError(
483
+ "`max_length` needs to be a stopping_criteria for now."
484
+ )
485
+
486
+ # 11. prepare beam search scorer
487
+ beam_scorer = BeamSearchScorer(
488
+ batch_size=batch_size,
489
+ num_beams=generation_config.num_beams,
490
+ device=inputs_tensor.device,
491
+ length_penalty=generation_config.length_penalty,
492
+ do_early_stopping=generation_config.early_stopping,
493
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
494
+ )
495
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
496
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
497
+ input_ids=input_ids,
498
+ expand_size=generation_config.num_beams,
499
+ is_encoder_decoder=self.config.is_encoder_decoder,
500
+ **model_kwargs,
501
+ )
502
+ # 13. run beam search
503
+ return self.beam_search(
504
+ input_ids,
505
+ beam_scorer,
506
+ logits_processor=logits_processor,
507
+ stopping_criteria=stopping_criteria,
508
+ pad_token_id=generation_config.pad_token_id,
509
+ eos_token_id=generation_config.eos_token_id,
510
+ output_scores=generation_config.output_scores,
511
+ return_dict_in_generate=generation_config.return_dict_in_generate,
512
+ synced_gpus=synced_gpus,
513
+ **model_kwargs,
514
+ )
515
+
516
+ elif is_beam_sample_gen_mode:
517
+ # 11. prepare logits warper
518
+ logits_warper = self._get_logits_warper(generation_config)
519
+
520
+ if stopping_criteria.max_length is None:
521
+ raise ValueError(
522
+ "`max_length` needs to be a stopping_criteria for now."
523
+ )
524
+ # 12. prepare beam search scorer
525
+ beam_scorer = BeamSearchScorer(
526
+ batch_size=batch_size * generation_config.num_return_sequences,
527
+ num_beams=generation_config.num_beams,
528
+ device=inputs_tensor.device,
529
+ length_penalty=generation_config.length_penalty,
530
+ do_early_stopping=generation_config.early_stopping,
531
+ )
532
+
533
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
534
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
535
+ input_ids=input_ids,
536
+ expand_size=generation_config.num_beams
537
+ * generation_config.num_return_sequences,
538
+ is_encoder_decoder=self.config.is_encoder_decoder,
539
+ **model_kwargs,
540
+ )
541
+
542
+ # 14. run beam sample
543
+ return self.beam_sample(
544
+ input_ids,
545
+ beam_scorer,
546
+ logits_processor=logits_processor,
547
+ logits_warper=logits_warper,
548
+ stopping_criteria=stopping_criteria,
549
+ pad_token_id=generation_config.pad_token_id,
550
+ eos_token_id=generation_config.eos_token_id,
551
+ output_scores=generation_config.output_scores,
552
+ return_dict_in_generate=generation_config.return_dict_in_generate,
553
+ synced_gpus=synced_gpus,
554
+ **model_kwargs,
555
+ )
556
+
557
+ elif is_group_beam_gen_mode:
558
+ if generation_config.num_return_sequences > generation_config.num_beams:
559
+ raise ValueError(
560
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
561
+ )
562
+
563
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
564
+ raise ValueError(
565
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
566
+ )
567
+
568
+ if stopping_criteria.max_length is None:
569
+ raise ValueError(
570
+ "`max_length` needs to be a stopping_criteria for now."
571
+ )
572
+
573
+ has_default_typical_p = (
574
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
575
+ )
576
+ if not has_default_typical_p:
577
+ raise ValueError(
578
+ "Decoder argument `typical_p` is not supported with beam groups."
579
+ )
580
+
581
+ # 11. prepare beam search scorer
582
+ beam_scorer = BeamSearchScorer(
583
+ batch_size=batch_size,
584
+ num_beams=generation_config.num_beams,
585
+ max_length=stopping_criteria.max_length,
586
+ device=inputs_tensor.device,
587
+ length_penalty=generation_config.length_penalty,
588
+ do_early_stopping=generation_config.early_stopping,
589
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
590
+ num_beam_groups=generation_config.num_beam_groups,
591
+ )
592
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
593
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
594
+ input_ids=input_ids,
595
+ expand_size=generation_config.num_beams,
596
+ is_encoder_decoder=self.config.is_encoder_decoder,
597
+ **model_kwargs,
598
+ )
599
+ # 13. run beam search
600
+ return self.group_beam_search(
601
+ input_ids,
602
+ beam_scorer,
603
+ logits_processor=logits_processor,
604
+ stopping_criteria=stopping_criteria,
605
+ pad_token_id=generation_config.pad_token_id,
606
+ eos_token_id=generation_config.eos_token_id,
607
+ output_scores=generation_config.output_scores,
608
+ return_dict_in_generate=generation_config.return_dict_in_generate,
609
+ synced_gpus=synced_gpus,
610
+ **model_kwargs,
611
+ )
612
+
613
+ elif is_constraint_gen_mode:
614
+ if generation_config.num_return_sequences > generation_config.num_beams:
615
+ raise ValueError(
616
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
617
+ )
618
+
619
+ if stopping_criteria.max_length is None:
620
+ raise ValueError(
621
+ "`max_length` needs to be a stopping_criteria for now."
622
+ )
623
+
624
+ if generation_config.num_beams <= 1:
625
+ raise ValueError(
626
+ "`num_beams` needs to be greater than 1 for constrained generation."
627
+ )
628
+
629
+ if generation_config.do_sample:
630
+ raise ValueError(
631
+ "`do_sample` needs to be false for constrained generation."
632
+ )
633
+
634
+ if (
635
+ generation_config.num_beam_groups is not None
636
+ and generation_config.num_beam_groups > 1
637
+ ):
638
+ raise ValueError(
639
+ "`num_beam_groups` not supported yet for constrained generation."
640
+ )
641
+
642
+ final_constraints = []
643
+ if generation_config.constraints is not None:
644
+ final_constraints = generation_config.constraints
645
+
646
+ if generation_config.force_words_ids is not None:
647
+
648
+ def typeerror():
649
+ raise ValueError(
650
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
651
+ f"of positive integers, but is {generation_config.force_words_ids}."
652
+ )
653
+
654
+ if (
655
+ not isinstance(generation_config.force_words_ids, list)
656
+ or len(generation_config.force_words_ids) == 0
657
+ ):
658
+ typeerror()
659
+
660
+ for word_ids in generation_config.force_words_ids:
661
+ if isinstance(word_ids[0], list):
662
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
663
+ typeerror()
664
+ if any(
665
+ not isinstance(token_ids, list) for token_ids in word_ids
666
+ ):
667
+ typeerror()
668
+ if any(
669
+ any(
670
+ (not isinstance(token_id, int) or token_id < 0)
671
+ for token_id in token_ids
672
+ )
673
+ for token_ids in word_ids
674
+ ):
675
+ typeerror()
676
+
677
+ constraint = DisjunctiveConstraint(word_ids)
678
+ else:
679
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
680
+ typeerror()
681
+ if any(
682
+ (not isinstance(token_id, int) or token_id < 0)
683
+ for token_id in word_ids
684
+ ):
685
+ typeerror()
686
+
687
+ constraint = PhrasalConstraint(word_ids)
688
+ final_constraints.append(constraint)
689
+
690
+ # 11. prepare beam search scorer
691
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
692
+ constraints=final_constraints,
693
+ batch_size=batch_size,
694
+ num_beams=generation_config.num_beams,
695
+ device=inputs_tensor.device,
696
+ length_penalty=generation_config.length_penalty,
697
+ do_early_stopping=generation_config.early_stopping,
698
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
699
+ )
700
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
701
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
702
+ input_ids=input_ids,
703
+ expand_size=generation_config.num_beams,
704
+ is_encoder_decoder=self.config.is_encoder_decoder,
705
+ **model_kwargs,
706
+ )
707
+ # 13. run beam search
708
+ return self.constrained_beam_search(
709
+ input_ids,
710
+ constrained_beam_scorer=constrained_beam_scorer,
711
+ logits_processor=logits_processor,
712
+ stopping_criteria=stopping_criteria,
713
+ pad_token_id=generation_config.pad_token_id,
714
+ eos_token_id=generation_config.eos_token_id,
715
+ output_scores=generation_config.output_scores,
716
+ return_dict_in_generate=generation_config.return_dict_in_generate,
717
+ synced_gpus=synced_gpus,
718
+ **model_kwargs,
719
+ )
720
+
721
+ @torch.no_grad()
722
+ def sample_stream(
723
+ self,
724
+ input_ids: torch.LongTensor,
725
+ logits_processor: Optional[LogitsProcessorList] = None,
726
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
727
+ logits_warper: Optional[LogitsProcessorList] = None,
728
+ max_length: Optional[int] = None,
729
+ pad_token_id: Optional[int] = None,
730
+ eos_token_id: Optional[Union[int, List[int]]] = None,
731
+ output_attentions: Optional[bool] = None,
732
+ output_hidden_states: Optional[bool] = None,
733
+ output_scores: Optional[bool] = None,
734
+ return_dict_in_generate: Optional[bool] = None,
735
+ synced_gpus: Optional[bool] = False,
736
+ **model_kwargs,
737
+ ) -> Union[SampleOutput, torch.LongTensor]:
738
+ r"""
739
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
740
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
741
+
742
+ <Tip warning={true}>
743
+
744
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
745
+ For an overview of generation strategies and code examples, check the [following
746
+ guide](./generation_strategies).
747
+
748
+ </Tip>
749
+
750
+ Parameters:
751
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
752
+ The sequence used as a prompt for the generation.
753
+ logits_processor (`LogitsProcessorList`, *optional*):
754
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
755
+ used to modify the prediction scores of the language modeling head applied at each generation step.
756
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
757
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
758
+ used to tell if the generation loop should stop.
759
+ logits_warper (`LogitsProcessorList`, *optional*):
760
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
761
+ to warp the prediction score distribution of the language modeling head applied before multinomial
762
+ sampling at each generation step.
763
+ max_length (`int`, *optional*, defaults to 20):
764
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
765
+ tokens. The maximum length of the sequence to be generated.
766
+ pad_token_id (`int`, *optional*):
767
+ The id of the *padding* token.
768
+ eos_token_id (`int`, *optional*):
769
+ The id of the *end-of-sequence* token.
770
+ output_attentions (`bool`, *optional*, defaults to `False`):
771
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
772
+ returned tensors for more details.
773
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
774
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
775
+ for more details.
776
+ output_scores (`bool`, *optional*, defaults to `False`):
777
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
778
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
779
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
780
+ synced_gpus (`bool`, *optional*, defaults to `False`):
781
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
782
+ model_kwargs:
783
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
784
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
785
+
786
+ Return:
787
+ [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
788
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
789
+ [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
790
+ `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
791
+ `model.config.is_encoder_decoder=True`.
792
+
793
+ Examples:
794
+
795
+ ```python
796
+ >>> from transformers import (
797
+ ... AutoTokenizer,
798
+ ... AutoModelForCausalLM,
799
+ ... LogitsProcessorList,
800
+ ... MinLengthLogitsProcessor,
801
+ ... TopKLogitsWarper,
802
+ ... TemperatureLogitsWarper,
803
+ ... StoppingCriteriaList,
804
+ ... MaxLengthCriteria,
805
+ ... )
806
+ >>> import torch
807
+
808
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
809
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
810
+
811
+ >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
812
+ >>> model.config.pad_token_id = model.config.eos_token_id
813
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
814
+
815
+ >>> input_prompt = "Today is a beautiful day, and"
816
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
817
+
818
+ >>> # instantiate logits processors
819
+ >>> logits_processor = LogitsProcessorList(
820
+ ... [
821
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
822
+ ... ]
823
+ ... )
824
+ >>> # instantiate logits processors
825
+ >>> logits_warper = LogitsProcessorList(
826
+ ... [
827
+ ... TopKLogitsWarper(50),
828
+ ... TemperatureLogitsWarper(0.7),
829
+ ... ]
830
+ ... )
831
+
832
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
833
+
834
+ >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
835
+ >>> outputs = model.sample(
836
+ ... input_ids,
837
+ ... logits_processor=logits_processor,
838
+ ... logits_warper=logits_warper,
839
+ ... stopping_criteria=stopping_criteria,
840
+ ... )
841
+
842
+ >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
843
+ ['Today is a beautiful day, and a wonderful day.\n\nI was lucky enough to meet the']
844
+ ```"""
845
+ # init values
846
+ logits_processor = (
847
+ logits_processor if logits_processor is not None else LogitsProcessorList()
848
+ )
849
+ stopping_criteria = (
850
+ stopping_criteria
851
+ if stopping_criteria is not None
852
+ else StoppingCriteriaList()
853
+ )
854
+ if max_length is not None:
855
+ warnings.warn(
856
+ "`max_length` is deprecated in this function, use"
857
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
858
+ UserWarning,
859
+ )
860
+ stopping_criteria = validate_stopping_criteria(
861
+ stopping_criteria, max_length
862
+ )
863
+ logits_warper = (
864
+ logits_warper if logits_warper is not None else LogitsProcessorList()
865
+ )
866
+ pad_token_id = (
867
+ pad_token_id
868
+ if pad_token_id is not None
869
+ else self.generation_config.pad_token_id
870
+ )
871
+ eos_token_id = (
872
+ eos_token_id
873
+ if eos_token_id is not None
874
+ else self.generation_config.eos_token_id
875
+ )
876
+ if isinstance(eos_token_id, int):
877
+ eos_token_id = [eos_token_id]
878
+ output_scores = (
879
+ output_scores
880
+ if output_scores is not None
881
+ else self.generation_config.output_scores
882
+ )
883
+ output_attentions = (
884
+ output_attentions
885
+ if output_attentions is not None
886
+ else self.generation_config.output_attentions
887
+ )
888
+ output_hidden_states = (
889
+ output_hidden_states
890
+ if output_hidden_states is not None
891
+ else self.generation_config.output_hidden_states
892
+ )
893
+ return_dict_in_generate = (
894
+ return_dict_in_generate
895
+ if return_dict_in_generate is not None
896
+ else self.generation_config.return_dict_in_generate
897
+ )
898
+
899
+ # init attention / hidden states / scores tuples
900
+ scores = () if (return_dict_in_generate and output_scores) else None
901
+ decoder_attentions = (
902
+ () if (return_dict_in_generate and output_attentions) else None
903
+ )
904
+ cross_attentions = (
905
+ () if (return_dict_in_generate and output_attentions) else None
906
+ )
907
+ decoder_hidden_states = (
908
+ () if (return_dict_in_generate and output_hidden_states) else None
909
+ )
910
+
911
+ # keep track of which sequences are already finished
912
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
913
+
914
+ this_peer_finished = False # used by synced_gpus only
915
+ # auto-regressive generation
916
+ while True:
917
+ if synced_gpus:
918
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
919
+ # The following logic allows an early break if all peers finished generating their sequence
920
+ this_peer_finished_flag = torch.tensor(
921
+ 0.0 if this_peer_finished else 1.0
922
+ ).to(input_ids.device)
923
+ # send 0.0 if we finished, 1.0 otherwise
924
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
925
+ # did all peers finish? the reduced sum will be 0.0 then
926
+ if this_peer_finished_flag.item() == 0.0:
927
+ break
928
+
929
+ # prepare model inputs
930
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
931
+
932
+ # forward pass to get next token
933
+ outputs = self(
934
+ **model_inputs,
935
+ return_dict=True,
936
+ output_attentions=output_attentions,
937
+ output_hidden_states=output_hidden_states,
938
+ )
939
+
940
+ if synced_gpus and this_peer_finished:
941
+ continue # don't waste resources running the code we don't need
942
+
943
+ next_token_logits = outputs.logits[:, -1, :]
944
+
945
+ # pre-process distribution
946
+ next_token_scores = logits_processor(input_ids, next_token_logits)
947
+ next_token_scores = logits_warper(input_ids, next_token_scores)
948
+
949
+ # Store scores, attentions and hidden_states when required
950
+ if return_dict_in_generate:
951
+ if output_scores:
952
+ scores += (next_token_scores,)
953
+ if output_attentions:
954
+ decoder_attentions += (
955
+ (outputs.decoder_attentions,)
956
+ if self.config.is_encoder_decoder
957
+ else (outputs.attentions,)
958
+ )
959
+ if self.config.is_encoder_decoder:
960
+ cross_attentions += (outputs.cross_attentions,)
961
+
962
+ if output_hidden_states:
963
+ decoder_hidden_states += (
964
+ (outputs.decoder_hidden_states,)
965
+ if self.config.is_encoder_decoder
966
+ else (outputs.hidden_states,)
967
+ )
968
+
969
+ # sample
970
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
971
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
972
+
973
+ # finished sentences should have their next token be a padding token
974
+ if eos_token_id is not None:
975
+ if pad_token_id is None:
976
+ raise ValueError(
977
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
978
+ )
979
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
980
+ 1 - unfinished_sequences
981
+ )
982
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
983
+ # update generated ids, model inputs, and length for next step
984
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
985
+ model_kwargs = self._update_model_kwargs_for_generation(
986
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
987
+ )
988
+
989
+ # if eos_token was found in one sentence, set sentence to finished
990
+ if eos_token_id is not None:
991
+ unfinished_sequences = unfinished_sequences.mul(
992
+ (sum(next_tokens != i for i in eos_token_id)).long()
993
+ )
994
+
995
+ # stop when each sentence is finished, or if we exceed the maximum length
996
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
997
+ if not synced_gpus:
998
+ break
999
+ else:
1000
+ this_peer_finished = True
1001
+
1002
+
1003
+ def init_stream_support():
1004
+ """Overload PreTrainedModel for streaming."""
1005
+ PreTrainedModel.generate_stream = NewGenerationMixin.generate
1006
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
1007
+
1008
+
1009
+ if __name__ == "__main__":
1010
+ from transformers import PreTrainedModel
1011
+ from transformers import AutoTokenizer, AutoModelForCausalLM
1012
+
1013
+ PreTrainedModel.generate = NewGenerationMixin.generate
1014
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
1015
+ model = AutoModelForCausalLM.from_pretrained(
1016
+ "bigscience/bloom-560m", torch_dtype=torch.float16
1017
+ )
1018
+
1019
+ tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
1020
+ model = model.to("cuda:0")
1021
+ model = model.eval()
1022
+ prompt_text = "hello? \n"
1023
+ input_ids = tokenizer(
1024
+ prompt_text, return_tensors="pt", add_special_tokens=False
1025
+ ).input_ids
1026
+ input_ids = input_ids.to("cuda:0")
1027
+
1028
+ with torch.no_grad():
1029
+ result = model.generate(
1030
+ input_ids,
1031
+ max_new_tokens=200,
1032
+ do_sample=True,
1033
+ top_k=30,
1034
+ top_p=0.85,
1035
+ temperature=0.35,
1036
+ repetition_penalty=1.2,
1037
+ early_stopping=True,
1038
+ seed=0,
1039
+ )
1040
+ print(tokenizer.decode(result, skip_special_tokens=True))
1041
+ generator = model.generate(
1042
+ input_ids,
1043
+ max_new_tokens=200,
1044
+ do_sample=True,
1045
+ top_k=30,
1046
+ top_p=0.85,
1047
+ temperature=0.35,
1048
+ repetition_penalty=1.2,
1049
+ early_stopping=True,
1050
+ seed=0,
1051
+ do_stream=True,
1052
+ )
1053
+ stream_result = ""
1054
+ for x in generator:
1055
+ chunk = tokenizer.decode(x, skip_special_tokens=True)
1056
+ stream_result += chunk
1057
+ print(stream_result)
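Besides the `__main__` demo above, the intended entry point is `init_stream_support()`, which patches `generate_stream`/`sample_stream` onto `PreTrainedModel` so the XTTS GPT can yield `(token, latent)` pairs as they are produced. A minimal, self-contained sketch of that contract follows; it uses a tiny randomly initialized GPT-2 and attaches a stand-in `final_norm` (the real XTTS GPT defines its own), so everything outside the `stream_generator` import is an assumption for illustration.

```python
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel

from TTS.tts.layers.xtts.stream_generator import init_stream_support

init_stream_support()  # adds generate_stream / sample_stream to PreTrainedModel

# Tiny random GPT-2 so the sketch runs offline; XTTS plugs its own GPT in here.
config = GPT2Config(vocab_size=1000, n_positions=128, n_embd=64, n_layer=2, n_head=2,
                    bos_token_id=0, eos_token_id=0)
model = GPT2LMHeadModel(config).eval()

# sample_stream yields (next_tokens, final_norm(last_hidden_state)); plain HF models
# have no `final_norm`, so attach an identity layer just for this demonstration.
model.final_norm = nn.Identity()

input_ids = torch.tensor([[1, 2, 3, 4]])
stream = model.generate_stream(
    input_ids,
    do_sample=True,
    do_stream=True,             # routes generate() into sample_stream()
    output_hidden_states=True,  # the stream also returns the last hidden state
    max_new_tokens=8,
)

with torch.no_grad():
    for token, latent in stream:
        # XTTS feeds these latents chunk by chunk into its HifiDecoder.
        print(int(token), tuple(latent.shape))
```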
TTS/TTS/tts/layers/xtts/tokenizer.py CHANGED
@@ -1,206 +1,468 @@
1
- import json
2
  import os
3
  import re
 
4
 
5
- import inflect
6
- import pandas as pd
7
- import pypinyin
8
  import torch
9
- from num2words import num2words
10
  from tokenizers import Tokenizer
11
- from unidecode import unidecode
12
-
13
- from TTS.tts.utils.text.cleaners import english_cleaners
14
-
15
- _inflect = inflect.engine()
16
- _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
17
- _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
18
- _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
19
- _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
20
- _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
21
- _number_re = re.compile(r"[0-9]+")
22
-
23
-
24
- def _remove_commas(m):
25
- return m.group(1).replace(",", "")
26
-
27
-
28
- def _expand_decimal_point(m):
29
- return m.group(1).replace(".", " point ")
30
-
31
-
32
- def _expand_dollars(m):
33
- match = m.group(1)
34
- parts = match.split(".")
35
- if len(parts) > 2:
36
- return match + " dollars" # Unexpected format
37
- dollars = int(parts[0]) if parts[0] else 0
38
- cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
39
- if dollars and cents:
40
- dollar_unit = "dollar" if dollars == 1 else "dollars"
41
- cent_unit = "cent" if cents == 1 else "cents"
42
- return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
43
- elif dollars:
44
- dollar_unit = "dollar" if dollars == 1 else "dollars"
45
- return "%s %s" % (dollars, dollar_unit)
46
- elif cents:
47
- cent_unit = "cent" if cents == 1 else "cents"
48
- return "%s %s" % (cents, cent_unit)
49
- else:
50
- return "zero dollars"
51
-
52
-
53
- def _expand_ordinal(m):
54
- return _inflect.number_to_words(m.group(0))
55
-
56
-
57
- def _expand_number(m):
58
- num = int(m.group(0))
59
- if num > 1000 and num < 3000:
60
- if num == 2000:
61
- return "two thousand"
62
- elif num > 2000 and num < 2010:
63
- return "two thousand " + _inflect.number_to_words(num % 100)
64
- elif num % 100 == 0:
65
- return _inflect.number_to_words(num // 100) + " hundred"
66
- else:
67
- return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
68
- else:
69
- return _inflect.number_to_words(num, andword="")
70
-
71
-
72
- def normalize_numbers(text):
73
- text = re.sub(_comma_number_re, _remove_commas, text)
74
- text = re.sub(_pounds_re, r"\1 pounds", text)
75
- text = re.sub(_dollars_re, _expand_dollars, text)
76
- text = re.sub(_decimal_number_re, _expand_decimal_point, text)
77
- text = re.sub(_ordinal_re, _expand_ordinal, text)
78
- text = re.sub(_number_re, _expand_number, text)
79
- return text
80
 
 
 
 
81
 
82
- # Regular expression matching whitespace:
83
  _whitespace_re = re.compile(r"\s+")
84
 
85
  # List of (regular expression, replacement) pairs for abbreviations:
86
- _abbreviations = [
87
- (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
88
- for x in [
89
- ("mrs", "misess"),
90
- ("mr", "mister"),
91
- ("dr", "doctor"),
92
- ("st", "saint"),
93
- ("co", "company"),
94
- ("jr", "junior"),
95
- ("maj", "major"),
96
- ("gen", "general"),
97
- ("drs", "doctors"),
98
- ("rev", "reverend"),
99
- ("lt", "lieutenant"),
100
- ("hon", "honorable"),
101
- ("sgt", "sergeant"),
102
- ("capt", "captain"),
103
- ("esq", "esquire"),
104
- ("ltd", "limited"),
105
- ("col", "colonel"),
106
- ("ft", "fort"),
107
- ]
108
- ]
109
-
110
-
111
- def expand_abbreviations(text):
112
- for regex, replacement in _abbreviations:
113
  text = re.sub(regex, replacement, text)
114
  return text
115
 
116
 
117
- def expand_numbers(text):
118
- return normalize_numbers(text)
119
-
120
-
121
- def lowercase(text):
122
- return text.lower()
123
-
124
-
125
- def collapse_whitespace(text):
126
- return re.sub(_whitespace_re, " ", text)
127
-
128
-
129
- def convert_to_ascii(text):
130
- return unidecode(text)
131
-
132
 
133
- def basic_cleaners(text):
134
- """Basic pipeline that lowercases and collapses whitespace without transliteration."""
135
- text = lowercase(text)
136
- text = collapse_whitespace(text)
137
- text = text.replace('"', "")
138
  return text
139
 
140
-
141
- def expand_numbers_multilang(text, lang):
142
- # TODO: Handle text more carefully. Currently, it just converts numbers without any context.
143
- # Find all numbers in the input string
144
- numbers = re.findall(r"\d+", text)
145
-
146
- # Transliterate the numbers to text
147
- for num in numbers:
148
- transliterated_num = "".join(num2words(num, lang=lang))
149
- text = text.replace(num, transliterated_num, 1)
150
-
151
  return text
152
 
153
-
154
- def transliteration_cleaners(text):
155
- """Pipeline for non-English text that transliterates to ASCII."""
156
- text = convert_to_ascii(text)
157
- text = lowercase(text)
158
- text = collapse_whitespace(text)
159
  return text
160
 
 
 
 
 
 
161
 
162
  def multilingual_cleaners(text, lang):
163
- text = lowercase(text)
164
- text = expand_numbers_multilang(text, lang)
165
- text = collapse_whitespace(text)
166
- text = text.replace('"', "")
167
- if lang == "tr":
168
  text = text.replace("İ", "i")
169
  text = text.replace("Ö", "ö")
170
  text = text.replace("Ü", "ü")
 
 
 
 
 
171
  return text
172
 
173
-
174
- def remove_extraneous_punctuation(word):
175
- replacement_punctuation = {"{": "(", "}": ")", "[": "(", "]": ")", "`": "'", "—": "-", "—": "-", "`": "'", "ʼ": "'"}
176
- replace = re.compile(
177
- "|".join([re.escape(k) for k in sorted(replacement_punctuation, key=len, reverse=True)]), flags=re.DOTALL
178
- )
179
- word = replace.sub(lambda x: replacement_punctuation[x.group(0)], word)
180
-
181
- # TODO: some of these are spoken ('@', '%', '+', etc). Integrate them into the cleaners.
182
- extraneous = re.compile(r"^[@#%_=\$\^&\*\+\\]$")
183
- word = extraneous.sub("", word)
184
- return word
185
-
186
-
187
- def arabic_cleaners(text):
188
  text = lowercase(text)
189
  text = collapse_whitespace(text)
190
  return text
191
 
 
 
192
 
193
- def chinese_cleaners(text):
 
194
  text = lowercase(text)
195
- text = "".join(
196
- [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
197
- )
198
  return text
199
 
200
-
201
  class VoiceBpeTokenizer:
202
  def __init__(self, vocab_file=None, preprocess=None):
203
  self.tokenizer = None
 
204
 
205
  if vocab_file is not None:
206
  with open(vocab_file, "r", encoding="utf-8") as f:
@@ -216,21 +478,20 @@ class VoiceBpeTokenizer:
216
  self.tokenizer = Tokenizer.from_file(vocab_file)
217
 
218
  def preprocess_text(self, txt, lang):
219
- if lang == "ja":
220
- import pykakasi
221
-
222
- kks = pykakasi.kakasi()
223
- results = kks.convert(txt)
224
- txt = " ".join([result["kana"] for result in results])
225
- txt = basic_cleaners(txt)
226
- elif lang == "en":
227
- txt = english_cleaners(txt)
228
- elif lang == "ar":
229
- txt = arabic_cleaners(txt)
230
- elif lang == "zh-cn":
231
- txt = chinese_cleaners(txt)
232
- else:
233
  txt = multilingual_cleaners(txt, lang)
 
 
 
 
 
 
 
 
 
 
 
 
234
  return txt
235
 
236
  def encode(self, txt, lang):
@@ -247,3 +508,9 @@ class VoiceBpeTokenizer:
247
  txt = txt.replace("[STOP]", "")
248
  txt = txt.replace("[UNK]", "")
249
  return txt
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
+ import json
4
 
 
 
 
5
  import torch
 
6
  from tokenizers import Tokenizer
7
 
8
+ import pypinyin
9
+ from num2words import num2words
10
+ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
11
 
 
12
  _whitespace_re = re.compile(r"\s+")
13
 
14
  # List of (regular expression, replacement) pairs for abbreviations:
15
+ _abbreviations = {
16
+ "en": [
17
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
18
+ for x in [
19
+ ("mrs", "misess"),
20
+ ("mr", "mister"),
21
+ ("dr", "doctor"),
22
+ ("st", "saint"),
23
+ ("co", "company"),
24
+ ("jr", "junior"),
25
+ ("maj", "major"),
26
+ ("gen", "general"),
27
+ ("drs", "doctors"),
28
+ ("rev", "reverend"),
29
+ ("lt", "lieutenant"),
30
+ ("hon", "honorable"),
31
+ ("sgt", "sergeant"),
32
+ ("capt", "captain"),
33
+ ("esq", "esquire"),
34
+ ("ltd", "limited"),
35
+ ("col", "colonel"),
36
+ ("ft", "fort"),
37
+ ]
38
+ ],
39
+ "es": [
40
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
41
+ for x in [
42
+ ("sra", "señora"),
43
+ ("sr", "señor"),
44
+ ("dr", "doctor"),
45
+ ("dra", "doctora"),
46
+ ("st", "santo"),
47
+ ("co", "compañía"),
48
+ ("jr", "junior"),
49
+ ("ltd", "limitada"),
50
+ ]
51
+ ],
52
+ "fr": [
53
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
54
+ for x in [
55
+ ("mme", "madame"),
56
+ ("mr", "monsieur"),
57
+ ("dr", "docteur"),
58
+ ("st", "saint"),
59
+ ("co", "compagnie"),
60
+ ("jr", "junior"),
61
+ ("ltd", "limitée"),
62
+ ]
63
+ ],
64
+ "de": [
65
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
66
+ for x in [
67
+ ("fr", "frau"),
68
+ ("dr", "doktor"),
69
+ ("st", "sankt"),
70
+ ("co", "firma"),
71
+ ("jr", "junior"),
72
+ ]
73
+ ],
74
+ "pt": [
75
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
76
+ for x in [
77
+ ("sra", "senhora"),
78
+ ("sr", "senhor"),
79
+ ("dr", "doutor"),
80
+ ("dra", "doutora"),
81
+ ("st", "santo"),
82
+ ("co", "companhia"),
83
+ ("jr", "júnior"),
84
+ ("ltd", "limitada"),
85
+ ]
86
+ ],
87
+ "it": [
88
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
89
+ for x in [
90
+ #("sig.ra", "signora"),
91
+ ("sig", "signore"),
92
+ ("dr", "dottore"),
93
+ ("st", "santo"),
94
+ ("co", "compagnia"),
95
+ ("jr", "junior"),
96
+ ("ltd", "limitata"),
97
+ ]
98
+ ],
99
+ "pl": [
100
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
101
+ for x in [
102
+ ("p", "pani"),
103
+ ("m", "pan"),
104
+ ("dr", "doktor"),
105
+ ("sw", "święty"),
106
+ ("jr", "junior"),
107
+ ]
108
+ ],
109
+ "ar": [
110
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
111
+ for x in [
112
+ # There are not many common abbreviations in Arabic as in English.
113
+ ]
114
+ ],
115
+ "zh-cn": [
116
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
117
+ for x in [
118
+ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
119
+ ]
120
+ ],
121
+ "cs": [
122
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
123
+ for x in [
124
+ ("dr", "doktor"), # doctor
125
+ ("ing", "inženýr"), # engineer
126
+ ("p", "pan"), # Could also map to pani for woman but no easy way to do it
127
+ # Other abbreviations would be specialized and not as common.
128
+ ]
129
+ ],
130
+ "ru": [
131
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
132
+ for x in [
133
+ ("г-жа", "госпожа"), # Mrs.
134
+ ("г-н", "господин"), # Mr.
135
+ ("д-р", "доктор"), # doctor
136
+ # Other abbreviations are less common or specialized.
137
+ ]
138
+ ],
139
+ "nl": [
140
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
141
+ for x in [
142
+ ("dhr", "de heer"), # Mr.
143
+ ("mevr", "mevrouw"), # Mrs.
144
+ ("dr", "dokter"), # doctor
145
+ ("jhr", "jonkheer"), # young lord or nobleman
146
+ # Dutch uses more abbreviations, but these are the most common ones.
147
+ ]
148
+ ],
149
+ "tr": [
150
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
151
+ for x in [
152
+ ("b", "bay"), # Mr.
153
+ ("byk", "büyük"), # büyük
154
+ ("dr", "doktor"), # doctor
155
+ # Add other Turkish abbreviations here if needed.
156
+ ]
157
+ ],
158
+ }
159
+
160
+ def expand_abbreviations_multilingual(text, lang='en'):
161
+ for regex, replacement in _abbreviations[lang]:
162
  text = re.sub(regex, replacement, text)
163
  return text
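Illustrative usage of the per-language abbreviation tables above (not part of the commit); the exact output depends on the regex order in `_abbreviations`, so treat the results as approximate.

```python
# Sketch only: abbreviation expansion for two of the supported languages.
print(expand_abbreviations_multilingual("Mr. Smith lives on St. George St.", lang="en"))
# roughly: "mister Smith lives on saint George saint"
print(expand_abbreviations_multilingual("El Sr. García y la Dra. Ruiz.", lang="es"))
# roughly: "El señor García y la doctora Ruiz."
```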
164
 
165
+ _symbols_multilingual = {
166
+ 'en': [
167
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
168
+ for x in [
169
+ ("&", " and "),
170
+ ("@", " at "),
171
+ ("%", " percent "),
172
+ ("#", " hash "),
173
+ ("$", " dollar "),
174
+ ("£", " pound "),
175
+ ("°", " degree ")
176
+ ]
177
+ ],
178
+ 'es': [
179
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
180
+ for x in [
181
+ ("&", " y "),
182
+ ("@", " arroba "),
183
+ ("%", " por ciento "),
184
+ ("#", " numeral "),
185
+ ("$", " dolar "),
186
+ ("£", " libra "),
187
+ ("°", " grados ")
188
+ ]
189
+ ],
190
+ 'fr': [
191
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
192
+ for x in [
193
+ ("&", " et "),
194
+ ("@", " arobase "),
195
+ ("%", " pour cent "),
196
+ ("#", " dièse "),
197
+ ("$", " dollar "),
198
+ ("£", " livre "),
199
+ ("°", " degrés ")
200
+ ]
201
+ ],
202
+ 'de': [
203
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
204
+ for x in [
205
+ ("&", " und "),
206
+ ("@", " at "),
207
+ ("%", " prozent "),
208
+ ("#", " raute "),
209
+ ("$", " dollar "),
210
+ ("£", " pfund "),
211
+ ("°", " grad ")
212
+ ]
213
+ ],
214
+ 'pt': [
215
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
216
+ for x in [
217
+ ("&", " e "),
218
+ ("@", " arroba "),
219
+ ("%", " por cento "),
220
+ ("#", " cardinal "),
221
+ ("$", " dólar "),
222
+ ("£", " libra "),
223
+ ("°", " graus ")
224
+ ]
225
+ ],
226
+ 'it': [
227
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
228
+ for x in [
229
+ ("&", " e "),
230
+ ("@", " chiocciola "),
231
+ ("%", " per cento "),
232
+ ("#", " cancelletto "),
233
+ ("$", " dollaro "),
234
+ ("£", " sterlina "),
235
+ ("°", " gradi ")
236
+ ]
237
+ ],
238
+ 'pl': [
239
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
240
+ for x in [
241
+ ("&", " i "),
242
+ ("@", " małpa "),
243
+ ("%", " procent "),
244
+ ("#", " krzyżyk "),
245
+ ("$", " dolar "),
246
+ ("£", " funt "),
247
+ ("°", " stopnie ")
248
+ ]
249
+ ],
250
+ "ar": [
251
+ # Arabic
252
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
253
+ for x in [
254
+ ("&", " و "),
255
+ ("@", " على "),
256
+ ("%", " في المئة "),
257
+ ("#", " رقم "),
258
+ ("$", " دولار "),
259
+ ("£", " جنيه "),
260
+ ("°", " درجة ")
261
+ ]
262
+ ],
263
+ "zh-cn": [
264
+ # Chinese
265
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
266
+ for x in [
267
+ ("&", " 和 "),
268
+ ("@", " 在 "),
269
+ ("%", " 百分之 "),
270
+ ("#", " 号 "),
271
+ ("$", " 美元 "),
272
+ ("£", " 英镑 "),
273
+ ("°", " 度 ")
274
+ ]
275
+ ],
276
+ "cs": [
277
+ # Czech
278
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
279
+ for x in [
280
+ ("&", " a "),
281
+ ("@", " na "),
282
+ ("%", " procento "),
283
+ ("#", " křížek "),
284
+ ("$", " dolar "),
285
+ ("£", " libra "),
286
+ ("°", " stupně ")
287
+ ]
288
+ ],
289
+ "ru": [
290
+ # Russian
291
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
292
+ for x in [
293
+ ("&", " и "),
294
+ ("@", " собака "),
295
+ ("%", " процентов "),
296
+ ("#", " номер "),
297
+ ("$", " доллар "),
298
+ ("£", " фунт "),
299
+ ("°", " градус ")
300
+ ]
301
+ ],
302
+ "nl": [
303
+ # Dutch
304
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
305
+ for x in [
306
+ ("&", " en "),
307
+ ("@", " bij "),
308
+ ("%", " procent "),
309
+ ("#", " hekje "),
310
+ ("$", " dollar "),
311
+ ("£", " pond "),
312
+ ("°", " graden ")
313
+ ]
314
+ ],
315
+ "tr": [
316
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
317
+ for x in [
318
+ ("&", " ve "),
319
+ ("@", " at "),
320
+ ("%", " yüzde "),
321
+ ("#", " diyez "),
322
+ ("$", " dolar "),
323
+ ("£", " sterlin "),
324
+ ("°", " derece ")
325
+ ]
326
+ ],
327
+ }
328
+
329
+ def expand_symbols_multilingual(text, lang='en'):
330
+ for regex, replacement in _symbols_multilingual[lang]:
331
+ text = re.sub(regex, replacement, text)
332
+ text = text.replace('  ', ' ') # Ensure there are no double spaces
333
+ return text.strip()
334
+
335
+
336
+ _ordinal_re = {
337
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
338
+ "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
339
+ "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
340
+ "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
341
+ "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
342
+ "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
343
+ "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
344
+ "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
345
+ "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
346
+ "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
347
+ "nl": re.compile(r"([0-9]+)(de|ste|e)"),
348
+ "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
349
+ }
350
+ _number_re = re.compile(r"[0-9]+")
351
+ _currency_re = {
352
+ 'USD': re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
353
+ 'GBP': re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
354
+ 'EUR': re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))")
355
+ }
356
 
357
+ _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
358
+ _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
359
+ _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
+ def _remove_commas(m):
362
+ text = m.group(0)
363
+ if "," in text:
364
+ text = text.replace(",", "")
 
365
  return text
366
 
367
+ def _remove_dots(m):
368
+ text = m.group(0)
369
+ if "." in text:
370
+ text = text.replace(".", "")
 
 
 
 
 
 
 
371
  return text
372
 
373
+ def _expand_decimal_point(m, lang='en'):
374
+ amount = m.group(1).replace(",", ".")
375
+ return num2words(float(amount), lang=lang if lang != "cs" else "cz")
376
+
377
+ def _expand_currency(m, lang='en', currency='USD'):
378
+ amount = float((re.sub(r'[^\d.]', '', m.group(0).replace(",", "."))))
379
+ full_amount = num2words(amount, to='currency', currency=currency, lang=lang if lang != "cs" else "cz")
380
+
381
+ and_equivalents = {
382
+ "en": ", ",
383
+ "es": " con ",
384
+ "fr": " et ",
385
+ "de": " und ",
386
+ "pt": " e ",
387
+ "it": " e ",
388
+ "pl": ", ",
389
+ "cs": ", ",
390
+ "ru": ", ",
391
+ "nl": ", ",
392
+ "ar": ", ",
393
+ "tr": ", ",
394
+ }
395
+
396
+ if amount.is_integer():
397
+ last_and = full_amount.rfind(and_equivalents[lang])
398
+ if last_and != -1:
399
+ full_amount = full_amount[:last_and]
400
+
401
+ return full_amount
402
+
403
+ def _expand_ordinal(m, lang='en'):
404
+ return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
405
+
406
+ def _expand_number(m, lang='en'):
407
+ return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
408
+
409
+ def expand_numbers_multilingual(text, lang='en'):
410
+ if lang == "zh-cn":
411
+ text = zh_num2words()(text)
412
+ else:
413
+ if lang in ["en", "ru"]:
414
+ text = re.sub(_comma_number_re, _remove_commas, text)
415
+ else:
416
+ text = re.sub(_dot_number_re, _remove_dots, text)
417
+ try:
418
+ text = re.sub(_currency_re['GBP'], lambda m: _expand_currency(m, lang, 'GBP'), text)
419
+ text = re.sub(_currency_re['USD'], lambda m: _expand_currency(m, lang, 'USD'), text)
420
+ text = re.sub(_currency_re['EUR'], lambda m: _expand_currency(m, lang, 'EUR'), text)
421
+ except:
422
+ pass
423
+ if lang != "tr":
424
+ text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
425
+ text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
426
+ text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
427
  return text
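A hedged example of the number, currency and ordinal expansion above (not part of the commit); the exact wording comes from `num2words`, so the outputs below are approximate.

```python
# Sketch only: numbers, currency amounts and ordinals are spelled out per language.
print(expand_numbers_multilingual("I paid $5 for the 2nd ticket.", lang="en"))
# approximately: "I paid five dollars for the second ticket."
print(expand_numbers_multilingual("Das kostet 3,5 Euro.", lang="de"))
# approximately: "Das kostet drei Komma fünf Euro."
```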
428
 
429
+ def lowercase(text):
430
+ return text.lower()
431
+
432
+ def collapse_whitespace(text):
433
+ return re.sub(_whitespace_re, " ", text)
434
 
435
  def multilingual_cleaners(text, lang):
436
+ text = text.replace('"', '')
437
+ if lang=="tr":
 
 
 
438
  text = text.replace("İ", "i")
439
  text = text.replace("Ö", "ö")
440
  text = text.replace("Ü", "ü")
441
+ text = lowercase(text)
442
+ text = expand_numbers_multilingual(text, lang)
443
+ text = expand_abbreviations_multilingual(text, lang)
444
+ text = expand_symbols_multilingual(text, lang=lang)
445
+ text = collapse_whitespace(text)
446
  return text
447
 
448
+ def basic_cleaners(text):
449
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  text = lowercase(text)
451
  text = collapse_whitespace(text)
452
  return text
453
 
454
+ def chinese_transliterate(text):
455
+ return "".join([p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)])
456
 
457
+ def japanese_cleaners(text, katsu):
458
+ text = katsu.romaji(text)
459
  text = lowercase(text)
 
 
 
460
  return text
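The two helpers above rely on optional packages (`pypinyin`, and `cutlet`, which in turn needs a MeCab dictionary). A hedged sketch, not part of the commit:

```python
# Sketch only: Chinese text becomes numbered pinyin, Japanese text becomes
# lower-cased romaji.
import cutlet

print(chinese_transliterate("你好"))                    # e.g. "ni3hao3"
print(japanese_cleaners("こんにちは", cutlet.Cutlet()))  # lower-cased romaji
```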
461
 
 
462
  class VoiceBpeTokenizer:
463
  def __init__(self, vocab_file=None, preprocess=None):
464
  self.tokenizer = None
465
+ self.katsu = None
466
 
467
  if vocab_file is not None:
468
  with open(vocab_file, "r", encoding="utf-8") as f:
 
478
  self.tokenizer = Tokenizer.from_file(vocab_file)
479
 
480
  def preprocess_text(self, txt, lang):
481
+ if lang in ["en", "es", "fr", "de", "pt", "it", "pl", "ar", "cs", "ru", "nl", "tr", "zh-cn"]:
 
 
 
 
 
 
 
 
 
 
 
 
 
482
  txt = multilingual_cleaners(txt, lang)
483
+ if lang == "zh-cn":
484
+ txt = chinese_transliterate(txt)
485
+ elif lang == "ja":
486
+ assert txt[:4] == "[ja]", "Japanese speech should start with the [ja] token."
487
+ txt = txt[4:]
488
+ if self.katsu is None:
489
+ import cutlet
490
+ self.katsu = cutlet.Cutlet()
491
+ txt = japanese_cleaners(txt, self.katsu)
492
+ txt = "[ja]" + txt
493
+ else:
494
+ raise NotImplementedError()
495
  return txt
496
 
497
  def encode(self, txt, lang):
 
508
  txt = txt.replace("[STOP]", "")
509
  txt = txt.replace("[UNK]", "")
510
  return txt
511
+
512
+ def __len__(self):
513
+ return self.tokenizer.get_vocab_size()
514
+
515
+ def get_number_tokens(self):
516
+ return max(self.tokenizer.get_vocab().values()) + 1
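End-to-end sketch for the updated `VoiceBpeTokenizer` (not part of the commit; the vocab path below is a placeholder, not something this commit ships):

```python
# Sketch only: construct the tokenizer from a vocab file and inspect it.
tok = VoiceBpeTokenizer(vocab_file="/path/to/xtts_vocab.json")  # placeholder path
print(tok.preprocess_text("Dr. Smith paid $5.", lang="en"))     # cleaned text
ids = tok.encode("Dr. Smith paid $5.", lang="en")               # encode() takes the language tag
print(len(tok), tok.get_number_tokens())                        # vocab size information
```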
TTS/TTS/tts/layers/xtts/zh_num2words.py ADDED
@@ -0,0 +1,1207 @@
1
+ # Authors:
2
+ # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
3
+ # 2019.9 - 2022 Jiayu DU
4
+
5
+ import sys, os, argparse
6
+ import string, re
7
+ import csv
8
+
9
+ # ================================================================================ #
10
+ # basic constant
11
+ # ================================================================================ #
12
+ CHINESE_DIGIS = u'零一二三四五六七八九'
13
+ BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
14
+ BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
15
+ SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
16
+ SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
17
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
18
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
19
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
20
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
21
+
22
+ ZERO_ALT = u'〇'
23
+ ONE_ALT = u'幺'
24
+ TWO_ALTS = [u'两', u'兩']
25
+
26
+ POSITIVE = [u'正', u'正']
27
+ NEGATIVE = [u'负', u'負']
28
+ POINT = [u'点', u'點']
29
+ # PLUS = [u'加', u'加']
30
+ # SIL = [u'杠', u'槓']
31
+
32
+ FILLER_CHARS = ['呃', '啊']
33
+
34
+ ER_WHITELIST = '(儿女|儿子|儿孙|女儿|儿媳|妻儿|' \
35
+ '胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|' \
36
+ '儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|' \
37
+ '佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)'
38
+ ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
39
+
40
+ # 中文数字系统类型
41
+ NUMBERING_TYPES = ['low', 'mid', 'high']
42
+
43
+ CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
44
+ '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
45
+ CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
46
+ COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
47
+ '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
48
+ '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
49
+ '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
50
+ '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
51
+ '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
52
+
53
+
54
+ # Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
55
+ CN_PUNCS_STOP = '!?。。'
56
+ CN_PUNCS_NONSTOP = '＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-'
57
+ CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
58
+
59
+ PUNCS = CN_PUNCS + string.punctuation
60
+ PUNCS_TRANSFORM = str.maketrans(PUNCS, ' ' * len(PUNCS), '') # replace puncs with space
61
+
62
+
63
+ # https://zh.wikipedia.org/wiki/全行和半行
64
+ QJ2BJ = {
65
+ '　': ' ',
66
+ '！': '!',
67
+ '＂': '"',
68
+ '＃': '#',
69
+ '＄': '$',
70
+ '％': '%',
71
+ '＆': '&',
72
+ '＇': "'",
73
+ '（': '(',
74
+ '）': ')',
75
+ '＊': '*',
76
+ '＋': '+',
77
+ '，': ',',
78
+ '－': '-',
79
+ '．': '.',
80
+ '／': '/',
81
+ '０': '0',
82
+ '１': '1',
83
+ '２': '2',
84
+ '３': '3',
85
+ '４': '4',
86
+ '５': '5',
87
+ '６': '6',
88
+ '７': '7',
89
+ '８': '8',
90
+ '９': '9',
91
+ '：': ':',
92
+ '；': ';',
93
+ '＜': '<',
94
+ '＝': '=',
95
+ '＞': '>',
96
+ '？': '?',
97
+ '＠': '@',
98
+ 'Ａ': 'A',
99
+ 'Ｂ': 'B',
100
+ 'Ｃ': 'C',
101
+ 'Ｄ': 'D',
102
+ 'Ｅ': 'E',
103
+ 'Ｆ': 'F',
104
+ 'Ｇ': 'G',
105
+ 'Ｈ': 'H',
106
+ 'Ｉ': 'I',
107
+ 'Ｊ': 'J',
108
+ 'Ｋ': 'K',
109
+ 'Ｌ': 'L',
110
+ 'Ｍ': 'M',
111
+ 'Ｎ': 'N',
112
+ 'Ｏ': 'O',
113
+ 'Ｐ': 'P',
114
+ 'Ｑ': 'Q',
115
+ 'Ｒ': 'R',
116
+ 'Ｓ': 'S',
117
+ 'Ｔ': 'T',
118
+ 'Ｕ': 'U',
119
+ 'Ｖ': 'V',
120
+ 'Ｗ': 'W',
121
+ 'Ｘ': 'X',
122
+ 'Ｙ': 'Y',
123
+ 'Ｚ': 'Z',
124
+ '［': '[',
125
+ '＼': '\\',
126
+ '］': ']',
127
+ '＾': '^',
128
+ '＿': '_',
129
+ '｀': '`',
130
+ 'ａ': 'a',
131
+ 'ｂ': 'b',
132
+ 'ｃ': 'c',
133
+ 'ｄ': 'd',
134
+ 'ｅ': 'e',
135
+ 'ｆ': 'f',
136
+ 'ｇ': 'g',
137
+ 'ｈ': 'h',
138
+ 'ｉ': 'i',
139
+ 'ｊ': 'j',
140
+ 'ｋ': 'k',
141
+ 'ｌ': 'l',
142
+ 'ｍ': 'm',
143
+ 'ｎ': 'n',
144
+ 'ｏ': 'o',
145
+ 'ｐ': 'p',
146
+ 'ｑ': 'q',
147
+ 'ｒ': 'r',
148
+ 'ｓ': 's',
149
+ 'ｔ': 't',
150
+ 'ｕ': 'u',
151
+ 'ｖ': 'v',
152
+ 'ｗ': 'w',
153
+ 'ｘ': 'x',
154
+ 'ｙ': 'y',
155
+ 'ｚ': 'z',
156
+ '｛': '{',
157
+ '｜': '|',
158
+ '｝': '}',
159
+ '～': '~',
160
+ }
161
+ QJ2BJ_TRANSFORM = str.maketrans(''.join(QJ2BJ.keys()), ''.join(QJ2BJ.values()), '')
162
+
163
+
164
+ # 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
165
+ # https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
166
+ CN_CHARS_COMMON = (
167
+ '一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举'
168
+ '乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互'
169
+ '亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从'
170
+ '仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优'
171
+ '伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚'
172
+ '佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣'
173
+ '侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯'
174
+ '俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌'
175
+ '偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚'
176
+ '僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六'
177
+ '兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况'
178
+ '冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈'
179
+ '刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐'
180
+ '剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼'
181
+ '劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹'
182
+ '区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵'
183
+ '卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔'
184
+ '叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊'
185
+ '同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆'
186
+ '呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎'
187
+ '咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌'
188
+ '响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛'
189
+ '唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴'
190
+ '啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌'
191
+ '嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡'
192
+ '嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓'
193
+ '嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢'
194
+ '圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥'
195
+ '坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩'
196
+ '垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基'
197
+ '埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填'
198
+ '塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复'
199
+ '夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖'
200
+ '套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮'
201
+ '妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈'
202
+ '娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻'
203
+ '婼婿媂媄媆媒媓媖��媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱'
204
+ '嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽'
205
+ '宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾'
206
+ '宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝'
207
+ '尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山'
208
+ '屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃'
209
+ '峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧'
210
+ '崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉'
211
+ '巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡'
212
+ '带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐'
213
+ '庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷'
214
+ '建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖'
215
+ '彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循'
216
+ '徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀'
217
+ '态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓'
218
+ '恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟'
219
+ '悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭'
220
+ '惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥'
221
+ '慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我'
222
+ '戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔'
223
+ '托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡'
224
+ '抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥'
225
+ '拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫'
226
+ '振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎'
227
+ '掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭'
228
+ '揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘'
229
+ '摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢'
230
+ '擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整'
231
+ '敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗'
232
+ '旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝'
233
+ '星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡'
234
+ '晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜'
235
+ '曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽'
236
+ '杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅'
237
+ '枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔'
238
+ '柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩'
239
+ '株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯'
240
+ '桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘'
241
+ '棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂'
242
+ '楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱'
243
+ '榴榷榻槁槃槊槌槎��槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞'
244
+ '橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正'
245
+ '此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒'
246
+ '毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮'
247
+ '氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽'
248
+ '汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽'
249
+ '沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼'
250
+ '泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿'
251
+ '流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸'
252
+ '浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅'
253
+ '淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟'
254
+ '渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆'
255
+ '溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕'
256
+ '滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶'
257
+ '漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧'
258
+ '澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵'
259
+ '灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈'
260
+ '烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯'
261
+ '焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵'
262
+ '熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍'
263
+ '牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷'
264
+ '犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎'
265
+ '猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎'
266
+ '玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊'
267
+ '珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊'
268
+ '琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙'
269
+ '瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱'
270
+ '璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯'
271
+ '田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐'
272
+ '疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒'
273
+ '痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩'
274
+ '瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙'
275
+ '皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾'
276
+ '省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦'
277
+ '睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知'
278
+ '矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮'
279
+ '砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍'
280
+ '碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷'
281
+ '磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭'
282
+ '祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣'
283
+ '秤秦秧秩秫秬秭积��秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗'
284
+ '穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立'
285
+ '竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯'
286
+ '笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅'
287
+ '箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾'
288
+ '簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮'
289
+ '粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢'
290
+ '縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁'
291
+ '绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩'
292
+ '绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕'
293
+ '编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐'
294
+ '网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲'
295
+ '羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者'
296
+ '耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋'
297
+ '职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸'
298
+ '肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳'
299
+ '胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒'
300
+ '腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻'
301
+ '臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般'
302
+ '舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊'
303
+ '芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈'
304
+ '苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆'
305
+ '茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐'
306
+ '荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛'
307
+ '莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥'
308
+ '菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著'
309
+ '葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺'
310
+ '蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷'
311
+ '蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸'
312
+ '薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱'
313
+ '虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆'
314
+ '蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗'
315
+ '蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃'
316
+ '螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡'
317
+ '蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒'
318
+ '袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂'
319
+ '褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉'
320
+ '觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认'
321
+ '讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词'
322
+ '诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请'
323
+ '诸诹诺读诼诽课诿��谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡'
324
+ '谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹'
325
+ '豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼'
326
+ '贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤'
327
+ '赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑'
328
+ '跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮'
329
+ '踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇'
330
+ '躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较'
331
+ '辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄'
332
+ '迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆'
333
+ '选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒'
334
+ '道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱'
335
+ '邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴'
336
+ '郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝'
337
+ '酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭'
338
+ '醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒'
339
+ '钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼'
340
+ '钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨'
341
+ '铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐'
342
+ '锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹'
343
+ '锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣'
344
+ '镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼'
345
+ '闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶'
346
+ '阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈'
347
+ '隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳'
348
+ '零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰'
349
+ '靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵'
350
+ '韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额'
351
+ '颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰'
352
+ '饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥'
353
+ '馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑'
354
+ '骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高'
355
+ '髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾'
356
+ '鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨'
357
+ '鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓'
358
+ '鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶'
359
+ '鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟'
360
+ '鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄'
361
+ '黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷'
362
+ '鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃'
363
+ '㛚㛹㟃㠇㠓㤘㥄㧐��㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡'
364
+ '䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽'
365
+ '𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯'
366
+ '𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟'
367
+ '𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟'
368
+ '𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟'
369
+ '𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓'
370
+ )
371
+ CN_CHARS_EXT = '吶诶屌囧飚屄'
372
+
373
+ CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
374
+ IN_CH_CHARS = { c : True for c in CN_CHARS }
375
+
376
+ EN_CHARS = string.ascii_letters + string.digits
377
+ IN_EN_CHARS = { c : True for c in EN_CHARS }
378
+
379
+ VALID_CHARS = CN_CHARS + EN_CHARS + ' '
380
+ IN_VALID_CHARS = { c : True for c in VALID_CHARS }
381
+
382
+ # ================================================================================ #
383
+ # basic class
384
+ # ================================================================================ #
385
+ class ChineseChar(object):
386
+ """
387
+ 中文字符
388
+ 每个字符对应简体和繁体,
389
+ e.g. 简体 = '负', 繁体 = '負'
390
+ 转换时可转换为简体或繁体
391
+ """
392
+
393
+ def __init__(self, simplified, traditional):
394
+ self.simplified = simplified
395
+ self.traditional = traditional
396
+ #self.__repr__ = self.__str__
397
+
398
+ def __str__(self):
399
+ return self.simplified or self.traditional or None
400
+
401
+ def __repr__(self):
402
+ return self.__str__()
403
+
404
+
405
+ class ChineseNumberUnit(ChineseChar):
406
+ """
407
+ 中文数字/数位字符
408
+ 每个字符除繁简体外还有一个额外的大写字符
409
+ e.g. '陆' 和 '陸'
410
+ """
411
+
412
+ def __init__(self, power, simplified, traditional, big_s, big_t):
413
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
414
+ self.power = power
415
+ self.big_s = big_s
416
+ self.big_t = big_t
417
+
418
+ def __str__(self):
419
+ return '10^{}'.format(self.power)
420
+
421
+ @classmethod
422
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
423
+
424
+ if small_unit:
425
+ return ChineseNumberUnit(power=index + 1,
426
+ simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1])
427
+ elif numbering_type == NUMBERING_TYPES[0]:
428
+ return ChineseNumberUnit(power=index + 8,
429
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
430
+ elif numbering_type == NUMBERING_TYPES[1]:
431
+ return ChineseNumberUnit(power=(index + 2) * 4,
432
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
433
+ elif numbering_type == NUMBERING_TYPES[2]:
434
+ return ChineseNumberUnit(power=pow(2, index + 3),
435
+ simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1])
436
+ else:
437
+ raise ValueError(
438
+ 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
439
+
440
+
441
+ class ChineseNumberDigit(ChineseChar):
442
+ """
443
+ 中文数字字符
444
+ """
445
+
446
+ def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
447
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
448
+ self.value = value
449
+ self.big_s = big_s
450
+ self.big_t = big_t
451
+ self.alt_s = alt_s
452
+ self.alt_t = alt_t
453
+
454
+ def __str__(self):
455
+ return str(self.value)
456
+
457
+ @classmethod
458
+ def create(cls, i, v):
459
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
460
+
461
+
462
+ class ChineseMath(ChineseChar):
463
+ """
464
+ 中文数位字符
465
+ """
466
+
467
+ def __init__(self, simplified, traditional, symbol, expression=None):
468
+ super(ChineseMath, self).__init__(simplified, traditional)
469
+ self.symbol = symbol
470
+ self.expression = expression
471
+ self.big_s = simplified
472
+ self.big_t = traditional
473
+
474
+
475
+ CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
476
+
477
+
478
+ class NumberSystem(object):
479
+ """
480
+ 中文数字系统
481
+ """
482
+ pass
483
+
484
+
485
+ class MathSymbol(object):
486
+ """
487
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
488
+ positive = ['正', '正']
489
+ negative = ['负', '負']
490
+ point = ['点', '點']
491
+ """
492
+
493
+ def __init__(self, positive, negative, point):
494
+ self.positive = positive
495
+ self.negative = negative
496
+ self.point = point
497
+
498
+ def __iter__(self):
499
+ for v in self.__dict__.values():
500
+ yield v
501
+
502
+
503
+ # class OtherSymbol(object):
504
+ # """
505
+ # 其他符号
506
+ # """
507
+ #
508
+ # def __init__(self, sil):
509
+ # self.sil = sil
510
+ #
511
+ # def __iter__(self):
512
+ # for v in self.__dict__.values():
513
+ # yield v
514
+
515
+
516
+ # ================================================================================ #
517
+ # basic utils
518
+ # ================================================================================ #
519
+ def create_system(numbering_type=NUMBERING_TYPES[1]):
520
+ """
521
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
522
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
523
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
524
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
525
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
526
+ 返回对应的数字系统
527
+ """
528
+
529
+ # chinese number units of '亿' and larger
530
+ all_larger_units = zip(
531
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
532
+ larger_units = [CNU.create(i, v, numbering_type, False)
533
+ for i, v in enumerate(all_larger_units)]
534
+ # chinese number units of '十, 百, 千, 万'
535
+ all_smaller_units = zip(
536
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
537
+ smaller_units = [CNU.create(i, v, small_unit=True)
538
+ for i, v in enumerate(all_smaller_units)]
539
+ # digis
540
+ chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS,
541
+ BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
542
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
543
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
544
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
545
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
546
+
547
+ # symbols
548
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x)
549
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x)
550
+ point_cn = CM(POINT[0], POINT[1], '.', lambda x,
551
+ y: float(str(x) + '.' + str(y)))
552
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
553
+ system = NumberSystem()
554
+ system.units = smaller_units + larger_units
555
+ system.digits = digits
556
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
557
+ # system.symbols = OtherSymbol(sil_cn)
558
+ return system
559
+
560
+
561
+ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
562
+
563
+ def get_symbol(char, system):
564
+ for u in system.units:
565
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
566
+ return u
567
+ for d in system.digits:
568
+ if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
569
+ return d
570
+ for m in system.math:
571
+ if char in [m.traditional, m.simplified]:
572
+ return m
573
+
574
+ def string2symbols(chinese_string, system):
575
+ int_string, dec_string = chinese_string, ''
576
+ for p in [system.math.point.simplified, system.math.point.traditional]:
577
+ if p in chinese_string:
578
+ int_string, dec_string = chinese_string.split(p)
579
+ break
580
+ return [get_symbol(c, system) for c in int_string], \
581
+ [get_symbol(c, system) for c in dec_string]
582
+
583
+ def correct_symbols(integer_symbols, system):
584
+ """
585
+ 一百八 to 一百八十
586
+ 一亿一千三百万 to 一亿 一千万 三百万
587
+ """
588
+
589
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
590
+ if integer_symbols[0].power == 1:
591
+ integer_symbols = [system.digits[1]] + integer_symbols
592
+
593
+ if len(integer_symbols) > 1:
594
+ if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
595
+ integer_symbols.append(
596
+ CNU(integer_symbols[-2].power - 1, None, None, None, None))
597
+
598
+ result = []
599
+ unit_count = 0
600
+ for s in integer_symbols:
601
+ if isinstance(s, CND):
602
+ result.append(s)
603
+ unit_count = 0
604
+ elif isinstance(s, CNU):
605
+ current_unit = CNU(s.power, None, None, None, None)
606
+ unit_count += 1
607
+
608
+ if unit_count == 1:
609
+ result.append(current_unit)
610
+ elif unit_count > 1:
611
+ for i in range(len(result)):
612
+ if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
613
+ result[-i - 1] = CNU(result[-i - 1].power +
614
+ current_unit.power, None, None, None, None)
615
+ return result
616
+
617
+ def compute_value(integer_symbols):
618
+ """
619
+ Compute the value.
620
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
621
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
622
+ """
623
+ value = [0]
624
+ last_power = 0
625
+ for s in integer_symbols:
626
+ if isinstance(s, CND):
627
+ value[-1] = s.value
628
+ elif isinstance(s, CNU):
629
+ value[-1] *= pow(10, s.power)
630
+ if s.power > last_power:
631
+ value[:-1] = list(map(lambda v: v *
632
+ pow(10, s.power), value[:-1]))
633
+ last_power = s.power
634
+ value.append(0)
635
+ return sum(value)
636
+
637
+ system = create_system(numbering_type)
638
+ int_part, dec_part = string2symbols(chinese_string, system)
639
+ int_part = correct_symbols(int_part, system)
640
+ int_str = str(compute_value(int_part))
641
+ dec_str = ''.join([str(d.value) for d in dec_part])
642
+ if dec_part:
643
+ return '{0}.{1}'.format(int_str, dec_str)
644
+ else:
645
+ return int_str
646
+
647
+
648
+ def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
649
+ traditional=False, alt_zero=False, alt_one=False, alt_two=True,
650
+ use_zeros=True, use_units=True):
651
+
652
+ def get_value(value_string, use_zeros=True):
653
+
654
+ striped_string = value_string.lstrip('0')
655
+
656
+ # record nothing if all zeros
657
+ if not striped_string:
658
+ return []
659
+
660
+ # record one digits
661
+ elif len(striped_string) == 1:
662
+ if use_zeros and len(value_string) != len(striped_string):
663
+ return [system.digits[0], system.digits[int(striped_string)]]
664
+ else:
665
+ return [system.digits[int(striped_string)]]
666
+
667
+ # recursively record multiple digits
668
+ else:
669
+ result_unit = next(u for u in reversed(
670
+ system.units) if u.power < len(striped_string))
671
+ result_string = value_string[:-result_unit.power]
672
+ return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
673
+
674
+ system = create_system(numbering_type)
675
+
676
+ int_dec = number_string.split('.')
677
+ if len(int_dec) == 1:
678
+ int_string = int_dec[0]
679
+ dec_string = ""
680
+ elif len(int_dec) == 2:
681
+ int_string = int_dec[0]
682
+ dec_string = int_dec[1]
683
+ else:
684
+ raise ValueError(
685
+ "invalid input num string with more than one dot: {}".format(number_string))
686
+
687
+ if use_units and len(int_string) > 1:
688
+ result_symbols = get_value(int_string)
689
+ else:
690
+ result_symbols = [system.digits[int(c)] for c in int_string]
691
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
692
+ if dec_string:
693
+ result_symbols += [system.math.point] + dec_symbols
694
+
695
+ if alt_two:
696
+ liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
697
+ system.digits[2].big_s, system.digits[2].big_t)
698
+ for i, v in enumerate(result_symbols):
699
+ if isinstance(v, CND) and v.value == 2:
700
+ next_symbol = result_symbols[i +
701
+ 1] if i < len(result_symbols) - 1 else None
702
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
703
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
704
+ if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
705
+ result_symbols[i] = liang
706
+
707
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
708
+ if big:
709
+ attr_name = 'big_'
710
+ if traditional:
711
+ attr_name += 't'
712
+ else:
713
+ attr_name += 's'
714
+ else:
715
+ if traditional:
716
+ attr_name = 'traditional'
717
+ else:
718
+ attr_name = 'simplified'
719
+
720
+ result = ''.join([getattr(s, attr_name) for s in result_symbols])
721
+
722
+ # if not use_zeros:
723
+ # result = result.strip(getattr(system.digits[0], attr_name))
724
+
725
+ if alt_zero:
726
+ result = result.replace(
727
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s)
728
+
729
+ if alt_one:
730
+ result = result.replace(
731
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s)
732
+
733
+ for i, p in enumerate(POINT):
734
+ if result.startswith(p):
735
+ return CHINESE_DIGIS[0] + result
736
+
737
+ # ^10, 11, .., 19
738
+ if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
739
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
740
+ result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
741
+ result = result[1:]
742
+
743
+ return result
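`chn2num` and `num2chn` convert between Arabic digits and Chinese numerals; a brief sketch (not part of the commit), with outputs shown approximately:

```python
# Sketch only: digit string to Chinese numerals and back.
print(num2chn("1986"))      # e.g. 一千九百八十六
print(chn2num("一百八十"))   # e.g. '180'
```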
744
+
745
+
746
+ # ================================================================================ #
747
+ # different types of rewriters
748
+ # ================================================================================ #
749
+ class Cardinal:
750
+ """
751
+ CARDINAL类
752
+ """
753
+
754
+ def __init__(self, cardinal=None, chntext=None):
755
+ self.cardinal = cardinal
756
+ self.chntext = chntext
757
+
758
+ def chntext2cardinal(self):
759
+ return chn2num(self.chntext)
760
+
761
+ def cardinal2chntext(self):
762
+ return num2chn(self.cardinal)
763
+
764
+ class Digit:
765
+ """
766
+ DIGIT类
767
+ """
768
+
769
+ def __init__(self, digit=None, chntext=None):
770
+ self.digit = digit
771
+ self.chntext = chntext
772
+
773
+ # def chntext2digit(self):
774
+ # return chn2num(self.chntext)
775
+
776
+ def digit2chntext(self):
777
+ return num2chn(self.digit, alt_two=False, use_units=False)
778
+
779
+
780
+ class TelePhone:
781
+ """
782
+ TELEPHONE类
783
+ """
784
+
785
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
786
+ self.telephone = telephone
787
+ self.raw_chntext = raw_chntext
788
+ self.chntext = chntext
789
+
790
+ # def chntext2telephone(self):
791
+ # sil_parts = self.raw_chntext.split('<SIL>')
792
+ # self.telephone = '-'.join([
793
+ # str(chn2num(p)) for p in sil_parts
794
+ # ])
795
+ # return self.telephone
796
+
797
+ def telephone2chntext(self, fixed=False):
798
+
799
+ if fixed:
800
+ sil_parts = self.telephone.split('-')
801
+ self.raw_chntext = '<SIL>'.join([
802
+ num2chn(part, alt_two=False, use_units=False) for part in sil_parts
803
+ ])
804
+ self.chntext = self.raw_chntext.replace('<SIL>', '')
805
+ else:
806
+ sp_parts = self.telephone.strip('+').split()
807
+ self.raw_chntext = '<SP>'.join([
808
+ num2chn(part, alt_two=False, use_units=False) for part in sp_parts
809
+ ])
810
+ self.chntext = self.raw_chntext.replace('<SP>', '')
811
+ return self.chntext
812
+
813
+
814
+ class Fraction:
815
+ """
816
+ FRACTION类
817
+ """
818
+
819
+ def __init__(self, fraction=None, chntext=None):
820
+ self.fraction = fraction
821
+ self.chntext = chntext
822
+
823
+ def chntext2fraction(self):
824
+ denominator, numerator = self.chntext.split('分之')
825
+ return chn2num(numerator) + '/' + chn2num(denominator)
826
+
827
+ def fraction2chntext(self):
828
+ numerator, denominator = self.fraction.split('/')
829
+ return num2chn(denominator) + '分之' + num2chn(numerator)
830
+
831
+
832
+ class Date:
833
+ """
834
+ DATE类
835
+ """
836
+
837
+ def __init__(self, date=None, chntext=None):
838
+ self.date = date
839
+ self.chntext = chntext
840
+
841
+ # def chntext2date(self):
842
+ # chntext = self.chntext
843
+ # try:
844
+ # year, other = chntext.strip().split('年', maxsplit=1)
845
+ # year = Digit(chntext=year).digit2chntext() + '年'
846
+ # except ValueError:
847
+ # other = chntext
848
+ # year = ''
849
+ # if other:
850
+ # try:
851
+ # month, day = other.strip().split('月', maxsplit=1)
852
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
853
+ # except ValueError:
854
+ # day = chntext
855
+ # month = ''
856
+ # if day:
857
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
858
+ # else:
859
+ # month = ''
860
+ # day = ''
861
+ # date = year + month + day
862
+ # self.date = date
863
+ # return self.date
864
+
865
+ def date2chntext(self):
866
+ date = self.date
867
+ try:
868
+ year, other = date.strip().split('年', 1)
869
+ year = Digit(digit=year).digit2chntext() + '年'
870
+ except ValueError:
871
+ other = date
872
+ year = ''
873
+ if other:
874
+ try:
875
+ month, day = other.strip().split('月', 1)
876
+ month = Cardinal(cardinal=month).cardinal2chntext() + '月'
877
+ except ValueError:
878
+ day = date
879
+ month = ''
880
+ if day:
881
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
882
+ else:
883
+ month = ''
884
+ day = ''
885
+ chntext = year + month + day
886
+ self.chntext = chntext
887
+ return self.chntext
888
+
889
+
890
+ class Money:
891
+ """
892
+ MONEY类
893
+ """
894
+
895
+ def __init__(self, money=None, chntext=None):
896
+ self.money = money
897
+ self.chntext = chntext
898
+
899
+ # def chntext2money(self):
900
+ # return self.money
901
+
902
+ def money2chntext(self):
903
+ money = self.money
904
+ pattern = re.compile(r'(\d+(\.\d+)?)')
905
+ matchers = pattern.findall(money)
906
+ if matchers:
907
+ for matcher in matchers:
908
+ money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
909
+ self.chntext = money
910
+ return self.chntext
911
+
912
+
913
+ class Percentage:
914
+ """
915
+ PERCENTAGE类
916
+ """
917
+
918
+ def __init__(self, percentage=None, chntext=None):
919
+ self.percentage = percentage
920
+ self.chntext = chntext
921
+
922
+ def chntext2percentage(self):
923
+ return chn2num(self.chntext.strip().strip('百分之')) + '%'
924
+
925
+ def percentage2chntext(self):
926
+ return '百分之' + num2chn(self.percentage.strip().strip('%'))
927
+
928
+
929
+ def normalize_nsw(raw_text):
930
+ text = '^' + raw_text + '$'
931
+
932
+ # 规范化日期
933
+ pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
934
+ matchers = pattern.findall(text)
935
+ if matchers:
936
+ #print('date')
937
+ for matcher in matchers:
938
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
939
+
940
+ # 规范化金钱
941
+ pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
942
+ matchers = pattern.findall(text)
943
+ if matchers:
944
+ #print('money')
945
+ for matcher in matchers:
946
+ text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
947
+
948
+ # 规范化固话/手机号码
949
+ # 手机
950
+ # http://www.jihaoba.com/news/show/13680
951
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
952
+ # 联通:130、131、132、156、155、186、185、176
953
+ # 电信:133、153、189、180、181、177
954
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
955
+ matchers = pattern.findall(text)
956
+ if matchers:
957
+ #print('telephone')
958
+ for matcher in matchers:
959
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
960
+ # 固话
961
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
962
+ matchers = pattern.findall(text)
963
+ if matchers:
964
+ # print('fixed telephone')
965
+ for matcher in matchers:
966
+ text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
967
+
968
+ # normalize fractions
969
+ pattern = re.compile(r"(\d+/\d+)")
970
+ matchers = pattern.findall(text)
971
+ if matchers:
972
+ #print('fraction')
973
+ for matcher in matchers:
974
+ text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
975
+
976
+ # normalize percentages
977
+ text = text.replace('%', '%')
978
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
979
+ matchers = pattern.findall(text)
980
+ if matchers:
981
+ #print('percentage')
982
+ for matcher in matchers:
983
+ text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
984
+
985
+ # normalize plain numbers followed by quantifiers
986
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
987
+ matchers = pattern.findall(text)
988
+ if matchers:
989
+ #print('cardinal+quantifier')
990
+ for matcher in matchers:
991
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
992
+
993
+ # normalize digit-string identifiers (read digit by digit)
994
+ pattern = re.compile(r"(\d{4,32})")
995
+ matchers = pattern.findall(text)
996
+ if matchers:
997
+ #print('digit')
998
+ for matcher in matchers:
999
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
1000
+
1001
+ # normalize remaining plain numbers
1002
+ pattern = re.compile(r"(\d+(\.\d+)?)")
1003
+ matchers = pattern.findall(text)
1004
+ if matchers:
1005
+ #print('cardinal')
1006
+ for matcher in matchers:
1007
+ text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
1008
+
1009
+
1010
+ # restore P2P, O2O, B2C, B2B etc
1011
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
1012
+ matchers = pattern.findall(text)
1013
+ if matchers:
1014
+ # print('particular')
1015
+ for matcher in matchers:
1016
+ text = text.replace(matcher[0], matcher[1]+'2'+matcher[2], 1)
1017
+
1018
+ return text.lstrip('^').rstrip('$')
1019
+
1020
+
1021
+ def remove_erhua(text):
1022
+ """
1023
+ Remove the erhua suffix 儿 from r-colored words (whitelisted 儿 words are kept), e.g.:
1024
+ 他女儿在那边儿 -> 他女儿在那边
1025
+ """
1026
+
1027
+ new_str=''
1028
+ while re.search('儿',text):
1029
+ a = re.search('儿',text).span()
1030
+ remove_er_flag = 0
1031
+
1032
+ if ER_WHITELIST_PATTERN.search(text):
1033
+ b = ER_WHITELIST_PATTERN.search(text).span()
1034
+ if b[0] <= a[0]:
1035
+ remove_er_flag = 1
1036
+
1037
+ if remove_er_flag == 0 :
1038
+ new_str = new_str + text[0:a[0]]
1039
+ text = text[a[1]:]
1040
+ else:
1041
+ new_str = new_str + text[0:b[1]]
1042
+ text = text[b[1]:]
1043
+
1044
+ text = new_str + text
1045
+ return text
1046
+
1047
+
1048
+ def remove_space(text):
1049
+ tokens = text.split()
1050
+ new = []
1051
+ for k,t in enumerate(tokens):
1052
+ if k != 0:
1053
+ if IN_EN_CHARS.get(tokens[k-1][-1]) and IN_EN_CHARS.get(t[0]):
1054
+ new.append(' ')
1055
+ new.append(t)
1056
+ return ''.join(new)
1057
+
1058
+
1059
+ class TextNorm:
1060
+ def __init__(self,
1061
+ to_banjiao:bool = False,
1062
+ to_upper:bool = False,
1063
+ to_lower:bool = False,
1064
+ remove_fillers:bool = False,
1065
+ remove_erhua:bool = False,
1066
+ check_chars:bool = False,
1067
+ remove_space:bool = False,
1068
+ cc_mode:str = '',
1069
+ ) :
1070
+ self.to_banjiao = to_banjiao
1071
+ self.to_upper = to_upper
1072
+ self.to_lower = to_lower
1073
+ self.remove_fillers = remove_fillers
1074
+ self.remove_erhua = remove_erhua
1075
+ self.check_chars = check_chars
1076
+ self.remove_space = remove_space
1077
+
1078
+ self.cc = None
1079
+ if cc_mode:
1080
+ from opencc import OpenCC # Open Chinese Convert: pip install opencc
1081
+ self.cc = OpenCC(cc_mode)
1082
+
1083
+ def __call__(self, text):
1084
+ if self.cc:
1085
+ text = self.cc.convert(text)
1086
+
1087
+ if self.to_banjiao:
1088
+ text = text.translate(QJ2BJ_TRANSFORM)
1089
+
1090
+ if self.to_upper:
1091
+ text = text.upper()
1092
+
1093
+ if self.to_lower:
1094
+ text = text.lower()
1095
+
1096
+ if self.remove_fillers:
1097
+ for c in FILLER_CHARS:
1098
+ text = text.replace(c, '')
1099
+
1100
+ if self.remove_erhua:
1101
+ text = remove_erhua(text)
1102
+
1103
+ text = normalize_nsw(text)
1104
+
1105
+ text = text.translate(PUNCS_TRANSFORM)
1106
+
1107
+ if self.check_chars:
1108
+ for c in text:
1109
+ if not IN_VALID_CHARS.get(c):
1110
+ print(f'WARNING: illegal char {c} in: {text}', file=sys.stderr)
1111
+ return ''
1112
+
1113
+ if self.remove_space:
1114
+ text = remove_space(text)
1115
+
1116
+ return text
1117
+
1118
+
1119
+ if __name__ == '__main__':
1120
+ p = argparse.ArgumentParser()
1121
+
1122
+ # normalizer options
1123
+ p.add_argument('--to_banjiao', action='store_true', help='convert quanjiao chars to banjiao')
1124
+ p.add_argument('--to_upper', action='store_true', help='convert to upper case')
1125
+ p.add_argument('--to_lower', action='store_true', help='convert to lower case')
1126
+ p.add_argument('--remove_fillers', action='store_true', help='remove filler chars such as "呃, 啊"')
1127
+ p.add_argument('--remove_erhua', action='store_true', help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
1128
+ p.add_argument('--check_chars', action='store_true' , help='skip sentences containing illegal chars')
1129
+ p.add_argument('--remove_space', action='store_true' , help='remove whitespace')
1130
+ p.add_argument('--cc_mode', choices=['', 't2s', 's2t'], default='', help='convert between traditional to simplified')
1131
+
1132
+ # I/O options
1133
+ p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
1134
+ p.add_argument('--has_key', action='store_true', help="will be deprecated, set --format ark instead")
1135
+ p.add_argument('--format', type=str, choices=['txt', 'ark', 'tsv'], default='txt', help='input format')
1136
+ p.add_argument('ifile', help='input filename, assume utf-8 encoding')
1137
+ p.add_argument('ofile', help='output filename')
1138
+
1139
+ args = p.parse_args()
1140
+
1141
+ if args.has_key:
1142
+ args.format = 'ark'
1143
+
1144
+ normalizer = TextNorm(
1145
+ to_banjiao = args.to_banjiao,
1146
+ to_upper = args.to_upper,
1147
+ to_lower = args.to_lower,
1148
+ remove_fillers = args.remove_fillers,
1149
+ remove_erhua = args.remove_erhua,
1150
+ check_chars = args.check_chars,
1151
+ remove_space = args.remove_space,
1152
+ cc_mode = args.cc_mode,
1153
+ )
1154
+
1166
+ ndone = 0
1167
+ with open(args.ifile, 'r', encoding = 'utf8') as istream, open(args.ofile, 'w+', encoding = 'utf8') as ostream:
1168
+ if args.format == 'tsv':
1169
+ reader = csv.DictReader(istream, delimiter = '\t')
1170
+ assert('TEXT' in reader.fieldnames)
1171
+ print('\t'.join(reader.fieldnames), file=ostream)
1172
+
1173
+ for item in reader:
1174
+ text = item['TEXT']
1175
+
1176
+ if text:
1177
+ text = normalizer(text)
1178
+
1179
+ if text:
1180
+ item['TEXT'] = text
1181
+ print('\t'.join([ item[f] for f in reader.fieldnames ]), file = ostream)
1182
+
1183
+ ndone += 1
1184
+ if ndone % args.log_interval == 0:
1185
+ print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True)
1186
+ else:
1187
+ for l in istream:
1188
+ key, text = '', ''
1189
+ if args.format == 'ark': # KALDI archive, line format: "key text"
1190
+ cols = l.strip().split(maxsplit=1)
1191
+ key, text = cols[0], cols[1] if len(cols) == 2 else ''
1192
+ else:
1193
+ text = l.strip()
1194
+
1195
+ if text:
1196
+ text = normalizer(text)
1197
+
1198
+ if text:
1199
+ if args.format == 'ark':
1200
+ print(key + '\t' + text, file = ostream)
1201
+ else:
1202
+ print(text, file = ostream)
1203
+
1204
+ ndone += 1
1205
+ if ndone % args.log_interval == 0:
1206
+ print(f'text norm: {ndone} lines done.', file = sys.stderr, flush = True)
1207
+ print(f'text norm: {ndone} lines done in total.', file = sys.stderr, flush = True)
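Taken together, the classes above implement non-standard-word (NSW) normalization for Mandarin: dates, money, phone numbers, fractions, percentages and plain numbers are expanded to Chinese words before phonemization. A minimal usage sketch, assuming `TextNorm` and the module-level constants defined above are in scope (the sample sentence and option values are illustrative only):

```python
# Minimal sketch; assumes TextNorm (defined above) is in scope.
normalizer = TextNorm(
    to_banjiao=True,     # convert full-width characters to half-width
    remove_erhua=False,  # keep erhua suffixes
    remove_space=True,   # drop spaces except between Latin-script tokens
)

# Dates, money amounts and percentages in the input are expanded to Chinese words.
print(normalizer("今天是2023年10月1日,票价120元,涨幅超过15%。"))
```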
TTS/TTS/tts/models/forward_tts.py CHANGED
@@ -396,6 +396,7 @@ class ForwardTTS(BaseTTS):
396
  - g: :math:`(B, C)`
397
  """
398
  if hasattr(self, "emb_g"):
 
399
  g = self.emb_g(g) # [B, C, 1]
400
  if g is not None:
401
  g = g.unsqueeze(-1)
@@ -683,9 +684,10 @@ class ForwardTTS(BaseTTS):
683
  # encoder pass
684
  o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
685
  # duration predictor pass
686
- o_dr_log = self.duration_predictor(o_en, x_mask)
687
  o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
688
  y_lengths = o_dr.sum(1)
 
689
  # pitch predictor pass
690
  o_pitch = None
691
  if self.args.use_pitch:
 
396
  - g: :math:`(B, C)`
397
  """
398
  if hasattr(self, "emb_g"):
399
+ g = g.type(torch.LongTensor)
400
  g = self.emb_g(g) # [B, C, 1]
401
  if g is not None:
402
  g = g.unsqueeze(-1)
 
684
  # encoder pass
685
  o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
686
  # duration predictor pass
687
+ o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
688
  o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
689
  y_lengths = o_dr.sum(1)
690
+
691
  # pitch predictor pass
692
  o_pitch = None
693
  if self.args.use_pitch:
TTS/TTS/tts/models/xtts.py CHANGED
@@ -13,9 +13,12 @@ from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedu
13
  from TTS.tts.layers.xtts.gpt import GPT
14
  from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
15
  from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
 
 
16
  from TTS.tts.models.base_tts import BaseTTS
17
  from TTS.utils.io import load_fsspec
18
 
 
19
 
20
  def load_audio(audiopath, sr=22050):
21
  """
@@ -195,13 +198,12 @@ class XttsArgs(Coqpit):
195
  Args:
196
  gpt_batch_size (int): The size of the auto-regressive batch.
197
  enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
198
- lazy_load (bool, optional): Whether to load models on demand. It reduces VRAM usage. Defaults to False.
199
  kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
200
  gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
201
  clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
202
  decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
203
  num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
204
- vocoder (VocType, optional): The vocoder to use for synthesis. Defaults to VocConf.Univnet.
205
 
206
  For GPT model:
207
  ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
@@ -231,12 +233,13 @@ class XttsArgs(Coqpit):
231
 
232
  gpt_batch_size: int = 1
233
  enable_redaction: bool = False
234
- lazy_load: bool = True
235
  kv_cache: bool = True
236
  gpt_checkpoint: str = None
237
  clvp_checkpoint: str = None
238
  decoder_checkpoint: str = None
239
  num_chars: int = 255
 
 
240
 
241
  # XTTS GPT Encoder params
242
  tokenizer_file: str = ""
@@ -266,6 +269,15 @@ class XttsArgs(Coqpit):
266
  diff_layer_drop: int = 0
267
  diff_unconditioned_percentage: int = 0
268
 
 
 
 
 
 
 
 
 
 
269
  # constants
270
  duration_const: int = 102400
271
 
@@ -285,7 +297,6 @@ class Xtts(BaseTTS):
285
 
286
  def __init__(self, config: Coqpit):
287
  super().__init__(config, ap=None, tokenizer=None)
288
- self.lazy_load = self.args.lazy_load
289
  self.mel_stats_path = None
290
  self.config = config
291
  self.gpt_checkpoint = self.args.gpt_checkpoint
@@ -295,14 +306,13 @@ class Xtts(BaseTTS):
295
 
296
  self.tokenizer = VoiceBpeTokenizer()
297
  self.gpt = None
298
- self.diffusion_decoder = None
299
  self.init_models()
300
  self.register_buffer("mel_stats", torch.ones(80))
301
 
302
  def init_models(self):
303
  """Initialize the models. We do it here since we need to load the tokenizer first."""
304
  if self.tokenizer.tokenizer is not None:
305
- self.args.gpt_number_text_tokens = self.tokenizer.tokenizer.get_vocab_size()
306
  self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
307
  self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
308
 
@@ -322,40 +332,50 @@ class Xtts(BaseTTS):
322
  stop_audio_token=self.args.gpt_stop_audio_token,
323
  )
324
 
325
- self.diffusion_decoder = DiffusionTts(
326
- model_channels=self.args.diff_model_channels,
327
- num_layers=self.args.diff_num_layers,
328
- in_channels=self.args.diff_in_channels,
329
- out_channels=self.args.diff_out_channels,
330
- in_latent_channels=self.args.diff_in_latent_channels,
331
- in_tokens=self.args.diff_in_tokens,
332
- dropout=self.args.diff_dropout,
333
- use_fp16=self.args.diff_use_fp16,
334
- num_heads=self.args.diff_num_heads,
335
- layer_drop=self.args.diff_layer_drop,
336
- unconditioned_percentage=self.args.diff_unconditioned_percentage,
337
- )
338
 
339
- self.vocoder = UnivNetGenerator()
340
 
341
  @property
342
  def device(self):
343
  return next(self.parameters()).device
344
 
345
- @contextmanager
346
- def lazy_load_model(self, model):
347
- """Context to load a model on demand.
348
-
349
- Args:
350
- model (nn.Module): The model to be loaded.
351
- """
352
- if self.lazy_load:
353
- yield model
354
- else:
355
- m = model.to(self.device)
356
- yield m
357
- m = model.cpu()
358
-
359
  def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
360
  """Compute the conditioning latents for the GPT model from the given audio.
361
 
@@ -370,6 +390,7 @@ class Xtts(BaseTTS):
370
  cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
371
  return cond_latent.transpose(1, 2)
372
 
 
373
  def get_diffusion_cond_latents(
374
  self,
375
  audio_path,
@@ -389,20 +410,33 @@ class Xtts(BaseTTS):
389
  )
390
  diffusion_conds.append(cond_mel)
391
  diffusion_conds = torch.stack(diffusion_conds, dim=1)
392
- with self.lazy_load_model(self.diffusion_decoder) as diffusion:
393
- diffusion_latent = diffusion.get_conditioning(diffusion_conds)
394
  return diffusion_latent
395
 
 
 
 
 
 
 
 
 
 
 
 
396
  def get_conditioning_latents(
397
  self,
398
  audio_path,
399
  gpt_cond_len=3,
400
- ):
 
 
 
 
 
 
401
  gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
402
- diffusion_cond_latents = self.get_diffusion_cond_latents(
403
- audio_path,
404
- )
405
- return gpt_cond_latents.to(self.device), diffusion_cond_latents.to(self.device)
406
 
407
  def synthesize(self, text, config, speaker_wav, language, **kwargs):
408
  """Synthesize speech with the given input text.
@@ -447,10 +481,10 @@ class Xtts(BaseTTS):
447
  "decoder_sampler": config.decoder_sampler,
448
  }
449
  settings.update(kwargs) # allow overriding of preset settings with kwargs
450
- return self.inference(text, ref_audio_path, language, **settings)
451
 
452
- @torch.no_grad()
453
- def inference(
454
  self,
455
  text,
456
  ref_audio_path,
@@ -469,6 +503,7 @@ class Xtts(BaseTTS):
469
  cond_free_k=2,
470
  diffusion_temperature=1.0,
471
  decoder_sampler="ddim",
 
472
  **hf_generate_kwargs,
473
  ):
474
  """
@@ -517,6 +552,9 @@ class Xtts(BaseTTS):
517
  Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
518
  Defaults to 1.0.
519
 
 
 
 
520
  hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
521
  transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
522
  here: https://huggingface.co/docs/transformers/internal/generation_utils
@@ -525,81 +563,217 @@ class Xtts(BaseTTS):
525
  Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
526
  Sample rate is 24kHz.
527
  """
528
- text = f"[{language}]{text.strip().lower()}"
529
- text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
530
-
531
- assert (
532
- text_tokens.shape[-1] < self.args.gpt_max_text_tokens
533
- ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
534
-
535
  (
536
  gpt_cond_latent,
537
  diffusion_conditioning,
 
538
  ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
539
-
540
- diffuser = load_discrete_vocoder_diffuser(
541
- desired_diffusion_steps=decoder_iterations,
542
  cond_free=cond_free,
543
  cond_free_k=cond_free_k,
544
- sampler=decoder_sampler,
 
 
 
545
  )
546
 
547
- with torch.no_grad():
548
- self.gpt = self.gpt.to(self.device)
549
- with self.lazy_load_model(self.gpt) as gpt:
550
- gpt_codes = gpt.generate(
551
- cond_latents=gpt_cond_latent,
552
- text_inputs=text_tokens,
553
- input_tokens=None,
554
- do_sample=do_sample,
555
- top_p=top_p,
556
- top_k=top_k,
557
- temperature=temperature,
558
- num_return_sequences=self.gpt_batch_size,
559
- length_penalty=length_penalty,
560
- repetition_penalty=repetition_penalty,
561
- output_attentions=False,
562
- **hf_generate_kwargs,
563
- )
564
 
565
- with self.lazy_load_model(self.gpt) as gpt:
566
- expected_output_len = torch.tensor(
567
- [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
568
- )
569
- text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
570
- gpt_latents = gpt(
571
- text_tokens,
572
- text_len,
573
- gpt_codes,
574
- expected_output_len,
575
- cond_latents=gpt_cond_latent,
576
- return_attentions=False,
577
- return_latent=True,
578
- )
579
- silence_token = 83
580
- ctokens = 0
581
- for k in range(gpt_codes.shape[-1]):
582
- if gpt_codes[0, k] == silence_token:
583
- ctokens += 1
584
- else:
585
- ctokens = 0
586
- if ctokens > 8:
587
- gpt_latents = gpt_latents[:, :k]
588
- break
589
-
590
- with self.lazy_load_model(self.diffusion_decoder) as diffusion:
591
  mel = do_spectrogram_diffusion(
592
- diffusion,
593
  diffuser,
594
  gpt_latents,
595
  diffusion_conditioning,
596
  temperature=diffusion_temperature,
597
  )
598
- with self.lazy_load_model(self.vocoder) as vocoder:
599
- wav = vocoder.inference(mel)
600
 
601
  return {"wav": wav.cpu().numpy().squeeze()}
602
603
  def forward(self):
604
  raise NotImplementedError("XTTS Training is not implemented")
605
 
@@ -616,7 +790,14 @@ class Xtts(BaseTTS):
616
  super().eval()
617
 
618
  def load_checkpoint(
619
- self, config, checkpoint_dir=None, checkpoint_path=None, vocab_path=None, eval=False, strict=True
 
 
 
 
 
 
 
620
  ):
621
  """
622
  Loads a checkpoint from disk and initializes the model's state and tokenizer.
@@ -626,7 +807,7 @@ class Xtts(BaseTTS):
626
  checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
627
  checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
628
  vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
629
- eval (bool, optional): Whether to set the model to evaluation mode. Defaults to False.
630
  strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
631
 
632
  Returns:
@@ -636,19 +817,34 @@ class Xtts(BaseTTS):
636
  model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
637
  vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
638
 
639
- if os.path.exists(os.path.join(checkpoint_dir, "vocab.json")):
640
- self.tokenizer = VoiceBpeTokenizer(vocab_file=os.path.join(checkpoint_dir, "vocab.json"))
641
 
642
  self.init_models()
643
- if eval:
644
- self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
645
- self.load_state_dict(load_fsspec(model_path, map_location=self.device)["model"], strict=strict)
646
 
647
  if eval:
648
- self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
 
 
 
 
649
  self.gpt.eval()
650
- self.diffusion_decoder.eval()
651
- self.vocoder.eval()
652
 
653
  def train_step(self):
654
  raise NotImplementedError("XTTS Training is not implemented")
 
13
  from TTS.tts.layers.xtts.gpt import GPT
14
  from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer
15
  from TTS.tts.layers.xtts.vocoder import UnivNetGenerator
16
+ from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
17
+ from TTS.tts.layers.xtts.stream_generator import init_stream_support
18
  from TTS.tts.models.base_tts import BaseTTS
19
  from TTS.utils.io import load_fsspec
20
 
21
+ init_stream_support()
22
 
23
  def load_audio(audiopath, sr=22050):
24
  """
 
198
  Args:
199
  gpt_batch_size (int): The size of the auto-regressive batch.
200
  enable_redaction (bool, optional): Whether to enable redaction. Defaults to True.
 
201
  kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
202
  gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
203
  clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
204
  decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
205
  num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.
206
+ use_hifigan (bool, optional): Whether to use hifigan or diffusion + univnet as a decoder. Defaults to True.
207
 
208
  For GPT model:
209
  ar_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604.
 
233
 
234
  gpt_batch_size: int = 1
235
  enable_redaction: bool = False
 
236
  kv_cache: bool = True
237
  gpt_checkpoint: str = None
238
  clvp_checkpoint: str = None
239
  decoder_checkpoint: str = None
240
  num_chars: int = 255
241
+ use_hifigan: bool = True
242
+ use_ne_hifigan: bool = False
243
 
244
  # XTTS GPT Encoder params
245
  tokenizer_file: str = ""
 
269
  diff_layer_drop: int = 0
270
  diff_unconditioned_percentage: int = 0
271
 
272
+ # HifiGAN Decoder params
273
+ input_sample_rate: int = 22050
274
+ output_sample_rate: int = 24000
275
+ output_hop_length: int = 256
276
+ ar_mel_length_compression: int = 1024
277
+ decoder_input_dim: int = 1024
278
+ d_vector_dim: int = 512
279
+ cond_d_vector_in_each_upsampling_layer: bool = True
280
+
281
  # constants
282
  duration_const: int = 102400
283
 
 
297
 
298
  def __init__(self, config: Coqpit):
299
  super().__init__(config, ap=None, tokenizer=None)
 
300
  self.mel_stats_path = None
301
  self.config = config
302
  self.gpt_checkpoint = self.args.gpt_checkpoint
 
306
 
307
  self.tokenizer = VoiceBpeTokenizer()
308
  self.gpt = None
 
309
  self.init_models()
310
  self.register_buffer("mel_stats", torch.ones(80))
311
 
312
  def init_models(self):
313
  """Initialize the models. We do it here since we need to load the tokenizer first."""
314
  if self.tokenizer.tokenizer is not None:
315
+ self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens()
316
  self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
317
  self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")
318
 
 
332
  stop_audio_token=self.args.gpt_stop_audio_token,
333
  )
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
+ if self.args.use_hifigan:
337
+ self.hifigan_decoder = HifiDecoder(
338
+ input_sample_rate=self.args.input_sample_rate,
339
+ output_sample_rate=self.args.output_sample_rate,
340
+ output_hop_length=self.args.output_hop_length,
341
+ ar_mel_length_compression=self.args.ar_mel_length_compression,
342
+ decoder_input_dim=self.args.decoder_input_dim,
343
+ d_vector_dim=self.args.d_vector_dim,
344
+ cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
345
+ )
346
+
347
+ if self.args.use_ne_hifigan:
348
+ self.ne_hifigan_decoder = HifiDecoder(
349
+ input_sample_rate=self.args.input_sample_rate,
350
+ output_sample_rate=self.args.output_sample_rate,
351
+ output_hop_length=self.args.output_hop_length,
352
+ ar_mel_length_compression=self.args.ar_mel_length_compression,
353
+ decoder_input_dim=self.args.decoder_input_dim,
354
+ d_vector_dim=self.args.d_vector_dim,
355
+ cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
356
+ )
357
+
358
+ if not (self.args.use_hifigan or self.args.use_ne_hifigan):
359
+ self.diffusion_decoder = DiffusionTts(
360
+ model_channels=self.args.diff_model_channels,
361
+ num_layers=self.args.diff_num_layers,
362
+ in_channels=self.args.diff_in_channels,
363
+ out_channels=self.args.diff_out_channels,
364
+ in_latent_channels=self.args.diff_in_latent_channels,
365
+ in_tokens=self.args.diff_in_tokens,
366
+ dropout=self.args.diff_dropout,
367
+ use_fp16=self.args.diff_use_fp16,
368
+ num_heads=self.args.diff_num_heads,
369
+ layer_drop=self.args.diff_layer_drop,
370
+ unconditioned_percentage=self.args.diff_unconditioned_percentage,
371
+ )
372
+ self.vocoder = UnivNetGenerator()
373
 
374
  @property
375
  def device(self):
376
  return next(self.parameters()).device
377
 
378
+ @torch.inference_mode()
379
  def get_gpt_cond_latents(self, audio_path: str, length: int = 3):
380
  """Compute the conditioning latents for the GPT model from the given audio.
381
 
 
390
  cond_latent = self.gpt.get_style_emb(mel.to(self.device), sample=False)
391
  return cond_latent.transpose(1, 2)
392
 
393
+ @torch.inference_mode()
394
  def get_diffusion_cond_latents(
395
  self,
396
  audio_path,
 
410
  )
411
  diffusion_conds.append(cond_mel)
412
  diffusion_conds = torch.stack(diffusion_conds, dim=1)
413
+ diffusion_latent = self.diffusion_decoder.get_conditioning(diffusion_conds)
 
414
  return diffusion_latent
415
 
416
+ @torch.inference_mode()
417
+ def get_speaker_embedding(
418
+ self,
419
+ audio_path
420
+ ):
421
+ audio = load_audio(audio_path, self.hifigan_decoder.speaker_encoder_audio_config["sample_rate"])
422
+ speaker_embedding = self.hifigan_decoder.speaker_encoder.forward(
423
+ audio.to(self.device), l2_norm=True
424
+ ).unsqueeze(-1).to(self.device)
425
+ return speaker_embedding
426
+
427
  def get_conditioning_latents(
428
  self,
429
  audio_path,
430
  gpt_cond_len=3,
431
+ ):
432
+ speaker_embedding = None
433
+ diffusion_cond_latents = None
434
+ if self.args.use_hifigan:
435
+ speaker_embedding = self.get_speaker_embedding(audio_path)
436
+ else:
437
+ diffusion_cond_latents = self.get_diffusion_cond_latents(audio_path)
438
  gpt_cond_latents = self.get_gpt_cond_latents(audio_path, length=gpt_cond_len) # [1, 1024, T]
439
+ return gpt_cond_latents, diffusion_cond_latents, speaker_embedding
 
 
 
440
 
441
  def synthesize(self, text, config, speaker_wav, language, **kwargs):
442
  """Synthesize speech with the given input text.
 
481
  "decoder_sampler": config.decoder_sampler,
482
  }
483
  settings.update(kwargs) # allow overriding of preset settings with kwargs
484
+ return self.full_inference(text, ref_audio_path, language, **settings)
485
 
486
+ @torch.inference_mode()
487
+ def full_inference(
488
  self,
489
  text,
490
  ref_audio_path,
 
503
  cond_free_k=2,
504
  diffusion_temperature=1.0,
505
  decoder_sampler="ddim",
506
+ decoder="hifigan",
507
  **hf_generate_kwargs,
508
  ):
509
  """
 
552
  Values at 0 are the "mean" prediction of the diffusion network and will sound bland and smeared.
553
  Defaults to 1.0.
554
 
555
+ decoder: (str) Selects the decoder to use: one of "hifigan", "ne_hifigan" or "diffusion".
556
+ Defaults to "hifigan".
557
+
558
  hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
559
  transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
560
  here: https://huggingface.co/docs/transformers/internal/generation_utils
 
563
  Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length.
564
  Sample rate is 24kHz.
565
  """
 
 
 
 
 
 
 
566
  (
567
  gpt_cond_latent,
568
  diffusion_conditioning,
569
+ speaker_embedding
570
  ) = self.get_conditioning_latents(audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len)
571
+ return self.inference(
572
+ text,
573
+ language,
574
+ gpt_cond_latent,
575
+ speaker_embedding,
576
+ diffusion_conditioning,
577
+ temperature=temperature,
578
+ length_penalty=length_penalty,
579
+ repetition_penalty=repetition_penalty,
580
+ top_k=top_k,
581
+ top_p=top_p,
582
+ do_sample=do_sample,
583
+ decoder_iterations=decoder_iterations,
584
  cond_free=cond_free,
585
  cond_free_k=cond_free_k,
586
+ diffusion_temperature=diffusion_temperature,
587
+ decoder_sampler=decoder_sampler,
588
+ decoder=decoder,
589
+ **hf_generate_kwargs,
590
  )
591
+
592
+ @torch.inference_mode()
593
+ def inference(
594
+ self,
595
+ text,
596
+ language,
597
+ gpt_cond_latent,
598
+ speaker_embedding,
599
+ diffusion_conditioning,
600
+ # GPT inference
601
+ temperature=0.65,
602
+ length_penalty=1,
603
+ repetition_penalty=2.0,
604
+ top_k=50,
605
+ top_p=0.85,
606
+ do_sample=True,
607
+ # Decoder inference
608
+ decoder_iterations=100,
609
+ cond_free=True,
610
+ cond_free_k=2,
611
+ diffusion_temperature=1.0,
612
+ decoder_sampler="ddim",
613
+ decoder="hifigan",
614
+ **hf_generate_kwargs,
615
+ ):
616
+ text = f"[{language}]{text.strip().lower()}"
617
+ text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
618
 
619
+ assert (
620
+ text_tokens.shape[-1] < self.args.gpt_max_text_tokens
621
+ ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
622
 
623
+ if not self.args.use_hifigan:
624
+ diffuser = load_discrete_vocoder_diffuser(
625
+ desired_diffusion_steps=decoder_iterations,
626
+ cond_free=cond_free,
627
+ cond_free_k=cond_free_k,
628
+ sampler=decoder_sampler,
629
+ )
630
+
631
+ with torch.no_grad():
632
+ gpt_codes = self.gpt.generate(
633
+ cond_latents=gpt_cond_latent,
634
+ text_inputs=text_tokens,
635
+ input_tokens=None,
636
+ do_sample=do_sample,
637
+ top_p=top_p,
638
+ top_k=top_k,
639
+ temperature=temperature,
640
+ num_return_sequences=self.gpt_batch_size,
641
+ length_penalty=length_penalty,
642
+ repetition_penalty=repetition_penalty,
643
+ output_attentions=False,
644
+ **hf_generate_kwargs,
645
+ )
646
+ expected_output_len = torch.tensor(
647
+ [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
648
+ )
649
+ text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
650
+ gpt_latents = self.gpt(
651
+ text_tokens,
652
+ text_len,
653
+ gpt_codes,
654
+ expected_output_len,
655
+ cond_latents=gpt_cond_latent,
656
+ return_attentions=False,
657
+ return_latent=True,
658
+ )
659
+ silence_token = 83
660
+ ctokens = 0
661
+ for k in range(gpt_codes.shape[-1]):
662
+ if gpt_codes[0, k] == silence_token:
663
+ ctokens += 1
664
+ else:
665
+ ctokens = 0
666
+ if ctokens > 8:
667
+ gpt_latents = gpt_latents[:, :k]
668
+ break
669
+
670
+ if decoder == "hifigan":
671
+ assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
672
+ wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
673
+ elif decoder == "ne_hifigan":
674
+ assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
675
+ wav = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding)
676
+ else:
677
+ assert hasattr(self, "diffusion_decoder"), "You must disable the hifigan decoders to use diffusion by setting config `use_ne_hifigan: false` and `use_hifigan: false`"
678
  mel = do_spectrogram_diffusion(
679
+ self.diffusion_decoder,
680
  diffuser,
681
  gpt_latents,
682
  diffusion_conditioning,
683
  temperature=diffusion_temperature,
684
  )
685
+ wav = self.vocoder.inference(mel)
 
686
 
687
  return {"wav": wav.cpu().numpy().squeeze()}
688
 
689
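With the refactor above, conditioning is computed once by `get_conditioning_latents` and then passed explicitly to `inference`. A hedged sketch of that flow (the reference wav path and text are placeholders, and `model` is assumed to be an `Xtts` instance loaded via `load_checkpoint(..., eval=True)` with `use_hifigan` enabled):

```python
# Hedged sketch; assumes `model` is a loaded Xtts instance with use_hifigan enabled.
gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav",  # placeholder path
    gpt_cond_len=3,
)

out = model.inference(
    "Hello world!",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    diffusion_conditioning,
    decoder="hifigan",  # or "ne_hifigan" / "diffusion", matching the decoders enabled in XttsArgs
)
wav = out["wav"]  # numpy array, 24 kHz
```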
+ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
690
+ """Handle chunk formatting in streaming mode"""
691
+ wav_chunk = wav_gen[:-overlap_len]
692
+ if wav_gen_prev is not None:
693
+ wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
694
+ if wav_overlap is not None:
695
+ crossfade_wav = wav_chunk[:overlap_len]
696
+ crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
697
+ wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
698
+ wav_chunk[:overlap_len] += crossfade_wav
699
+ wav_overlap = wav_gen[-overlap_len:]
700
+ wav_gen_prev = wav_gen
701
+ return wav_chunk, wav_gen_prev, wav_overlap
702
+
703
+ @torch.inference_mode()
704
+ def inference_stream(
705
+ self,
706
+ text,
707
+ language,
708
+ gpt_cond_latent,
709
+ speaker_embedding,
710
+ # Streaming
711
+ stream_chunk_size=20,
712
+ overlap_wav_len=1024,
713
+ # GPT inference
714
+ temperature=0.65,
715
+ length_penalty=1,
716
+ repetition_penalty=2.0,
717
+ top_k=50,
718
+ top_p=0.85,
719
+ do_sample=True,
720
+ # Decoder inference
721
+ decoder="hifigan",
722
+ **hf_generate_kwargs,
723
+ ):
724
+ assert hasattr(self, "hifigan_decoder"), "`inference_stream` requires use_hifigan to be set to true in the config.model_args; diffusion is too slow to stream."
725
+ text = f"[{language}]{text.strip().lower()}"
726
+ text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)
727
+
728
+ fake_inputs = self.gpt.compute_embeddings(
729
+ gpt_cond_latent.to(self.device),
730
+ text_tokens,
731
+ )
732
+ gpt_generator = self.gpt.get_generator(
733
+ fake_inputs=fake_inputs,
734
+ top_k=top_k,
735
+ top_p=top_p,
736
+ temperature=temperature,
737
+ do_sample=do_sample,
738
+ num_beams=1,
739
+ num_return_sequences=1,
740
+ length_penalty=float(length_penalty),
741
+ repetition_penalty=float(repetition_penalty),
742
+ output_attentions=False,
743
+ output_hidden_states=True,
744
+ **hf_generate_kwargs,
745
+ )
746
+
747
+ last_tokens = []
748
+ all_latents = []
749
+ wav_gen_prev = None
750
+ wav_overlap = None
751
+ is_end = False
752
+
753
+ while not is_end:
754
+ try:
755
+ x, latent = next(gpt_generator)
756
+ last_tokens += [x]
757
+ all_latents += [latent]
758
+ except StopIteration:
759
+ is_end = True
760
+
761
+ if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
762
+ gpt_latents = torch.cat(all_latents, dim=0)[None, :]
763
+ if decoder == "hifigan":
764
+ assert hasattr(self, "hifigan_decoder"), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`"
765
+ wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
766
+ elif decoder == "ne_hifigan":
767
+ assert hasattr(self, "ne_hifigan_decoder"), "You must enable ne_hifigan decoder to use it by setting config `use_ne_hifigan: true`"
768
+ wav_gen = self.ne_hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
769
+ else:
770
+ raise NotImplementedError("Diffusion for streaming inference not implemented.")
771
+ wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
772
+ wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
773
+ )
774
+ last_tokens = []
775
+ yield wav_chunk
776
+
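A short sketch of consuming the streaming generator above; the reference wav and text are placeholders, and `model` is assumed to be an `Xtts` instance with `use_hifigan` enabled, as `inference_stream` requires:

```python
import torch

# Hedged sketch; chunks are yielded as 1-D torch tensors of audio samples.
gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")

chunks = []
for chunk in model.inference_stream(
    "This sentence is synthesized and played back chunk by chunk.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    stream_chunk_size=20,  # number of GPT tokens decoded per audio chunk
):
    chunks.append(chunk)

wav = torch.cat(chunks).cpu().numpy()
```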
777
  def forward(self):
778
  raise NotImplementedError("XTTS Training is not implemented")
779
 
 
790
  super().eval()
791
 
792
  def load_checkpoint(
793
+ self,
794
+ config,
795
+ checkpoint_dir=None,
796
+ checkpoint_path=None,
797
+ vocab_path=None,
798
+ eval=True,
799
+ strict=True,
800
+ use_deepspeed=False,
801
  ):
802
  """
803
  Loads a checkpoint from disk and initializes the model's state and tokenizer.
 
807
  checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
808
  checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
809
  vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
810
+ eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
811
  strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
812
 
813
  Returns:
 
817
  model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
818
  vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")
819
 
820
+ if os.path.exists(vocab_path):
821
+ self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)
822
 
823
  self.init_models()
824
+
825
+ checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
826
+ ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan or self.args.use_ne_hifigan else []
827
+ ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"]
828
+ ignore_keys += [] if self.args.use_ne_hifigan else ["ne_hifigan_decoder"]
829
+ for key in list(checkpoint.keys()):
830
+ if key.split(".")[0] in ignore_keys:
831
+ del checkpoint[key]
832
+
833
+ # deal with v1 and v1.1. V1 has the init_gpt_for_inference keys, v1.1 does not
834
+ try:
835
+ self.load_state_dict(checkpoint, strict=strict)
836
+ except Exception:
837
+ if eval:
838
+ self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
839
+ self.load_state_dict(checkpoint, strict=strict)
840
 
841
  if eval:
842
+ if hasattr(self, "hifigan_decoder"): self.hifigan_decoder.eval()
843
+ if hasattr(self, "ne_hifigan_decoder"): self.hifigan_decoder.eval()
844
+ if hasattr(self, "diffusion_decoder"): self.diffusion_decoder.eval()
845
+ if hasattr(self, "vocoder"): self.vocoder.eval()
846
+ self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
847
  self.gpt.eval()
 
 
848
 
849
  def train_step(self):
850
  raise NotImplementedError("XTTS Training is not implemented")
TTS/TTS/utils/audio/numpy_transforms.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Tuple
2
 
3
  import librosa
@@ -427,16 +428,24 @@ def load_wav(*, filename: str, sample_rate: int = None, resample: bool = False,
427
  return x
428
 
429
 
430
- def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, **kwargs) -> None:
431
  """Save float waveform to a file using Scipy.
432
 
433
  Args:
434
  wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
435
  path (str): Path to a output file.
436
  sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
 
437
  """
438
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
439
- scipy.io.wavfile.write(path, sample_rate, wav_norm.astype(np.int16))
 
 
 
 
 
 
 
440
 
441
 
442
  def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
 
1
+ from io import BytesIO
2
  from typing import Tuple
3
 
4
  import librosa
 
428
  return x
429
 
430
 
431
+ def save_wav(*, wav: np.ndarray, path: str, sample_rate: int = None, pipe_out = None, **kwargs) -> None:
432
  """Save float waveform to a file using Scipy.
433
 
434
  Args:
435
  wav (np.ndarray): Waveform with float values in range [-1, 1] to save.
436
  path (str): Path to a output file.
437
  sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
438
+ pipe_out (optional): File-like object with a .buffer attribute (e.g. sys.stdout); when given, the wav bytes are also written to it for shell piping. Defaults to None.
439
  """
440
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
441
+
442
+ wav_norm = wav_norm.astype(np.int16)
443
+ if pipe_out:
444
+ wav_buffer = BytesIO()
445
+ scipy.io.wavfile.write(wav_buffer, sample_rate, wav_norm)
446
+ wav_buffer.seek(0)
447
+ pipe_out.buffer.write(wav_buffer.read())
448
+ scipy.io.wavfile.write(path, sample_rate, wav_norm)
449
 
450
 
451
  def mulaw_encode(*, wav: np.ndarray, mulaw_qc: int, **kwargs) -> np.ndarray:
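Since the new `pipe_out` argument is consumed through its `.buffer` attribute, the intended object is a standard text stream such as `sys.stdout` rather than a raw `BytesIO`. A minimal sketch, assuming the module path `TTS.utils.audio.numpy_transforms` and a placeholder waveform:

```python
import sys

import numpy as np

from TTS.utils.audio.numpy_transforms import save_wav

# One second of silence as a placeholder waveform.
wav = np.zeros(22050, dtype=np.float32)

# Writes output.wav to disk and, because pipe_out is given, also streams the same
# wav bytes to stdout so the script can be piped, e.g. `python synth.py | aplay`
# (synth.py is a placeholder script name).
save_wav(wav=wav, path="output.wav", sample_rate=22050, pipe_out=sys.stdout)
```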
TTS/TTS/utils/audio/processor.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from typing import Dict, Tuple
2
 
3
  import librosa
@@ -693,20 +694,27 @@ class AudioProcessor(object):
693
  x = self.rms_volume_norm(x, self.db_level)
694
  return x
695
 
696
- def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
697
  """Save a waveform to a file using Scipy.
698
 
699
  Args:
700
  wav (np.ndarray): Waveform to save.
701
  path (str): Path to a output file.
702
  sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
 
703
  """
704
  if self.do_rms_norm:
705
  wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
706
  else:
707
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
708
 
709
- scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))
 
 
 
 
 
 
710
 
711
  def get_duration(self, filename: str) -> float:
712
  """Get the duration of a wav file using Librosa.
 
1
+ from io import BytesIO
2
  from typing import Dict, Tuple
3
 
4
  import librosa
 
694
  x = self.rms_volume_norm(x, self.db_level)
695
  return x
696
 
697
+ def save_wav(self, wav: np.ndarray, path: str, sr: int = None, pipe_out = None) -> None:
698
  """Save a waveform to a file using Scipy.
699
 
700
  Args:
701
  wav (np.ndarray): Waveform to save.
702
  path (str): Path to a output file.
703
  sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
704
+ pipe_out (optional): File-like object with a .buffer attribute (e.g. sys.stdout); when given, the wav bytes are also written to it for shell piping. Defaults to None.
705
  """
706
  if self.do_rms_norm:
707
  wav_norm = self.rms_volume_norm(wav, self.db_level) * 32767
708
  else:
709
  wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
710
 
711
+ wav_norm = wav_norm.astype(np.int16)
712
+ if pipe_out:
713
+ wav_buffer = BytesIO()
714
+ scipy.io.wavfile.write(wav_buffer, sr if sr else self.sample_rate, wav_norm)
715
+ wav_buffer.seek(0)
716
+ pipe_out.buffer.write(wav_buffer.read())
717
+ scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm)
718
 
719
  def get_duration(self, filename: str) -> float:
720
  """Get the duration of a wav file using Librosa.
TTS/TTS/utils/manage.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
  from shutil import copyfile, rmtree
7
  from typing import Dict, List, Tuple
8
 
 
9
  import requests
10
  from tqdm import tqdm
11
 
@@ -293,8 +294,9 @@ class ModelManager(object):
293
  # get model from models.json
294
  model_item = self.models_dict[model_type][lang][dataset][model]
295
  model_item["model_type"] = model_type
 
296
  model_item = self.set_model_url(model_item)
297
- return model_item, model_full_name, model
298
 
299
  def ask_tos(self, model_full_path):
300
  """Ask the user to agree to the terms of service"""
@@ -320,6 +322,44 @@ class ModelManager(object):
320
  return False
321
  return True
322
323
  def download_model(self, model_name):
324
  """Download model files given the full model name.
325
  Model name is in the format
@@ -334,37 +374,39 @@ class ModelManager(object):
334
  Args:
335
  model_name (str): model name as explained above.
336
  """
337
- model_item, model_full_name, model = self._set_model_item(model_name)
338
  # set the model specific output path
339
  output_path = os.path.join(self.output_prefix, model_full_name)
340
  if os.path.exists(output_path):
341
- print(f" > {model_name} is already downloaded.")
342
  else:
343
- os.makedirs(output_path, exist_ok=True)
344
- # handle TOS
345
- if not self.tos_agreed(model_item, output_path):
346
- if not self.ask_tos(output_path):
347
- os.rmdir(output_path)
348
- raise Exception(" [!] You must agree to the terms of service to use this model.")
349
- print(f" > Downloading model to {output_path}")
350
- try:
351
- if "fairseq" in model_name:
352
- self.download_fairseq_model(model_name, output_path)
353
- elif "github_rls_url" in model_item:
354
- self._download_github_model(model_item, output_path)
355
- elif "hf_url" in model_item:
356
- self._download_hf_model(model_item, output_path)
357
-
358
- except requests.RequestException as e:
359
- print(f" > Failed to download the model file to {output_path}")
360
- rmtree(output_path)
361
- raise e
362
- self.print_model_license(model_item=model_item)
363
  # find downloaded files
364
  output_model_path = output_path
365
  output_config_path = None
366
  if (
367
- model not in ["tortoise-v2", "bark", "xtts_v1"] and "fairseq" not in model_name
368
  ): # TODO:This is stupid but don't care for now.
369
  output_model_path, output_config_path = self._find_files(output_path)
370
  # update paths in the config.json
 
6
  from shutil import copyfile, rmtree
7
  from typing import Dict, List, Tuple
8
 
9
+ import fsspec
10
  import requests
11
  from tqdm import tqdm
12
 
 
294
  # get model from models.json
295
  model_item = self.models_dict[model_type][lang][dataset][model]
296
  model_item["model_type"] = model_type
297
+ md5hash = model_item["model_hash"] if "model_hash" in model_item else None
298
  model_item = self.set_model_url(model_item)
299
+ return model_item, model_full_name, model, md5hash
300
 
301
  def ask_tos(self, model_full_path):
302
  """Ask the user to agree to the terms of service"""
 
322
  return False
323
  return True
324
 
325
+ def create_dir_and_download_model(self, model_name, model_item, output_path):
326
+ os.makedirs(output_path, exist_ok=True)
327
+ # handle TOS
328
+ if not self.tos_agreed(model_item, output_path):
329
+ if not self.ask_tos(output_path):
330
+ os.rmdir(output_path)
331
+ raise Exception(" [!] You must agree to the terms of service to use this model.")
332
+ print(f" > Downloading model to {output_path}")
333
+ try:
334
+ if "fairseq" in model_name:
335
+ self.download_fairseq_model(model_name, output_path)
336
+ elif "github_rls_url" in model_item:
337
+ self._download_github_model(model_item, output_path)
338
+ elif "hf_url" in model_item:
339
+ self._download_hf_model(model_item, output_path)
340
+
341
+ except requests.RequestException as e:
342
+ print(f" > Failed to download the model file to {output_path}")
343
+ rmtree(output_path)
344
+ raise e
345
+ self.print_model_license(model_item=model_item)
346
+
347
+ def check_if_configs_are_equal(self, model_name, model_item, output_path):
348
+ with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
349
+ config_local = json.load(f)
350
+ remote_url = None
351
+ for url in model_item["hf_url"]:
352
+ if "config.json" in url:
353
+ remote_url = url
354
+ break
355
+
356
+ with fsspec.open(remote_url, "r", encoding="utf-8") as f:
357
+ config_remote = json.load(f)
358
+
359
+ if not config_local == config_remote:
360
+ print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
361
+ self.create_dir_and_download_model(model_name, model_item, output_path)
362
+
363
  def download_model(self, model_name):
364
  """Download model files given the full model name.
365
  Model name is in the format
 
374
  Args:
375
  model_name (str): model name as explained above.
376
  """
377
+ model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
378
  # set the model specific output path
379
  output_path = os.path.join(self.output_prefix, model_full_name)
380
  if os.path.exists(output_path):
381
+ if md5sum is not None:
382
+ md5sum_file = os.path.join(output_path, "hash.md5")
383
+ if os.path.isfile(md5sum_file):
384
+ with open(md5sum_file, mode="r") as f:
385
+ if not f.read() == md5sum:
386
+ print(f" > {model_name} has been updated, clearing model cache...")
387
+ self.create_dir_and_download_model(model_name, model_item, output_path)
388
+ else:
389
+ print(f" > {model_name} is already downloaded.")
390
+ else:
391
+ print(f" > {model_name} has been updated, clearing model cache...")
392
+ self.create_dir_and_download_model(model_name, model_item, output_path)
393
+ # if the configs are different, redownload it
394
+ # ToDo: we need a better way to handle it
395
+ if "xtts_v1" in model_name:
396
+ try:
397
+ self.check_if_configs_are_equal(model_name, model_item, output_path)
398
+ except Exception:
399
+ pass
400
+ else:
401
+ print(f" > {model_name} is already downloaded.")
402
  else:
403
+ self.create_dir_and_download_model(model_name, model_item, output_path)
404
+
405
  # find downloaded files
406
  output_model_path = output_path
407
  output_config_path = None
408
  if (
409
+ model not in ["tortoise-v2", "bark", "xtts_v1", "xtts_v1.1"] and "fairseq" not in model_name
410
  ): # TODO:This is stupid but don't care for now.
411
  output_model_path, output_config_path = self._find_files(output_path)
412
  # update paths in the config.json
TTS/TTS/utils/synthesizer.py CHANGED
@@ -235,19 +235,20 @@ class Synthesizer(nn.Module):
235
  """
236
  return self.seg.segment(text)
237
 
238
- def save_wav(self, wav: List[int], path: str) -> None:
239
  """Save the waveform as a file.
240
 
241
  Args:
242
  wav (List[int]): waveform as a list of values.
243
  path (str): output path to save the waveform.
 
244
  """
245
  # if tensor convert to numpy
246
  if torch.is_tensor(wav):
247
  wav = wav.cpu().numpy()
248
  if isinstance(wav, list):
249
  wav = np.array(wav)
250
- save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate)
251
 
252
  def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
253
  output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
@@ -299,11 +300,7 @@ class Synthesizer(nn.Module):
299
  speaker_embedding = None
300
  speaker_id = None
301
  if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
302
- # handle Neon models with single speaker.
303
- if len(self.tts_model.speaker_manager.name_to_id) == 1:
304
- speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
305
-
306
- elif speaker_name and isinstance(speaker_name, str):
307
  if self.tts_config.use_d_vector_file:
308
  # get the average speaker embedding from the saved d_vectors.
309
  speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
@@ -313,7 +310,9 @@ class Synthesizer(nn.Module):
313
  else:
314
  # get speaker idx from the speaker name
315
  speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
316
-
 
 
317
  elif not speaker_name and not speaker_wav:
318
  raise ValueError(
319
  " [!] Looks like you are using a multi-speaker model. "
 
235
  """
236
  return self.seg.segment(text)
237
 
238
+ def save_wav(self, wav: List[int], path: str, pipe_out = None) -> None:
239
  """Save the waveform as a file.
240
 
241
  Args:
242
  wav (List[int]): waveform as a list of values.
243
  path (str): output path to save the waveform.
244
+ pipe_out (optional): File-like object with a .buffer attribute (e.g. sys.stdout); when given, the wav bytes are also written to it for shell piping. Defaults to None.
245
  """
246
  # if tensor convert to numpy
247
  if torch.is_tensor(wav):
248
  wav = wav.cpu().numpy()
249
  if isinstance(wav, list):
250
  wav = np.array(wav)
251
+ save_wav(wav=wav, path=path, sample_rate=self.output_sample_rate, pipe_out=pipe_out)
252
 
253
  def voice_conversion(self, source_wav: str, target_wav: str) -> List[int]:
254
  output_wav = self.vc_model.voice_conversion(source_wav, target_wav)
 
300
  speaker_embedding = None
301
  speaker_id = None
302
  if self.tts_speakers_file or hasattr(self.tts_model.speaker_manager, "name_to_id"):
303
+ if speaker_name and isinstance(speaker_name, str):
 
 
 
 
304
  if self.tts_config.use_d_vector_file:
305
  # get the average speaker embedding from the saved d_vectors.
306
  speaker_embedding = self.tts_model.speaker_manager.get_mean_embedding(
 
310
  else:
311
  # get speaker idx from the speaker name
312
  speaker_id = self.tts_model.speaker_manager.name_to_id[speaker_name]
313
+ # handle Neon models with single speaker.
314
+ elif len(self.tts_model.speaker_manager.name_to_id) == 1:
315
+ speaker_id = list(self.tts_model.speaker_manager.name_to_id.values())[0]
316
  elif not speaker_name and not speaker_wav:
317
  raise ValueError(
318
  " [!] Looks like you are using a multi-speaker model. "
TTS/docs/source/formatting_your_dataset.md CHANGED
@@ -17,19 +17,20 @@ Let's assume you created the audio clips and their transcription. You can collec
17
  ...
18
  ```
19
 
20
- You can either create separate transcription files for each clip or create a text file that maps each audio clip to its transcription. In this file, each line must be delimitered by a special character separating the audio file name from the transcription. And make sure that the delimiter is not used in the transcription text.
21
 
22
  We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc.
23
 
24
  ```
25
  # metadata.txt
26
 
27
- audio1|This is my sentence.
28
- audio2|This is maybe my sentence.
29
- audio3|This is certainly my sentence.
30
- audio4|Let this be your sentence.
31
  ...
32
  ```
 
 
33
 
34
  In the end, we have the following folder structure
35
  ```
 
17
  ...
18
  ```
19
 
20
+ You can either create separate transcription files for each clip or create a single text file that maps each audio clip to its transcription. In this file, each line must contain the audio file name, the transcription and the normalized transcription, separated by a special delimiter character. Make sure that the delimiter does not appear in the transcription text.
21
 
22
  We recommend the following format delimited by `|`. In the following example, `audio1`, `audio2` refer to files `audio1.wav`, `audio2.wav` etc.
23
 
24
  ```
25
  # metadata.txt
26
 
27
+ audio1|This is my sentence.|This is my sentence.
28
+ audio2|1469 and 1470|fourteen sixty-nine and fourteen seventy
29
+ audio3|It'll be $16 sir.|It'll be sixteen dollars sir.
 
30
  ...
31
  ```
32
+ *If you don't have normalized transcriptions, you can use the same transcription for both columns. If that is your case, we recommend applying normalization later in the pipeline, either in the text cleaner or in the phonemizer.*
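A minimal sketch, assuming the `|`-delimited, three-column `metadata.txt` layout above, of how such a file can be loaded (the helper name is illustrative):

```python
# Illustrative helper; assumes the |-delimited three-column metadata.txt format above.
def load_metadata(path="metadata.txt"):
    items = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comments such as "# metadata.txt"
            audio_id, text, normalized_text = line.split("|")
            items.append(
                {"audio_file": f"{audio_id}.wav", "text": text, "normalized_text": normalized_text}
            )
    return items
```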
33
+
34
 
35
  In the end, we have the following folder structure
36
  ```
TTS/docs/source/implementing_a_new_model.md CHANGED
@@ -41,7 +41,7 @@
41
  6. Optionally, define `MyModelArgs`.
42
 
43
  `MyModelArgs` is a 👨‍✈️Coqpit class that sets all the class arguments of the `MyModel`. `MyModelArgs` must have
44
- all the fields neccessary to instantiate the `MyModel`. However, for training, you need to pass `MyModelConfig` to
45
  the model.
46
 
47
  7. Test `MyModel`.
 
41
  6. Optionally, define `MyModelArgs`.
42
 
43
  `MyModelArgs` is a 👨‍✈️Coqpit class that sets all the class arguments of the `MyModel`. `MyModelArgs` must have
44
+ all the fields necessary to instantiate the `MyModel`. However, for training, you need to pass `MyModelConfig` to
45
  the model.
46
 
47
  7. Test `MyModel`.
TTS/docs/source/inference.md CHANGED
@@ -114,18 +114,24 @@ tts-server --model_name "<type>/<language>/<dataset>/<model_name>" \
114
  You can run a multi-speaker and multi-lingual model in Python as
115
 
116
  ```python
 
117
  from TTS.api import TTS
118
 
119
- # List available 🐸TTS models and choose the first one
120
- model_name = TTS().list_models()[0]
 
 
 
 
121
  # Init TTS
122
- tts = TTS(model_name)
 
123
  # Run TTS
124
- # ❗ Since this model is multi-speaker and multi-lingual, we must set the target speaker and the language
125
- # Text to speech with a numpy output
126
- wav = tts.tts("This is a test! This is also a test!!", speaker=tts.speakers[0], language=tts.languages[0])
127
  # Text to speech to a file
128
- tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
129
  ```
130
 
131
  #### Here is an example for a single speaker model.
 
114
  You can run a multi-speaker and multi-lingual model in Python as
115
 
116
  ```python
117
+ import torch
118
  from TTS.api import TTS
119
 
120
+ # Get device
121
+ device = "cuda" if torch.cuda.is_available() else "cpu"
122
+
123
+ # List available 🐸TTS models
124
+ print(TTS().list_models())
125
+
126
  # Init TTS
127
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1").to(device)
128
+
129
  # Run TTS
130
+ # ❗ Since this is a multi-lingual voice cloning model, we must set the target speaker_wav and language
131
+ # Text to speech list of amplitude values as output
132
+ wav = tts.tts(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en")
133
  # Text to speech to a file
134
+ tts.tts_to_file(text="Hello world!", speaker_wav="my/cloning/audio.wav", language="en", file_path="output.wav")
135
  ```
136
 
137
  #### Here is an example for a single speaker model.
TTS/docs/source/main_classes/trainer_api.md CHANGED
@@ -1,3 +1,3 @@
1
  # Trainer API
2
 
3
- We made the trainer a seprate project on https://github.com/coqui-ai/Trainer
 
1
  # Trainer API
2
 
3
+ We made the trainer a separate project on https://github.com/coqui-ai/Trainer
TTS/docs/source/models/forward_tts.md CHANGED
@@ -12,7 +12,7 @@ Currently we provide the following pre-configured architectures:
12
 
13
  - **FastPitch:**
14
 
15
- It uses the same FastSpeech architecture that is conditioned on fundemental frequency (f0) contours with the
16
  promise of more expressive speech.
17
 
18
  - **SpeedySpeech:**
 
12
 
13
  - **FastPitch:**
14
 
15
+ It uses the same FastSpeech architecture that is conditioned on fundamental frequency (f0) contours with the
16
  promise of more expressive speech.
17
 
18
  - **SpeedySpeech:**
TTS/docs/source/models/xtts.md CHANGED
@@ -28,7 +28,8 @@ This model is licensed under [Coqui Public Model License](https://coqui.ai/cpml)
28
  Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
29
  You can also mail us at info@coqui.ai.
30
 
31
- Using 🐸TTS API:
 
32
 
33
  ```python
34
  from TTS.api import TTS
@@ -39,16 +40,9 @@ tts.tts_to_file(text="It took me quite a long time to develop a voice, and now t
39
  file_path="output.wav",
40
  speaker_wav="/path/to/target/speaker.wav",
41
  language="en")
42
-
43
- # generate speech by cloning a voice using custom settings
44
- tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
45
- file_path="output.wav",
46
- speaker_wav="/path/to/target/speaker.wav",
47
- language="en",
48
- decoder_iterations=30)
49
  ```
50
 
51
- Using 🐸TTS Command line:
52
 
53
  ```console
54
  tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
@@ -58,25 +52,85 @@ Using 🐸TTS Command line:
58
  --use_cuda true
59
  ```
60
 
61
- Using model directly:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  ```python
 
 
 
 
64
  from TTS.tts.configs.xtts_config import XttsConfig
65
  from TTS.tts.models.xtts import Xtts
66
 
 
67
  config = XttsConfig()
68
  config.load_json("/path/to/xtts/config.json")
69
  model = Xtts.init_from_config(config)
70
- model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)
71
  model.cuda()
72
 
73
- outputs = model.synthesize(
 
 
 
 
 
74
  "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
75
- config,
76
- speaker_wav="/data/TTS-public/_refclips/3.wav",
77
- gpt_cond_len=3,
78
- language="en",
79
  )
 
 
 
 
 
 
 
 
 
80
  ```
81
 
82
 
 
28
  Come and join in our 🐸Community. We're active on [Discord](https://discord.gg/fBC58unbKE) and [Twitter](https://twitter.com/coqui_ai).
29
  You can also mail us at info@coqui.ai.
30
 
31
+ ### Inference
32
+ #### 🐸TTS API
33
 
34
  ```python
35
  from TTS.api import TTS
 
40
  file_path="output.wav",
41
  speaker_wav="/path/to/target/speaker.wav",
42
  language="en")
 
 
 
 
 
 
 
43
  ```
44
 
45
+ #### 🐸TTS Command line
46
 
47
  ```console
48
  tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 \
 
52
  --use_cuda true
53
  ```
54
 
55
+ #### Using the model directly
56
+
57
+ If you want to be able to run with `use_deepspeed=True` and enjoy the speedup, you need to install deepspeed first.
58
+
59
+ ```console
60
+ pip install deepspeed==0.8.3
61
+ ```
62
+
63
+ ```python
64
+ import os
65
+ import torch
66
+ import torchaudio
67
+ from TTS.tts.configs.xtts_config import XttsConfig
68
+ from TTS.tts.models.xtts import Xtts
69
+
70
+ print("Loading model...")
71
+ config = XttsConfig()
72
+ config.load_json("/path/to/xtts/config.json")
73
+ model = Xtts.init_from_config(config)
74
+ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
75
+ model.cuda()
76
+
77
+ print("Computing speaker latents...")
78
+ gpt_cond_latent, diffusion_conditioning, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
79
+
80
+ print("Inference...")
81
+ out = model.inference(
82
+ "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
83
+ "en",
84
+ gpt_cond_latent,
85
+ speaker_embedding,
86
+ diffusion_conditioning,
87
+ temperature=0.7, # Add custom parameters here
88
+ )
89
+ torchaudio.save("xtts.wav", torch.tensor(out["wav"]).unsqueeze(0), 24000)
90
+ ```
91
+
92
+
93
+ #### Streaming inference
94
+
95
+ Here the goal is to stream the audio as it is being generated. This is useful for real-time applications.
96
+ Streaming inference is typically slower than regular inference, but it allows you to get the first chunk of audio faster.
97
+
98
 
99
  ```python
100
+ import os
101
+ import time
102
+ import torch
103
+ import torchaudio
104
  from TTS.tts.configs.xtts_config import XttsConfig
105
  from TTS.tts.models.xtts import Xtts
106
 
107
+ print("Loading model...")
108
  config = XttsConfig()
109
  config.load_json("/path/to/xtts/config.json")
110
  model = Xtts.init_from_config(config)
111
+ model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", use_deepspeed=True)
112
  model.cuda()
113
 
114
+ print("Computing speaker latents...")
115
+ gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path="reference.wav")
116
+
117
+ print("Inference...")
118
+ t0 = time.time()
119
+ chunks = model.inference_stream(
120
  "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
121
+ "en",
122
+ gpt_cond_latent,
123
+ speaker_embedding
 
124
  )
125
+
126
+ wav_chunks = []
127
+ for i, chunk in enumerate(chunks):
128
+ if i == 0:
129
+ print(f"Time to first chunck: {time.time() - t0}")
130
+ print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
131
+ wav_chunks.append(chunk)
132
+ wav = torch.cat(wav_chunks, dim=0)
133
+ torchaudio.save("xtts_streaming.wav", wav.squeeze().unsqueeze(0).cpu(), 24000)
134
  ```
135
 
136
 
TTS/notebooks/ExtractTTSpectrogram.ipynb CHANGED
@@ -13,15 +13,15 @@
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
16
- "%load_ext autoreload\n",
17
- "%autoreload 2\n",
18
  "import os\n",
19
  "import sys\n",
20
  "import torch\n",
21
  "import importlib\n",
22
  "import numpy as np\n",
23
- "from tqdm import tqdm as tqdm\n",
24
  "from torch.utils.data import DataLoader\n",
 
 
25
  "from TTS.tts.datasets.dataset import TTSDataset\n",
26
  "from TTS.tts.layers.losses import L1LossMasked\n",
27
  "from TTS.utils.audio import AudioProcessor\n",
@@ -33,8 +33,8 @@
33
  "\n",
34
  "%matplotlib inline\n",
35
  "\n",
36
- "import os\n",
37
- "os.environ['CUDA_VISIBLE_DEVICES']='2'"
38
  ]
39
  },
40
  {
@@ -43,6 +43,7 @@
43
  "metadata": {},
44
  "outputs": [],
45
  "source": [
 
46
  "def set_filename(wav_path, out_path):\n",
47
  " wav_file = os.path.basename(wav_path)\n",
48
  " file_name = wav_file.split('.')[0]\n",
@@ -61,6 +62,7 @@
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
 
64
  "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
65
  "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
66
  "DATASET = \"ljspeech\"\n",
@@ -73,12 +75,15 @@
73
  "QUANTIZE_BIT = None\n",
74
  "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
75
  "\n",
 
76
  "use_cuda = torch.cuda.is_available()\n",
77
  "print(\" > CUDA enabled: \", use_cuda)\n",
78
  "\n",
 
79
  "C = load_config(CONFIG_PATH)\n",
80
  "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
81
- "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
 
82
  ]
83
  },
84
  {
@@ -87,14 +92,13 @@
87
  "metadata": {},
88
  "outputs": [],
89
  "source": [
90
- "print(C['r'])\n",
91
- "# if the vocabulary was passed, replace the default\n",
92
  "if 'characters' in C and C['characters']:\n",
93
  " symbols, phonemes = make_symbols(**C.characters)\n",
94
  "\n",
95
- "# load the model\n",
96
  "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
97
- "# TODO: multiple speaker\n",
98
  "model = setup_model(C)\n",
99
  "model.load_checkpoint(C, MODEL_FILE, eval=True)"
100
  ]
@@ -105,11 +109,12 @@
105
  "metadata": {},
106
  "outputs": [],
107
  "source": [
 
108
  "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
109
  "preprocessor = getattr(preprocessor, DATASET.lower())\n",
110
  "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
111
  "dataset = TTSDataset(\n",
112
- " checkpoint[\"config\"][\"r\"],\n",
113
  " C.text_cleaner,\n",
114
  " False,\n",
115
  " ap,\n",
@@ -124,6 +129,24 @@
124
  ")\n"
125
  ]
126
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  {
128
  "cell_type": "markdown",
129
  "metadata": {},
@@ -137,83 +160,85 @@
137
  "metadata": {},
138
  "outputs": [],
139
  "source": [
140
- "import pickle\n",
141
- "\n",
142
- "file_idxs = []\n",
143
- "metadata = []\n",
144
- "losses = []\n",
145
- "postnet_losses = []\n",
146
- "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
147
  "with torch.no_grad():\n",
148
- " for data in tqdm(loader):\n",
149
- " # setup input data\n",
150
- " text_input = data[0]\n",
151
- " text_lengths = data[1]\n",
152
- " linear_input = data[3]\n",
153
- " mel_input = data[4]\n",
154
- " mel_lengths = data[5]\n",
155
- " stop_targets = data[6]\n",
156
- " item_idx = data[7]\n",
157
  "\n",
158
- " # dispatch data to GPU\n",
159
- " if use_cuda:\n",
160
- " text_input = text_input.cuda()\n",
161
- " text_lengths = text_lengths.cuda()\n",
162
- " mel_input = mel_input.cuda()\n",
163
- " mel_lengths = mel_lengths.cuda()\n",
164
  "\n",
165
- " mask = sequence_mask(text_lengths)\n",
166
- " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
167
- " \n",
168
- " # compute loss\n",
169
- " loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
170
- " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
171
- " losses.append(loss.item())\n",
172
- " postnet_losses.append(loss_postnet.item())\n",
173
  "\n",
174
- " # compute mel specs from linear spec if model is Tacotron\n",
175
- " if C.model == \"Tacotron\":\n",
176
- " mel_specs = []\n",
177
- " postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
178
- " for b in range(postnet_outputs.shape[0]):\n",
179
- " postnet_output = postnet_outputs[b]\n",
180
- " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
181
- " postnet_outputs = torch.stack(mel_specs)\n",
182
- " elif C.model == \"Tacotron2\":\n",
183
- " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
184
- " alignments = alignments.detach().cpu().numpy()\n",
185
  "\n",
186
- " if not DRY_RUN:\n",
187
- " for idx in range(text_input.shape[0]):\n",
188
- " wav_file_path = item_idx[idx]\n",
189
- " wav = ap.load_wav(wav_file_path)\n",
190
- " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
191
- " file_idxs.append(file_name)\n",
 
 
 
 
 
192
  "\n",
193
- " # quantize and save wav\n",
194
- " if QUANTIZED_WAV:\n",
195
- " wavq = ap.quantize(wav)\n",
196
- " np.save(wavq_path, wavq)\n",
 
 
197
  "\n",
198
- " # save TTS mel\n",
199
- " mel = postnet_outputs[idx]\n",
200
- " mel_length = mel_lengths[idx]\n",
201
- " mel = mel[:mel_length, :].T\n",
202
- " np.save(mel_path, mel)\n",
203
  "\n",
204
- " metadata.append([wav_file_path, mel_path])\n",
 
 
 
 
205
  "\n",
206
- " # for wavernn\n",
207
- " if not DRY_RUN:\n",
208
- " pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n",
209
- " \n",
210
- " # for pwgan\n",
211
- " with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
212
- " for data in metadata:\n",
213
- " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  "\n",
215
- " print(np.mean(losses))\n",
216
- " print(np.mean(postnet_losses))"
 
217
  ]
218
  },
219
  {
 
13
  "metadata": {},
14
  "outputs": [],
15
  "source": [
 
 
16
  "import os\n",
17
  "import sys\n",
18
  "import torch\n",
19
  "import importlib\n",
20
  "import numpy as np\n",
21
+ "from tqdm import tqdm\n",
22
  "from torch.utils.data import DataLoader\n",
23
+ "import soundfile as sf\n",
24
+ "import pickle\n",
25
  "from TTS.tts.datasets.dataset import TTSDataset\n",
26
  "from TTS.tts.layers.losses import L1LossMasked\n",
27
  "from TTS.utils.audio import AudioProcessor\n",
 
33
  "\n",
34
  "%matplotlib inline\n",
35
  "\n",
36
+ "# Configure CUDA visibility\n",
37
+ "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
38
  ]
39
  },
40
  {
 
43
  "metadata": {},
44
  "outputs": [],
45
  "source": [
46
+ "# Function to create directories and file names\n",
47
  "def set_filename(wav_path, out_path):\n",
48
  " wav_file = os.path.basename(wav_path)\n",
49
  " file_name = wav_file.split('.')[0]\n",
 
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
65
+ "# Paths and configurations\n",
66
  "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
67
  "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
68
  "DATASET = \"ljspeech\"\n",
 
75
  "QUANTIZE_BIT = None\n",
76
  "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n",
77
  "\n",
78
+ "# Check CUDA availability\n",
79
  "use_cuda = torch.cuda.is_available()\n",
80
  "print(\" > CUDA enabled: \", use_cuda)\n",
81
  "\n",
82
+ "# Load the configuration\n",
83
  "C = load_config(CONFIG_PATH)\n",
84
  "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
85
+ "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n",
86
+ "print(C['r'])"
87
  ]
88
  },
89
  {
 
92
  "metadata": {},
93
  "outputs": [],
94
  "source": [
95
+ "# If the vocabulary was passed, replace the default\n",
 
96
  "if 'characters' in C and C['characters']:\n",
97
  " symbols, phonemes = make_symbols(**C.characters)\n",
98
  "\n",
99
+ "# Load the model\n",
100
  "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
101
+ "# TODO: multiple speakers\n",
102
  "model = setup_model(C)\n",
103
  "model.load_checkpoint(C, MODEL_FILE, eval=True)"
104
  ]
 
109
  "metadata": {},
110
  "outputs": [],
111
  "source": [
112
+ "# Load the preprocessor based on the dataset\n",
113
  "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
114
  "preprocessor = getattr(preprocessor, DATASET.lower())\n",
115
  "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n",
116
  "dataset = TTSDataset(\n",
117
+ " C,\n",
118
  " C.text_cleaner,\n",
119
  " False,\n",
120
  " ap,\n",
 
129
  ")\n"
130
  ]
131
  },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "# Initialize lists for storing results\n",
139
+ "file_idxs = []\n",
140
+ "metadata = []\n",
141
+ "losses = []\n",
142
+ "postnet_losses = []\n",
143
+ "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
144
+ "\n",
145
+ "# Create log file\n",
146
+ "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
147
+ "log_file = open(log_file_path, \"w\")"
148
+ ]
149
+ },
150
  {
151
  "cell_type": "markdown",
152
  "metadata": {},
 
160
  "metadata": {},
161
  "outputs": [],
162
  "source": [
163
+ "# Start processing with a progress bar\n",
 
 
 
 
 
 
164
  "with torch.no_grad():\n",
165
+ " for data in tqdm(loader, desc=\"Processing\"):\n",
166
+ " try:\n",
167
+ " # setup input data\n",
168
+ " text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n",
 
 
 
 
 
169
  "\n",
170
+ " # dispatch data to GPU\n",
171
+ " if use_cuda:\n",
172
+ " text_input = text_input.cuda()\n",
173
+ " text_lengths = text_lengths.cuda()\n",
174
+ " mel_input = mel_input.cuda()\n",
175
+ " mel_lengths = mel_lengths.cuda()\n",
176
  "\n",
177
+ " mask = sequence_mask(text_lengths)\n",
178
+ " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
 
 
 
 
 
 
179
  "\n",
180
+ " # compute loss\n",
181
+ " loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
182
+ " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
183
+ " losses.append(loss.item())\n",
184
+ " postnet_losses.append(loss_postnet.item())\n",
 
 
 
 
 
 
185
  "\n",
186
+ " # compute mel specs from linear spec if the model is Tacotron\n",
187
+ " if C.model == \"Tacotron\":\n",
188
+ " mel_specs = []\n",
189
+ " postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
190
+ " for b in range(postnet_outputs.shape[0]):\n",
191
+ " postnet_output = postnet_outputs[b]\n",
192
+ " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
193
+ " postnet_outputs = torch.stack(mel_specs)\n",
194
+ " elif C.model == \"Tacotron2\":\n",
195
+ " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
196
+ " alignments = alignments.detach().cpu().numpy()\n",
197
  "\n",
198
+ " if not DRY_RUN:\n",
199
+ " for idx in range(text_input.shape[0]):\n",
200
+ " wav_file_path = item_idx[idx]\n",
201
+ " wav = ap.load_wav(wav_file_path)\n",
202
+ " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
203
+ " file_idxs.append(file_name)\n",
204
  "\n",
205
+ " # quantize and save wav\n",
206
+ " if QUANTIZED_WAV:\n",
207
+ " wavq = ap.quantize(wav)\n",
208
+ " np.save(wavq_path, wavq)\n",
 
209
  "\n",
210
+ " # save TTS mel\n",
211
+ " mel = postnet_outputs[idx]\n",
212
+ " mel_length = mel_lengths[idx]\n",
213
+ " mel = mel[:mel_length, :].T\n",
214
+ " np.save(mel_path, mel)\n",
215
  "\n",
216
+ " metadata.append([wav_file_path, mel_path])\n",
217
+ "\n",
218
+ " except Exception as e:\n",
219
+ " log_file.write(f\"Error processing data: {str(e)}\\n\")\n",
220
+ "\n",
221
+ " # Calculate and log mean losses\n",
222
+ " mean_loss = np.mean(losses)\n",
223
+ " mean_postnet_loss = np.mean(postnet_losses)\n",
224
+ " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
225
+ " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
226
+ "\n",
227
+ "# Close the log file\n",
228
+ "log_file.close()\n",
229
+ "\n",
230
+ "# For wavernn\n",
231
+ "if not DRY_RUN:\n",
232
+ " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n",
233
+ "\n",
234
+ "# For pwgan\n",
235
+ "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
236
+ " for data in metadata:\n",
237
+ " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n",
238
  "\n",
239
+ "# Print mean losses\n",
240
+ "print(f\"Mean Loss: {mean_loss}\")\n",
241
+ "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
242
  ]
243
  },
244
  {
TTS/notebooks/dataset_analysis/AnalyzeDataset.ipynb CHANGED
@@ -100,7 +100,7 @@
100
  " wav_file = item[\"audio_file\"].strip()\n",
101
  " wav_files.append(wav_file)\n",
102
  " if not os.path.exists(wav_file):\n",
103
- " print(waf_path)"
104
  ]
105
  },
106
  {
 
100
  " wav_file = item[\"audio_file\"].strip()\n",
101
  " wav_files.append(wav_file)\n",
102
  " if not os.path.exists(wav_file):\n",
103
+ " print(wav_file)"
104
  ]
105
  },
106
  {
TTS/requirements.ja.txt CHANGED
@@ -2,3 +2,4 @@
2
  # japanese g2p deps
3
  mecab-python3==1.0.6
4
  unidic-lite==1.0.8
 
 
2
  # japanese g2p deps
3
  mecab-python3==1.0.6
4
  unidic-lite==1.0.8
5
+ cutlet
TTS/tests/api_tests/test_synthesize_api.py CHANGED
@@ -13,3 +13,16 @@ def test_synthesize():
13
  '--text "This is it" '
14
  f'--out_path "{output_path}"'
15
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  '--text "This is it" '
14
  f'--out_path "{output_path}"'
15
  )
16
+
17
+ # 🐸 Coqui studio model with speed arg.
18
+ run_cli(
19
+ 'tts --model_name "coqui_studio/en/Torcull Diarmuid/coqui_studio" '
20
+ '--text "This is it but slow" --speed 0.1'
21
+ f'--out_path "{output_path}"'
22
+ )
23
+
24
+ # test pipe_out command
25
+ run_cli(
26
+ 'tts --text "test." --pipe_out '
27
+ f'--out_path "{output_path}" | aplay'
28
+ )
TTS/tests/zoo_tests/test_models.py CHANGED
@@ -3,13 +3,20 @@ import glob
3
  import os
4
  import shutil
5
 
 
 
6
  from tests import get_tests_data_path, get_tests_output_path, run_cli
7
  from TTS.tts.utils.languages import LanguageManager
8
  from TTS.tts.utils.speakers import SpeakerManager
9
  from TTS.utils.generic_utils import get_user_data_dir
10
  from TTS.utils.manage import ModelManager
11
 
12
- MODELS_WITH_SEP_TESTS = ["bark", "xtts"]
 
 
 
 
 
13
 
14
 
15
  def run_models(offset=0, step=1):
@@ -17,7 +24,8 @@ def run_models(offset=0, step=1):
17
  print(" > Run synthesizer with all the models.")
18
  output_path = os.path.join(get_tests_output_path(), "output.wav")
19
  manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
20
- model_names = [name for name in manager.list_models() if name in MODELS_WITH_SEP_TESTS]
 
21
  for model_name in model_names[offset::step]:
22
  print(f"\n > Run - {model_name}")
23
  model_path, _, _ = manager.download_model(model_name)
@@ -67,23 +75,85 @@ def run_models(offset=0, step=1):
67
 
68
 
69
  def test_xtts():
 
70
  output_path = os.path.join(get_tests_output_path(), "output.wav")
71
  speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
72
- run_cli(
73
- "yes | "
74
- f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
75
- f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
76
- f'--speaker_wav "{speaker_wav}" --language_idx "en"'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
 
80
  def test_bark():
81
  """Bark is too big to run on github actions. We need to test it locally"""
82
  output_path = os.path.join(get_tests_output_path(), "output.wav")
83
- run_cli(
84
- f" tts --model_name tts_models/multilingual/multi-dataset/bark "
85
- f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
86
- )
 
 
 
 
 
 
 
87
 
88
 
89
  def test_voice_conversion():
 
3
  import os
4
  import shutil
5
 
6
+ import torch
7
+
8
  from tests import get_tests_data_path, get_tests_output_path, run_cli
9
  from TTS.tts.utils.languages import LanguageManager
10
  from TTS.tts.utils.speakers import SpeakerManager
11
  from TTS.utils.generic_utils import get_user_data_dir
12
  from TTS.utils.manage import ModelManager
13
 
14
+ MODELS_WITH_SEP_TESTS = [
15
+ "tts_models/multilingual/multi-dataset/bark",
16
+ "tts_models/en/multi-dataset/tortoise-v2",
17
+ "tts_models/multilingual/multi-dataset/xtts_v1",
18
+ "tts_models/multilingual/multi-dataset/xtts_v1.1",
19
+ ]
20
 
21
 
22
  def run_models(offset=0, step=1):
 
24
  print(" > Run synthesizer with all the models.")
25
  output_path = os.path.join(get_tests_output_path(), "output.wav")
26
  manager = ModelManager(output_prefix=get_tests_output_path(), progress_bar=False)
27
+ model_names = [name for name in manager.list_models() if name not in MODELS_WITH_SEP_TESTS]
28
+ print("Model names:", model_names)
29
  for model_name in model_names[offset::step]:
30
  print(f"\n > Run - {model_name}")
31
  model_path, _, _ = manager.download_model(model_name)
 
75
 
76
 
77
  def test_xtts():
78
+ """XTTS is too big to run on github actions. We need to test it locally"""
79
  output_path = os.path.join(get_tests_output_path(), "output.wav")
80
  speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
81
+ use_gpu = torch.cuda.is_available()
82
+ if use_gpu:
83
+ run_cli(
84
+ "yes | "
85
+ f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
86
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
87
+ f'--speaker_wav "{speaker_wav}" --language_idx "en"'
88
+ )
89
+ else:
90
+ run_cli(
91
+ "yes | "
92
+ f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v1 "
93
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False '
94
+ f'--speaker_wav "{speaker_wav}" --language_idx "en"'
95
+ )
96
+
97
+
98
+ def test_xtts_streaming():
99
+ """Testing the new inference_stream method"""
100
+ from TTS.tts.configs.xtts_config import XttsConfig
101
+ from TTS.tts.models.xtts import Xtts
102
+ speaker_wav = os.path.join(get_tests_data_path(), "ljspeech", "wavs", "LJ001-0001.wav")
103
+ model_path = os.path.join(get_user_data_dir("tts"), "tts_models--multilingual--multi-dataset--xtts_v1")
104
+ config = XttsConfig()
105
+ config.load_json(os.path.join(model_path, "config.json"))
106
+ model = Xtts.init_from_config(config)
107
+ model.load_checkpoint(config, checkpoint_dir=model_path)
108
+ model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
109
+
110
+ print("Computing speaker latents...")
111
+ gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
112
+
113
+ print("Inference...")
114
+ chunks = model.inference_stream(
115
+ "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
116
+ "en",
117
+ gpt_cond_latent,
118
+ speaker_embedding
119
  )
120
+ wav_chunks = []
121
+ for i, chunk in enumerate(chunks):
122
+ if i == 0:
123
+ assert chunk.shape[-1] > 5000
124
+ wav_chunks.append(chunk)
125
+ assert len(wav_chunks) > 1
126
+
127
+
128
+ def test_tortoise():
129
+ output_path = os.path.join(get_tests_output_path(), "output.wav")
130
+ use_gpu = torch.cuda.is_available()
131
+ if use_gpu:
132
+ run_cli(
133
+ f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
134
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
135
+ )
136
+ else:
137
+ run_cli(
138
+ f" tts --model_name tts_models/en/multi-dataset/tortoise-v2 "
139
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
140
+ )
141
 
142
 
143
  def test_bark():
144
  """Bark is too big to run on github actions. We need to test it locally"""
145
  output_path = os.path.join(get_tests_output_path(), "output.wav")
146
+ use_gpu = torch.cuda.is_available()
147
+ if use_gpu:
148
+ run_cli(
149
+ f" tts --model_name tts_models/multilingual/multi-dataset/bark "
150
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True'
151
+ )
152
+ else:
153
+ run_cli(
154
+ f" tts --model_name tts_models/multilingual/multi-dataset/bark "
155
+ f'--text "This is an example." --out_path "{output_path}" --progress_bar False'
156
+ )
157
 
158
 
159
  def test_voice_conversion():