# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

import torch

from models.tta.autoencoder.autoencoder_trainer import AutoencoderKLTrainer
from models.tta.ldm.audioldm_trainer import AudioLDMTrainer
from utils.util import load_config


def build_trainer(args, cfg):
    """Instantiate the trainer class selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace; forwarded unchanged to the trainer.
        cfg: Loaded experiment configuration. ``cfg.model_type`` picks the
            trainer class ("AutoencoderKL" or "AudioLDM").

    Returns:
        The trainer instance, constructed as ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not one of the supported names.
    """
    supported_trainer = {
        "AutoencoderKL": AutoencoderKLTrainer,
        "AudioLDM": AudioLDMTrainer,
    }
    try:
        trainer_class = supported_trainer[cfg.model_type]
    except KeyError:
        # Keep the KeyError type for existing callers, but say what was wrong
        # and what the valid choices are instead of echoing the bare key.
        raise KeyError(
            f"Unsupported model_type {cfg.model_type!r}; "
            f"expected one of {sorted(supported_trainer)}"
        ) from None
    return trainer_class(args, cfg)


def main():
    """Parse CLI arguments, prepare the experiment directory, and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        help="json files for configurations.",
        required=True,
    )
    parser.add_argument(
        "--num_workers", type=int, default=6, help="Number of dataloader workers."
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        help="A specific name to note the experiment",
        required=True,
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="The model name to restore",
    )
    parser.add_argument(
        "--log_level", default="info", help="logging level (info, debug, warning)"
    )
    parser.add_argument("--stdout_interval", default=5, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    args = parser.parse_args()

    cfg = load_config(args.config)
    cfg.exp_name = args.exp_name

    # Model saving dir: <cfg.log_dir>/<exp_name>, created before the trainer starts.
    args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
    os.makedirs(args.log_dir, exist_ok=True)

    if not cfg.train.ddp:
        # NOTE(review): without DDP, local_rank is overwritten with a CUDA
        # device object (not an int); presumably the trainers use it as the
        # target device in that mode — confirm before changing.
        args.local_rank = torch.device("cuda")

    # Build trainer
    trainer = build_trainer(args, cfg)

    # Restore models. restore() takes no arguments here; it presumably reads
    # args.resume (and args.log_dir) from the args it was built with — verify.
    if args.resume:
        trainer.restore()

    trainer.train()


if __name__ == "__main__":
    main()