drewThomasson commited on
Commit
c39db41
·
verified ·
1 Parent(s): a782fff

Upload 10 files

Browse files
Files changed (10) hide show
  1. README.md +73 -13
  2. install.bat +10 -0
  3. install.sh +13 -0
  4. requirements.txt +7 -0
  5. start.bat +5 -0
  6. start.sh +9 -0
  7. utils/formatter.py +198 -0
  8. utils/gpt_train.py +221 -0
  9. utils/tokenizer.py +869 -0
  10. xtts_demo.py +693 -0
README.md CHANGED
@@ -1,13 +1,73 @@
1
- ---
2
- title: Xtts Finetune Webui Other Guys Work
3
- emoji: 🐢
4
- colorFrom: blue
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # xtts-finetune-webui
2
+
3
+ This webui is a slightly modified copy of the [official webui](https://github.com/coqui-ai/TTS/pull/3296) for finetune xtts.
4
+
5
+ If you are looking for an option for normal XTTS use look here [https://github.com/daswer123/xtts-webui](https://github.com/daswer123/xtts-webui)
6
+
7
+ ## TODO
8
+ - [ ] Add the ability to use via console
9
+
10
+ ## Key features:
11
+
12
+ ### Data processing
13
+
14
+ 1. Updated faster-whisper to 0.10.0 with the ability to select a larger-v3 model.
15
+ 2. Changed output folder to output folder inside the main folder.
16
+ 3. If there is already a dataset in the output folder and you want to add new data, you can do so by simply adding new audio, what was there will not be processed again and the new data will be automatically added
17
+ 4. Turn on VAD filter
18
+ 5. After the dataset is created, a file is created that specifies the language of the dataset. This file is read before training so that the language always matches. It is convenient when you restart the interface
19
+
20
+ ### Fine-tuning XTTS Encoder
21
+
22
+ 1. Added the ability to select the base model for XTTS, as well as when you re-training does not need to download the model again.
23
+ 2. Added ability to select custom model as base model during training, which will allow finetune already finetune model.
24
+ 3. Added possibility to get optimized version of the model for 1 click ( step 2.5, put optimized version in output folder).
25
+ 4. You can choose whether to delete training folders after you have optimized the model
26
+ 5. When you optimize the model, the example reference audio is moved to the output folder
27
+ 6. Checking for correctness of the specified language and dataset language
28
+
29
+ ### Inference
30
+
31
+ 1. Added possibility to customize infer settings during model checking.
32
+
33
+ ### Other
34
+
35
+ 1. If you accidentally restart the interface during one of the steps, you can load data to additional buttons
36
+ 2. Removed the display of logs as it was causing problems when restarted
37
+ 3. The finished result is copied to the ready folder, these are fully finished files, you can move them anywhere and use them as a standard model
38
+ 4. Added support for finetune Japanese
39
+
40
+ ## Changes in webui
41
+
42
+ ### 1 - Data processing
43
+
44
+ ![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/8f09b829-098b-48f5-9668-832e7319403b)
45
+
46
+ ### 2 - Fine-tuning XTTS Encoder
47
+
48
+ ![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/897540d9-3a6b-463c-abb8-261c289cc929)
49
+
50
+ ### 3 - Inference
51
+
52
+ ![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/aa05bcd4-8642-4de4-8f2f-bc0f5571af63)
53
+
54
+ ## Install
55
+
56
+ 1. Make sure you have `Cuda` installed
57
+ 2. `git clone https://github.com/daswer123/xtts-finetune-webui`
58
+ 3. `cd xtts-finetune-webui`
59
+ 4. `pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118`
60
+ 5. `pip install -r requirements.txt`
61
+
62
+ ### If you're using Windows
63
+
64
+ 1. First start `install.bat`
65
+ 2. To start the server start `start.bat`
66
+ 3. Go to the local address `127.0.0.1:5003`
67
+
68
+ ### On Linux
69
+
70
+ 1. Run `bash install.sh`
71
+ 2. To start the server start `start.sh`
72
+ 3. Go to the local address `127.0.0.1:5003`
73
+ ~
install.bat ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ python -m venv venv
4
+ call venv/scripts/activate
5
+
6
+
7
+ pip install -r .\requirements.txt
8
+ pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
9
+
10
+ python xtts_demo.py
install.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Create a Python virtual environment
4
+ python -m venv venv
5
+ # Activate the virtual environment
6
+ source venv/bin/activate
7
+
8
+ # Install other dependencies from requirements.txt
9
+ pip install -r requirements.txt
10
+ pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
11
+
12
+ python xtts_demo.py
13
+
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ faster_whisper==1.0.2
2
+ gradio==4.13.0
3
+ spacy==3.7.4
4
+ coqui-tts[languages] == 0.24.1
5
+
6
+ cutlet
7
+ fugashi[unidic-lite]
start.bat ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ call venv/scripts/activate
4
+
5
+ python xtts_demo.py
start.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ # Create a Python virtual environment
4
+ python -m venv venv
5
+ # Activate the virtual environment
6
+ source venv/bin/activate
7
+
8
+ python xtts_demo.py
9
+
utils/formatter.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torchaudio
4
+ import pandas
5
+ from faster_whisper import WhisperModel
6
+ from glob import glob
7
+
8
+ from tqdm import tqdm
9
+
10
+ from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
11
+ # Add support for JA train
12
+ # from utils.tokenizer import multilingual_cleaners
13
+
14
+ import torch
15
+ import torchaudio
16
+ # torch.set_num_threads(1)
17
+
18
+
19
+ torch.set_num_threads(16)
20
+ import os
21
+
22
+ audio_types = (".wav", ".mp3", ".flac")
23
+
24
+ def find_latest_best_model(folder_path):
25
+ search_path = os.path.join(folder_path, '**', 'best_model.pth')
26
+ files = glob(search_path, recursive=True)
27
+ latest_file = max(files, key=os.path.getctime, default=None)
28
+ return latest_file
29
+
30
+
31
+ def list_audios(basePath, contains=None):
32
+ # return the set of files that are valid
33
+ return list_files(basePath, validExts=audio_types, contains=contains)
34
+
35
+ def list_files(basePath, validExts=None, contains=None):
36
+ # loop over the directory structure
37
+ for (rootDir, dirNames, filenames) in os.walk(basePath):
38
+ # loop over the filenames in the current directory
39
+ for filename in filenames:
40
+ # if the contains string is not none and the filename does not contain
41
+ # the supplied string, then ignore the file
42
+ if contains is not None and filename.find(contains) == -1:
43
+ continue
44
+
45
+ # determine the file extension of the current file
46
+ ext = filename[filename.rfind("."):].lower()
47
+
48
+ # check to see if the file is an audio and should be processed
49
+ if validExts is None or ext.endswith(validExts):
50
+ # construct the path to the audio and yield it
51
+ audioPath = os.path.join(rootDir, filename)
52
+ yield audioPath
53
+
54
+ def format_audio_list(audio_files, asr_model, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
55
+ audio_total_size = 0
56
+ os.makedirs(out_path, exist_ok=True)
57
+
58
+ lang_file_path = os.path.join(out_path, "lang.txt")
59
+ current_language = None
60
+ if os.path.exists(lang_file_path):
61
+ with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
62
+ current_language = existing_lang_file.read().strip()
63
+
64
+ if current_language != target_language:
65
+ with open(lang_file_path, 'w', encoding='utf-8') as lang_file:
66
+ lang_file.write(target_language + '\n')
67
+ print("Warning, existing language does not match target language. Updated lang.txt with target language.")
68
+ else:
69
+ print("Existing language matches target language")
70
+
71
+ metadata = {"audio_file": [], "text": [], "speaker_name": []}
72
+ train_metadata_path = os.path.join(out_path, "metadata_train.csv")
73
+ eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
74
+
75
+ existing_metadata = {'train': None, 'eval': None}
76
+ if os.path.exists(train_metadata_path):
77
+ existing_metadata['train'] = pandas.read_csv(train_metadata_path, sep="|")
78
+ print("Existing training metadata found and loaded.")
79
+
80
+ if os.path.exists(eval_metadata_path):
81
+ existing_metadata['eval'] = pandas.read_csv(eval_metadata_path, sep="|")
82
+ print("Existing evaluation metadata found and loaded.")
83
+
84
+ if gradio_progress is not None:
85
+ tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
86
+ else:
87
+ tqdm_object = tqdm(audio_files)
88
+
89
+ for audio_path in tqdm_object:
90
+ audio_file_name_without_ext, _= os.path.splitext(os.path.basename(audio_path))
91
+ prefix_check = f"wavs/{audio_file_name_without_ext}_"
92
+
93
+ skip_processing = False
94
+ for key in ['train', 'eval']:
95
+ if existing_metadata[key] is not None:
96
+ mask = existing_metadata[key]['audio_file'].str.startswith(prefix_check)
97
+ if mask.any():
98
+ print(f"Segments from {audio_file_name_without_ext} have been previously processed; skipping...")
99
+ skip_processing = True
100
+ break
101
+
102
+ if skip_processing:
103
+ continue
104
+
105
+ wav, sr = torchaudio.load(audio_path)
106
+ if wav.size(0) != 1:
107
+ wav = torch.mean(wav, dim=0, keepdim=True)
108
+
109
+ wav = wav.squeeze()
110
+ audio_total_size += (wav.size(-1) / sr)
111
+
112
+ segments, _= asr_model.transcribe(audio_path, vad_filter=True, word_timestamps=True, language=target_language)
113
+ segments = list(segments)
114
+ i = 0
115
+ sentence = ""
116
+ sentence_start = None
117
+ first_word = True
118
+ words_list = []
119
+ for _, segment in enumerate(segments):
120
+ words = list(segment.words)
121
+ words_list.extend(words)
122
+
123
+ for word_idx, word in enumerate(words_list):
124
+ if first_word:
125
+ sentence_start = word.start
126
+ if word_idx == 0:
127
+ sentence_start = max(sentence_start - buffer, 0)
128
+ else:
129
+ previous_word_end = words_list[word_idx - 1].end
130
+ sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
131
+
132
+ sentence = word.word
133
+ first_word = False
134
+ else:
135
+ sentence += word.word
136
+
137
+ if word.word[-1] in ["!", "。", ".", "?"]:
138
+ sentence = sentence[1:]
139
+ sentence = multilingual_cleaners(sentence, target_language)
140
+ audio_file_name, _= os.path.splitext(os.path.basename(audio_path))
141
+ audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
142
+
143
+ if word_idx + 1 < len(words_list):
144
+ next_word_start = words_list[word_idx + 1].start
145
+ else:
146
+ next_word_start = (wav.shape[0] - 1) / sr
147
+
148
+ word_end = min((word.end + next_word_start) / 2, word.end + buffer)
149
+
150
+ absolute_path = os.path.join(out_path, audio_file)
151
+ os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
152
+ i += 1
153
+ first_word = True
154
+
155
+ audio = wav[int(sr*sentence_start):int(sr *word_end)].unsqueeze(0)
156
+ if audio.size(-1) >= sr / 3:
157
+ torchaudio.save(absolute_path, audio, sr)
158
+ else:
159
+ continue
160
+
161
+ metadata["audio_file"].append(audio_file)
162
+ metadata["text"].append(sentence)
163
+ metadata["speaker_name"].append(speaker_name)
164
+
165
+ df = pandas.DataFrame(metadata)
166
+
167
+ mode = 'w' if not os.path.exists(train_metadata_path) else 'a'
168
+ header = not os.path.exists(train_metadata_path)
169
+ df.to_csv(train_metadata_path, sep="|", index=False, mode=mode, header=header)
170
+
171
+ mode = 'w' if not os.path.exists(eval_metadata_path) else 'a'
172
+ header = not os.path.exists(eval_metadata_path)
173
+ df.to_csv(eval_metadata_path, sep="|", index=False, mode=mode, header=header)
174
+
175
+ metadata = {"audio_file": [], "text": [], "speaker_name": []}
176
+
177
+ if os.path.exists(train_metadata_path) and os.path.exists(eval_metadata_path):
178
+ existing_train_df = existing_metadata['train']
179
+ existing_eval_df = existing_metadata['eval']
180
+ else:
181
+ existing_train_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
182
+ existing_eval_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
183
+
184
+ new_data_df = pandas.read_csv(train_metadata_path, sep="|")
185
+
186
+ combined_train_df = pandas.concat([existing_train_df, new_data_df], ignore_index=True).drop_duplicates().reset_index(drop=True)
187
+ combined_eval_df = pandas.concat([existing_eval_df, new_data_df], ignore_index=True).drop_duplicates().reset_index(drop=True)
188
+
189
+ combined_train_df_shuffled = combined_train_df.sample(frac=1)
190
+ num_val_samples = int(len(combined_train_df_shuffled)* eval_percentage)
191
+
192
+ final_eval_set = combined_train_df_shuffled[:num_val_samples]
193
+ final_training_set = combined_train_df_shuffled[num_val_samples:]
194
+
195
+ final_training_set.sort_values('audio_file').to_csv(train_metadata_path, sep='|', index=False)
196
+ final_eval_set.sort_values('audio_file').to_csv(eval_metadata_path, sep='|', index=False)
197
+
198
+ return train_metadata_path, eval_metadata_path, audio_total_size
utils/gpt_train.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import gc
4
+ from pathlib import Path
5
+
6
+ from trainer import Trainer, TrainerArgs
7
+
8
+ from TTS.config.shared_configs import BaseDatasetConfig
9
+ from TTS.tts.datasets import load_tts_samples
10
+ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
11
+ from TTS.utils.manage import ModelManager
12
+ import shutil
13
+
14
+
15
+ def train_gpt(custom_model,version, language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
16
+ # Logging parameters
17
+ RUN_NAME = "GPT_XTTS_FT"
18
+ PROJECT_NAME = "XTTS_trainer"
19
+ DASHBOARD_LOGGER = "tensorboard"
20
+ LOGGER_URI = None
21
+
22
+ # print(f"XTTS version = {version}")
23
+
24
+ # Set here the path that the checkpoints will be saved. Default: ./run/training/
25
+ OUT_PATH = os.path.join(output_path, "run", "training")
26
+
27
+ # Training Parameters
28
+ OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
29
+ START_WITH_EVAL = False # if True it will star with evaluation
30
+ BATCH_SIZE = batch_size # set here the batch size
31
+ GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
32
+
33
+
34
+ # Define here the dataset that you want to use for the fine-tuning on.
35
+ config_dataset = BaseDatasetConfig(
36
+ formatter="coqui",
37
+ dataset_name="ft_dataset",
38
+ path=os.path.dirname(train_csv),
39
+ meta_file_train=train_csv,
40
+ meta_file_val=eval_csv,
41
+ language=language,
42
+ )
43
+
44
+ # Add here the configs of the datasets
45
+ DATASETS_CONFIG_LIST = [config_dataset]
46
+
47
+ # Define the path where XTTS v2.0.1 files will be downloaded
48
+ CHECKPOINTS_OUT_PATH = os.path.join(Path.cwd(), "base_models",f"{version}")
49
+ os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
50
+
51
+
52
+ # DVAE files
53
+ DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
54
+ MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
55
+
56
+ # Set the path to the downloaded files
57
+ DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
58
+ MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
59
+
60
+ # download DVAE files if needed
61
+ if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
62
+ print(" > Downloading DVAE files!")
63
+ ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
64
+
65
+
66
+ # Download XTTS v2.0 checkpoint if needed
67
+ TOKENIZER_FILE_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/vocab.json"
68
+ XTTS_CHECKPOINT_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/model.pth"
69
+ XTTS_CONFIG_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/config.json"
70
+ XTTS_SPEAKER_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/speakers_xtts.pth"
71
+
72
+ # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
73
+ TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
74
+ XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
75
+ XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file
76
+ XTTS_SPEAKER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_SPEAKER_LINK)) # speakers_xtts.pth file
77
+
78
+ # download XTTS v2.0 files if needed
79
+ if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
80
+ print(f" > Downloading XTTS v{version} files!")
81
+ ModelManager._download_model_files(
82
+ [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK,XTTS_SPEAKER_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
83
+ )
84
+
85
+ # Transfer this files to ready folder
86
+ READY_MODEL_PATH = os.path.join(output_path,"ready")
87
+ if not os.path.exists(READY_MODEL_PATH):
88
+ os.makedirs(READY_MODEL_PATH)
89
+
90
+ NEW_TOKENIZER_FILE = os.path.join(READY_MODEL_PATH, "vocab.json")
91
+ # NEW_XTTS_CHECKPOINT = os.path.join(READY_MODEL_PATH, "model.pth")
92
+ NEW_XTTS_CONFIG_FILE = os.path.join(READY_MODEL_PATH, "config.json")
93
+ NEW_XTTS_SPEAKER_FILE = os.path.join(READY_MODEL_PATH, "speakers_xtts.pth")
94
+
95
+ shutil.copy(TOKENIZER_FILE, NEW_TOKENIZER_FILE)
96
+ # shutil.copy(XTTS_CHECKPOINT, os.path.join(READY_MODEL_PATH, "model.pth"))
97
+ shutil.copy(XTTS_CONFIG_FILE, NEW_XTTS_CONFIG_FILE)
98
+ shutil.copy(XTTS_SPEAKER_FILE, NEW_XTTS_SPEAKER_FILE)
99
+
100
+ # Use from ready folder
101
+ TOKENIZER_FILE = NEW_TOKENIZER_FILE # vocab.json file
102
+ # XTTS_CHECKPOINT = NEW_XTTS_CHECKPOINT # model.pth file
103
+ XTTS_CONFIG_FILE = NEW_XTTS_CONFIG_FILE # config.json file
104
+ XTTS_SPEAKER_FILE = NEW_XTTS_SPEAKER_FILE # speakers_xtts.pth file
105
+
106
+
107
+ if custom_model != "":
108
+ if os.path.exists(custom_model) and custom_model.endswith('.pth'):
109
+ XTTS_CHECKPOINT = custom_model
110
+ print(f" > Loading custom model: {XTTS_CHECKPOINT}")
111
+ else:
112
+ print(" > Error: The specified custom model is not a valid .pth file path.")
113
+
114
+ num_workers = 8
115
+ if language == "ja":
116
+ num_workers = 0
117
+ # init args and config
118
+ model_args = GPTArgs(
119
+ max_conditioning_length=132300, # 6 secs
120
+ min_conditioning_length=66150, # 3 secs
121
+ debug_loading_failures=False,
122
+ max_wav_length=max_audio_length, # ~11.6 seconds
123
+ max_text_length=200,
124
+ mel_norm_file=MEL_NORM_FILE,
125
+ dvae_checkpoint=DVAE_CHECKPOINT,
126
+ xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
127
+ tokenizer_file=TOKENIZER_FILE,
128
+ gpt_num_audio_tokens=1026,
129
+ gpt_start_audio_token=1024,
130
+ gpt_stop_audio_token=1025,
131
+ gpt_use_masking_gt_prompt_approach=True,
132
+ gpt_use_perceiver_resampler=True,
133
+ )
134
+ # define audio config
135
+ audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
136
+ # training parameters config
137
+ config = GPTTrainerConfig(
138
+ epochs=num_epochs,
139
+ output_path=OUT_PATH,
140
+ model_args=model_args,
141
+ run_name=RUN_NAME,
142
+ project_name=PROJECT_NAME,
143
+ run_description="""
144
+ GPT XTTS training
145
+ """,
146
+ dashboard_logger=DASHBOARD_LOGGER,
147
+ logger_uri=LOGGER_URI,
148
+ audio=audio_config,
149
+ batch_size=BATCH_SIZE,
150
+ batch_group_size=48,
151
+ eval_batch_size=BATCH_SIZE,
152
+ num_loader_workers=num_workers,
153
+ eval_split_max_size=256,
154
+ print_step=50,
155
+ plot_step=100,
156
+ log_model_step=100,
157
+ save_step=1000,
158
+ save_n_checkpoints=1,
159
+ save_checkpoints=True,
160
+ # target_loss="loss",
161
+ print_eval=False,
162
+ # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
163
+ optimizer="AdamW",
164
+ optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
165
+ optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
166
+ lr=5e-06, # learning rate
167
+ lr_scheduler="MultiStepLR",
168
+ # it was adjusted accordly for the new step scheme
169
+ lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
170
+ test_sentences=[],
171
+ )
172
+
173
+ # init the model from config
174
+ model = GPTTrainer.init_from_config(config)
175
+
176
+ # load training samples
177
+ train_samples, eval_samples = load_tts_samples(
178
+ DATASETS_CONFIG_LIST,
179
+ eval_split=True,
180
+ eval_split_max_size=config.eval_split_max_size,
181
+ eval_split_size=config.eval_split_size,
182
+ )
183
+
184
+ # init the trainer and 🚀
185
+ trainer = Trainer(
186
+ TrainerArgs(
187
+ restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
188
+ skip_train_epoch=False,
189
+ start_with_eval=START_WITH_EVAL,
190
+ grad_accum_steps=GRAD_ACUMM_STEPS,
191
+ ),
192
+ config,
193
+ output_path=OUT_PATH,
194
+ model=model,
195
+ train_samples=train_samples,
196
+ eval_samples=eval_samples,
197
+ )
198
+ trainer.fit()
199
+
200
+ # get the longest text audio file to use as speaker reference
201
+ samples_len = [len(item["text"].split(" ")) for item in train_samples]
202
+ longest_text_idx = samples_len.index(max(samples_len))
203
+ speaker_ref = train_samples[longest_text_idx]["audio_file"]
204
+
205
+ trainer_out_path = trainer.output_path
206
+
207
+ # close file handlers and remove them from the logger
208
+ for handler in logging.getLogger('trainer').handlers:
209
+ if isinstance(handler, logging.FileHandler):
210
+ handler.close()
211
+ logging.getLogger('trainer').removeHandler(handler)
212
+
213
+ # now you should be able to delete the log file
214
+ log_file = os.path.join(trainer.output_path, f"trainer_{trainer.args.rank}_log.txt")
215
+ os.remove(log_file)
216
+
217
+ # deallocate VRAM and RAM
218
+ del model, trainer, train_samples, eval_samples
219
+ gc.collect()
220
+
221
+ return XTTS_SPEAKER_FILE,XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
utils/tokenizer.py ADDED
@@ -0,0 +1,869 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import textwrap
4
+ from functools import cached_property
5
+
6
+ import pypinyin
7
+ import torch
8
+ from hangul_romanize import Transliter
9
+ from hangul_romanize.rule import academic
10
+ from num2words import num2words
11
+ from spacy.lang.ar import Arabic
12
+ from spacy.lang.en import English
13
+ from spacy.lang.es import Spanish
14
+ from spacy.lang.ja import Japanese
15
+ from spacy.lang.zh import Chinese
16
+ from tokenizers import Tokenizer
17
+
18
+ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
19
+
20
+
21
+ def get_spacy_lang(lang):
22
+ if lang == "zh":
23
+ return Chinese()
24
+ elif lang == "ja":
25
+ return Japanese()
26
+ elif lang == "ar":
27
+ return Arabic()
28
+ elif lang == "es":
29
+ return Spanish()
30
+ else:
31
+ # For most languages, Enlish does the job
32
+ return English()
33
+
34
+
35
+ def split_sentence(text, lang, text_split_length=250):
36
+ """Preprocess the input text"""
37
+ text_splits = []
38
+ if text_split_length is not None and len(text) >= text_split_length:
39
+ text_splits.append("")
40
+ nlp = get_spacy_lang(lang)
41
+ nlp.add_pipe("sentencizer")
42
+ doc = nlp(text)
43
+ for sentence in doc.sents:
44
+ if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
45
+ # if the last sentence + the current sentence is less than the text_split_length
46
+ # then add the current sentence to the last sentence
47
+ text_splits[-1] += " " + str(sentence)
48
+ text_splits[-1] = text_splits[-1].lstrip()
49
+ elif len(str(sentence)) > text_split_length:
50
+ # if the current sentence is greater than the text_split_length
51
+ for line in textwrap.wrap(
52
+ str(sentence),
53
+ width=text_split_length,
54
+ drop_whitespace=True,
55
+ break_on_hyphens=False,
56
+ tabsize=1,
57
+ ):
58
+ text_splits.append(str(line))
59
+ else:
60
+ text_splits.append(str(sentence))
61
+
62
+ if len(text_splits) > 1:
63
+ if text_splits[0] == "":
64
+ del text_splits[0]
65
+ else:
66
+ text_splits = [text.lstrip()]
67
+
68
+ return text_splits
69
+
70
+
71
+ _whitespace_re = re.compile(r"\s+")
72
+
73
+ # List of (regular expression, replacement) pairs for abbreviations:
74
+ _abbreviations = {
75
+ "en": [
76
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
77
+ for x in [
78
+ ("mrs", "misess"),
79
+ ("mr", "mister"),
80
+ ("dr", "doctor"),
81
+ ("st", "saint"),
82
+ ("co", "company"),
83
+ ("jr", "junior"),
84
+ ("maj", "major"),
85
+ ("gen", "general"),
86
+ ("drs", "doctors"),
87
+ ("rev", "reverend"),
88
+ ("lt", "lieutenant"),
89
+ ("hon", "honorable"),
90
+ ("sgt", "sergeant"),
91
+ ("capt", "captain"),
92
+ ("esq", "esquire"),
93
+ ("ltd", "limited"),
94
+ ("col", "colonel"),
95
+ ("ft", "fort"),
96
+ ]
97
+ ],
98
+ "es": [
99
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
100
+ for x in [
101
+ ("sra", "señora"),
102
+ ("sr", "señor"),
103
+ ("dr", "doctor"),
104
+ ("dra", "doctora"),
105
+ ("st", "santo"),
106
+ ("co", "compañía"),
107
+ ("jr", "junior"),
108
+ ("ltd", "limitada"),
109
+ ]
110
+ ],
111
+ "fr": [
112
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
113
+ for x in [
114
+ ("mme", "madame"),
115
+ ("mr", "monsieur"),
116
+ ("dr", "docteur"),
117
+ ("st", "saint"),
118
+ ("co", "compagnie"),
119
+ ("jr", "junior"),
120
+ ("ltd", "limitée"),
121
+ ]
122
+ ],
123
+ "de": [
124
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
125
+ for x in [
126
+ ("fr", "frau"),
127
+ ("dr", "doktor"),
128
+ ("st", "sankt"),
129
+ ("co", "firma"),
130
+ ("jr", "junior"),
131
+ ]
132
+ ],
133
+ "pt": [
134
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
135
+ for x in [
136
+ ("sra", "senhora"),
137
+ ("sr", "senhor"),
138
+ ("dr", "doutor"),
139
+ ("dra", "doutora"),
140
+ ("st", "santo"),
141
+ ("co", "companhia"),
142
+ ("jr", "júnior"),
143
+ ("ltd", "limitada"),
144
+ ]
145
+ ],
146
+ "it": [
147
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
148
+ for x in [
149
+ # ("sig.ra", "signora"),
150
+ ("sig", "signore"),
151
+ ("dr", "dottore"),
152
+ ("st", "santo"),
153
+ ("co", "compagnia"),
154
+ ("jr", "junior"),
155
+ ("ltd", "limitata"),
156
+ ]
157
+ ],
158
+ "pl": [
159
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
160
+ for x in [
161
+ ("p", "pani"),
162
+ ("m", "pan"),
163
+ ("dr", "doktor"),
164
+ ("sw", "święty"),
165
+ ("jr", "junior"),
166
+ ]
167
+ ],
168
+ "ar": [
169
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
170
+ for x in [
171
+ # There are not many common abbreviations in Arabic as in English.
172
+ ]
173
+ ],
174
+ "zh": [
175
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
176
+ for x in [
177
+ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
178
+ ]
179
+ ],
180
+ "cs": [
181
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
182
+ for x in [
183
+ ("dr", "doktor"), # doctor
184
+ ("ing", "inženýr"), # engineer
185
+ ("p", "pan"), # Could also map to pani for woman but no easy way to do it
186
+ # Other abbreviations would be specialized and not as common.
187
+ ]
188
+ ],
189
+ "ru": [
190
+ (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
191
+ for x in [
192
+ ("г-жа", "госпожа"), # Mrs.
193
+ ("г-н", "господин"), # Mr.
194
+ ("д-р", "доктор"), # doctor
195
+ # Other abbreviations are less common or specialized.
196
+ ]
197
+ ],
198
+ "nl": [
199
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
200
+ for x in [
201
+ ("dhr", "de heer"), # Mr.
202
+ ("mevr", "mevrouw"), # Mrs.
203
+ ("dr", "dokter"), # doctor
204
+ ("jhr", "jonkheer"), # young lord or nobleman
205
+ # Dutch uses more abbreviations, but these are the most common ones.
206
+ ]
207
+ ],
208
+ "tr": [
209
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
210
+ for x in [
211
+ ("b", "bay"), # Mr.
212
+ ("byk", "büyük"), # büyük
213
+ ("dr", "doktor"), # doctor
214
+ # Add other Turkish abbreviations here if needed.
215
+ ]
216
+ ],
217
+ "hu": [
218
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
219
+ for x in [
220
+ ("dr", "doktor"), # doctor
221
+ ("b", "bácsi"), # Mr.
222
+ ("nőv", "nővér"), # nurse
223
+ # Add other Hungarian abbreviations here if needed.
224
+ ]
225
+ ],
226
+ "ko": [
227
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
228
+ for x in [
229
+ # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
230
+ ]
231
+ ],
232
+ "ja": [
233
+ (re.compile("\\b%s\\b" % x[0]), x[1])
234
+ for x in [
235
+ ("氏", "さん"), # Mr.
236
+ ("夫人", "おんなのひと"), # Mrs.
237
+ ("博士", "はかせ"), # Doctor or PhD
238
+ ("株", "株式会社"), # Corporation
239
+ ("有", "有限会社"), # Limited company
240
+ ("大学", "だいがく"), # University
241
+ ("先生", "せんせい"), # Teacher/Professor/Master
242
+ ("君", "くん") # Used at the end of boys' names to express familiarity or affection.
243
+ ]
244
+ ],
245
+ }
246
+
247
+
248
+ def expand_abbreviations_multilingual(text, lang="en"):
249
+ for regex, replacement in _abbreviations[lang]:
250
+ text = re.sub(regex, replacement, text)
251
+ return text
252
+
253
+
254
+ _symbols_multilingual = {
255
+ "en": [
256
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
257
+ for x in [
258
+ ("&", " and "),
259
+ ("@", " at "),
260
+ ("%", " percent "),
261
+ ("#", " hash "),
262
+ ("$", " dollar "),
263
+ ("£", " pound "),
264
+ ("°", " degree "),
265
+ ]
266
+ ],
267
+ "es": [
268
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
269
+ for x in [
270
+ ("&", " y "),
271
+ ("@", " arroba "),
272
+ ("%", " por ciento "),
273
+ ("#", " numeral "),
274
+ ("$", " dolar "),
275
+ ("£", " libra "),
276
+ ("°", " grados "),
277
+ ]
278
+ ],
279
+ "fr": [
280
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
281
+ for x in [
282
+ ("&", " et "),
283
+ ("@", " arobase "),
284
+ ("%", " pour cent "),
285
+ ("#", " dièse "),
286
+ ("$", " dollar "),
287
+ ("£", " livre "),
288
+ ("°", " degrés "),
289
+ ]
290
+ ],
291
+ "de": [
292
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
293
+ for x in [
294
+ ("&", " und "),
295
+ ("@", " at "),
296
+ ("%", " prozent "),
297
+ ("#", " raute "),
298
+ ("$", " dollar "),
299
+ ("£", " pfund "),
300
+ ("°", " grad "),
301
+ ]
302
+ ],
303
+ "pt": [
304
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
305
+ for x in [
306
+ ("&", " e "),
307
+ ("@", " arroba "),
308
+ ("%", " por cento "),
309
+ ("#", " cardinal "),
310
+ ("$", " dólar "),
311
+ ("£", " libra "),
312
+ ("°", " graus "),
313
+ ]
314
+ ],
315
+ "it": [
316
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
317
+ for x in [
318
+ ("&", " e "),
319
+ ("@", " chiocciola "),
320
+ ("%", " per cento "),
321
+ ("#", " cancelletto "),
322
+ ("$", " dollaro "),
323
+ ("£", " sterlina "),
324
+ ("°", " gradi "),
325
+ ]
326
+ ],
327
+ "pl": [
328
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
329
+ for x in [
330
+ ("&", " i "),
331
+ ("@", " małpa "),
332
+ ("%", " procent "),
333
+ ("#", " krzyżyk "),
334
+ ("$", " dolar "),
335
+ ("£", " funt "),
336
+ ("°", " stopnie "),
337
+ ]
338
+ ],
339
+ "ar": [
340
+ # Arabic
341
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
342
+ for x in [
343
+ ("&", " و "),
344
+ ("@", " على "),
345
+ ("%", " في المئة "),
346
+ ("#", " رقم "),
347
+ ("$", " دولار "),
348
+ ("£", " جنيه "),
349
+ ("°", " درجة "),
350
+ ]
351
+ ],
352
+ "zh": [
353
+ # Chinese
354
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
355
+ for x in [
356
+ ("&", " 和 "),
357
+ ("@", " 在 "),
358
+ ("%", " 百分之 "),
359
+ ("#", " 号 "),
360
+ ("$", " 美元 "),
361
+ ("£", " 英镑 "),
362
+ ("°", " 度 "),
363
+ ]
364
+ ],
365
+ "cs": [
366
+ # Czech
367
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
368
+ for x in [
369
+ ("&", " a "),
370
+ ("@", " na "),
371
+ ("%", " procento "),
372
+ ("#", " křížek "),
373
+ ("$", " dolar "),
374
+ ("£", " libra "),
375
+ ("°", " stupně "),
376
+ ]
377
+ ],
378
+ "ru": [
379
+ # Russian
380
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
381
+ for x in [
382
+ ("&", " и "),
383
+ ("@", " собака "),
384
+ ("%", " процентов "),
385
+ ("#", " номер "),
386
+ ("$", " доллар "),
387
+ ("£", " фунт "),
388
+ ("°", " градус "),
389
+ ]
390
+ ],
391
+ "nl": [
392
+ # Dutch
393
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
394
+ for x in [
395
+ ("&", " en "),
396
+ ("@", " bij "),
397
+ ("%", " procent "),
398
+ ("#", " hekje "),
399
+ ("$", " dollar "),
400
+ ("£", " pond "),
401
+ ("°", " graden "),
402
+ ]
403
+ ],
404
+ "tr": [
405
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
406
+ for x in [
407
+ ("&", " ve "),
408
+ ("@", " at "),
409
+ ("%", " yüzde "),
410
+ ("#", " diyez "),
411
+ ("$", " dolar "),
412
+ ("£", " sterlin "),
413
+ ("°", " derece "),
414
+ ]
415
+ ],
416
+ "hu": [
417
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
418
+ for x in [
419
+ ("&", " és "),
420
+ ("@", " kukac "),
421
+ ("%", " százalék "),
422
+ ("#", " kettőskereszt "),
423
+ ("$", " dollár "),
424
+ ("£", " font "),
425
+ ("°", " fok "),
426
+ ]
427
+ ],
428
+ "ko": [
429
+ # Korean
430
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
431
+ for x in [
432
+ ("&", " 그리고 "),
433
+ ("@", " 에 "),
434
+ ("%", " 퍼센트 "),
435
+ ("#", " 번호 "),
436
+ ("$", " 달러 "),
437
+ ("£", " 파운드 "),
438
+ ("°", " 도 "),
439
+ ]
440
+ ],
441
+ "ja": [
442
+ (re.compile(r"%s" % re.escape(x[0])), x[1])
443
+ for x in [
444
+ ("&", " と "),
445
+ ("@", " アットマーク "),
446
+ ("%", " パーセント "),
447
+ ("#", " ナンバー "),
448
+ ("$", " ドル "),
449
+ ("£", " ポンド "),
450
+ ("°", " 度"),
451
+ ]
452
+ ],
453
+ }
454
+
455
+
456
+ def expand_symbols_multilingual(text, lang="en"):
457
+ for regex, replacement in _symbols_multilingual[lang]:
458
+ text = re.sub(regex, replacement, text)
459
+ text = text.replace(" ", " ") # Ensure there are no double spaces
460
+ return text.strip()
461
+
462
+
463
+ _ordinal_re = {
464
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
465
+ "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
466
+ "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
467
+ "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
468
+ "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
469
+ "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
470
+ "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
471
+ "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
472
+ "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
473
+ "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
474
+ "nl": re.compile(r"([0-9]+)(de|ste|e)"),
475
+ "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
476
+ "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
477
+ "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
478
+ "ja": re.compile(r"([0-9]+)(番|回|つ|目|等|位)")
479
+ }
480
+ _number_re = re.compile(r"[0-9]+")
481
+ _currency_re = {
482
+ "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
483
+ "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
484
+ "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
485
+ }
486
+
487
+ _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
488
+ _dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
489
+ _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
490
+
491
+
492
+ def _remove_commas(m):
493
+ text = m.group(0)
494
+ if "," in text:
495
+ text = text.replace(",", "")
496
+ return text
497
+
498
+
499
+ def _remove_dots(m):
500
+ text = m.group(0)
501
+ if "." in text:
502
+ text = text.replace(".", "")
503
+ return text
504
+
505
+
506
+ def _expand_decimal_point(m, lang="en"):
507
+ amount = m.group(1).replace(",", ".")
508
+ return num2words(float(amount), lang=lang if lang != "cs" else "cz")
509
+
510
+
511
+ def _expand_currency(m, lang="en", currency="USD"):
512
+ amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
513
+ full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
514
+
515
+ and_equivalents = {
516
+ "en": ", ",
517
+ "es": " con ",
518
+ "fr": " et ",
519
+ "de": " und ",
520
+ "pt": " e ",
521
+ "it": " e ",
522
+ "pl": ", ",
523
+ "cs": ", ",
524
+ "ru": ", ",
525
+ "nl": ", ",
526
+ "ar": ", ",
527
+ "tr": ", ",
528
+ "hu": ", ",
529
+ "ko": ", ",
530
+ }
531
+
532
+ if amount.is_integer():
533
+ last_and = full_amount.rfind(and_equivalents[lang])
534
+ if last_and != -1:
535
+ full_amount = full_amount[:last_and]
536
+
537
+ return full_amount
538
+
539
+
540
+ def _expand_ordinal(m, lang="en"):
541
+ return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
542
+
543
+
544
+ def _expand_number(m, lang="en"):
545
+ return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
546
+
547
+
548
+ def expand_numbers_multilingual(text, lang="en"):
549
+ if lang == "zh":
550
+ text = zh_num2words()(text)
551
+ else:
552
+ if lang in ["en", "ru"]:
553
+ text = re.sub(_comma_number_re, _remove_commas, text)
554
+ else:
555
+ text = re.sub(_dot_number_re, _remove_dots, text)
556
+ try:
557
+ text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
558
+ text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
559
+ text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
560
+ except:
561
+ pass
562
+ if lang != "tr":
563
+ text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
564
+ text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
565
+ text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
566
+ return text
567
+
568
+
569
+ def lowercase(text):
570
+ return text.lower()
571
+
572
+
573
+ def collapse_whitespace(text):
574
+ return re.sub(_whitespace_re, " ", text)
575
+
576
+
577
+ def multilingual_cleaners(text, lang):
578
+ text = text.replace('"', "")
579
+ if lang == "tr":
580
+ text = text.replace("İ", "i")
581
+ text = text.replace("Ö", "ö")
582
+ text = text.replace("Ü", "ü")
583
+ text = lowercase(text)
584
+ text = expand_numbers_multilingual(text, lang)
585
+ text = expand_abbreviations_multilingual(text, lang)
586
+ text = expand_symbols_multilingual(text, lang=lang)
587
+ text = collapse_whitespace(text)
588
+ return text
589
+
590
+
591
+ def basic_cleaners(text):
592
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
593
+ text = lowercase(text)
594
+ text = collapse_whitespace(text)
595
+ return text
596
+
597
+
598
+ def chinese_transliterate(text):
599
+ return "".join(
600
+ [p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
601
+ )
602
+
603
+
604
+ def japanese_cleaners(text, katsu):
605
+ text = katsu.romaji(text)
606
+ text = lowercase(text)
607
+ return text
608
+
609
+
610
+ def korean_transliterate(text):
611
+ r = Transliter(academic)
612
+ return r.translit(text)
613
+
614
+
615
+ DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json")
616
+
617
+
618
+ class VoiceBpeTokenizer:
619
+ def __init__(self, vocab_file=None):
620
+ self.tokenizer = None
621
+ if vocab_file is not None:
622
+ self.tokenizer = Tokenizer.from_file(vocab_file)
623
+ self.char_limits = {
624
+ "en": 250,
625
+ "de": 253,
626
+ "fr": 273,
627
+ "es": 239,
628
+ "it": 213,
629
+ "pt": 203,
630
+ "pl": 224,
631
+ "zh": 82,
632
+ "ar": 166,
633
+ "cs": 186,
634
+ "ru": 182,
635
+ "nl": 251,
636
+ "tr": 226,
637
+ "ja": 71,
638
+ "hu": 224,
639
+ "ko": 95,
640
+ }
641
+
642
+ @cached_property
643
+ def katsu(self):
644
+ import cutlet
645
+
646
+ return cutlet.Cutlet()
647
+
648
+ def check_input_length(self, txt, lang):
649
+ lang = lang.split("-")[0] # remove the region
650
+ limit = self.char_limits.get(lang, 250)
651
+ if len(txt) > limit:
652
+ print(
653
+ f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
654
+ )
655
+
656
+ def preprocess_text(self, txt, lang):
657
+ if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
658
+ txt = multilingual_cleaners(txt, lang)
659
+ if lang == "zh":
660
+ txt = chinese_transliterate(txt)
661
+ if lang == "ko":
662
+ txt = korean_transliterate(txt)
663
+ elif lang == "ja":
664
+ txt = japanese_cleaners(txt, self.katsu)
665
+ elif lang == "hi":
666
+ # @manmay will implement this
667
+ txt = basic_cleaners(txt)
668
+ else:
669
+ raise NotImplementedError(f"Language '{lang}' is not supported.")
670
+ return txt
671
+
672
+ def encode(self, txt, lang):
673
+ lang = lang.split("-")[0] # remove the region
674
+ self.check_input_length(txt, lang)
675
+ txt = self.preprocess_text(txt, lang)
676
+ lang = "zh-cn" if lang == "zh" else lang
677
+ txt = f"[{lang}]{txt}"
678
+ txt = txt.replace(" ", "[SPACE]")
679
+ return self.tokenizer.encode(txt).ids
680
+
681
+ def decode(self, seq):
682
+ if isinstance(seq, torch.Tensor):
683
+ seq = seq.cpu().numpy()
684
+ txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
685
+ txt = txt.replace("[SPACE]", " ")
686
+ txt = txt.replace("[STOP]", "")
687
+ txt = txt.replace("[UNK]", "")
688
+ return txt
689
+
690
+ def __len__(self):
691
+ return self.tokenizer.get_vocab_size()
692
+
693
+ def get_number_tokens(self):
694
+ return max(self.tokenizer.get_vocab().values()) + 1
695
+
696
+
697
+ def test_expand_numbers_multilingual():
698
+ test_cases = [
699
+ # English
700
+ ("In 12.5 seconds.", "In twelve point five seconds.", "en"),
701
+ ("There were 50 soldiers.", "There were fifty soldiers.", "en"),
702
+ ("This is a 1st test", "This is a first test", "en"),
703
+ ("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
704
+ ("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
705
+ ("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
706
+ ("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
707
+ # French
708
+ ("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
709
+ ("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
710
+ ("Ceci est un 1er test", "Ceci est un premier test", "fr"),
711
+ ("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
712
+ ("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
713
+ ("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
714
+ ("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
715
+ # German
716
+ ("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
717
+ ("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
718
+ ("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender
719
+ ("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
720
+ ("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
721
+ ("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
722
+ # Spanish
723
+ ("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
724
+ ("Había 50 soldados.", "Había cincuenta soldados.", "es"),
725
+ ("Este es un 1er test", "Este es un primero test", "es"),
726
+ ("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
727
+ ("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
728
+ ("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
729
+ # Italian
730
+ ("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
731
+ ("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
732
+ ("Questo è un 1° test", "Questo è un primo test", "it"),
733
+ ("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
734
+ ("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
735
+ ("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
736
+ # Portuguese
737
+ ("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
738
+ ("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
739
+ ("Este é um 1º teste", "Este é um primeiro teste", "pt"),
740
+ ("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
741
+ ("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
742
+ (
743
+ "Isso custará 20,15€ senhor.",
744
+ "Isso custará vinte euros e quinze cêntimos senhor.",
745
+ "pt",
746
+ ), # "cêntimos" should be "centavos" num2words issue
747
+ # Polish
748
+ ("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
749
+ ("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
750
+ ("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
751
+ ("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
752
+ # Arabic
753
+ ("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
754
+ ("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
755
+ # ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
756
+ # ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
757
+ # Czech
758
+ ("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
759
+ ("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
760
+ ("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
761
+ ("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
762
+ # Russian
763
+ ("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
764
+ ("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
765
+ ("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
766
+ ("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
767
+ # Dutch
768
+ ("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
769
+ ("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
770
+ ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
771
+ ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
772
+ # Chinese (Simplified)
773
+ ("在12.5秒内", "在十二点五秒内", "zh"),
774
+ ("有50名士兵", "有五十名士兵", "zh"),
775
+ # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
776
+ # ("那将是20€先生", '那将是二十欧元先生', 'zh'),
777
+ # Turkish
778
+ # ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
779
+ ("50 asker vardı.", "elli asker vardı.", "tr"),
780
+ ("Bu 1. test", "Bu birinci test", "tr"),
781
+ # ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
782
+ # Hungarian
783
+ ("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
784
+ ("50 katona volt.", "ötven katona volt.", "hu"),
785
+ ("Ez az 1. teszt", "Ez az első teszt", "hu"),
786
+ # Korean
787
+ ("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
788
+ ("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
789
+ ("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
790
+ ]
791
+ for a, b, lang in test_cases:
792
+ out = expand_numbers_multilingual(a, lang=lang)
793
+ assert out == b, f"'{out}' vs '{b}'"
794
+
795
+
796
+ def test_abbreviations_multilingual():
797
+ test_cases = [
798
+ # English
799
+ ("Hello Mr. Smith.", "Hello mister Smith.", "en"),
800
+ ("Dr. Jones is here.", "doctor Jones is here.", "en"),
801
+ # Spanish
802
+ ("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
803
+ ("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
804
+ # French
805
+ ("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
806
+ ("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
807
+ # German
808
+ ("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
809
+ # Portuguese
810
+ ("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
811
+ ("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
812
+ # Italian
813
+ ("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
814
+ # ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
815
+ # Polish
816
+ ("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
817
+ ("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
818
+ # Czech
819
+ ("P. Novák", "pan Novák", "cs"),
820
+ ("Dr. Vojtěch", "doktor Vojtěch", "cs"),
821
+ # Dutch
822
+ ("Dhr. Jansen", "de heer Jansen", "nl"),
823
+ ("Mevr. de Vries", "mevrouw de Vries", "nl"),
824
+ # Russian
825
+ ("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
826
+ ("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
827
+ # Turkish
828
+ ("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
829
+ ("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
830
+ # Hungarian
831
+ ("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
832
+ ]
833
+
834
+ for a, b, lang in test_cases:
835
+ out = expand_abbreviations_multilingual(a, lang=lang)
836
+ assert out == b, f"'{out}' vs '{b}'"
837
+
838
+
839
+ def test_symbols_multilingual():
840
+ test_cases = [
841
+ ("I have 14% battery", "I have 14 percent battery", "en"),
842
+ ("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
843
+ ("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
844
+ ("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
845
+ ("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
846
+ ("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
847
+ ("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
848
+ ("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
849
+ ("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
850
+ ("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
851
+ ("Я буду @ дома", "Я буду собака дома", "ru"),
852
+ ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
853
+ ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
854
+ ("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
855
+ ("我的电量为 14%", "我的电量为 14 百分之", "zh"),
856
+ ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
857
+ ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
858
+ ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
859
+ ]
860
+
861
+ for a, b, lang in test_cases:
862
+ out = expand_symbols_multilingual(a, lang=lang)
863
+ assert out == b, f"'{out}' vs '{b}'"
864
+
865
+
866
+ if __name__ == "__main__":
867
+ test_expand_numbers_multilingual()
868
+ test_abbreviations_multilingual()
869
+ test_symbols_multilingual()
xtts_demo.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import sys
4
+ import tempfile
5
+ from pathlib import Path
6
+
7
+ import os
8
+ import shutil
9
+ import glob
10
+
11
+ import gradio as gr
12
+ import librosa.display
13
+ import numpy as np
14
+
15
+ import torch
16
+ import torchaudio
17
+ import traceback
18
+ from utils.formatter import format_audio_list,find_latest_best_model, list_audios
19
+ from utils.gpt_train import train_gpt
20
+
21
+ from faster_whisper import WhisperModel
22
+
23
+ from TTS.tts.configs.xtts_config import XttsConfig
24
+ from TTS.tts.models.xtts import Xtts
25
+
26
+ from TTS.tts.configs.xtts_config import XttsConfig
27
+ from TTS.tts.models.xtts import Xtts
28
+
29
+ # Clear logs
30
+ def remove_log_file(file_path):
31
+ log_file = Path(file_path)
32
+
33
+ if log_file.exists() and log_file.is_file():
34
+ log_file.unlink()
35
+
36
+ # remove_log_file(str(Path.cwd() / "log.out"))
37
+
38
+ def clear_gpu_cache():
39
+ # clear the GPU cache
40
+ if torch.cuda.is_available():
41
+ torch.cuda.empty_cache()
42
+
43
+ XTTS_MODEL = None
44
+ def load_model(xtts_checkpoint, xtts_config, xtts_vocab,xtts_speaker):
45
+ global XTTS_MODEL
46
+ clear_gpu_cache()
47
+ if not xtts_checkpoint or not xtts_config or not xtts_vocab:
48
+ return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
49
+ config = XttsConfig()
50
+ config.load_json(xtts_config)
51
+ XTTS_MODEL = Xtts.init_from_config(config)
52
+ print("Loading XTTS model! ")
53
+ XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab,speaker_file_path=xtts_speaker, use_deepspeed=False)
54
+ if torch.cuda.is_available():
55
+ XTTS_MODEL.cuda()
56
+
57
+ print("Model Loaded!")
58
+ return "Model Loaded!"
59
+
60
+ def run_tts(lang, tts_text, speaker_audio_file, temperature, length_penalty,repetition_penalty,top_k,top_p,sentence_split,use_config):
61
+ if XTTS_MODEL is None or not speaker_audio_file:
62
+ return "You need to run the previous step to load the model !!", None, None
63
+
64
+ gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
65
+
66
+ if use_config:
67
+ out = XTTS_MODEL.inference(
68
+ text=tts_text,
69
+ language=lang,
70
+ gpt_cond_latent=gpt_cond_latent,
71
+ speaker_embedding=speaker_embedding,
72
+ temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
73
+ length_penalty=XTTS_MODEL.config.length_penalty,
74
+ repetition_penalty=XTTS_MODEL.config.repetition_penalty,
75
+ top_k=XTTS_MODEL.config.top_k,
76
+ top_p=XTTS_MODEL.config.top_p,
77
+ enable_text_splitting = True
78
+ )
79
+ else:
80
+ out = XTTS_MODEL.inference(
81
+ text=tts_text,
82
+ language=lang,
83
+ gpt_cond_latent=gpt_cond_latent,
84
+ speaker_embedding=speaker_embedding,
85
+ temperature=temperature, # Add custom parameters here
86
+ length_penalty=length_penalty,
87
+ repetition_penalty=float(repetition_penalty),
88
+ top_k=top_k,
89
+ top_p=top_p,
90
+ enable_text_splitting = sentence_split
91
+ )
92
+
93
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
94
+ out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
95
+ out_path = fp.name
96
+ torchaudio.save(out_path, out["wav"], 24000)
97
+
98
+ return "Speech generated !", out_path, speaker_audio_file
99
+
100
+
101
+ def load_params_tts(out_path,version):
102
+
103
+ out_path = Path(out_path)
104
+
105
+ # base_model_path = Path.cwd() / "models" / version
106
+
107
+ # if not base_model_path.exists():
108
+ # return "Base model not found !","","",""
109
+
110
+ ready_model_path = out_path / "ready"
111
+
112
+ vocab_path = ready_model_path / "vocab.json"
113
+ config_path = ready_model_path / "config.json"
114
+ speaker_path = ready_model_path / "speakers_xtts.pth"
115
+ reference_path = ready_model_path / "reference.wav"
116
+
117
+ model_path = ready_model_path / "model.pth"
118
+
119
+ if not model_path.exists():
120
+ model_path = ready_model_path / "unoptimize_model.pth"
121
+ if not model_path.exists():
122
+ return "Params for TTS not found", "", "", ""
123
+
124
+ return "Params for TTS loaded", model_path, config_path, vocab_path,speaker_path, reference_path
125
+
126
+
127
+ if __name__ == "__main__":
128
+
129
+ parser = argparse.ArgumentParser(
130
+ description="""XTTS fine-tuning demo\n\n"""
131
+ """
132
+ Example runs:
133
+ python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
134
+ """,
135
+ formatter_class=argparse.RawTextHelpFormatter,
136
+ )
137
+ parser.add_argument(
138
+ "--port",
139
+ type=int,
140
+ help="Port to run the gradio demo. Default: 5003",
141
+ default=5003,
142
+ )
143
+ parser.add_argument(
144
+ "--out_path",
145
+ type=str,
146
+ help="Output path (where data and checkpoints will be saved) Default: output/",
147
+ default=str(Path.cwd() / "finetune_models"),
148
+ )
149
+
150
+ parser.add_argument(
151
+ "--num_epochs",
152
+ type=int,
153
+ help="Number of epochs to train. Default: 6",
154
+ default=6,
155
+ )
156
+ parser.add_argument(
157
+ "--batch_size",
158
+ type=int,
159
+ help="Batch size. Default: 2",
160
+ default=2,
161
+ )
162
+ parser.add_argument(
163
+ "--grad_acumm",
164
+ type=int,
165
+ help="Grad accumulation steps. Default: 1",
166
+ default=1,
167
+ )
168
+ parser.add_argument(
169
+ "--max_audio_length",
170
+ type=int,
171
+ help="Max permitted audio size in seconds. Default: 11",
172
+ default=11,
173
+ )
174
+
175
+ args = parser.parse_args()
176
+
177
+ with gr.Blocks() as demo:
178
+ with gr.Tab("1 - Data processing"):
179
+ out_path = gr.Textbox(
180
+ label="Output path (where data and checkpoints will be saved):",
181
+ value=args.out_path,
182
+ )
183
+ # upload_file = gr.Audio(
184
+ # sources="upload",
185
+ # label="Select here the audio files that you want to use for XTTS trainining !",
186
+ # type="filepath",
187
+ # )
188
+ upload_file = gr.File(
189
+ file_count="multiple",
190
+ label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
191
+ )
192
+
193
+ audio_folder_path = gr.Textbox(
194
+ label="Path to the folder with audio files (optional):",
195
+ value="",
196
+ )
197
+
198
+ whisper_model = gr.Dropdown(
199
+ label="Whisper Model",
200
+ value="large-v3",
201
+ choices=[
202
+ "large-v3",
203
+ "large-v2",
204
+ "large",
205
+ "medium",
206
+ "small"
207
+ ],
208
+ )
209
+
210
+ lang = gr.Dropdown(
211
+ label="Dataset Language",
212
+ value="en",
213
+ choices=[
214
+ "en",
215
+ "es",
216
+ "fr",
217
+ "de",
218
+ "it",
219
+ "pt",
220
+ "pl",
221
+ "tr",
222
+ "ru",
223
+ "nl",
224
+ "cs",
225
+ "ar",
226
+ "zh",
227
+ "hu",
228
+ "ko",
229
+ "ja"
230
+ ],
231
+ )
232
+ progress_data = gr.Label(
233
+ label="Progress:"
234
+ )
235
+ # demo.load(read_logs, None, logs, every=1)
236
+
237
+ prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
238
+
239
+ def preprocess_dataset(audio_path, audio_folder_path, language, whisper_model, out_path, train_csv, eval_csv, progress=gr.Progress(track_tqdm=True)):
240
+ clear_gpu_cache()
241
+
242
+ train_csv = ""
243
+ eval_csv = ""
244
+
245
+ out_path = os.path.join(out_path, "dataset")
246
+ os.makedirs(out_path, exist_ok=True)
247
+
248
+ if audio_folder_path:
249
+ audio_files = list(list_audios(audio_folder_path))
250
+ else:
251
+ audio_files = audio_path
252
+
253
+ if not audio_files:
254
+ return "No audio files found! Please provide files via Gradio or specify a folder path.", "", ""
255
+ else:
256
+ try:
257
+ # Loading Whisper
258
+ device = "cuda" if torch.cuda.is_available() else "cpu"
259
+
260
+ # Detect compute type
261
+ if torch.cuda.is_available():
262
+ compute_type = "float16"
263
+ else:
264
+ compute_type = "float32"
265
+
266
+ asr_model = WhisperModel(whisper_model, device=device, compute_type=compute_type)
267
+ train_meta, eval_meta, audio_total_size = format_audio_list(audio_files, asr_model=asr_model, target_language=language, out_path=out_path, gradio_progress=progress)
268
+ except:
269
+ traceback.print_exc()
270
+ error = traceback.format_exc()
271
+ return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
272
+
273
+ # clear_gpu_cache()
274
+
275
+ # if audio total len is less than 2 minutes raise an error
276
+ if audio_total_size < 120:
277
+ message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
278
+ print(message)
279
+ return message, "", ""
280
+
281
+ print("Dataset Processed!")
282
+ return "Dataset Processed!", train_meta, eval_meta
283
+
284
+
285
+ with gr.Tab("2 - Fine-tuning XTTS Encoder"):
286
+ load_params_btn = gr.Button(value="Load Params from output folder")
287
+ version = gr.Dropdown(
288
+ label="XTTS base version",
289
+ value="v2.0.2",
290
+ choices=[
291
+ "v2.0.3",
292
+ "v2.0.2",
293
+ "v2.0.1",
294
+ "v2.0.0",
295
+ "main"
296
+ ],
297
+ )
298
+ train_csv = gr.Textbox(
299
+ label="Train CSV:",
300
+ )
301
+ eval_csv = gr.Textbox(
302
+ label="Eval CSV:",
303
+ )
304
+ custom_model = gr.Textbox(
305
+ label="(Optional) Custom model.pth file , leave blank if you want to use the base file.",
306
+ value="",
307
+ )
308
+ num_epochs = gr.Slider(
309
+ label="Number of epochs:",
310
+ minimum=1,
311
+ maximum=100,
312
+ step=1,
313
+ value=args.num_epochs,
314
+ )
315
+ batch_size = gr.Slider(
316
+ label="Batch size:",
317
+ minimum=2,
318
+ maximum=512,
319
+ step=1,
320
+ value=args.batch_size,
321
+ )
322
+ grad_acumm = gr.Slider(
323
+ label="Grad accumulation steps:",
324
+ minimum=2,
325
+ maximum=128,
326
+ step=1,
327
+ value=args.grad_acumm,
328
+ )
329
+ max_audio_length = gr.Slider(
330
+ label="Max permitted audio size in seconds:",
331
+ minimum=2,
332
+ maximum=20,
333
+ step=1,
334
+ value=args.max_audio_length,
335
+ )
336
+ clear_train_data = gr.Dropdown(
337
+ label="Clear train data, you will delete selected folder, after optimizing",
338
+ value="none",
339
+ choices=[
340
+ "none",
341
+ "run",
342
+ "dataset",
343
+ "all"
344
+ ])
345
+
346
+ progress_train = gr.Label(
347
+ label="Progress:"
348
+ )
349
+
350
+ # demo.load(read_logs, None, logs_tts_train, every=1)
351
+ train_btn = gr.Button(value="Step 2 - Run the training")
352
+ optimize_model_btn = gr.Button(value="Step 2.5 - Optimize the model")
353
+
354
+ def train_model(custom_model,version,language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
355
+ clear_gpu_cache()
356
+
357
+ run_dir = Path(output_path) / "run"
358
+
359
+ # # Remove train dir
360
+ if run_dir.exists():
361
+ os.remove(run_dir)
362
+
363
+ # Check if the dataset language matches the language you specified
364
+ lang_file_path = Path(output_path) / "dataset" / "lang.txt"
365
+
366
+ # Check if lang.txt already exists and contains a different language
367
+ current_language = None
368
+ if lang_file_path.exists():
369
+ with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
370
+ current_language = existing_lang_file.read().strip()
371
+ if current_language != language:
372
+ print("The language that was prepared for the dataset does not match the specified language. Change the language to the one specified in the dataset")
373
+ language = current_language
374
+
375
+ if not train_csv or not eval_csv:
376
+ return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
377
+ try:
378
+ # convert seconds to waveform frames
379
+ max_audio_length = int(max_audio_length * 22050)
380
+ speaker_xtts_path,config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(custom_model,version,language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
381
+ except:
382
+ traceback.print_exc()
383
+ error = traceback.format_exc()
384
+ return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
385
+
386
+ # copy original files to avoid parameters changes issues
387
+ # os.system(f"cp {config_path} {exp_path}")
388
+ # os.system(f"cp {vocab_file} {exp_path}")
389
+
390
+ ready_dir = Path(output_path) / "ready"
391
+
392
+ ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
393
+
394
+ shutil.copy(ft_xtts_checkpoint, ready_dir / "unoptimize_model.pth")
395
+ # os.remove(ft_xtts_checkpoint)
396
+
397
+ ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")
398
+
399
+ # Reference
400
+ # Move reference audio to output folder and rename it
401
+ speaker_reference_path = Path(speaker_wav)
402
+ speaker_reference_new_path = ready_dir / "reference.wav"
403
+ shutil.copy(speaker_reference_path, speaker_reference_new_path)
404
+
405
+ print("Model training done!")
406
+ # clear_gpu_cache()
407
+ return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint,speaker_xtts_path, speaker_reference_new_path
408
+
409
+ def optimize_model(out_path, clear_train_data):
410
+ # print(out_path)
411
+ out_path = Path(out_path) # Ensure that out_path is a Path object.
412
+
413
+ ready_dir = out_path / "ready"
414
+ run_dir = out_path / "run"
415
+ dataset_dir = out_path / "dataset"
416
+
417
+ # Clear specified training data directories.
418
+ if clear_train_data in {"run", "all"} and run_dir.exists():
419
+ try:
420
+ shutil.rmtree(run_dir)
421
+ except PermissionError as e:
422
+ print(f"An error occurred while deleting {run_dir}: {e}")
423
+
424
+ if clear_train_data in {"dataset", "all"} and dataset_dir.exists():
425
+ try:
426
+ shutil.rmtree(dataset_dir)
427
+ except PermissionError as e:
428
+ print(f"An error occurred while deleting {dataset_dir}: {e}")
429
+
430
+ # Get full path to model
431
+ model_path = ready_dir / "unoptimize_model.pth"
432
+
433
+ if not model_path.is_file():
434
+ return "Unoptimized model not found in ready folder", ""
435
+
436
+ # Load the checkpoint and remove unnecessary parts.
437
+ checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
438
+ del checkpoint["optimizer"]
439
+
440
+ for key in list(checkpoint["model"].keys()):
441
+ if "dvae" in key:
442
+ del checkpoint["model"][key]
443
+
444
+ # Make sure out_path is a Path object or convert it to Path
445
+ os.remove(model_path)
446
+
447
+ # Save the optimized model.
448
+ optimized_model_file_name="model.pth"
449
+ optimized_model=ready_dir/optimized_model_file_name
450
+
451
+ torch.save(checkpoint, optimized_model)
452
+ ft_xtts_checkpoint=str(optimized_model)
453
+
454
+ clear_gpu_cache()
455
+
456
+ return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint
457
+
458
+ def load_params(out_path):
459
+ path_output = Path(out_path)
460
+
461
+ dataset_path = path_output / "dataset"
462
+
463
+ if not dataset_path.exists():
464
+ return "The output folder does not exist!", "", ""
465
+
466
+ eval_train = dataset_path / "metadata_train.csv"
467
+ eval_csv = dataset_path / "metadata_eval.csv"
468
+
469
+ # Write the target language to lang.txt in the output directory
470
+ lang_file_path = dataset_path / "lang.txt"
471
+
472
+ # Check if lang.txt already exists and contains a different language
473
+ current_language = None
474
+ if os.path.exists(lang_file_path):
475
+ with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
476
+ current_language = existing_lang_file.read().strip()
477
+
478
+ clear_gpu_cache()
479
+
480
+ print(current_language)
481
+ return "The data has been updated", eval_train, eval_csv, current_language
482
+
483
+ with gr.Tab("3 - Inference"):
484
+ with gr.Row():
485
+ with gr.Column() as col1:
486
+ load_params_tts_btn = gr.Button(value="Load params for TTS from output folder")
487
+ xtts_checkpoint = gr.Textbox(
488
+ label="XTTS checkpoint path:",
489
+ value="",
490
+ )
491
+ xtts_config = gr.Textbox(
492
+ label="XTTS config path:",
493
+ value="",
494
+ )
495
+
496
+ xtts_vocab = gr.Textbox(
497
+ label="XTTS vocab path:",
498
+ value="",
499
+ )
500
+ xtts_speaker = gr.Textbox(
501
+ label="XTTS speaker path:",
502
+ value="",
503
+ )
504
+ progress_load = gr.Label(
505
+ label="Progress:"
506
+ )
507
+ load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
508
+
509
+ with gr.Column() as col2:
510
+ speaker_reference_audio = gr.Textbox(
511
+ label="Speaker reference audio:",
512
+ value="",
513
+ )
514
+ tts_language = gr.Dropdown(
515
+ label="Language",
516
+ value="en",
517
+ choices=[
518
+ "en",
519
+ "es",
520
+ "fr",
521
+ "de",
522
+ "it",
523
+ "pt",
524
+ "pl",
525
+ "tr",
526
+ "ru",
527
+ "nl",
528
+ "cs",
529
+ "ar",
530
+ "zh",
531
+ "hu",
532
+ "ko",
533
+ "ja",
534
+ ]
535
+ )
536
+ tts_text = gr.Textbox(
537
+ label="Input Text.",
538
+ value="This model sounds really good and above all, it's reasonably fast.",
539
+ )
540
+ with gr.Accordion("Advanced settings", open=False) as acr:
541
+ temperature = gr.Slider(
542
+ label="temperature",
543
+ minimum=0,
544
+ maximum=1,
545
+ step=0.05,
546
+ value=0.75,
547
+ )
548
+ length_penalty = gr.Slider(
549
+ label="length_penalty",
550
+ minimum=-10.0,
551
+ maximum=10.0,
552
+ step=0.5,
553
+ value=1,
554
+ )
555
+ repetition_penalty = gr.Slider(
556
+ label="repetition penalty",
557
+ minimum=1,
558
+ maximum=10,
559
+ step=0.5,
560
+ value=5,
561
+ )
562
+ top_k = gr.Slider(
563
+ label="top_k",
564
+ minimum=1,
565
+ maximum=100,
566
+ step=1,
567
+ value=50,
568
+ )
569
+ top_p = gr.Slider(
570
+ label="top_p",
571
+ minimum=0,
572
+ maximum=1,
573
+ step=0.05,
574
+ value=0.85,
575
+ )
576
+ sentence_split = gr.Checkbox(
577
+ label="Enable text splitting",
578
+ value=True,
579
+ )
580
+ use_config = gr.Checkbox(
581
+ label="Use Inference settings from config, if disabled use the settings above",
582
+ value=False,
583
+ )
584
+ tts_btn = gr.Button(value="Step 4 - Inference")
585
+
586
+ with gr.Column() as col3:
587
+ progress_gen = gr.Label(
588
+ label="Progress:"
589
+ )
590
+ tts_output_audio = gr.Audio(label="Generated Audio.")
591
+ reference_audio = gr.Audio(label="Reference audio used.")
592
+
593
+ prompt_compute_btn.click(
594
+ fn=preprocess_dataset,
595
+ inputs=[
596
+ upload_file,
597
+ audio_folder_path,
598
+ lang,
599
+ whisper_model,
600
+ out_path,
601
+ train_csv,
602
+ eval_csv
603
+ ],
604
+ outputs=[
605
+ progress_data,
606
+ train_csv,
607
+ eval_csv,
608
+ ],
609
+ )
610
+
611
+
612
+ load_params_btn.click(
613
+ fn=load_params,
614
+ inputs=[out_path],
615
+ outputs=[
616
+ progress_train,
617
+ train_csv,
618
+ eval_csv,
619
+ lang
620
+ ]
621
+ )
622
+
623
+
624
+ train_btn.click(
625
+ fn=train_model,
626
+ inputs=[
627
+ custom_model,
628
+ version,
629
+ lang,
630
+ train_csv,
631
+ eval_csv,
632
+ num_epochs,
633
+ batch_size,
634
+ grad_acumm,
635
+ out_path,
636
+ max_audio_length,
637
+ ],
638
+ outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint,xtts_speaker, speaker_reference_audio],
639
+ )
640
+
641
+ optimize_model_btn.click(
642
+ fn=optimize_model,
643
+ inputs=[
644
+ out_path,
645
+ clear_train_data
646
+ ],
647
+ outputs=[progress_train,xtts_checkpoint],
648
+ )
649
+
650
+ load_btn.click(
651
+ fn=load_model,
652
+ inputs=[
653
+ xtts_checkpoint,
654
+ xtts_config,
655
+ xtts_vocab,
656
+ xtts_speaker
657
+ ],
658
+ outputs=[progress_load],
659
+ )
660
+
661
+ tts_btn.click(
662
+ fn=run_tts,
663
+ inputs=[
664
+ tts_language,
665
+ tts_text,
666
+ speaker_reference_audio,
667
+ temperature,
668
+ length_penalty,
669
+ repetition_penalty,
670
+ top_k,
671
+ top_p,
672
+ sentence_split,
673
+ use_config
674
+ ],
675
+ outputs=[progress_gen, tts_output_audio,reference_audio],
676
+ )
677
+
678
+ load_params_tts_btn.click(
679
+ fn=load_params_tts,
680
+ inputs=[
681
+ out_path,
682
+ version
683
+ ],
684
+ outputs=[progress_load,xtts_checkpoint,xtts_config,xtts_vocab,xtts_speaker,speaker_reference_audio],
685
+ )
686
+
687
+ demo.launch(
688
+ share=False,
689
+ debug=False,
690
+ server_port=args.port,
691
+ # inweb=True,
692
+ # server_name="localhost"
693
+ )