# Quenya-TTS / run_training_pipeline.py
# Author: AnnieZzz — "Update app.py and requirements.txt" (commit cd4e2cb, verified)
import argparse
import os
import random
import sys
import torch
from TrainingInterfaces.TrainingPipelines.Avocodo_combined import run as hifi_codo
from TrainingInterfaces.TrainingPipelines.BigVGAN_combined import run as bigvgan
from TrainingInterfaces.TrainingPipelines.FastSpeech2Embedding_IntegrationTest import run as fs_integration_test
from TrainingInterfaces.TrainingPipelines.GST_FastSpeech2 import run as embedding
from TrainingInterfaces.TrainingPipelines.StochasticToucanTTS_Nancy import run as nancystoch
from TrainingInterfaces.TrainingPipelines.ToucanTTS_IntegrationTest import run as tt_integration_test
from TrainingInterfaces.TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Nancy import run as nancy
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Finnish import run as finnish
from TrainingInterfaces.TrainingPipelines.ToucanTTS_English import run as english
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Quenya import run as quenya
from TrainingInterfaces.TrainingPipelines.finetuning_example import run as fine_tuning_example
from TrainingInterfaces.TrainingPipelines.pretrain_aligner import run as aligner
# Registry mapping each CLI pipeline name to the training entry point it launches.
pipeline_dict = {
    # minimal fine-tuning walkthrough
    "fine_ex": fine_tuning_example,
    # quick end-to-end sanity checks
    "fs_it": fs_integration_test,
    "tt_it": tt_integration_test,
    # standard ToucanTTS training recipes
    "nancy": nancy,
    "nancystoch": nancystoch,
    "meta": meta,
    "finnish": finnish,
    "english": english,
    "quenya": quenya,
    # vocoder training (usually unnecessary — pretrained checkpoints are provided)
    "avocodo": hifi_codo,
    "bigvgan": bigvgan,
    # joint GST embedding + FastSpeech 2 training on expressive data (pretrained checkpoint recommended instead)
    "embedding": embedding,
    # aligner training from scratch (pretrained checkpoint recommended instead)
    "aligner": aligner,
}
if __name__ == '__main__':
    # Parse the command line, pick the requested training pipeline, configure the
    # compute device and RNG seeds, then hand control to the pipeline's run().
    parser = argparse.ArgumentParser(description='Training with the IMS Toucan Speech Synthesis Toolkit')

    parser.add_argument('pipeline',
                        choices=list(pipeline_dict.keys()),
                        help="Select pipeline to train.")

    parser.add_argument('--gpu_id',
                        type=str,
                        help="Which GPU to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
                        default="cpu")

    parser.add_argument('--resume_checkpoint',
                        type=str,
                        help="Path to checkpoint to resume from.",
                        default=None)

    parser.add_argument('--resume',
                        action="store_true",
                        help="Automatically load the highest checkpoint and continue from there.",
                        default=False)

    parser.add_argument('--finetune',
                        action="store_true",
                        help="Whether to fine-tune from the specified checkpoint.",
                        default=False)

    parser.add_argument('--model_save_dir',
                        type=str,
                        help="Directory where the checkpoints should be saved to.",
                        default=None)

    parser.add_argument('--wandb',
                        action="store_true",
                        help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
                        default=False)

    parser.add_argument('--wandb_resume_id',
                        type=str,
                        help="ID of a stopped wandb run to continue tracking",
                        default=None)

    args = parser.parse_args()

    # --finetune needs a checkpoint source: either an explicit path or --resume auto-discovery.
    if args.finetune and args.resume_checkpoint is None and not args.resume:
        print("Need to provide path to checkpoint to fine-tune from!")
        # Exit with a non-zero status so shell scripts / CI can detect the usage error
        # (bare sys.exit() would report success).
        sys.exit(1)

    if args.gpu_id == "cpu":
        # Hide all CUDA devices so downstream torch code cannot accidentally grab one.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        device = torch.device("cpu")
        print("No GPU specified, using CPU. Training will likely not work without GPU.")
    else:
        # PCI_BUS_ID ordering makes the numeric id match nvidia-smi's numbering.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
        device = torch.device("cuda")
        print(f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device.")

    # Seed every RNG we rely on for reproducible runs.
    # (torch.random.manual_seed is an alias of torch.manual_seed, so one call suffices.)
    torch.manual_seed(131714)
    random.seed(131714)

    pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
                                 resume_checkpoint=args.resume_checkpoint,
                                 resume=args.resume,
                                 finetune=args.finetune,
                                 model_dir=args.model_save_dir,
                                 use_wandb=args.wandb,
                                 wandb_resume_id=args.wandb_resume_id)