|
import argparse |
|
import os |
|
import sys |
|
import tempfile |
|
|
|
import gradio as gr |
|
import librosa.display |
|
import numpy as np |
|
|
|
import os |
|
import torch |
|
import torchaudio |
|
import traceback |
|
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list |
|
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt |
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig |
|
from TTS.tts.models.xtts import Xtts |
|
|
|
|
|
def clear_gpu_cache():
    """Ask PyTorch to empty its CUDA cache; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
# Model shared between the "load" and "inference" steps; stays None until a
# load has fully succeeded.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load a fine-tuned XTTS model into the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint.
        xtts_config: Path to the XTTS config JSON file.
        xtts_vocab: Path to the vocabulary file.

    Returns:
        A status message for the UI: either a hint that some path is missing
        or "Model Loaded!".

    Any exception raised while reading the config or loading the checkpoint
    propagates to the caller. In that case ``XTTS_MODEL`` is left as ``None``
    (never a partially initialised model), so the inference step keeps asking
    the user to load a model first.
    """
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    config = XttsConfig()
    config.load_json(xtts_config)

    # Drop the previous model before building the new one so its memory can be
    # reclaimed, and so a failure below cannot leave a half-loaded model
    # published in the global.
    XTTS_MODEL = None
    clear_gpu_cache()

    model = Xtts.init_from_config(config)
    print("Loading XTTS model! ")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
    if torch.cuda.is_available():
        model.cuda()

    # Publish only once the model is completely ready for inference.
    XTTS_MODEL = model

    print("Model Loaded!")
    return "Model Loaded!"
