import argparse
import os
import sys
import tempfile
import traceback

import gradio as gr
import librosa.display
import numpy as np
import torch
import torchaudio

from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts


def clear_gpu_cache():
    """Release cached CUDA memory; does nothing when CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Currently loaded XTTS model, or None when no model has been loaded yet.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load an XTTS model into the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint file.
        xtts_config: Path to the XTTS JSON config file.
        xtts_vocab: Path to the vocab file.

    Returns:
        A status message: a hint to fill in the paths when any of them is
        empty, otherwise ``"Model Loaded!"``.

    Any exception raised while reading the config or the checkpoint
    propagates to the caller; in that case ``XTTS_MODEL`` is left as ``None``.
    """
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    # Drop the reference to any previously loaded model before building the
    # new one, then clear the cache again so its memory can be reused.
    XTTS_MODEL = None
    clear_gpu_cache()

    config = XttsConfig()
    config.load_json(xtts_config)

    # Build into a local and publish to the global only once loading has
    # fully succeeded, so code checking `XTTS_MODEL is None` never sees a
    # model whose checkpoint failed to load.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model! ")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
    if torch.cuda.is_available():
        model.cuda()

    XTTS_MODEL = model
    print("Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesize ``tts_text`` with the loaded model and save it as a WAV file.

    Args:
        lang: Language code passed to the model's ``inference`` call.
        tts_text: Text to synthesize.
        speaker_audio_file: Path to the reference audio used to compute the
            conditioning latents.

    Returns:
        A ``(status_message, generated_wav_path, speaker_audio_file)`` tuple.
        The last two items are ``None`` when no model is loaded or no
        reference audio was given.
    """
    if XTTS_MODEL is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None

    model_config = XTTS_MODEL.config
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=model_config.gpt_cond_len,
        max_ref_length=model_config.max_ref_len,
        sound_norm_refs=model_config.sound_norm_refs,
    )
    out = XTTS_MODEL.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=model_config.temperature,  # Add custom parameters here
        length_penalty=model_config.length_penalty,
        repetition_penalty=model_config.repetition_penalty,
        top_k=model_config.top_k,
        top_p=model_config.top_p,
    )

    # Add a leading dimension before handing the waveform to torchaudio.save.
    wav = torch.tensor(out["wav"]).unsqueeze(0)

    # Only reserve a unique file name here. The handle is closed before
    # torchaudio writes to the path, because a named temporary file cannot be
    # reopened by name while still open on every platform (e.g. Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    torchaudio.save(out_path, wav, 24000)

    return "Speech generated !", out_path, speaker_audio_file


# define a logger to redirect
class Logger:
    """File-like object that mirrors every write to the terminal and a log file."""

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Keep the stream that was sys.stdout at construction time.
        self.terminal = sys.stdout
        # Explicit UTF-8 so non-ASCII output does not depend on the locale
        # encoding; read_logs() reads the file back with the same encoding.
        self.log = open(self.log_file, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False


# redirect stdout and stderr to a file
sys.stdout = Logger()
sys.stderr = sys.stdout

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)


def read_logs():
    """Flush the redirected stdout and return the full content of its log file."""
    sys.stdout.flush()
    with open(sys.stdout.log_file, "r", encoding="utf-8") as f:
        return f.read()


# Language codes offered by both the dataset and the inference dropdowns.
SUPPORTED_LANGUAGES = [
    "en",
    "es",
    "fr",
    "de",
    "it",
    "pt",
    "pl",
    "tr",
    "ru",
    "nl",
    "cs",
    "ar",
    "zh",
    "hu",
    "ko",
    "ja",
]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """ Example runs: python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
        default="/tmp/xtts_ft/",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 10",
        default=10,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 4",
        default=4,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )

    args = parser.parse_args()

    with gr.Blocks() as demo:
        with gr.Tab("1 - Data processing"):
            out_path = gr.Textbox(
                label="Output path (where data and checkpoints will be saved):",
                value=args.out_path,
            )
            upload_file = gr.File(
                file_count="multiple",
                label="Select here the audio files that you want to use for XTTS trainining (Supported formats: wav, mp3, and flac)",
            )
            lang = gr.Dropdown(
                label="Dataset Language",
                value="en",
                choices=list(SUPPORTED_LANGUAGES),
            )
            progress_data = gr.Label(label="Progress:")
            logs = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            # Refresh the log box from the log file every second.
            demo.load(read_logs, None, logs, every=1)

            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

            def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                """Build the training dataset from the uploaded audio files.

                Args:
                    audio_path: Uploaded audio files; ``None`` or empty when
                        nothing has been uploaded (yet).
                    language: Target language passed to ``format_audio_list``.
                    out_path: Base output directory; data goes to its
                        ``dataset`` sub-directory.
                    progress: Gradio progress tracker forwarded to
                        ``format_audio_list``.

                Returns:
                    A ``(status_message, train_meta, eval_meta)`` tuple; the
                    last two items are empty strings on failure.
                """
                clear_gpu_cache()

                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)

                # Treat an empty upload list the same as no upload at all.
                if not audio_path:
                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""

                try:
                    train_meta, eval_meta, audio_total_size = format_audio_list(
                        audio_path,
                        target_language=language,
                        out_path=out_path,
                        gradio_progress=progress,
                    )
                except Exception:
                    # `Exception` rather than a bare except so that
                    # KeyboardInterrupt / SystemExit are not swallowed.
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""

                clear_gpu_cache()

                # if audio total len is less than 2 minutes raise an error
                if audio_total_size < 120:
                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
                    print(message)
                    return message, "", ""

                print("Dataset Processed!")
                return "Dataset Processed!", train_meta, eval_meta

        with gr.Tab("2 - Fine-tuning XTTS Encoder"):
            train_csv = gr.Textbox(
                label="Train CSV:",
            )
            eval_csv = gr.Textbox(
                label="Eval CSV:",
            )
            num_epochs = gr.Slider(
                label="Number of epochs:",
                minimum=1,
                maximum=100,
                step=1,
                value=args.num_epochs,
            )
            batch_size = gr.Slider(
                label="Batch size:",
                minimum=2,
                maximum=512,
                step=1,
                value=args.batch_size,
            )
            # The minimum is 1 so that the documented CLI default
            # (--grad_acumm 1) lies inside the slider range.
            grad_acumm = gr.Slider(
                label="Grad accumulation steps:",
                minimum=1,
                maximum=128,
                step=1,
                value=args.grad_acumm,
            )
            max_audio_length = gr.Slider(
                label="Max permitted audio size in seconds:",
                minimum=2,
                maximum=20,
                step=1,
                value=args.max_audio_length,
            )
            progress_train = gr.Label(label="Progress:")
            logs_tts_train = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            # Refresh the log box from the log file every second.
            demo.load(read_logs, None, logs_tts_train, every=1)
            train_btn = gr.Button(value="Step 2 - Run the training")

            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                """Run ``train_gpt`` and report the paths needed for inference.

                Args:
                    language: Language passed to ``train_gpt``.
                    train_csv: Path to the training metadata CSV.
                    eval_csv: Path to the evaluation metadata CSV.
                    num_epochs: Number of training epochs.
                    batch_size: Training batch size.
                    grad_acumm: Gradient accumulation steps.
                    output_path: Output directory passed to ``train_gpt``.
                    max_audio_length: Maximum audio length in seconds.

                Returns:
                    A ``(status_message, config_path, vocab_file,
                    ft_xtts_checkpoint, speaker_wav)`` tuple; the last four
                    items are empty strings on failure.
                """
                # Local import: shutil is not among the file-level imports.
                import shutil

                clear_gpu_cache()
                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""

                try:
                    # convert seconds to waveform frames (22050 frames per second)
                    max_audio_length = int(max_audio_length * 22050)
                    config_path, _original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(
                        language,
                        num_epochs,
                        batch_size,
                        grad_acumm,
                        train_csv,
                        eval_csv,
                        output_path=output_path,
                        max_audio_length=max_audio_length,
                    )
                except Exception:
                    # `Exception` rather than a bare except so that
                    # KeyboardInterrupt / SystemExit are not swallowed.
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

                # copy original files to avoid parameters changes issues
                # shutil.copy replaces a shell `cp` so paths containing spaces
                # or shell metacharacters are handled and no shell is needed.
                # A failed copy stays non-fatal (as the shell call was), but
                # is now reported in the log.
                for src_file in (config_path, vocab_file):
                    try:
                        shutil.copy(src_file, exp_path)
                    except OSError:
                        traceback.print_exc()

                ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
                print("Model training done!")
                clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav

        with gr.Tab("3 - Inference"):
            with gr.Row():
                with gr.Column() as col1:
                    xtts_checkpoint = gr.Textbox(
                        label="XTTS checkpoint path:",
                        value="",
                    )
                    xtts_config = gr.Textbox(
                        label="XTTS config path:",
                        value="",
                    )
                    xtts_vocab = gr.Textbox(
                        label="XTTS vocab path:",
                        value="",
                    )
                    progress_load = gr.Label(label="Progress:")
                    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

                with gr.Column() as col2:
                    speaker_reference_audio = gr.Textbox(
                        label="Speaker reference audio:",
                        value="",
                    )
                    tts_language = gr.Dropdown(
                        label="Language",
                        value="en",
                        choices=list(SUPPORTED_LANGUAGES),
                    )
                    tts_text = gr.Textbox(
                        label="Input Text.",
                        value="This model sounds really good and above all, it's reasonably fast.",
                    )
                    tts_btn = gr.Button(value="Step 4 - Inference")

                with gr.Column() as col3:
                    progress_gen = gr.Label(label="Progress:")
                    tts_output_audio = gr.Audio(label="Generated Audio.")
                    reference_audio = gr.Audio(label="Reference audio used.")

            # Step 1: dataset creation fills the CSV fields of tab 2.
            prompt_compute_btn.click(
                fn=preprocess_dataset,
                inputs=[
                    upload_file,
                    lang,
                    out_path,
                ],
                outputs=[
                    progress_data,
                    train_csv,
                    eval_csv,
                ],
            )

            # Step 2: training fills the model path fields of tab 3.
            train_btn.click(
                fn=train_model,
                inputs=[
                    lang,
                    train_csv,
                    eval_csv,
                    num_epochs,
                    batch_size,
                    grad_acumm,
                    out_path,
                    max_audio_length,
                ],
                outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
            )

            # Step 3: load the fine-tuned model.
            load_btn.click(
                fn=load_model,
                inputs=[
                    xtts_checkpoint,
                    xtts_config,
                    xtts_vocab,
                ],
                outputs=[progress_load],
            )

            # Step 4: run inference with the loaded model.
            tts_btn.click(
                fn=run_tts,
                inputs=[
                    tts_language,
                    tts_text,
                    speaker_reference_audio,
                ],
                outputs=[progress_gen, tts_output_audio, reference_audio],
            )

    demo.launch(
        share=True,
        debug=False,
        server_port=args.port,
        server_name="0.0.0.0",
    )