|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
|
import torch |
|
|
|
from models.tts.fastspeech2.fs2_trainer import FastSpeech2Trainer |
|
from models.tts.vits.vits_trainer import VITSTrainer |
|
from models.tts.valle.valle_trainer import VALLETrainer |
|
from models.tts.naturalspeech2.ns2_trainer import NS2Trainer |
|
from models.tts.valle_v2.valle_ar_trainer import ValleARTrainer as VALLE_V2_AR |
|
from models.tts.valle_v2.valle_nar_trainer import ValleNARTrainer as VALLE_V2_NAR |
|
from models.tts.jets.jets_trainer import JetsTrainer |
|
|
|
from utils.util import load_config |
|
|
|
|
|
def build_trainer(args, cfg):
    """Instantiate the trainer class selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace forwarded to the trainer.
        cfg: Loaded configuration object; ``cfg.model_type`` picks the trainer.

    Returns:
        A constructed trainer instance for the requested model type.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported model type.
    """
    trainer_registry = {
        "FastSpeech2": FastSpeech2Trainer,
        "VITS": VITSTrainer,
        "VALLE": VALLETrainer,
        "NaturalSpeech2": NS2Trainer,
        "VALLE_V2_AR": VALLE_V2_AR,
        "VALLE_V2_NAR": VALLE_V2_NAR,
        "Jets": JetsTrainer,
    }
    return trainer_registry[cfg.model_type](args, cfg)
|
|
|
|
|
def cuda_relevant(deterministic=False):
    """Configure global CUDA/cuDNN backend flags for a training run.

    Frees cached GPU memory, enables TF32 kernels, and toggles deterministic
    execution of the cuDNN/torch backends.

    Args:
        deterministic: If True, force deterministic algorithms and disable
            cuDNN benchmark autotuning — slower but reproducible. If False
            (default), enable benchmark autotuning for speed.
    """
    # Release cached, unused GPU memory back to the driver (no-op without CUDA).
    torch.cuda.empty_cache()

    # TF32 trades a small amount of precision for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True

    # Reproducibility switches. NOTE: the previous version also assigned
    # `cudnn.benchmark = False` unconditionally before this point; that was
    # dead code, always overwritten by the assignment below, and is removed.
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
|
|
|
|
|
def _augmented_dataset_names(dataset, preprocess_cfg):
    """Return augmented dataset names (e.g. "<dataset>_pitch_shift") for
    every augmentation flag enabled in ``preprocess_cfg``."""
    augmentations = [
        ("pitch_shift", preprocess_cfg.use_pitch_shift),
        ("formant_shift", preprocess_cfg.use_formant_shift),
        ("equalizer", preprocess_cfg.use_equalizer),
        ("time_stretch", preprocess_cfg.use_time_stretch),
    ]
    return [f"{dataset}_{suffix}" for suffix, enabled in augmentations if enabled]


def main():
    """CLI entry point: parse arguments, load the config, register augmented
    datasets, configure CUDA, and run the train (or test) loop."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="config.json",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="random seed",
        required=False,
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default="exp_name",
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume", action="store_true", help="The model name to restore"
    )
    parser.add_argument(
        "--test", action="store_true", default=False, help="Test the model"
    )
    parser.add_argument(
        "--log_level", default="warning", help="logging level (debug, info, warning)"
    )
    parser.add_argument(
        "--resume_type",
        type=str,
        default="resume",
        help="Resume training or finetuning.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default=None,
        help="Checkpoint for resume training or finetuning.",
    )
    parser.add_argument(
        "--resume_from_ckpt_path",
        type=str,
        default="",
        help="Checkpoint for resume training or finetuning.",
    )

    args = parser.parse_args()
    cfg = load_config(args.config)

    # Register augmented dataset variants so the trainer sees them alongside
    # the originals. `getattr` chain replaces the previous nested `hasattr`
    # pyramid; `isinstance` replaces the `type(...) == list` comparison.
    data_augment = getattr(getattr(cfg, "preprocess", None), "data_augment", None)
    if isinstance(data_augment, list) and data_augment:
        new_datasets_list = []
        for dataset in data_augment:
            new_datasets_list.extend(
                _augmented_dataset_names(dataset, cfg.preprocess)
            )
        cfg.dataset.extend(new_datasets_list)

    print("experiment name: ", args.exp_name)

    # CUDA/cuDNN global configuration (TF32, benchmark autotuning, etc.).
    cuda_relevant()

    print(f"Building {cfg.model_type} trainer")
    trainer = build_trainer(args, cfg)
    print(f"Start training {cfg.model_type} model")
    if args.test:
        trainer.test_loop()
    else:
        trainer.train_loop()
|
|
|
|
|
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|