|
|
|
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` with the loaded XTTS model and save it as a wav file.

    Args:
        lang: Language code passed to the model's ``inference`` call.
        tts_text: Text to synthesize.
        speaker_audio_file: Path of the reference audio used to compute the
            conditioning latents and speaker embedding.

    Returns:
        A ``(status message, generated wav path, reference audio path)`` tuple.
        The two paths are ``None`` when no model is loaded or no reference
        audio was given.
    """
    model = XTTS_MODEL
    if model is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None

    cfg = model.config

    # Conditioning computed from the reference audio; the settings come from the model config.
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=cfg.gpt_cond_len,
        max_ref_length=cfg.max_ref_len,
        sound_norm_refs=cfg.sound_norm_refs,
    )

    out = model.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=cfg.temperature,
        length_penalty=cfg.length_penalty,
        repetition_penalty=cfg.repetition_penalty,
        top_k=cfg.top_k,
        top_p=cfg.top_p,
    )

    # delete=False keeps the file on disk after the handle closes, so the
    # returned path stays usable by the caller.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        waveform = torch.tensor(out["wav"]).unsqueeze(0)
        out_path = wav_file.name
        torchaudio.save(out_path, waveform, 24000)

    return "Speech generated !", out_path, speaker_audio_file
|
|
|
|
|
|
|
|
|
|
|
class Logger:
    """File-like tee: every write goes to the current stdout and to a log file.

    Args:
        filename: Path of the log file. It is opened in "w" mode, so any
            existing content is truncated.
    """

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Captured at construction time, so writes keep reaching the stream
        # that was sys.stdout when this object was created.
        self.terminal = sys.stdout
        # errors="replace": text that the file encoding cannot represent is
        # written as "?" instead of raising, so logging a message can never
        # make the write itself fail with UnicodeEncodeError.
        self.log = open(self.log_file, "w", errors="replace")

    def write(self, message):
        """Write ``message`` to both the terminal stream and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both underlying streams."""
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        """Always report a non-interactive stream."""
        return False
|
|
|
|
|
# Route stdout and stderr through one shared Logger instance so that everything
# printed is also written to its log file.
sys.stdout = sys.stderr = Logger()

import logging

# The handler is created after the redirection above, so log records are
# emitted through the Logger as well.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
|
|
|
def read_logs():
    """Return the full text of the log file behind the current ``sys.stdout``.

    ``sys.stdout`` is expected to expose a ``log_file`` attribute (the path to
    read). The stream is flushed first so pending output is included.
    """
    stream = sys.stdout
    stream.flush()
    with open(stream.log_file, "r") as handle:
        contents = handle.read()
    return contents
|
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Command-line options. Apart from --port, they provide the initial
    # values of the corresponding widgets in the UI.
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """
        Example runs:
        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
        default="/tmp/xtts_ft/",
    )

    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 10",
        default=10,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 4",
        default=4,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )

    args = parser.parse_args()

    # Language codes offered by both the dataset and the inference dropdowns;
    # kept in one place so the two lists cannot drift apart.
    supported_languages = (
        "en",
        "es",
        "fr",
        "de",
        "it",
        "pt",
        "pl",
        "tr",
        "ru",
        "nl",
        "cs",
        "ar",
        "zh",
        "hu",
        "ko",
        "ja",
    )

    with gr.Blocks() as demo:
        # --------------------------------------------------------------
        # Tab 1: turn uploaded audio files into a dataset.
        # --------------------------------------------------------------
        with gr.Tab("1 - Data processing"):
            out_path = gr.Textbox(
                label="Output path (where data and checkpoints will be saved):",
                value=args.out_path,
            )
            upload_file = gr.File(
                file_count="multiple",
                label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
            )
            lang = gr.Dropdown(
                label="Dataset Language",
                value="en",
                choices=list(supported_languages),
            )
            progress_data = gr.Label(
                label="Progress:"
            )
            logs = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            # Refresh the log box from read_logs every second.
            demo.load(read_logs, None, logs, every=1)

            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

            def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                """Create the dataset from the uploaded audio files.

                Args:
                    audio_path: Uploaded file(s) from the file widget; may be
                        None or empty when nothing was uploaded.
                    language: Dataset language code.
                    out_path: Base output directory; a "dataset" sub-directory
                        is created inside it.
                    progress: Gradio progress tracker forwarded to
                        ``format_audio_list``.

                Returns:
                    ``(status message, train metadata, eval metadata)`` where
                    the metadata values are what ``format_audio_list`` returned
                    and are shown in the Train CSV / Eval CSV fields. Both are
                    "" on any failure.
                """
                clear_gpu_cache()
                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)

                # `not audio_path` also covers an empty upload list, not just None.
                if not audio_path:
                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""

                try:
                    train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
                except Exception:
                    # Report the failure in the UI instead of raising. Catching
                    # Exception (not a bare except) lets KeyboardInterrupt and
                    # SystemExit still stop the process.
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""

                clear_gpu_cache()

                # Reject datasets shorter than 120 (two minutes, per the message below).
                if audio_total_size < 120:
                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
                    print(message)
                    return message, "", ""

                print("Dataset Processed!")
                return "Dataset Processed!", train_meta, eval_meta

        # --------------------------------------------------------------
        # Tab 2: fine-tuning.
        # --------------------------------------------------------------
        with gr.Tab("2 - Fine-tuning XTTS Encoder"):
            train_csv = gr.Textbox(
                label="Train CSV:",
            )
            eval_csv = gr.Textbox(
                label="Eval CSV:",
            )
            num_epochs = gr.Slider(
                label="Number of epochs:",
                minimum=1,
                maximum=100,
                step=1,
                value=args.num_epochs,
            )
            batch_size = gr.Slider(
                label="Batch size:",
                minimum=2,
                maximum=512,
                step=1,
                value=args.batch_size,
            )
            grad_acumm = gr.Slider(
                label="Grad accumulation steps:",
                # The --grad_acumm default is 1, so the slider range must
                # include 1; otherwise the initial value is out of range.
                minimum=1,
                maximum=128,
                step=1,
                value=args.grad_acumm,
            )
            max_audio_length = gr.Slider(
                label="Max permitted audio size in seconds:",
                minimum=2,
                maximum=20,
                step=1,
                value=args.max_audio_length,
            )
            progress_train = gr.Label(
                label="Progress:"
            )
            logs_tts_train = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            demo.load(read_logs, None, logs_tts_train, every=1)
            train_btn = gr.Button(value="Step 2 - Run the training")

            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                """Run the fine-tuning and report the produced artifacts.

                Args:
                    language: Dataset language code.
                    train_csv: Train metadata path from the data step.
                    eval_csv: Eval metadata path from the data step.
                    num_epochs: Number of training epochs.
                    batch_size: Training batch size.
                    grad_acumm: Gradient accumulation steps.
                    output_path: Base output directory passed to ``train_gpt``.
                    max_audio_length: Maximum audio length in seconds.

                Returns:
                    ``(status message, config path, vocab path, fine-tuned
                    checkpoint path, speaker reference wav)``. The last four
                    are "" on failure.
                """
                import shutil

                clear_gpu_cache()
                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
                try:
                    # Seconds -> samples; 22050 is presumably the training
                    # sample rate (TODO confirm against train_gpt).
                    max_audio_length = int(max_audio_length * 22050)
                    config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
                except Exception:
                    # Same reporting convention as the data step: show the
                    # traceback in the UI rather than raising.
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

                # Keep a copy of the config and vocab next to the experiment.
                # shutil.copy replaces the former `cp` shell command: it works
                # on every platform, handles paths with spaces, and does not
                # pass the user-supplied output path through a shell. As
                # before, a failed copy is not fatal; it is only logged.
                for artifact in (config_path, vocab_file):
                    try:
                        shutil.copy(artifact, exp_path)
                    except OSError:
                        traceback.print_exc()

                ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
                print("Model training done!")
                clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav

        # --------------------------------------------------------------
        # Tab 3: load the fine-tuned model and run inference.
        # --------------------------------------------------------------
        with gr.Tab("3 - Inference"):
            with gr.Row():
                with gr.Column() as col1:
                    xtts_checkpoint = gr.Textbox(
                        label="XTTS checkpoint path:",
                        value="",
                    )
                    xtts_config = gr.Textbox(
                        label="XTTS config path:",
                        value="",
                    )

                    xtts_vocab = gr.Textbox(
                        label="XTTS vocab path:",
                        value="",
                    )
                    progress_load = gr.Label(
                        label="Progress:"
                    )
                    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

                with gr.Column() as col2:
                    speaker_reference_audio = gr.Textbox(
                        label="Speaker reference audio:",
                        value="",
                    )
                    tts_language = gr.Dropdown(
                        label="Language",
                        value="en",
                        choices=list(supported_languages),
                    )
                    tts_text = gr.Textbox(
                        label="Input Text.",
                        value="This model sounds really good and above all, it's reasonably fast.",
                    )
                    tts_btn = gr.Button(value="Step 4 - Inference")

                with gr.Column() as col3:
                    progress_gen = gr.Label(
                        label="Progress:"
                    )
                    tts_output_audio = gr.Audio(label="Generated Audio.")
                    reference_audio = gr.Audio(label="Reference audio used.")

        # --------------------------------------------------------------
        # Wiring: each step's outputs pre-fill the inputs of the next step.
        # --------------------------------------------------------------
        prompt_compute_btn.click(
            fn=preprocess_dataset,
            inputs=[
                upload_file,
                lang,
                out_path,
            ],
            outputs=[
                progress_data,
                train_csv,
                eval_csv,
            ],
        )

        train_btn.click(
            fn=train_model,
            inputs=[
                lang,
                train_csv,
                eval_csv,
                num_epochs,
                batch_size,
                grad_acumm,
                out_path,
                max_audio_length,
            ],
            outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
        )

        load_btn.click(
            fn=load_model,
            inputs=[
                xtts_checkpoint,
                xtts_config,
                xtts_vocab
            ],
            outputs=[progress_load],
        )

        tts_btn.click(
            fn=run_tts,
            inputs=[
                tts_language,
                tts_text,
                speaker_reference_audio,
            ],
            outputs=[progress_gen, tts_output_audio, reference_audio],
        )

    # NOTE(review): share=True together with server_name="0.0.0.0" makes the
    # demo (including its free-form path fields) reachable from outside this
    # machine - confirm this exposure is intended before deploying.
    demo.launch(
        share=True,
        debug=False,
        server_port=args.port,
        server_name="0.0.0.0"
    )
|
|