Spaces:
Runtime error
Runtime error
drewThomasson
commited on
Upload 10 files
Browse files- README.md +73 -13
- install.bat +10 -0
- install.sh +13 -0
- requirements.txt +7 -0
- start.bat +5 -0
- start.sh +9 -0
- utils/formatter.py +198 -0
- utils/gpt_train.py +221 -0
- utils/tokenizer.py +869 -0
- xtts_demo.py +693 -0
README.md
CHANGED
@@ -1,13 +1,73 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# xtts-finetune-webui
|
2 |
+
|
3 |
+
This webui is a slightly modified copy of the [official webui](https://github.com/coqui-ai/TTS/pull/3296) for finetune xtts.
|
4 |
+
|
5 |
+
If you are looking for an option for normal XTTS use look here [https://github.com/daswer123/xtts-webui](https://github.com/daswer123/xtts-webui)
|
6 |
+
|
7 |
+
## TODO
|
8 |
+
- [ ] Add the ability to use via console
|
9 |
+
|
10 |
+
## Key features:
|
11 |
+
|
12 |
+
### Data processing
|
13 |
+
|
14 |
+
1. Updated faster-whisper to 0.10.0 with the ability to select a larger-v3 model.
|
15 |
+
2. Changed output folder to output folder inside the main folder.
|
16 |
+
3. If there is already a dataset in the output folder and you want to add new data, you can do so by simply adding new audio, what was there will not be processed again and the new data will be automatically added
|
17 |
+
4. Turn on VAD filter
|
18 |
+
5. After the dataset is created, a file is created that specifies the language of the dataset. This file is read before training so that the language always matches. It is convenient when you restart the interface
|
19 |
+
|
20 |
+
### Fine-tuning XTTS Encoder
|
21 |
+
|
22 |
+
1. Added the ability to select the base model for XTTS, as well as when you re-training does not need to download the model again.
|
23 |
+
2. Added ability to select custom model as base model during training, which will allow finetune already finetune model.
|
24 |
+
3. Added possibility to get optimized version of the model for 1 click ( step 2.5, put optimized version in output folder).
|
25 |
+
4. You can choose whether to delete training folders after you have optimized the model
|
26 |
+
5. When you optimize the model, the example reference audio is moved to the output folder
|
27 |
+
6. Checking for correctness of the specified language and dataset language
|
28 |
+
|
29 |
+
### Inference
|
30 |
+
|
31 |
+
1. Added possibility to customize infer settings during model checking.
|
32 |
+
|
33 |
+
### Other
|
34 |
+
|
35 |
+
1. If you accidentally restart the interface during one of the steps, you can load data to additional buttons
|
36 |
+
2. Removed the display of logs as it was causing problems when restarted
|
37 |
+
3. The finished result is copied to the ready folder, these are fully finished files, you can move them anywhere and use them as a standard model
|
38 |
+
4. Added support for finetune Japanese
|
39 |
+
|
40 |
+
## Changes in webui
|
41 |
+
|
42 |
+
### 1 - Data processing
|
43 |
+
|
44 |
+
![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/8f09b829-098b-48f5-9668-832e7319403b)
|
45 |
+
|
46 |
+
### 2 - Fine-tuning XTTS Encoder
|
47 |
+
|
48 |
+
![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/897540d9-3a6b-463c-abb8-261c289cc929)
|
49 |
+
|
50 |
+
### 3 - Inference
|
51 |
+
|
52 |
+
![image](https://github.com/daswer123/xtts-finetune-webui/assets/22278673/aa05bcd4-8642-4de4-8f2f-bc0f5571af63)
|
53 |
+
|
54 |
+
## Install
|
55 |
+
|
56 |
+
1. Make sure you have `Cuda` installed
|
57 |
+
2. `git clone https://github.com/daswer123/xtts-finetune-webui`
|
58 |
+
3. `cd xtts-finetune-webui`
|
59 |
+
4. `pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118`
|
60 |
+
5. `pip install -r requirements.txt`
|
61 |
+
|
62 |
+
### If you're using Windows
|
63 |
+
|
64 |
+
1. First start `install.bat`
|
65 |
+
2. To start the server start `start.bat`
|
66 |
+
3. Go to the local address `127.0.0.1:5003`
|
67 |
+
|
68 |
+
### On Linux
|
69 |
+
|
70 |
+
1. Run `bash install.sh`
|
71 |
+
2. To start the server start `start.sh`
|
72 |
+
3. Go to the local address `127.0.0.1:5003`
|
73 |
+
~
|
install.bat
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
python -m venv venv
|
4 |
+
call venv/scripts/activate
|
5 |
+
|
6 |
+
|
7 |
+
pip install -r .\requirements.txt
|
8 |
+
pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
|
9 |
+
|
10 |
+
python xtts_demo.py
|
install.sh
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Create a Python virtual environment
|
4 |
+
python -m venv venv
|
5 |
+
# Activate the virtual environment
|
6 |
+
source venv/bin/activate
|
7 |
+
|
8 |
+
# Install other dependencies from requirements.txt
|
9 |
+
pip install -r requirements.txt
|
10 |
+
pip install torch==2.1.1+cu118 torchaudio==2.1.1+cu118 --index-url https://download.pytorch.org/whl/cu118
|
11 |
+
|
12 |
+
python xtts_demo.py
|
13 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
faster_whisper==1.0.2
|
2 |
+
gradio==4.13.0
|
3 |
+
spacy==3.7.4
|
4 |
+
coqui-tts[languages] == 0.24.1
|
5 |
+
|
6 |
+
cutlet
|
7 |
+
fugashi[unidic-lite]
|
start.bat
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
@echo off
|
2 |
+
|
3 |
+
call venv/scripts/activate
|
4 |
+
|
5 |
+
python xtts_demo.py
|
start.sh
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Create a Python virtual environment
|
4 |
+
python -m venv venv
|
5 |
+
# Activate the virtual environment
|
6 |
+
source venv/bin/activate
|
7 |
+
|
8 |
+
python xtts_demo.py
|
9 |
+
|
utils/formatter.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gc
|
3 |
+
import torchaudio
|
4 |
+
import pandas
|
5 |
+
from faster_whisper import WhisperModel
|
6 |
+
from glob import glob
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
|
11 |
+
# Add support for JA train
|
12 |
+
# from utils.tokenizer import multilingual_cleaners
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torchaudio
|
16 |
+
# torch.set_num_threads(1)
|
17 |
+
|
18 |
+
|
19 |
+
torch.set_num_threads(16)
|
20 |
+
import os
|
21 |
+
|
22 |
+
audio_types = (".wav", ".mp3", ".flac")
|
23 |
+
|
24 |
+
def find_latest_best_model(folder_path):
|
25 |
+
search_path = os.path.join(folder_path, '**', 'best_model.pth')
|
26 |
+
files = glob(search_path, recursive=True)
|
27 |
+
latest_file = max(files, key=os.path.getctime, default=None)
|
28 |
+
return latest_file
|
29 |
+
|
30 |
+
|
31 |
+
def list_audios(basePath, contains=None):
|
32 |
+
# return the set of files that are valid
|
33 |
+
return list_files(basePath, validExts=audio_types, contains=contains)
|
34 |
+
|
35 |
+
def list_files(basePath, validExts=None, contains=None):
|
36 |
+
# loop over the directory structure
|
37 |
+
for (rootDir, dirNames, filenames) in os.walk(basePath):
|
38 |
+
# loop over the filenames in the current directory
|
39 |
+
for filename in filenames:
|
40 |
+
# if the contains string is not none and the filename does not contain
|
41 |
+
# the supplied string, then ignore the file
|
42 |
+
if contains is not None and filename.find(contains) == -1:
|
43 |
+
continue
|
44 |
+
|
45 |
+
# determine the file extension of the current file
|
46 |
+
ext = filename[filename.rfind("."):].lower()
|
47 |
+
|
48 |
+
# check to see if the file is an audio and should be processed
|
49 |
+
if validExts is None or ext.endswith(validExts):
|
50 |
+
# construct the path to the audio and yield it
|
51 |
+
audioPath = os.path.join(rootDir, filename)
|
52 |
+
yield audioPath
|
53 |
+
|
54 |
+
def format_audio_list(audio_files, asr_model, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
|
55 |
+
audio_total_size = 0
|
56 |
+
os.makedirs(out_path, exist_ok=True)
|
57 |
+
|
58 |
+
lang_file_path = os.path.join(out_path, "lang.txt")
|
59 |
+
current_language = None
|
60 |
+
if os.path.exists(lang_file_path):
|
61 |
+
with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
|
62 |
+
current_language = existing_lang_file.read().strip()
|
63 |
+
|
64 |
+
if current_language != target_language:
|
65 |
+
with open(lang_file_path, 'w', encoding='utf-8') as lang_file:
|
66 |
+
lang_file.write(target_language + '\n')
|
67 |
+
print("Warning, existing language does not match target language. Updated lang.txt with target language.")
|
68 |
+
else:
|
69 |
+
print("Existing language matches target language")
|
70 |
+
|
71 |
+
metadata = {"audio_file": [], "text": [], "speaker_name": []}
|
72 |
+
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
|
73 |
+
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
|
74 |
+
|
75 |
+
existing_metadata = {'train': None, 'eval': None}
|
76 |
+
if os.path.exists(train_metadata_path):
|
77 |
+
existing_metadata['train'] = pandas.read_csv(train_metadata_path, sep="|")
|
78 |
+
print("Existing training metadata found and loaded.")
|
79 |
+
|
80 |
+
if os.path.exists(eval_metadata_path):
|
81 |
+
existing_metadata['eval'] = pandas.read_csv(eval_metadata_path, sep="|")
|
82 |
+
print("Existing evaluation metadata found and loaded.")
|
83 |
+
|
84 |
+
if gradio_progress is not None:
|
85 |
+
tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
|
86 |
+
else:
|
87 |
+
tqdm_object = tqdm(audio_files)
|
88 |
+
|
89 |
+
for audio_path in tqdm_object:
|
90 |
+
audio_file_name_without_ext, _= os.path.splitext(os.path.basename(audio_path))
|
91 |
+
prefix_check = f"wavs/{audio_file_name_without_ext}_"
|
92 |
+
|
93 |
+
skip_processing = False
|
94 |
+
for key in ['train', 'eval']:
|
95 |
+
if existing_metadata[key] is not None:
|
96 |
+
mask = existing_metadata[key]['audio_file'].str.startswith(prefix_check)
|
97 |
+
if mask.any():
|
98 |
+
print(f"Segments from {audio_file_name_without_ext} have been previously processed; skipping...")
|
99 |
+
skip_processing = True
|
100 |
+
break
|
101 |
+
|
102 |
+
if skip_processing:
|
103 |
+
continue
|
104 |
+
|
105 |
+
wav, sr = torchaudio.load(audio_path)
|
106 |
+
if wav.size(0) != 1:
|
107 |
+
wav = torch.mean(wav, dim=0, keepdim=True)
|
108 |
+
|
109 |
+
wav = wav.squeeze()
|
110 |
+
audio_total_size += (wav.size(-1) / sr)
|
111 |
+
|
112 |
+
segments, _= asr_model.transcribe(audio_path, vad_filter=True, word_timestamps=True, language=target_language)
|
113 |
+
segments = list(segments)
|
114 |
+
i = 0
|
115 |
+
sentence = ""
|
116 |
+
sentence_start = None
|
117 |
+
first_word = True
|
118 |
+
words_list = []
|
119 |
+
for _, segment in enumerate(segments):
|
120 |
+
words = list(segment.words)
|
121 |
+
words_list.extend(words)
|
122 |
+
|
123 |
+
for word_idx, word in enumerate(words_list):
|
124 |
+
if first_word:
|
125 |
+
sentence_start = word.start
|
126 |
+
if word_idx == 0:
|
127 |
+
sentence_start = max(sentence_start - buffer, 0)
|
128 |
+
else:
|
129 |
+
previous_word_end = words_list[word_idx - 1].end
|
130 |
+
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
|
131 |
+
|
132 |
+
sentence = word.word
|
133 |
+
first_word = False
|
134 |
+
else:
|
135 |
+
sentence += word.word
|
136 |
+
|
137 |
+
if word.word[-1] in ["!", "。", ".", "?"]:
|
138 |
+
sentence = sentence[1:]
|
139 |
+
sentence = multilingual_cleaners(sentence, target_language)
|
140 |
+
audio_file_name, _= os.path.splitext(os.path.basename(audio_path))
|
141 |
+
audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
|
142 |
+
|
143 |
+
if word_idx + 1 < len(words_list):
|
144 |
+
next_word_start = words_list[word_idx + 1].start
|
145 |
+
else:
|
146 |
+
next_word_start = (wav.shape[0] - 1) / sr
|
147 |
+
|
148 |
+
word_end = min((word.end + next_word_start) / 2, word.end + buffer)
|
149 |
+
|
150 |
+
absolute_path = os.path.join(out_path, audio_file)
|
151 |
+
os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
|
152 |
+
i += 1
|
153 |
+
first_word = True
|
154 |
+
|
155 |
+
audio = wav[int(sr*sentence_start):int(sr *word_end)].unsqueeze(0)
|
156 |
+
if audio.size(-1) >= sr / 3:
|
157 |
+
torchaudio.save(absolute_path, audio, sr)
|
158 |
+
else:
|
159 |
+
continue
|
160 |
+
|
161 |
+
metadata["audio_file"].append(audio_file)
|
162 |
+
metadata["text"].append(sentence)
|
163 |
+
metadata["speaker_name"].append(speaker_name)
|
164 |
+
|
165 |
+
df = pandas.DataFrame(metadata)
|
166 |
+
|
167 |
+
mode = 'w' if not os.path.exists(train_metadata_path) else 'a'
|
168 |
+
header = not os.path.exists(train_metadata_path)
|
169 |
+
df.to_csv(train_metadata_path, sep="|", index=False, mode=mode, header=header)
|
170 |
+
|
171 |
+
mode = 'w' if not os.path.exists(eval_metadata_path) else 'a'
|
172 |
+
header = not os.path.exists(eval_metadata_path)
|
173 |
+
df.to_csv(eval_metadata_path, sep="|", index=False, mode=mode, header=header)
|
174 |
+
|
175 |
+
metadata = {"audio_file": [], "text": [], "speaker_name": []}
|
176 |
+
|
177 |
+
if os.path.exists(train_metadata_path) and os.path.exists(eval_metadata_path):
|
178 |
+
existing_train_df = existing_metadata['train']
|
179 |
+
existing_eval_df = existing_metadata['eval']
|
180 |
+
else:
|
181 |
+
existing_train_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
|
182 |
+
existing_eval_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
|
183 |
+
|
184 |
+
new_data_df = pandas.read_csv(train_metadata_path, sep="|")
|
185 |
+
|
186 |
+
combined_train_df = pandas.concat([existing_train_df, new_data_df], ignore_index=True).drop_duplicates().reset_index(drop=True)
|
187 |
+
combined_eval_df = pandas.concat([existing_eval_df, new_data_df], ignore_index=True).drop_duplicates().reset_index(drop=True)
|
188 |
+
|
189 |
+
combined_train_df_shuffled = combined_train_df.sample(frac=1)
|
190 |
+
num_val_samples = int(len(combined_train_df_shuffled)* eval_percentage)
|
191 |
+
|
192 |
+
final_eval_set = combined_train_df_shuffled[:num_val_samples]
|
193 |
+
final_training_set = combined_train_df_shuffled[num_val_samples:]
|
194 |
+
|
195 |
+
final_training_set.sort_values('audio_file').to_csv(train_metadata_path, sep='|', index=False)
|
196 |
+
final_eval_set.sort_values('audio_file').to_csv(eval_metadata_path, sep='|', index=False)
|
197 |
+
|
198 |
+
return train_metadata_path, eval_metadata_path, audio_total_size
|
utils/gpt_train.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import gc
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from trainer import Trainer, TrainerArgs
|
7 |
+
|
8 |
+
from TTS.config.shared_configs import BaseDatasetConfig
|
9 |
+
from TTS.tts.datasets import load_tts_samples
|
10 |
+
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
|
11 |
+
from TTS.utils.manage import ModelManager
|
12 |
+
import shutil
|
13 |
+
|
14 |
+
|
15 |
+
def train_gpt(custom_model,version, language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
|
16 |
+
# Logging parameters
|
17 |
+
RUN_NAME = "GPT_XTTS_FT"
|
18 |
+
PROJECT_NAME = "XTTS_trainer"
|
19 |
+
DASHBOARD_LOGGER = "tensorboard"
|
20 |
+
LOGGER_URI = None
|
21 |
+
|
22 |
+
# print(f"XTTS version = {version}")
|
23 |
+
|
24 |
+
# Set here the path that the checkpoints will be saved. Default: ./run/training/
|
25 |
+
OUT_PATH = os.path.join(output_path, "run", "training")
|
26 |
+
|
27 |
+
# Training Parameters
|
28 |
+
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False
|
29 |
+
START_WITH_EVAL = False # if True it will star with evaluation
|
30 |
+
BATCH_SIZE = batch_size # set here the batch size
|
31 |
+
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps
|
32 |
+
|
33 |
+
|
34 |
+
# Define here the dataset that you want to use for the fine-tuning on.
|
35 |
+
config_dataset = BaseDatasetConfig(
|
36 |
+
formatter="coqui",
|
37 |
+
dataset_name="ft_dataset",
|
38 |
+
path=os.path.dirname(train_csv),
|
39 |
+
meta_file_train=train_csv,
|
40 |
+
meta_file_val=eval_csv,
|
41 |
+
language=language,
|
42 |
+
)
|
43 |
+
|
44 |
+
# Add here the configs of the datasets
|
45 |
+
DATASETS_CONFIG_LIST = [config_dataset]
|
46 |
+
|
47 |
+
# Define the path where XTTS v2.0.1 files will be downloaded
|
48 |
+
CHECKPOINTS_OUT_PATH = os.path.join(Path.cwd(), "base_models",f"{version}")
|
49 |
+
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)
|
50 |
+
|
51 |
+
|
52 |
+
# DVAE files
|
53 |
+
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
|
54 |
+
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
|
55 |
+
|
56 |
+
# Set the path to the downloaded files
|
57 |
+
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
|
58 |
+
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))
|
59 |
+
|
60 |
+
# download DVAE files if needed
|
61 |
+
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
|
62 |
+
print(" > Downloading DVAE files!")
|
63 |
+
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)
|
64 |
+
|
65 |
+
|
66 |
+
# Download XTTS v2.0 checkpoint if needed
|
67 |
+
TOKENIZER_FILE_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/vocab.json"
|
68 |
+
XTTS_CHECKPOINT_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/model.pth"
|
69 |
+
XTTS_CONFIG_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{version}/config.json"
|
70 |
+
XTTS_SPEAKER_LINK = f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/speakers_xtts.pth"
|
71 |
+
|
72 |
+
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
|
73 |
+
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file
|
74 |
+
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file
|
75 |
+
XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file
|
76 |
+
XTTS_SPEAKER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_SPEAKER_LINK)) # speakers_xtts.pth file
|
77 |
+
|
78 |
+
# download XTTS v2.0 files if needed
|
79 |
+
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
|
80 |
+
print(f" > Downloading XTTS v{version} files!")
|
81 |
+
ModelManager._download_model_files(
|
82 |
+
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK,XTTS_SPEAKER_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
|
83 |
+
)
|
84 |
+
|
85 |
+
# Transfer this files to ready folder
|
86 |
+
READY_MODEL_PATH = os.path.join(output_path,"ready")
|
87 |
+
if not os.path.exists(READY_MODEL_PATH):
|
88 |
+
os.makedirs(READY_MODEL_PATH)
|
89 |
+
|
90 |
+
NEW_TOKENIZER_FILE = os.path.join(READY_MODEL_PATH, "vocab.json")
|
91 |
+
# NEW_XTTS_CHECKPOINT = os.path.join(READY_MODEL_PATH, "model.pth")
|
92 |
+
NEW_XTTS_CONFIG_FILE = os.path.join(READY_MODEL_PATH, "config.json")
|
93 |
+
NEW_XTTS_SPEAKER_FILE = os.path.join(READY_MODEL_PATH, "speakers_xtts.pth")
|
94 |
+
|
95 |
+
shutil.copy(TOKENIZER_FILE, NEW_TOKENIZER_FILE)
|
96 |
+
# shutil.copy(XTTS_CHECKPOINT, os.path.join(READY_MODEL_PATH, "model.pth"))
|
97 |
+
shutil.copy(XTTS_CONFIG_FILE, NEW_XTTS_CONFIG_FILE)
|
98 |
+
shutil.copy(XTTS_SPEAKER_FILE, NEW_XTTS_SPEAKER_FILE)
|
99 |
+
|
100 |
+
# Use from ready folder
|
101 |
+
TOKENIZER_FILE = NEW_TOKENIZER_FILE # vocab.json file
|
102 |
+
# XTTS_CHECKPOINT = NEW_XTTS_CHECKPOINT # model.pth file
|
103 |
+
XTTS_CONFIG_FILE = NEW_XTTS_CONFIG_FILE # config.json file
|
104 |
+
XTTS_SPEAKER_FILE = NEW_XTTS_SPEAKER_FILE # speakers_xtts.pth file
|
105 |
+
|
106 |
+
|
107 |
+
if custom_model != "":
|
108 |
+
if os.path.exists(custom_model) and custom_model.endswith('.pth'):
|
109 |
+
XTTS_CHECKPOINT = custom_model
|
110 |
+
print(f" > Loading custom model: {XTTS_CHECKPOINT}")
|
111 |
+
else:
|
112 |
+
print(" > Error: The specified custom model is not a valid .pth file path.")
|
113 |
+
|
114 |
+
num_workers = 8
|
115 |
+
if language == "ja":
|
116 |
+
num_workers = 0
|
117 |
+
# init args and config
|
118 |
+
model_args = GPTArgs(
|
119 |
+
max_conditioning_length=132300, # 6 secs
|
120 |
+
min_conditioning_length=66150, # 3 secs
|
121 |
+
debug_loading_failures=False,
|
122 |
+
max_wav_length=max_audio_length, # ~11.6 seconds
|
123 |
+
max_text_length=200,
|
124 |
+
mel_norm_file=MEL_NORM_FILE,
|
125 |
+
dvae_checkpoint=DVAE_CHECKPOINT,
|
126 |
+
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune
|
127 |
+
tokenizer_file=TOKENIZER_FILE,
|
128 |
+
gpt_num_audio_tokens=1026,
|
129 |
+
gpt_start_audio_token=1024,
|
130 |
+
gpt_stop_audio_token=1025,
|
131 |
+
gpt_use_masking_gt_prompt_approach=True,
|
132 |
+
gpt_use_perceiver_resampler=True,
|
133 |
+
)
|
134 |
+
# define audio config
|
135 |
+
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
|
136 |
+
# training parameters config
|
137 |
+
config = GPTTrainerConfig(
|
138 |
+
epochs=num_epochs,
|
139 |
+
output_path=OUT_PATH,
|
140 |
+
model_args=model_args,
|
141 |
+
run_name=RUN_NAME,
|
142 |
+
project_name=PROJECT_NAME,
|
143 |
+
run_description="""
|
144 |
+
GPT XTTS training
|
145 |
+
""",
|
146 |
+
dashboard_logger=DASHBOARD_LOGGER,
|
147 |
+
logger_uri=LOGGER_URI,
|
148 |
+
audio=audio_config,
|
149 |
+
batch_size=BATCH_SIZE,
|
150 |
+
batch_group_size=48,
|
151 |
+
eval_batch_size=BATCH_SIZE,
|
152 |
+
num_loader_workers=num_workers,
|
153 |
+
eval_split_max_size=256,
|
154 |
+
print_step=50,
|
155 |
+
plot_step=100,
|
156 |
+
log_model_step=100,
|
157 |
+
save_step=1000,
|
158 |
+
save_n_checkpoints=1,
|
159 |
+
save_checkpoints=True,
|
160 |
+
# target_loss="loss",
|
161 |
+
print_eval=False,
|
162 |
+
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
|
163 |
+
optimizer="AdamW",
|
164 |
+
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
|
165 |
+
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
|
166 |
+
lr=5e-06, # learning rate
|
167 |
+
lr_scheduler="MultiStepLR",
|
168 |
+
# it was adjusted accordly for the new step scheme
|
169 |
+
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
|
170 |
+
test_sentences=[],
|
171 |
+
)
|
172 |
+
|
173 |
+
# init the model from config
|
174 |
+
model = GPTTrainer.init_from_config(config)
|
175 |
+
|
176 |
+
# load training samples
|
177 |
+
train_samples, eval_samples = load_tts_samples(
|
178 |
+
DATASETS_CONFIG_LIST,
|
179 |
+
eval_split=True,
|
180 |
+
eval_split_max_size=config.eval_split_max_size,
|
181 |
+
eval_split_size=config.eval_split_size,
|
182 |
+
)
|
183 |
+
|
184 |
+
# init the trainer and 🚀
|
185 |
+
trainer = Trainer(
|
186 |
+
TrainerArgs(
|
187 |
+
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
|
188 |
+
skip_train_epoch=False,
|
189 |
+
start_with_eval=START_WITH_EVAL,
|
190 |
+
grad_accum_steps=GRAD_ACUMM_STEPS,
|
191 |
+
),
|
192 |
+
config,
|
193 |
+
output_path=OUT_PATH,
|
194 |
+
model=model,
|
195 |
+
train_samples=train_samples,
|
196 |
+
eval_samples=eval_samples,
|
197 |
+
)
|
198 |
+
trainer.fit()
|
199 |
+
|
200 |
+
# get the longest text audio file to use as speaker reference
|
201 |
+
samples_len = [len(item["text"].split(" ")) for item in train_samples]
|
202 |
+
longest_text_idx = samples_len.index(max(samples_len))
|
203 |
+
speaker_ref = train_samples[longest_text_idx]["audio_file"]
|
204 |
+
|
205 |
+
trainer_out_path = trainer.output_path
|
206 |
+
|
207 |
+
# close file handlers and remove them from the logger
|
208 |
+
for handler in logging.getLogger('trainer').handlers:
|
209 |
+
if isinstance(handler, logging.FileHandler):
|
210 |
+
handler.close()
|
211 |
+
logging.getLogger('trainer').removeHandler(handler)
|
212 |
+
|
213 |
+
# now you should be able to delete the log file
|
214 |
+
log_file = os.path.join(trainer.output_path, f"trainer_{trainer.args.rank}_log.txt")
|
215 |
+
os.remove(log_file)
|
216 |
+
|
217 |
+
# deallocate VRAM and RAM
|
218 |
+
del model, trainer, train_samples, eval_samples
|
219 |
+
gc.collect()
|
220 |
+
|
221 |
+
return XTTS_SPEAKER_FILE,XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
|
utils/tokenizer.py
ADDED
@@ -0,0 +1,869 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import textwrap
|
4 |
+
from functools import cached_property
|
5 |
+
|
6 |
+
import pypinyin
|
7 |
+
import torch
|
8 |
+
from hangul_romanize import Transliter
|
9 |
+
from hangul_romanize.rule import academic
|
10 |
+
from num2words import num2words
|
11 |
+
from spacy.lang.ar import Arabic
|
12 |
+
from spacy.lang.en import English
|
13 |
+
from spacy.lang.es import Spanish
|
14 |
+
from spacy.lang.ja import Japanese
|
15 |
+
from spacy.lang.zh import Chinese
|
16 |
+
from tokenizers import Tokenizer
|
17 |
+
|
18 |
+
from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words
|
19 |
+
|
20 |
+
|
21 |
+
def get_spacy_lang(lang):
|
22 |
+
if lang == "zh":
|
23 |
+
return Chinese()
|
24 |
+
elif lang == "ja":
|
25 |
+
return Japanese()
|
26 |
+
elif lang == "ar":
|
27 |
+
return Arabic()
|
28 |
+
elif lang == "es":
|
29 |
+
return Spanish()
|
30 |
+
else:
|
31 |
+
# For most languages, Enlish does the job
|
32 |
+
return English()
|
33 |
+
|
34 |
+
|
35 |
+
def split_sentence(text, lang, text_split_length=250):
|
36 |
+
"""Preprocess the input text"""
|
37 |
+
text_splits = []
|
38 |
+
if text_split_length is not None and len(text) >= text_split_length:
|
39 |
+
text_splits.append("")
|
40 |
+
nlp = get_spacy_lang(lang)
|
41 |
+
nlp.add_pipe("sentencizer")
|
42 |
+
doc = nlp(text)
|
43 |
+
for sentence in doc.sents:
|
44 |
+
if len(text_splits[-1]) + len(str(sentence)) <= text_split_length:
|
45 |
+
# if the last sentence + the current sentence is less than the text_split_length
|
46 |
+
# then add the current sentence to the last sentence
|
47 |
+
text_splits[-1] += " " + str(sentence)
|
48 |
+
text_splits[-1] = text_splits[-1].lstrip()
|
49 |
+
elif len(str(sentence)) > text_split_length:
|
50 |
+
# if the current sentence is greater than the text_split_length
|
51 |
+
for line in textwrap.wrap(
|
52 |
+
str(sentence),
|
53 |
+
width=text_split_length,
|
54 |
+
drop_whitespace=True,
|
55 |
+
break_on_hyphens=False,
|
56 |
+
tabsize=1,
|
57 |
+
):
|
58 |
+
text_splits.append(str(line))
|
59 |
+
else:
|
60 |
+
text_splits.append(str(sentence))
|
61 |
+
|
62 |
+
if len(text_splits) > 1:
|
63 |
+
if text_splits[0] == "":
|
64 |
+
del text_splits[0]
|
65 |
+
else:
|
66 |
+
text_splits = [text.lstrip()]
|
67 |
+
|
68 |
+
return text_splits
|
69 |
+
|
70 |
+
|
71 |
+
_whitespace_re = re.compile(r"\s+")
|
72 |
+
|
73 |
+
# List of (regular expression, replacement) pairs for abbreviations:
|
74 |
+
_abbreviations = {
|
75 |
+
"en": [
|
76 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
77 |
+
for x in [
|
78 |
+
("mrs", "misess"),
|
79 |
+
("mr", "mister"),
|
80 |
+
("dr", "doctor"),
|
81 |
+
("st", "saint"),
|
82 |
+
("co", "company"),
|
83 |
+
("jr", "junior"),
|
84 |
+
("maj", "major"),
|
85 |
+
("gen", "general"),
|
86 |
+
("drs", "doctors"),
|
87 |
+
("rev", "reverend"),
|
88 |
+
("lt", "lieutenant"),
|
89 |
+
("hon", "honorable"),
|
90 |
+
("sgt", "sergeant"),
|
91 |
+
("capt", "captain"),
|
92 |
+
("esq", "esquire"),
|
93 |
+
("ltd", "limited"),
|
94 |
+
("col", "colonel"),
|
95 |
+
("ft", "fort"),
|
96 |
+
]
|
97 |
+
],
|
98 |
+
"es": [
|
99 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
100 |
+
for x in [
|
101 |
+
("sra", "señora"),
|
102 |
+
("sr", "señor"),
|
103 |
+
("dr", "doctor"),
|
104 |
+
("dra", "doctora"),
|
105 |
+
("st", "santo"),
|
106 |
+
("co", "compañía"),
|
107 |
+
("jr", "junior"),
|
108 |
+
("ltd", "limitada"),
|
109 |
+
]
|
110 |
+
],
|
111 |
+
"fr": [
|
112 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
113 |
+
for x in [
|
114 |
+
("mme", "madame"),
|
115 |
+
("mr", "monsieur"),
|
116 |
+
("dr", "docteur"),
|
117 |
+
("st", "saint"),
|
118 |
+
("co", "compagnie"),
|
119 |
+
("jr", "junior"),
|
120 |
+
("ltd", "limitée"),
|
121 |
+
]
|
122 |
+
],
|
123 |
+
"de": [
|
124 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
125 |
+
for x in [
|
126 |
+
("fr", "frau"),
|
127 |
+
("dr", "doktor"),
|
128 |
+
("st", "sankt"),
|
129 |
+
("co", "firma"),
|
130 |
+
("jr", "junior"),
|
131 |
+
]
|
132 |
+
],
|
133 |
+
"pt": [
|
134 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
135 |
+
for x in [
|
136 |
+
("sra", "senhora"),
|
137 |
+
("sr", "senhor"),
|
138 |
+
("dr", "doutor"),
|
139 |
+
("dra", "doutora"),
|
140 |
+
("st", "santo"),
|
141 |
+
("co", "companhia"),
|
142 |
+
("jr", "júnior"),
|
143 |
+
("ltd", "limitada"),
|
144 |
+
]
|
145 |
+
],
|
146 |
+
"it": [
|
147 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
148 |
+
for x in [
|
149 |
+
# ("sig.ra", "signora"),
|
150 |
+
("sig", "signore"),
|
151 |
+
("dr", "dottore"),
|
152 |
+
("st", "santo"),
|
153 |
+
("co", "compagnia"),
|
154 |
+
("jr", "junior"),
|
155 |
+
("ltd", "limitata"),
|
156 |
+
]
|
157 |
+
],
|
158 |
+
"pl": [
|
159 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
160 |
+
for x in [
|
161 |
+
("p", "pani"),
|
162 |
+
("m", "pan"),
|
163 |
+
("dr", "doktor"),
|
164 |
+
("sw", "święty"),
|
165 |
+
("jr", "junior"),
|
166 |
+
]
|
167 |
+
],
|
168 |
+
"ar": [
|
169 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
170 |
+
for x in [
|
171 |
+
# There are not many common abbreviations in Arabic as in English.
|
172 |
+
]
|
173 |
+
],
|
174 |
+
"zh": [
|
175 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
176 |
+
for x in [
|
177 |
+
# Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
|
178 |
+
]
|
179 |
+
],
|
180 |
+
"cs": [
|
181 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
182 |
+
for x in [
|
183 |
+
("dr", "doktor"), # doctor
|
184 |
+
("ing", "inženýr"), # engineer
|
185 |
+
("p", "pan"), # Could also map to pani for woman but no easy way to do it
|
186 |
+
# Other abbreviations would be specialized and not as common.
|
187 |
+
]
|
188 |
+
],
|
189 |
+
"ru": [
|
190 |
+
(re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
|
191 |
+
for x in [
|
192 |
+
("г-жа", "госпожа"), # Mrs.
|
193 |
+
("г-н", "господин"), # Mr.
|
194 |
+
("д-р", "доктор"), # doctor
|
195 |
+
# Other abbreviations are less common or specialized.
|
196 |
+
]
|
197 |
+
],
|
198 |
+
"nl": [
|
199 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
200 |
+
for x in [
|
201 |
+
("dhr", "de heer"), # Mr.
|
202 |
+
("mevr", "mevrouw"), # Mrs.
|
203 |
+
("dr", "dokter"), # doctor
|
204 |
+
("jhr", "jonkheer"), # young lord or nobleman
|
205 |
+
# Dutch uses more abbreviations, but these are the most common ones.
|
206 |
+
]
|
207 |
+
],
|
208 |
+
"tr": [
|
209 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
210 |
+
for x in [
|
211 |
+
("b", "bay"), # Mr.
|
212 |
+
("byk", "büyük"), # büyük
|
213 |
+
("dr", "doktor"), # doctor
|
214 |
+
# Add other Turkish abbreviations here if needed.
|
215 |
+
]
|
216 |
+
],
|
217 |
+
"hu": [
|
218 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
219 |
+
for x in [
|
220 |
+
("dr", "doktor"), # doctor
|
221 |
+
("b", "bácsi"), # Mr.
|
222 |
+
("nőv", "nővér"), # nurse
|
223 |
+
# Add other Hungarian abbreviations here if needed.
|
224 |
+
]
|
225 |
+
],
|
226 |
+
"ko": [
|
227 |
+
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
|
228 |
+
for x in [
|
229 |
+
# Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
|
230 |
+
]
|
231 |
+
],
|
232 |
+
"ja": [
|
233 |
+
(re.compile("\\b%s\\b" % x[0]), x[1])
|
234 |
+
for x in [
|
235 |
+
("氏", "さん"), # Mr.
|
236 |
+
("夫人", "おんなのひと"), # Mrs.
|
237 |
+
("博士", "はかせ"), # Doctor or PhD
|
238 |
+
("株", "株式会社"), # Corporation
|
239 |
+
("有", "有限会社"), # Limited company
|
240 |
+
("大学", "だいがく"), # University
|
241 |
+
("先生", "せんせい"), # Teacher/Professor/Master
|
242 |
+
("君", "くん") # Used at the end of boys' names to express familiarity or affection.
|
243 |
+
]
|
244 |
+
],
|
245 |
+
}
|
246 |
+
|
247 |
+
|
248 |
+
def expand_abbreviations_multilingual(text, lang="en"):
|
249 |
+
for regex, replacement in _abbreviations[lang]:
|
250 |
+
text = re.sub(regex, replacement, text)
|
251 |
+
return text
|
252 |
+
|
253 |
+
|
254 |
+
_symbols_multilingual = {
|
255 |
+
"en": [
|
256 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
257 |
+
for x in [
|
258 |
+
("&", " and "),
|
259 |
+
("@", " at "),
|
260 |
+
("%", " percent "),
|
261 |
+
("#", " hash "),
|
262 |
+
("$", " dollar "),
|
263 |
+
("£", " pound "),
|
264 |
+
("°", " degree "),
|
265 |
+
]
|
266 |
+
],
|
267 |
+
"es": [
|
268 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
269 |
+
for x in [
|
270 |
+
("&", " y "),
|
271 |
+
("@", " arroba "),
|
272 |
+
("%", " por ciento "),
|
273 |
+
("#", " numeral "),
|
274 |
+
("$", " dolar "),
|
275 |
+
("£", " libra "),
|
276 |
+
("°", " grados "),
|
277 |
+
]
|
278 |
+
],
|
279 |
+
"fr": [
|
280 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
281 |
+
for x in [
|
282 |
+
("&", " et "),
|
283 |
+
("@", " arobase "),
|
284 |
+
("%", " pour cent "),
|
285 |
+
("#", " dièse "),
|
286 |
+
("$", " dollar "),
|
287 |
+
("£", " livre "),
|
288 |
+
("°", " degrés "),
|
289 |
+
]
|
290 |
+
],
|
291 |
+
"de": [
|
292 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
293 |
+
for x in [
|
294 |
+
("&", " und "),
|
295 |
+
("@", " at "),
|
296 |
+
("%", " prozent "),
|
297 |
+
("#", " raute "),
|
298 |
+
("$", " dollar "),
|
299 |
+
("£", " pfund "),
|
300 |
+
("°", " grad "),
|
301 |
+
]
|
302 |
+
],
|
303 |
+
"pt": [
|
304 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
305 |
+
for x in [
|
306 |
+
("&", " e "),
|
307 |
+
("@", " arroba "),
|
308 |
+
("%", " por cento "),
|
309 |
+
("#", " cardinal "),
|
310 |
+
("$", " dólar "),
|
311 |
+
("£", " libra "),
|
312 |
+
("°", " graus "),
|
313 |
+
]
|
314 |
+
],
|
315 |
+
"it": [
|
316 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
317 |
+
for x in [
|
318 |
+
("&", " e "),
|
319 |
+
("@", " chiocciola "),
|
320 |
+
("%", " per cento "),
|
321 |
+
("#", " cancelletto "),
|
322 |
+
("$", " dollaro "),
|
323 |
+
("£", " sterlina "),
|
324 |
+
("°", " gradi "),
|
325 |
+
]
|
326 |
+
],
|
327 |
+
"pl": [
|
328 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
329 |
+
for x in [
|
330 |
+
("&", " i "),
|
331 |
+
("@", " małpa "),
|
332 |
+
("%", " procent "),
|
333 |
+
("#", " krzyżyk "),
|
334 |
+
("$", " dolar "),
|
335 |
+
("£", " funt "),
|
336 |
+
("°", " stopnie "),
|
337 |
+
]
|
338 |
+
],
|
339 |
+
"ar": [
|
340 |
+
# Arabic
|
341 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
342 |
+
for x in [
|
343 |
+
("&", " و "),
|
344 |
+
("@", " على "),
|
345 |
+
("%", " في المئة "),
|
346 |
+
("#", " رقم "),
|
347 |
+
("$", " دولار "),
|
348 |
+
("£", " جنيه "),
|
349 |
+
("°", " درجة "),
|
350 |
+
]
|
351 |
+
],
|
352 |
+
"zh": [
|
353 |
+
# Chinese
|
354 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
355 |
+
for x in [
|
356 |
+
("&", " 和 "),
|
357 |
+
("@", " 在 "),
|
358 |
+
("%", " 百分之 "),
|
359 |
+
("#", " 号 "),
|
360 |
+
("$", " 美元 "),
|
361 |
+
("£", " 英镑 "),
|
362 |
+
("°", " 度 "),
|
363 |
+
]
|
364 |
+
],
|
365 |
+
"cs": [
|
366 |
+
# Czech
|
367 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
368 |
+
for x in [
|
369 |
+
("&", " a "),
|
370 |
+
("@", " na "),
|
371 |
+
("%", " procento "),
|
372 |
+
("#", " křížek "),
|
373 |
+
("$", " dolar "),
|
374 |
+
("£", " libra "),
|
375 |
+
("°", " stupně "),
|
376 |
+
]
|
377 |
+
],
|
378 |
+
"ru": [
|
379 |
+
# Russian
|
380 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
381 |
+
for x in [
|
382 |
+
("&", " и "),
|
383 |
+
("@", " собака "),
|
384 |
+
("%", " процентов "),
|
385 |
+
("#", " номер "),
|
386 |
+
("$", " доллар "),
|
387 |
+
("£", " фунт "),
|
388 |
+
("°", " градус "),
|
389 |
+
]
|
390 |
+
],
|
391 |
+
"nl": [
|
392 |
+
# Dutch
|
393 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
394 |
+
for x in [
|
395 |
+
("&", " en "),
|
396 |
+
("@", " bij "),
|
397 |
+
("%", " procent "),
|
398 |
+
("#", " hekje "),
|
399 |
+
("$", " dollar "),
|
400 |
+
("£", " pond "),
|
401 |
+
("°", " graden "),
|
402 |
+
]
|
403 |
+
],
|
404 |
+
"tr": [
|
405 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
406 |
+
for x in [
|
407 |
+
("&", " ve "),
|
408 |
+
("@", " at "),
|
409 |
+
("%", " yüzde "),
|
410 |
+
("#", " diyez "),
|
411 |
+
("$", " dolar "),
|
412 |
+
("£", " sterlin "),
|
413 |
+
("°", " derece "),
|
414 |
+
]
|
415 |
+
],
|
416 |
+
"hu": [
|
417 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
418 |
+
for x in [
|
419 |
+
("&", " és "),
|
420 |
+
("@", " kukac "),
|
421 |
+
("%", " százalék "),
|
422 |
+
("#", " kettőskereszt "),
|
423 |
+
("$", " dollár "),
|
424 |
+
("£", " font "),
|
425 |
+
("°", " fok "),
|
426 |
+
]
|
427 |
+
],
|
428 |
+
"ko": [
|
429 |
+
# Korean
|
430 |
+
(re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
|
431 |
+
for x in [
|
432 |
+
("&", " 그리고 "),
|
433 |
+
("@", " 에 "),
|
434 |
+
("%", " 퍼센트 "),
|
435 |
+
("#", " 번호 "),
|
436 |
+
("$", " 달러 "),
|
437 |
+
("£", " 파운드 "),
|
438 |
+
("°", " 도 "),
|
439 |
+
]
|
440 |
+
],
|
441 |
+
"ja": [
|
442 |
+
(re.compile(r"%s" % re.escape(x[0])), x[1])
|
443 |
+
for x in [
|
444 |
+
("&", " と "),
|
445 |
+
("@", " アットマーク "),
|
446 |
+
("%", " パーセント "),
|
447 |
+
("#", " ナンバー "),
|
448 |
+
("$", " ドル "),
|
449 |
+
("£", " ポンド "),
|
450 |
+
("°", " 度"),
|
451 |
+
]
|
452 |
+
],
|
453 |
+
}
|
454 |
+
|
455 |
+
|
456 |
+
def expand_symbols_multilingual(text, lang="en"):
|
457 |
+
for regex, replacement in _symbols_multilingual[lang]:
|
458 |
+
text = re.sub(regex, replacement, text)
|
459 |
+
text = text.replace(" ", " ") # Ensure there are no double spaces
|
460 |
+
return text.strip()
|
461 |
+
|
462 |
+
|
463 |
+
_ordinal_re = {
|
464 |
+
"en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
|
465 |
+
"es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
|
466 |
+
"fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
|
467 |
+
"de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
|
468 |
+
"pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
|
469 |
+
"it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
|
470 |
+
"pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
|
471 |
+
"ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
|
472 |
+
"cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
|
473 |
+
"ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
|
474 |
+
"nl": re.compile(r"([0-9]+)(de|ste|e)"),
|
475 |
+
"tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
|
476 |
+
"hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
|
477 |
+
"ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
|
478 |
+
"ja": re.compile(r"([0-9]+)(番|回|つ|目|等|位)")
|
479 |
+
}
|
480 |
+
_number_re = re.compile(r"[0-9]+")
|
481 |
+
_currency_re = {
|
482 |
+
"USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
|
483 |
+
"GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
|
484 |
+
"EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
|
485 |
+
}
|
486 |
+
|
487 |
+
_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
|
488 |
+
_dot_number_re = re.compile(r"\b\d{1,3}(.\d{3})*(\,\d+)?\b")
|
489 |
+
_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
|
490 |
+
|
491 |
+
|
492 |
+
def _remove_commas(m):
|
493 |
+
text = m.group(0)
|
494 |
+
if "," in text:
|
495 |
+
text = text.replace(",", "")
|
496 |
+
return text
|
497 |
+
|
498 |
+
|
499 |
+
def _remove_dots(m):
|
500 |
+
text = m.group(0)
|
501 |
+
if "." in text:
|
502 |
+
text = text.replace(".", "")
|
503 |
+
return text
|
504 |
+
|
505 |
+
|
506 |
+
def _expand_decimal_point(m, lang="en"):
|
507 |
+
amount = m.group(1).replace(",", ".")
|
508 |
+
return num2words(float(amount), lang=lang if lang != "cs" else "cz")
|
509 |
+
|
510 |
+
|
511 |
+
def _expand_currency(m, lang="en", currency="USD"):
|
512 |
+
amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
|
513 |
+
full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
|
514 |
+
|
515 |
+
and_equivalents = {
|
516 |
+
"en": ", ",
|
517 |
+
"es": " con ",
|
518 |
+
"fr": " et ",
|
519 |
+
"de": " und ",
|
520 |
+
"pt": " e ",
|
521 |
+
"it": " e ",
|
522 |
+
"pl": ", ",
|
523 |
+
"cs": ", ",
|
524 |
+
"ru": ", ",
|
525 |
+
"nl": ", ",
|
526 |
+
"ar": ", ",
|
527 |
+
"tr": ", ",
|
528 |
+
"hu": ", ",
|
529 |
+
"ko": ", ",
|
530 |
+
}
|
531 |
+
|
532 |
+
if amount.is_integer():
|
533 |
+
last_and = full_amount.rfind(and_equivalents[lang])
|
534 |
+
if last_and != -1:
|
535 |
+
full_amount = full_amount[:last_and]
|
536 |
+
|
537 |
+
return full_amount
|
538 |
+
|
539 |
+
|
540 |
+
def _expand_ordinal(m, lang="en"):
|
541 |
+
return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
|
542 |
+
|
543 |
+
|
544 |
+
def _expand_number(m, lang="en"):
|
545 |
+
return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
|
546 |
+
|
547 |
+
|
548 |
+
def expand_numbers_multilingual(text, lang="en"):
|
549 |
+
if lang == "zh":
|
550 |
+
text = zh_num2words()(text)
|
551 |
+
else:
|
552 |
+
if lang in ["en", "ru"]:
|
553 |
+
text = re.sub(_comma_number_re, _remove_commas, text)
|
554 |
+
else:
|
555 |
+
text = re.sub(_dot_number_re, _remove_dots, text)
|
556 |
+
try:
|
557 |
+
text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
|
558 |
+
text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
|
559 |
+
text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
|
560 |
+
except:
|
561 |
+
pass
|
562 |
+
if lang != "tr":
|
563 |
+
text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
|
564 |
+
text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
|
565 |
+
text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
|
566 |
+
return text
|
567 |
+
|
568 |
+
|
569 |
+
def lowercase(text):
|
570 |
+
return text.lower()
|
571 |
+
|
572 |
+
|
573 |
+
def collapse_whitespace(text):
|
574 |
+
return re.sub(_whitespace_re, " ", text)
|
575 |
+
|
576 |
+
|
577 |
+
def multilingual_cleaners(text, lang):
|
578 |
+
text = text.replace('"', "")
|
579 |
+
if lang == "tr":
|
580 |
+
text = text.replace("İ", "i")
|
581 |
+
text = text.replace("Ö", "ö")
|
582 |
+
text = text.replace("Ü", "ü")
|
583 |
+
text = lowercase(text)
|
584 |
+
text = expand_numbers_multilingual(text, lang)
|
585 |
+
text = expand_abbreviations_multilingual(text, lang)
|
586 |
+
text = expand_symbols_multilingual(text, lang=lang)
|
587 |
+
text = collapse_whitespace(text)
|
588 |
+
return text
|
589 |
+
|
590 |
+
|
591 |
+
def basic_cleaners(text):
|
592 |
+
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
593 |
+
text = lowercase(text)
|
594 |
+
text = collapse_whitespace(text)
|
595 |
+
return text
|
596 |
+
|
597 |
+
|
598 |
+
def chinese_transliterate(text):
|
599 |
+
return "".join(
|
600 |
+
[p[0] for p in pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)]
|
601 |
+
)
|
602 |
+
|
603 |
+
|
604 |
+
def japanese_cleaners(text, katsu):
|
605 |
+
text = katsu.romaji(text)
|
606 |
+
text = lowercase(text)
|
607 |
+
return text
|
608 |
+
|
609 |
+
|
610 |
+
def korean_transliterate(text):
|
611 |
+
r = Transliter(academic)
|
612 |
+
return r.translit(text)
|
613 |
+
|
614 |
+
|
615 |
+
DEFAULT_VOCAB_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../data/tokenizer.json")
|
616 |
+
|
617 |
+
|
618 |
+
class VoiceBpeTokenizer:
|
619 |
+
def __init__(self, vocab_file=None):
|
620 |
+
self.tokenizer = None
|
621 |
+
if vocab_file is not None:
|
622 |
+
self.tokenizer = Tokenizer.from_file(vocab_file)
|
623 |
+
self.char_limits = {
|
624 |
+
"en": 250,
|
625 |
+
"de": 253,
|
626 |
+
"fr": 273,
|
627 |
+
"es": 239,
|
628 |
+
"it": 213,
|
629 |
+
"pt": 203,
|
630 |
+
"pl": 224,
|
631 |
+
"zh": 82,
|
632 |
+
"ar": 166,
|
633 |
+
"cs": 186,
|
634 |
+
"ru": 182,
|
635 |
+
"nl": 251,
|
636 |
+
"tr": 226,
|
637 |
+
"ja": 71,
|
638 |
+
"hu": 224,
|
639 |
+
"ko": 95,
|
640 |
+
}
|
641 |
+
|
642 |
+
@cached_property
|
643 |
+
def katsu(self):
|
644 |
+
import cutlet
|
645 |
+
|
646 |
+
return cutlet.Cutlet()
|
647 |
+
|
648 |
+
def check_input_length(self, txt, lang):
|
649 |
+
lang = lang.split("-")[0] # remove the region
|
650 |
+
limit = self.char_limits.get(lang, 250)
|
651 |
+
if len(txt) > limit:
|
652 |
+
print(
|
653 |
+
f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."
|
654 |
+
)
|
655 |
+
|
656 |
+
def preprocess_text(self, txt, lang):
|
657 |
+
if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}:
|
658 |
+
txt = multilingual_cleaners(txt, lang)
|
659 |
+
if lang == "zh":
|
660 |
+
txt = chinese_transliterate(txt)
|
661 |
+
if lang == "ko":
|
662 |
+
txt = korean_transliterate(txt)
|
663 |
+
elif lang == "ja":
|
664 |
+
txt = japanese_cleaners(txt, self.katsu)
|
665 |
+
elif lang == "hi":
|
666 |
+
# @manmay will implement this
|
667 |
+
txt = basic_cleaners(txt)
|
668 |
+
else:
|
669 |
+
raise NotImplementedError(f"Language '{lang}' is not supported.")
|
670 |
+
return txt
|
671 |
+
|
672 |
+
def encode(self, txt, lang):
|
673 |
+
lang = lang.split("-")[0] # remove the region
|
674 |
+
self.check_input_length(txt, lang)
|
675 |
+
txt = self.preprocess_text(txt, lang)
|
676 |
+
lang = "zh-cn" if lang == "zh" else lang
|
677 |
+
txt = f"[{lang}]{txt}"
|
678 |
+
txt = txt.replace(" ", "[SPACE]")
|
679 |
+
return self.tokenizer.encode(txt).ids
|
680 |
+
|
681 |
+
def decode(self, seq):
|
682 |
+
if isinstance(seq, torch.Tensor):
|
683 |
+
seq = seq.cpu().numpy()
|
684 |
+
txt = self.tokenizer.decode(seq, skip_special_tokens=False).replace(" ", "")
|
685 |
+
txt = txt.replace("[SPACE]", " ")
|
686 |
+
txt = txt.replace("[STOP]", "")
|
687 |
+
txt = txt.replace("[UNK]", "")
|
688 |
+
return txt
|
689 |
+
|
690 |
+
def __len__(self):
|
691 |
+
return self.tokenizer.get_vocab_size()
|
692 |
+
|
693 |
+
def get_number_tokens(self):
|
694 |
+
return max(self.tokenizer.get_vocab().values()) + 1
|
695 |
+
|
696 |
+
|
697 |
+
def test_expand_numbers_multilingual():
|
698 |
+
test_cases = [
|
699 |
+
# English
|
700 |
+
("In 12.5 seconds.", "In twelve point five seconds.", "en"),
|
701 |
+
("There were 50 soldiers.", "There were fifty soldiers.", "en"),
|
702 |
+
("This is a 1st test", "This is a first test", "en"),
|
703 |
+
("That will be $20 sir.", "That will be twenty dollars sir.", "en"),
|
704 |
+
("That will be 20€ sir.", "That will be twenty euro sir.", "en"),
|
705 |
+
("That will be 20.15€ sir.", "That will be twenty euro, fifteen cents sir.", "en"),
|
706 |
+
("That's 100,000.5.", "That's one hundred thousand point five.", "en"),
|
707 |
+
# French
|
708 |
+
("En 12,5 secondes.", "En douze virgule cinq secondes.", "fr"),
|
709 |
+
("Il y avait 50 soldats.", "Il y avait cinquante soldats.", "fr"),
|
710 |
+
("Ceci est un 1er test", "Ceci est un premier test", "fr"),
|
711 |
+
("Cela vous fera $20 monsieur.", "Cela vous fera vingt dollars monsieur.", "fr"),
|
712 |
+
("Cela vous fera 20€ monsieur.", "Cela vous fera vingt euros monsieur.", "fr"),
|
713 |
+
("Cela vous fera 20,15€ monsieur.", "Cela vous fera vingt euros et quinze centimes monsieur.", "fr"),
|
714 |
+
("Ce sera 100.000,5.", "Ce sera cent mille virgule cinq.", "fr"),
|
715 |
+
# German
|
716 |
+
("In 12,5 Sekunden.", "In zwölf Komma fünf Sekunden.", "de"),
|
717 |
+
("Es gab 50 Soldaten.", "Es gab fünfzig Soldaten.", "de"),
|
718 |
+
("Dies ist ein 1. Test", "Dies ist ein erste Test", "de"), # Issue with gender
|
719 |
+
("Das macht $20 Herr.", "Das macht zwanzig Dollar Herr.", "de"),
|
720 |
+
("Das macht 20€ Herr.", "Das macht zwanzig Euro Herr.", "de"),
|
721 |
+
("Das macht 20,15€ Herr.", "Das macht zwanzig Euro und fünfzehn Cent Herr.", "de"),
|
722 |
+
# Spanish
|
723 |
+
("En 12,5 segundos.", "En doce punto cinco segundos.", "es"),
|
724 |
+
("Había 50 soldados.", "Había cincuenta soldados.", "es"),
|
725 |
+
("Este es un 1er test", "Este es un primero test", "es"),
|
726 |
+
("Eso le costará $20 señor.", "Eso le costará veinte dólares señor.", "es"),
|
727 |
+
("Eso le costará 20€ señor.", "Eso le costará veinte euros señor.", "es"),
|
728 |
+
("Eso le costará 20,15€ señor.", "Eso le costará veinte euros con quince céntimos señor.", "es"),
|
729 |
+
# Italian
|
730 |
+
("In 12,5 secondi.", "In dodici virgola cinque secondi.", "it"),
|
731 |
+
("C'erano 50 soldati.", "C'erano cinquanta soldati.", "it"),
|
732 |
+
("Questo è un 1° test", "Questo è un primo test", "it"),
|
733 |
+
("Ti costerà $20 signore.", "Ti costerà venti dollari signore.", "it"),
|
734 |
+
("Ti costerà 20€ signore.", "Ti costerà venti euro signore.", "it"),
|
735 |
+
("Ti costerà 20,15€ signore.", "Ti costerà venti euro e quindici centesimi signore.", "it"),
|
736 |
+
# Portuguese
|
737 |
+
("Em 12,5 segundos.", "Em doze vírgula cinco segundos.", "pt"),
|
738 |
+
("Havia 50 soldados.", "Havia cinquenta soldados.", "pt"),
|
739 |
+
("Este é um 1º teste", "Este é um primeiro teste", "pt"),
|
740 |
+
("Isso custará $20 senhor.", "Isso custará vinte dólares senhor.", "pt"),
|
741 |
+
("Isso custará 20€ senhor.", "Isso custará vinte euros senhor.", "pt"),
|
742 |
+
(
|
743 |
+
"Isso custará 20,15€ senhor.",
|
744 |
+
"Isso custará vinte euros e quinze cêntimos senhor.",
|
745 |
+
"pt",
|
746 |
+
), # "cêntimos" should be "centavos" num2words issue
|
747 |
+
# Polish
|
748 |
+
("W 12,5 sekundy.", "W dwanaście przecinek pięć sekundy.", "pl"),
|
749 |
+
("Było 50 żołnierzy.", "Było pięćdziesiąt żołnierzy.", "pl"),
|
750 |
+
("To będzie kosztować 20€ panie.", "To będzie kosztować dwadzieścia euro panie.", "pl"),
|
751 |
+
("To będzie kosztować 20,15€ panie.", "To będzie kosztować dwadzieścia euro, piętnaście centów panie.", "pl"),
|
752 |
+
# Arabic
|
753 |
+
("في الـ 12,5 ثانية.", "في الـ اثنا عشر , خمسون ثانية.", "ar"),
|
754 |
+
("كان هناك 50 جنديًا.", "كان هناك خمسون جنديًا.", "ar"),
|
755 |
+
# ("ستكون النتيجة $20 يا سيد.", 'ستكون النتيجة عشرون دولار يا سيد.', 'ar'), # $ and € are mising from num2words
|
756 |
+
# ("ستكون النتيجة 20€ يا سيد.", 'ستكون النتيجة عشرون يورو يا سيد.', 'ar'),
|
757 |
+
# Czech
|
758 |
+
("Za 12,5 vteřiny.", "Za dvanáct celá pět vteřiny.", "cs"),
|
759 |
+
("Bylo tam 50 vojáků.", "Bylo tam padesát vojáků.", "cs"),
|
760 |
+
("To bude stát 20€ pane.", "To bude stát dvacet euro pane.", "cs"),
|
761 |
+
("To bude 20.15€ pane.", "To bude dvacet euro, patnáct centů pane.", "cs"),
|
762 |
+
# Russian
|
763 |
+
("Через 12.5 секунды.", "Через двенадцать запятая пять секунды.", "ru"),
|
764 |
+
("Там было 50 солдат.", "Там было пятьдесят солдат.", "ru"),
|
765 |
+
("Это будет 20.15€ сэр.", "Это будет двадцать евро, пятнадцать центов сэр.", "ru"),
|
766 |
+
("Это будет стоить 20€ господин.", "Это будет стоить двадцать евро господин.", "ru"),
|
767 |
+
# Dutch
|
768 |
+
("In 12,5 seconden.", "In twaalf komma vijf seconden.", "nl"),
|
769 |
+
("Er waren 50 soldaten.", "Er waren vijftig soldaten.", "nl"),
|
770 |
+
("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"),
|
771 |
+
("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"),
|
772 |
+
# Chinese (Simplified)
|
773 |
+
("在12.5秒内", "在十二点五秒内", "zh"),
|
774 |
+
("有50名士兵", "有五十名士兵", "zh"),
|
775 |
+
# ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work
|
776 |
+
# ("那将是20€先生", '那将是二十欧元先生', 'zh'),
|
777 |
+
# Turkish
|
778 |
+
# ("12,5 saniye içinde.", 'On iki virgül beş saniye içinde.', 'tr'), # decimal doesn't work for TR
|
779 |
+
("50 asker vardı.", "elli asker vardı.", "tr"),
|
780 |
+
("Bu 1. test", "Bu birinci test", "tr"),
|
781 |
+
# ("Bu 100.000,5.", 'Bu yüz bin virgül beş.', 'tr'),
|
782 |
+
# Hungarian
|
783 |
+
("12,5 másodperc alatt.", "tizenkettő egész öt tized másodperc alatt.", "hu"),
|
784 |
+
("50 katona volt.", "ötven katona volt.", "hu"),
|
785 |
+
("Ez az 1. teszt", "Ez az első teszt", "hu"),
|
786 |
+
# Korean
|
787 |
+
("12.5 초 안에.", "십이 점 다섯 초 안에.", "ko"),
|
788 |
+
("50 명의 병사가 있었다.", "오십 명의 병사가 있었다.", "ko"),
|
789 |
+
("이것은 1 번째 테스트입니다", "이것은 첫 번째 테스트입니다", "ko"),
|
790 |
+
]
|
791 |
+
for a, b, lang in test_cases:
|
792 |
+
out = expand_numbers_multilingual(a, lang=lang)
|
793 |
+
assert out == b, f"'{out}' vs '{b}'"
|
794 |
+
|
795 |
+
|
796 |
+
def test_abbreviations_multilingual():
|
797 |
+
test_cases = [
|
798 |
+
# English
|
799 |
+
("Hello Mr. Smith.", "Hello mister Smith.", "en"),
|
800 |
+
("Dr. Jones is here.", "doctor Jones is here.", "en"),
|
801 |
+
# Spanish
|
802 |
+
("Hola Sr. Garcia.", "Hola señor Garcia.", "es"),
|
803 |
+
("La Dra. Martinez es muy buena.", "La doctora Martinez es muy buena.", "es"),
|
804 |
+
# French
|
805 |
+
("Bonjour Mr. Dupond.", "Bonjour monsieur Dupond.", "fr"),
|
806 |
+
("Mme. Moreau est absente aujourd'hui.", "madame Moreau est absente aujourd'hui.", "fr"),
|
807 |
+
# German
|
808 |
+
("Frau Dr. Müller ist sehr klug.", "Frau doktor Müller ist sehr klug.", "de"),
|
809 |
+
# Portuguese
|
810 |
+
("Olá Sr. Silva.", "Olá senhor Silva.", "pt"),
|
811 |
+
("Dra. Costa, você está disponível?", "doutora Costa, você está disponível?", "pt"),
|
812 |
+
# Italian
|
813 |
+
("Buongiorno, Sig. Rossi.", "Buongiorno, signore Rossi.", "it"),
|
814 |
+
# ("Sig.ra Bianchi, posso aiutarti?", 'signora Bianchi, posso aiutarti?', 'it'), # Issue with matching that pattern
|
815 |
+
# Polish
|
816 |
+
("Dzień dobry, P. Kowalski.", "Dzień dobry, pani Kowalski.", "pl"),
|
817 |
+
("M. Nowak, czy mogę zadać pytanie?", "pan Nowak, czy mogę zadać pytanie?", "pl"),
|
818 |
+
# Czech
|
819 |
+
("P. Novák", "pan Novák", "cs"),
|
820 |
+
("Dr. Vojtěch", "doktor Vojtěch", "cs"),
|
821 |
+
# Dutch
|
822 |
+
("Dhr. Jansen", "de heer Jansen", "nl"),
|
823 |
+
("Mevr. de Vries", "mevrouw de Vries", "nl"),
|
824 |
+
# Russian
|
825 |
+
("Здравствуйте Г-н Иванов.", "Здравствуйте господин Иванов.", "ru"),
|
826 |
+
("Д-р Смирнов здесь, чтобы увидеть вас.", "доктор Смирнов здесь, чтобы увидеть вас.", "ru"),
|
827 |
+
# Turkish
|
828 |
+
("Merhaba B. Yılmaz.", "Merhaba bay Yılmaz.", "tr"),
|
829 |
+
("Dr. Ayşe burada.", "doktor Ayşe burada.", "tr"),
|
830 |
+
# Hungarian
|
831 |
+
("Dr. Szabó itt van.", "doktor Szabó itt van.", "hu"),
|
832 |
+
]
|
833 |
+
|
834 |
+
for a, b, lang in test_cases:
|
835 |
+
out = expand_abbreviations_multilingual(a, lang=lang)
|
836 |
+
assert out == b, f"'{out}' vs '{b}'"
|
837 |
+
|
838 |
+
|
839 |
+
def test_symbols_multilingual():
|
840 |
+
test_cases = [
|
841 |
+
("I have 14% battery", "I have 14 percent battery", "en"),
|
842 |
+
("Te veo @ la fiesta", "Te veo arroba la fiesta", "es"),
|
843 |
+
("J'ai 14° de fièvre", "J'ai 14 degrés de fièvre", "fr"),
|
844 |
+
("Die Rechnung beträgt £ 20", "Die Rechnung beträgt pfund 20", "de"),
|
845 |
+
("O meu email é ana&joao@gmail.com", "O meu email é ana e joao arroba gmail.com", "pt"),
|
846 |
+
("linguaggio di programmazione C#", "linguaggio di programmazione C cancelletto", "it"),
|
847 |
+
("Moja temperatura to 36.6°", "Moja temperatura to 36.6 stopnie", "pl"),
|
848 |
+
("Mám 14% baterie", "Mám 14 procento baterie", "cs"),
|
849 |
+
("Těším se na tebe @ party", "Těším se na tebe na party", "cs"),
|
850 |
+
("У меня 14% заряда", "У меня 14 процентов заряда", "ru"),
|
851 |
+
("Я буду @ дома", "Я буду собака дома", "ru"),
|
852 |
+
("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"),
|
853 |
+
("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"),
|
854 |
+
("لدي 14% في البطارية", "لدي 14 في المئة في البطارية", "ar"),
|
855 |
+
("我的电量为 14%", "我的电量为 14 百分之", "zh"),
|
856 |
+
("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"),
|
857 |
+
("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"),
|
858 |
+
("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"),
|
859 |
+
]
|
860 |
+
|
861 |
+
for a, b, lang in test_cases:
|
862 |
+
out = expand_symbols_multilingual(a, lang=lang)
|
863 |
+
assert out == b, f"'{out}' vs '{b}'"
|
864 |
+
|
865 |
+
|
866 |
+
if __name__ == "__main__":
|
867 |
+
test_expand_numbers_multilingual()
|
868 |
+
test_abbreviations_multilingual()
|
869 |
+
test_symbols_multilingual()
|
xtts_demo.py
ADDED
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import os
|
8 |
+
import shutil
|
9 |
+
import glob
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import librosa.display
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torchaudio
|
17 |
+
import traceback
|
18 |
+
from utils.formatter import format_audio_list,find_latest_best_model, list_audios
|
19 |
+
from utils.gpt_train import train_gpt
|
20 |
+
|
21 |
+
from faster_whisper import WhisperModel
|
22 |
+
|
23 |
+
from TTS.tts.configs.xtts_config import XttsConfig
|
24 |
+
from TTS.tts.models.xtts import Xtts
|
25 |
+
|
26 |
+
from TTS.tts.configs.xtts_config import XttsConfig
|
27 |
+
from TTS.tts.models.xtts import Xtts
|
28 |
+
|
29 |
+
# Clear logs
|
30 |
+
def remove_log_file(file_path):
|
31 |
+
log_file = Path(file_path)
|
32 |
+
|
33 |
+
if log_file.exists() and log_file.is_file():
|
34 |
+
log_file.unlink()
|
35 |
+
|
36 |
+
# remove_log_file(str(Path.cwd() / "log.out"))
|
37 |
+
|
38 |
+
def clear_gpu_cache():
|
39 |
+
# clear the GPU cache
|
40 |
+
if torch.cuda.is_available():
|
41 |
+
torch.cuda.empty_cache()
|
42 |
+
|
43 |
+
XTTS_MODEL = None
|
44 |
+
def load_model(xtts_checkpoint, xtts_config, xtts_vocab,xtts_speaker):
|
45 |
+
global XTTS_MODEL
|
46 |
+
clear_gpu_cache()
|
47 |
+
if not xtts_checkpoint or not xtts_config or not xtts_vocab:
|
48 |
+
return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
|
49 |
+
config = XttsConfig()
|
50 |
+
config.load_json(xtts_config)
|
51 |
+
XTTS_MODEL = Xtts.init_from_config(config)
|
52 |
+
print("Loading XTTS model! ")
|
53 |
+
XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab,speaker_file_path=xtts_speaker, use_deepspeed=False)
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
XTTS_MODEL.cuda()
|
56 |
+
|
57 |
+
print("Model Loaded!")
|
58 |
+
return "Model Loaded!"
|
59 |
+
|
60 |
+
def run_tts(lang, tts_text, speaker_audio_file, temperature, length_penalty,repetition_penalty,top_k,top_p,sentence_split,use_config):
|
61 |
+
if XTTS_MODEL is None or not speaker_audio_file:
|
62 |
+
return "You need to run the previous step to load the model !!", None, None
|
63 |
+
|
64 |
+
gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
|
65 |
+
|
66 |
+
if use_config:
|
67 |
+
out = XTTS_MODEL.inference(
|
68 |
+
text=tts_text,
|
69 |
+
language=lang,
|
70 |
+
gpt_cond_latent=gpt_cond_latent,
|
71 |
+
speaker_embedding=speaker_embedding,
|
72 |
+
temperature=XTTS_MODEL.config.temperature, # Add custom parameters here
|
73 |
+
length_penalty=XTTS_MODEL.config.length_penalty,
|
74 |
+
repetition_penalty=XTTS_MODEL.config.repetition_penalty,
|
75 |
+
top_k=XTTS_MODEL.config.top_k,
|
76 |
+
top_p=XTTS_MODEL.config.top_p,
|
77 |
+
enable_text_splitting = True
|
78 |
+
)
|
79 |
+
else:
|
80 |
+
out = XTTS_MODEL.inference(
|
81 |
+
text=tts_text,
|
82 |
+
language=lang,
|
83 |
+
gpt_cond_latent=gpt_cond_latent,
|
84 |
+
speaker_embedding=speaker_embedding,
|
85 |
+
temperature=temperature, # Add custom parameters here
|
86 |
+
length_penalty=length_penalty,
|
87 |
+
repetition_penalty=float(repetition_penalty),
|
88 |
+
top_k=top_k,
|
89 |
+
top_p=top_p,
|
90 |
+
enable_text_splitting = sentence_split
|
91 |
+
)
|
92 |
+
|
93 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
|
94 |
+
out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
|
95 |
+
out_path = fp.name
|
96 |
+
torchaudio.save(out_path, out["wav"], 24000)
|
97 |
+
|
98 |
+
return "Speech generated !", out_path, speaker_audio_file
|
99 |
+
|
100 |
+
|
101 |
+
def load_params_tts(out_path,version):
|
102 |
+
|
103 |
+
out_path = Path(out_path)
|
104 |
+
|
105 |
+
# base_model_path = Path.cwd() / "models" / version
|
106 |
+
|
107 |
+
# if not base_model_path.exists():
|
108 |
+
# return "Base model not found !","","",""
|
109 |
+
|
110 |
+
ready_model_path = out_path / "ready"
|
111 |
+
|
112 |
+
vocab_path = ready_model_path / "vocab.json"
|
113 |
+
config_path = ready_model_path / "config.json"
|
114 |
+
speaker_path = ready_model_path / "speakers_xtts.pth"
|
115 |
+
reference_path = ready_model_path / "reference.wav"
|
116 |
+
|
117 |
+
model_path = ready_model_path / "model.pth"
|
118 |
+
|
119 |
+
if not model_path.exists():
|
120 |
+
model_path = ready_model_path / "unoptimize_model.pth"
|
121 |
+
if not model_path.exists():
|
122 |
+
return "Params for TTS not found", "", "", ""
|
123 |
+
|
124 |
+
return "Params for TTS loaded", model_path, config_path, vocab_path,speaker_path, reference_path
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
|
129 |
+
parser = argparse.ArgumentParser(
|
130 |
+
description="""XTTS fine-tuning demo\n\n"""
|
131 |
+
"""
|
132 |
+
Example runs:
|
133 |
+
python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
|
134 |
+
""",
|
135 |
+
formatter_class=argparse.RawTextHelpFormatter,
|
136 |
+
)
|
137 |
+
parser.add_argument(
|
138 |
+
"--port",
|
139 |
+
type=int,
|
140 |
+
help="Port to run the gradio demo. Default: 5003",
|
141 |
+
default=5003,
|
142 |
+
)
|
143 |
+
parser.add_argument(
|
144 |
+
"--out_path",
|
145 |
+
type=str,
|
146 |
+
help="Output path (where data and checkpoints will be saved) Default: output/",
|
147 |
+
default=str(Path.cwd() / "finetune_models"),
|
148 |
+
)
|
149 |
+
|
150 |
+
parser.add_argument(
|
151 |
+
"--num_epochs",
|
152 |
+
type=int,
|
153 |
+
help="Number of epochs to train. Default: 6",
|
154 |
+
default=6,
|
155 |
+
)
|
156 |
+
parser.add_argument(
|
157 |
+
"--batch_size",
|
158 |
+
type=int,
|
159 |
+
help="Batch size. Default: 2",
|
160 |
+
default=2,
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--grad_acumm",
|
164 |
+
type=int,
|
165 |
+
help="Grad accumulation steps. Default: 1",
|
166 |
+
default=1,
|
167 |
+
)
|
168 |
+
parser.add_argument(
|
169 |
+
"--max_audio_length",
|
170 |
+
type=int,
|
171 |
+
help="Max permitted audio size in seconds. Default: 11",
|
172 |
+
default=11,
|
173 |
+
)
|
174 |
+
|
175 |
+
args = parser.parse_args()
|
176 |
+
|
177 |
+
with gr.Blocks() as demo:
|
178 |
+
with gr.Tab("1 - Data processing"):
|
179 |
+
out_path = gr.Textbox(
|
180 |
+
label="Output path (where data and checkpoints will be saved):",
|
181 |
+
value=args.out_path,
|
182 |
+
)
|
183 |
+
# upload_file = gr.Audio(
|
184 |
+
# sources="upload",
|
185 |
+
# label="Select here the audio files that you want to use for XTTS trainining !",
|
186 |
+
# type="filepath",
|
187 |
+
# )
|
188 |
+
upload_file = gr.File(
|
189 |
+
file_count="multiple",
|
190 |
+
label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
|
191 |
+
)
|
192 |
+
|
193 |
+
audio_folder_path = gr.Textbox(
|
194 |
+
label="Path to the folder with audio files (optional):",
|
195 |
+
value="",
|
196 |
+
)
|
197 |
+
|
198 |
+
whisper_model = gr.Dropdown(
|
199 |
+
label="Whisper Model",
|
200 |
+
value="large-v3",
|
201 |
+
choices=[
|
202 |
+
"large-v3",
|
203 |
+
"large-v2",
|
204 |
+
"large",
|
205 |
+
"medium",
|
206 |
+
"small"
|
207 |
+
],
|
208 |
+
)
|
209 |
+
|
210 |
+
lang = gr.Dropdown(
|
211 |
+
label="Dataset Language",
|
212 |
+
value="en",
|
213 |
+
choices=[
|
214 |
+
"en",
|
215 |
+
"es",
|
216 |
+
"fr",
|
217 |
+
"de",
|
218 |
+
"it",
|
219 |
+
"pt",
|
220 |
+
"pl",
|
221 |
+
"tr",
|
222 |
+
"ru",
|
223 |
+
"nl",
|
224 |
+
"cs",
|
225 |
+
"ar",
|
226 |
+
"zh",
|
227 |
+
"hu",
|
228 |
+
"ko",
|
229 |
+
"ja"
|
230 |
+
],
|
231 |
+
)
|
232 |
+
progress_data = gr.Label(
|
233 |
+
label="Progress:"
|
234 |
+
)
|
235 |
+
# demo.load(read_logs, None, logs, every=1)
|
236 |
+
|
237 |
+
prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")
|
238 |
+
|
239 |
+
def preprocess_dataset(audio_path, audio_folder_path, language, whisper_model, out_path, train_csv, eval_csv, progress=gr.Progress(track_tqdm=True)):
|
240 |
+
clear_gpu_cache()
|
241 |
+
|
242 |
+
train_csv = ""
|
243 |
+
eval_csv = ""
|
244 |
+
|
245 |
+
out_path = os.path.join(out_path, "dataset")
|
246 |
+
os.makedirs(out_path, exist_ok=True)
|
247 |
+
|
248 |
+
if audio_folder_path:
|
249 |
+
audio_files = list(list_audios(audio_folder_path))
|
250 |
+
else:
|
251 |
+
audio_files = audio_path
|
252 |
+
|
253 |
+
if not audio_files:
|
254 |
+
return "No audio files found! Please provide files via Gradio or specify a folder path.", "", ""
|
255 |
+
else:
|
256 |
+
try:
|
257 |
+
# Loading Whisper
|
258 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
259 |
+
|
260 |
+
# Detect compute type
|
261 |
+
if torch.cuda.is_available():
|
262 |
+
compute_type = "float16"
|
263 |
+
else:
|
264 |
+
compute_type = "float32"
|
265 |
+
|
266 |
+
asr_model = WhisperModel(whisper_model, device=device, compute_type=compute_type)
|
267 |
+
train_meta, eval_meta, audio_total_size = format_audio_list(audio_files, asr_model=asr_model, target_language=language, out_path=out_path, gradio_progress=progress)
|
268 |
+
except:
|
269 |
+
traceback.print_exc()
|
270 |
+
error = traceback.format_exc()
|
271 |
+
return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""
|
272 |
+
|
273 |
+
# clear_gpu_cache()
|
274 |
+
|
275 |
+
# if audio total len is less than 2 minutes raise an error
|
276 |
+
if audio_total_size < 120:
|
277 |
+
message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
|
278 |
+
print(message)
|
279 |
+
return message, "", ""
|
280 |
+
|
281 |
+
print("Dataset Processed!")
|
282 |
+
return "Dataset Processed!", train_meta, eval_meta
|
283 |
+
|
284 |
+
|
285 |
+
with gr.Tab("2 - Fine-tuning XTTS Encoder"):
|
286 |
+
load_params_btn = gr.Button(value="Load Params from output folder")
|
287 |
+
version = gr.Dropdown(
|
288 |
+
label="XTTS base version",
|
289 |
+
value="v2.0.2",
|
290 |
+
choices=[
|
291 |
+
"v2.0.3",
|
292 |
+
"v2.0.2",
|
293 |
+
"v2.0.1",
|
294 |
+
"v2.0.0",
|
295 |
+
"main"
|
296 |
+
],
|
297 |
+
)
|
298 |
+
train_csv = gr.Textbox(
|
299 |
+
label="Train CSV:",
|
300 |
+
)
|
301 |
+
eval_csv = gr.Textbox(
|
302 |
+
label="Eval CSV:",
|
303 |
+
)
|
304 |
+
custom_model = gr.Textbox(
|
305 |
+
label="(Optional) Custom model.pth file , leave blank if you want to use the base file.",
|
306 |
+
value="",
|
307 |
+
)
|
308 |
+
num_epochs = gr.Slider(
|
309 |
+
label="Number of epochs:",
|
310 |
+
minimum=1,
|
311 |
+
maximum=100,
|
312 |
+
step=1,
|
313 |
+
value=args.num_epochs,
|
314 |
+
)
|
315 |
+
batch_size = gr.Slider(
|
316 |
+
label="Batch size:",
|
317 |
+
minimum=2,
|
318 |
+
maximum=512,
|
319 |
+
step=1,
|
320 |
+
value=args.batch_size,
|
321 |
+
)
|
322 |
+
grad_acumm = gr.Slider(
|
323 |
+
label="Grad accumulation steps:",
|
324 |
+
minimum=2,
|
325 |
+
maximum=128,
|
326 |
+
step=1,
|
327 |
+
value=args.grad_acumm,
|
328 |
+
)
|
329 |
+
max_audio_length = gr.Slider(
|
330 |
+
label="Max permitted audio size in seconds:",
|
331 |
+
minimum=2,
|
332 |
+
maximum=20,
|
333 |
+
step=1,
|
334 |
+
value=args.max_audio_length,
|
335 |
+
)
|
336 |
+
clear_train_data = gr.Dropdown(
|
337 |
+
label="Clear train data, you will delete selected folder, after optimizing",
|
338 |
+
value="none",
|
339 |
+
choices=[
|
340 |
+
"none",
|
341 |
+
"run",
|
342 |
+
"dataset",
|
343 |
+
"all"
|
344 |
+
])
|
345 |
+
|
346 |
+
progress_train = gr.Label(
|
347 |
+
label="Progress:"
|
348 |
+
)
|
349 |
+
|
350 |
+
# demo.load(read_logs, None, logs_tts_train, every=1)
|
351 |
+
train_btn = gr.Button(value="Step 2 - Run the training")
|
352 |
+
optimize_model_btn = gr.Button(value="Step 2.5 - Optimize the model")
|
353 |
+
|
354 |
+
def train_model(custom_model,version,language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
355 |
+
clear_gpu_cache()
|
356 |
+
|
357 |
+
run_dir = Path(output_path) / "run"
|
358 |
+
|
359 |
+
# # Remove train dir
|
360 |
+
if run_dir.exists():
|
361 |
+
os.remove(run_dir)
|
362 |
+
|
363 |
+
# Check if the dataset language matches the language you specified
|
364 |
+
lang_file_path = Path(output_path) / "dataset" / "lang.txt"
|
365 |
+
|
366 |
+
# Check if lang.txt already exists and contains a different language
|
367 |
+
current_language = None
|
368 |
+
if lang_file_path.exists():
|
369 |
+
with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
|
370 |
+
current_language = existing_lang_file.read().strip()
|
371 |
+
if current_language != language:
|
372 |
+
print("The language that was prepared for the dataset does not match the specified language. Change the language to the one specified in the dataset")
|
373 |
+
language = current_language
|
374 |
+
|
375 |
+
if not train_csv or not eval_csv:
|
376 |
+
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
377 |
+
try:
|
378 |
+
# convert seconds to waveform frames
|
379 |
+
max_audio_length = int(max_audio_length * 22050)
|
380 |
+
speaker_xtts_path,config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(custom_model,version,language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
|
381 |
+
except:
|
382 |
+
traceback.print_exc()
|
383 |
+
error = traceback.format_exc()
|
384 |
+
return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""
|
385 |
+
|
386 |
+
# copy original files to avoid parameters changes issues
|
387 |
+
# os.system(f"cp {config_path} {exp_path}")
|
388 |
+
# os.system(f"cp {vocab_file} {exp_path}")
|
389 |
+
|
390 |
+
ready_dir = Path(output_path) / "ready"
|
391 |
+
|
392 |
+
ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
|
393 |
+
|
394 |
+
shutil.copy(ft_xtts_checkpoint, ready_dir / "unoptimize_model.pth")
|
395 |
+
# os.remove(ft_xtts_checkpoint)
|
396 |
+
|
397 |
+
ft_xtts_checkpoint = os.path.join(ready_dir, "unoptimize_model.pth")
|
398 |
+
|
399 |
+
# Reference
|
400 |
+
# Move reference audio to output folder and rename it
|
401 |
+
speaker_reference_path = Path(speaker_wav)
|
402 |
+
speaker_reference_new_path = ready_dir / "reference.wav"
|
403 |
+
shutil.copy(speaker_reference_path, speaker_reference_new_path)
|
404 |
+
|
405 |
+
print("Model training done!")
|
406 |
+
# clear_gpu_cache()
|
407 |
+
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint,speaker_xtts_path, speaker_reference_new_path
|
408 |
+
|
409 |
+
def optimize_model(out_path, clear_train_data):
|
410 |
+
# print(out_path)
|
411 |
+
out_path = Path(out_path) # Ensure that out_path is a Path object.
|
412 |
+
|
413 |
+
ready_dir = out_path / "ready"
|
414 |
+
run_dir = out_path / "run"
|
415 |
+
dataset_dir = out_path / "dataset"
|
416 |
+
|
417 |
+
# Clear specified training data directories.
|
418 |
+
if clear_train_data in {"run", "all"} and run_dir.exists():
|
419 |
+
try:
|
420 |
+
shutil.rmtree(run_dir)
|
421 |
+
except PermissionError as e:
|
422 |
+
print(f"An error occurred while deleting {run_dir}: {e}")
|
423 |
+
|
424 |
+
if clear_train_data in {"dataset", "all"} and dataset_dir.exists():
|
425 |
+
try:
|
426 |
+
shutil.rmtree(dataset_dir)
|
427 |
+
except PermissionError as e:
|
428 |
+
print(f"An error occurred while deleting {dataset_dir}: {e}")
|
429 |
+
|
430 |
+
# Get full path to model
|
431 |
+
model_path = ready_dir / "unoptimize_model.pth"
|
432 |
+
|
433 |
+
if not model_path.is_file():
|
434 |
+
return "Unoptimized model not found in ready folder", ""
|
435 |
+
|
436 |
+
# Load the checkpoint and remove unnecessary parts.
|
437 |
+
checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
|
438 |
+
del checkpoint["optimizer"]
|
439 |
+
|
440 |
+
for key in list(checkpoint["model"].keys()):
|
441 |
+
if "dvae" in key:
|
442 |
+
del checkpoint["model"][key]
|
443 |
+
|
444 |
+
# Make sure out_path is a Path object or convert it to Path
|
445 |
+
os.remove(model_path)
|
446 |
+
|
447 |
+
# Save the optimized model.
|
448 |
+
optimized_model_file_name="model.pth"
|
449 |
+
optimized_model=ready_dir/optimized_model_file_name
|
450 |
+
|
451 |
+
torch.save(checkpoint, optimized_model)
|
452 |
+
ft_xtts_checkpoint=str(optimized_model)
|
453 |
+
|
454 |
+
clear_gpu_cache()
|
455 |
+
|
456 |
+
return f"Model optimized and saved at {ft_xtts_checkpoint}!", ft_xtts_checkpoint
|
457 |
+
|
458 |
+
def load_params(out_path):
|
459 |
+
path_output = Path(out_path)
|
460 |
+
|
461 |
+
dataset_path = path_output / "dataset"
|
462 |
+
|
463 |
+
if not dataset_path.exists():
|
464 |
+
return "The output folder does not exist!", "", ""
|
465 |
+
|
466 |
+
eval_train = dataset_path / "metadata_train.csv"
|
467 |
+
eval_csv = dataset_path / "metadata_eval.csv"
|
468 |
+
|
469 |
+
# Write the target language to lang.txt in the output directory
|
470 |
+
lang_file_path = dataset_path / "lang.txt"
|
471 |
+
|
472 |
+
# Check if lang.txt already exists and contains a different language
|
473 |
+
current_language = None
|
474 |
+
if os.path.exists(lang_file_path):
|
475 |
+
with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
|
476 |
+
current_language = existing_lang_file.read().strip()
|
477 |
+
|
478 |
+
clear_gpu_cache()
|
479 |
+
|
480 |
+
print(current_language)
|
481 |
+
return "The data has been updated", eval_train, eval_csv, current_language
|
482 |
+
|
483 |
+
with gr.Tab("3 - Inference"):
|
484 |
+
with gr.Row():
|
485 |
+
with gr.Column() as col1:
|
486 |
+
load_params_tts_btn = gr.Button(value="Load params for TTS from output folder")
|
487 |
+
xtts_checkpoint = gr.Textbox(
|
488 |
+
label="XTTS checkpoint path:",
|
489 |
+
value="",
|
490 |
+
)
|
491 |
+
xtts_config = gr.Textbox(
|
492 |
+
label="XTTS config path:",
|
493 |
+
value="",
|
494 |
+
)
|
495 |
+
|
496 |
+
xtts_vocab = gr.Textbox(
|
497 |
+
label="XTTS vocab path:",
|
498 |
+
value="",
|
499 |
+
)
|
500 |
+
xtts_speaker = gr.Textbox(
|
501 |
+
label="XTTS speaker path:",
|
502 |
+
value="",
|
503 |
+
)
|
504 |
+
progress_load = gr.Label(
|
505 |
+
label="Progress:"
|
506 |
+
)
|
507 |
+
load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")
|
508 |
+
|
509 |
+
with gr.Column() as col2:
|
510 |
+
speaker_reference_audio = gr.Textbox(
|
511 |
+
label="Speaker reference audio:",
|
512 |
+
value="",
|
513 |
+
)
|
514 |
+
tts_language = gr.Dropdown(
|
515 |
+
label="Language",
|
516 |
+
value="en",
|
517 |
+
choices=[
|
518 |
+
"en",
|
519 |
+
"es",
|
520 |
+
"fr",
|
521 |
+
"de",
|
522 |
+
"it",
|
523 |
+
"pt",
|
524 |
+
"pl",
|
525 |
+
"tr",
|
526 |
+
"ru",
|
527 |
+
"nl",
|
528 |
+
"cs",
|
529 |
+
"ar",
|
530 |
+
"zh",
|
531 |
+
"hu",
|
532 |
+
"ko",
|
533 |
+
"ja",
|
534 |
+
]
|
535 |
+
)
|
536 |
+
tts_text = gr.Textbox(
|
537 |
+
label="Input Text.",
|
538 |
+
value="This model sounds really good and above all, it's reasonably fast.",
|
539 |
+
)
|
540 |
+
with gr.Accordion("Advanced settings", open=False) as acr:
|
541 |
+
temperature = gr.Slider(
|
542 |
+
label="temperature",
|
543 |
+
minimum=0,
|
544 |
+
maximum=1,
|
545 |
+
step=0.05,
|
546 |
+
value=0.75,
|
547 |
+
)
|
548 |
+
length_penalty = gr.Slider(
|
549 |
+
label="length_penalty",
|
550 |
+
minimum=-10.0,
|
551 |
+
maximum=10.0,
|
552 |
+
step=0.5,
|
553 |
+
value=1,
|
554 |
+
)
|
555 |
+
repetition_penalty = gr.Slider(
|
556 |
+
label="repetition penalty",
|
557 |
+
minimum=1,
|
558 |
+
maximum=10,
|
559 |
+
step=0.5,
|
560 |
+
value=5,
|
561 |
+
)
|
562 |
+
top_k = gr.Slider(
|
563 |
+
label="top_k",
|
564 |
+
minimum=1,
|
565 |
+
maximum=100,
|
566 |
+
step=1,
|
567 |
+
value=50,
|
568 |
+
)
|
569 |
+
top_p = gr.Slider(
|
570 |
+
label="top_p",
|
571 |
+
minimum=0,
|
572 |
+
maximum=1,
|
573 |
+
step=0.05,
|
574 |
+
value=0.85,
|
575 |
+
)
|
576 |
+
sentence_split = gr.Checkbox(
|
577 |
+
label="Enable text splitting",
|
578 |
+
value=True,
|
579 |
+
)
|
580 |
+
use_config = gr.Checkbox(
|
581 |
+
label="Use Inference settings from config, if disabled use the settings above",
|
582 |
+
value=False,
|
583 |
+
)
|
584 |
+
tts_btn = gr.Button(value="Step 4 - Inference")
|
585 |
+
|
586 |
+
with gr.Column() as col3:
|
587 |
+
progress_gen = gr.Label(
|
588 |
+
label="Progress:"
|
589 |
+
)
|
590 |
+
tts_output_audio = gr.Audio(label="Generated Audio.")
|
591 |
+
reference_audio = gr.Audio(label="Reference audio used.")
|
592 |
+
|
593 |
+
prompt_compute_btn.click(
|
594 |
+
fn=preprocess_dataset,
|
595 |
+
inputs=[
|
596 |
+
upload_file,
|
597 |
+
audio_folder_path,
|
598 |
+
lang,
|
599 |
+
whisper_model,
|
600 |
+
out_path,
|
601 |
+
train_csv,
|
602 |
+
eval_csv
|
603 |
+
],
|
604 |
+
outputs=[
|
605 |
+
progress_data,
|
606 |
+
train_csv,
|
607 |
+
eval_csv,
|
608 |
+
],
|
609 |
+
)
|
610 |
+
|
611 |
+
|
612 |
+
load_params_btn.click(
|
613 |
+
fn=load_params,
|
614 |
+
inputs=[out_path],
|
615 |
+
outputs=[
|
616 |
+
progress_train,
|
617 |
+
train_csv,
|
618 |
+
eval_csv,
|
619 |
+
lang
|
620 |
+
]
|
621 |
+
)
|
622 |
+
|
623 |
+
|
624 |
+
train_btn.click(
|
625 |
+
fn=train_model,
|
626 |
+
inputs=[
|
627 |
+
custom_model,
|
628 |
+
version,
|
629 |
+
lang,
|
630 |
+
train_csv,
|
631 |
+
eval_csv,
|
632 |
+
num_epochs,
|
633 |
+
batch_size,
|
634 |
+
grad_acumm,
|
635 |
+
out_path,
|
636 |
+
max_audio_length,
|
637 |
+
],
|
638 |
+
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint,xtts_speaker, speaker_reference_audio],
|
639 |
+
)
|
640 |
+
|
641 |
+
optimize_model_btn.click(
|
642 |
+
fn=optimize_model,
|
643 |
+
inputs=[
|
644 |
+
out_path,
|
645 |
+
clear_train_data
|
646 |
+
],
|
647 |
+
outputs=[progress_train,xtts_checkpoint],
|
648 |
+
)
|
649 |
+
|
650 |
+
load_btn.click(
|
651 |
+
fn=load_model,
|
652 |
+
inputs=[
|
653 |
+
xtts_checkpoint,
|
654 |
+
xtts_config,
|
655 |
+
xtts_vocab,
|
656 |
+
xtts_speaker
|
657 |
+
],
|
658 |
+
outputs=[progress_load],
|
659 |
+
)
|
660 |
+
|
661 |
+
tts_btn.click(
|
662 |
+
fn=run_tts,
|
663 |
+
inputs=[
|
664 |
+
tts_language,
|
665 |
+
tts_text,
|
666 |
+
speaker_reference_audio,
|
667 |
+
temperature,
|
668 |
+
length_penalty,
|
669 |
+
repetition_penalty,
|
670 |
+
top_k,
|
671 |
+
top_p,
|
672 |
+
sentence_split,
|
673 |
+
use_config
|
674 |
+
],
|
675 |
+
outputs=[progress_gen, tts_output_audio,reference_audio],
|
676 |
+
)
|
677 |
+
|
678 |
+
load_params_tts_btn.click(
|
679 |
+
fn=load_params_tts,
|
680 |
+
inputs=[
|
681 |
+
out_path,
|
682 |
+
version
|
683 |
+
],
|
684 |
+
outputs=[progress_load,xtts_checkpoint,xtts_config,xtts_vocab,xtts_speaker,speaker_reference_audio],
|
685 |
+
)
|
686 |
+
|
687 |
+
demo.launch(
|
688 |
+
share=False,
|
689 |
+
debug=False,
|
690 |
+
server_port=args.port,
|
691 |
+
# inweb=True,
|
692 |
+
# server_name="localhost"
|
693 |
+
)
|