|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
import torch |
|
|
|
from models.tta.autoencoder.autoencoder_trainer import AutoencoderKLTrainer |
|
from models.tta.ldm.audioldm_trainer import AudioLDMTrainer |
|
from utils.util import load_config |
|
|
|
|
|
def build_trainer(args, cfg):
    """Instantiate the trainer selected by ``cfg.model_type``.

    Args:
        args: Parsed command-line namespace, forwarded unchanged to the trainer.
        cfg: Loaded experiment configuration; ``cfg.model_type`` names the
            trainer to build.

    Returns:
        The trainer instance, constructed as ``trainer_class(args, cfg)``.

    Raises:
        KeyError: If ``cfg.model_type`` is not a supported trainer name. The
            message lists the supported names.
    """
    supported_trainer = {
        "AutoencoderKL": AutoencoderKLTrainer,
        "AudioLDM": AudioLDMTrainer,
    }

    try:
        trainer_class = supported_trainer[cfg.model_type]
    except KeyError:
        # Keep the KeyError type callers may rely on, but replace the bare
        # missing-key message with one that lists the valid choices.
        raise KeyError(
            f"Unsupported model_type {cfg.model_type!r}; "
            f"expected one of {sorted(supported_trainer)}"
        ) from None

    return trainer_class(args, cfg)
|
|
|
|
|
def _parse_args():
    """Define the training CLI and parse ``sys.argv`` into a namespace."""
    parser = argparse.ArgumentParser()
    # --config and --exp_name are mandatory. A `default=` is never used by
    # argparse when `required=True`, so none is given.
    parser.add_argument(
        "--config",
        required=True,
        help="json files for configurations.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=6, help="Number of dataloader workers."
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        required=True,
        help="A specific name to note the experiment",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="The model name to restore",
    )
    parser.add_argument(
        "--log_level", default="info", help="logging level (info, debug, warning)"
    )
    parser.add_argument("--stdout_interval", default=5, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    return parser.parse_args()


def main() -> None:
    """Parse CLI arguments, load the config, build the trainer and train.

    Side effects:
        - Sets ``cfg.exp_name`` from ``--exp_name``.
        - Sets ``args.log_dir`` to ``<cfg.log_dir>/<exp_name>`` and creates
          that directory if it does not exist.
        - When ``cfg.train.ddp`` is falsy, replaces ``args.local_rank`` with
          ``torch.device("cuda")``.
        - Calls ``trainer.restore()`` before training when ``--resume`` is
          given (non-empty).
    """
    args = _parse_args()

    cfg = load_config(args.config)
    cfg.exp_name = args.exp_name

    args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
    os.makedirs(args.log_dir, exist_ok=True)

    if not cfg.train.ddp:
        # NOTE(review): this overwrites the int --local_rank with a device
        # object. Presumably the trainers treat `args.local_rank` as the
        # target device when DDP is off — confirm before changing the type.
        args.local_rank = torch.device("cuda")

    trainer = build_trainer(args, cfg)

    if args.resume:
        # restore() takes no arguments here; presumably the trainer reads
        # args.resume itself — verify against the trainer implementation.
        trainer.restore()
    trainer.train()
|
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|