Spaces:
Running
Running
import argparse | |
import os | |
import random | |
import sys | |
import torch | |
from TrainingInterfaces.TrainingPipelines.Avocodo_combined import run as hifi_codo | |
from TrainingInterfaces.TrainingPipelines.BigVGAN_combined import run as bigvgan | |
from TrainingInterfaces.TrainingPipelines.FastSpeech2Embedding_IntegrationTest import run as fs_integration_test | |
from TrainingInterfaces.TrainingPipelines.GST_FastSpeech2 import run as embedding | |
from TrainingInterfaces.TrainingPipelines.StochasticToucanTTS_Nancy import run as nancystoch | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_IntegrationTest import run as tt_integration_test | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_MetaCheckpoint import run as meta | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Nancy import run as nancy | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Finnish import run as finnish | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_English import run as english | |
from TrainingInterfaces.TrainingPipelines.ToucanTTS_Quenya import run as quenya | |
from TrainingInterfaces.TrainingPipelines.finetuning_example import run as fine_tuning_example | |
from TrainingInterfaces.TrainingPipelines.pretrain_aligner import run as aligner | |
# Registry mapping each command-line pipeline name to the `run` function of the
# corresponding training pipeline module. Insertion order is preserved, so the
# CLI lists the choices in exactly this order.
pipeline_dict = dict(
    # the finetuning example
    fine_ex=fine_tuning_example,
    # integration tests
    fs_it=fs_integration_test,
    tt_it=tt_integration_test,
    # regular ToucanTTS pipelines
    nancy=nancy,
    nancystoch=nancystoch,
    meta=meta,
    finnish=finnish,
    english=english,
    quenya=quenya,
    # training vocoders (not recommended, best to use provided checkpoint)
    avocodo=hifi_codo,
    bigvgan=bigvgan,
    # training the GST embedding jointly with FastSpeech 2 on expressive data (not recommended, best to use provided checkpoint)
    embedding=embedding,
    # training the aligner from scratch (not recommended, best to use provided checkpoint)
    aligner=aligner,
)
if __name__ == '__main__':
    # Command-line entry point: parse options, restrict GPU visibility, seed the
    # RNGs, then hand control to the selected pipeline's `run` function.
    parser = argparse.ArgumentParser(description='Training with the IMS Toucan Speech Synthesis Toolkit')
    parser.add_argument('pipeline',
                        choices=list(pipeline_dict.keys()),
                        help="Select pipeline to train.")
    parser.add_argument('--gpu_id',
                        type=str,
                        help="Which GPU to run on. If not specified runs on CPU, but other than for integration tests that doesn't make much sense.",
                        default="cpu")
    parser.add_argument('--resume_checkpoint',
                        type=str,
                        help="Path to checkpoint to resume from.",
                        default=None)
    parser.add_argument('--resume',
                        action="store_true",
                        help="Automatically load the highest checkpoint and continue from there.",
                        default=False)
    parser.add_argument('--finetune',
                        action="store_true",
                        help="Whether to fine-tune from the specified checkpoint.",
                        default=False)
    parser.add_argument('--model_save_dir',
                        type=str,
                        help="Directory where the checkpoints should be saved to.",
                        default=None)
    parser.add_argument('--wandb',
                        action="store_true",
                        help="Whether to use weights and biases to track training runs. Requires you to run wandb login and place your auth key before.",
                        default=False)
    parser.add_argument('--wandb_resume_id',
                        type=str,
                        help="ID of a stopped wandb run to continue tracking",
                        default=None)
    args = parser.parse_args()

    # Fine-tuning needs a starting point: an explicit checkpoint or --resume.
    if args.finetune and args.resume_checkpoint is None and not args.resume:
        # parser.error prints usage and the message to stderr and exits with a
        # non-zero status; a bare sys.exit() would have reported success.
        parser.error("Need to provide path to checkpoint to fine-tune from!")

    # These environment variables only take effect if CUDA has not yet been
    # initialised in this process, so they must be set before any GPU work.
    if args.gpu_id == "cpu":
        # Hide every CUDA device from this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
        print("No GPU specified, using CPU. Training will likely not work without GPU.")
    else:
        # Enumerate devices by PCI bus ID (CUDA's default ordering is
        # fastest-first), then expose only the requested one.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu_id}"
        print(f"Making GPU {os.environ['CUDA_VISIBLE_DEVICES']} the only visible device.")

    # Fixed seed for reproducible runs. torch.manual_seed seeds the RNG on all
    # devices (torch.random.manual_seed is the same function, so one call is enough).
    seed = 131714
    torch.manual_seed(seed)
    random.seed(seed)

    pipeline_dict[args.pipeline](gpu_id=args.gpu_id,
                                 resume_checkpoint=args.resume_checkpoint,
                                 resume=args.resume,
                                 finetune=args.finetune,
                                 model_dir=args.model_save_dir,
                                 use_wandb=args.wandb,
                                 wandb_resume_id=args.wandb_resume_id)