import os
from dataclasses import dataclass, field

from trainer import Trainer, TrainerArgs

from TTS.config import load_config, register_config
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models import setup_model


@dataclass
class TrainTTSArgs(TrainerArgs):
    """Trainer arguments extended with the path to a model config file."""

    # Optional path to a config file; when unset, `main` falls back to
    # `continue_path` and then to console arguments.
    config_path: str = field(default=None, metadata={"help": "Path to the config file."})


def main():
    """Run `tts` model training directly by a `config.json` file.

    The config is resolved from the first available source, in this order:

    1. ``--config_path``: load the given config file.
    2. ``--continue_path``: load ``config.json`` from that experiment folder.
    3. Otherwise: build the config from the remaining console arguments.

    Command-line arguments not recognised by the trainer argument parser are
    treated as config overrides. The training samples are then loaded, the
    model is created and handed to the ``Trainer``, and ``fit()`` is called.
    """
    # init trainer args
    train_args = TrainTTSArgs()
    parser = train_args.init_argparse(arg_prefix="")

    # override trainer args from command-line args; anything the parser does
    # not recognise is kept aside as config overrides
    args, config_overrides = parser.parse_known_args()
    train_args.parse_args(args)

    # load config.json and register
    if args.config_path or args.continue_path:
        # An explicit config file takes precedence over the config stored in
        # a previous experiment folder.
        config_file = args.config_path or os.path.join(args.continue_path, "config.json")
        config = load_config(config_file)
        if config_overrides:
            config.parse_known_args(config_overrides, relaxed_parser=True)
    else:
        # init from console args: this must be reachable when neither path is
        # given, otherwise `config` would be unbound below
        from TTS.config.shared_configs import BaseTrainingConfig  # pylint: disable=import-outside-toplevel

        config_base = BaseTrainingConfig()
        config_base.parse_known_args(config_overrides)
        # NOTE(review): the overrides are parsed only into `config_base`; the
        # model config instantiated here is created with no arguments.
        config = register_config(config_base.model)()

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        config.datasets,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the model from config
    model = setup_model(config, train_samples + eval_samples)

    # init the trainer and 🚀
    trainer = Trainer(
        train_args,
        model.config,
        config.output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
        # arguments were already parsed above
        parse_command_line_args=False,
    )
    trainer.fit()


if __name__ == "__main__":
    main()