import matplotlib.pyplot as plt
import json
import torch
import torchaudio


def configure_args(config, args):
    """Override entries of the nested ``config`` dict with values from ``args``.

    Most options are copied from ``args`` into the matching section of
    ``config`` only when they were supplied (i.e. are not ``None``). Values
    already present in ``config`` therefore act as defaults. ``config`` is
    modified in place and is also returned.

    Args:
        config: Nested dict that must contain the ``"general"``,
            ``"preprocess"`` and ``"train"`` sections. ``config["train"]``
            must also contain a ``"feature_loss"`` dict whenever
            ``args.feature_loss_type`` is not ``None``.
        args: Namespace-like object (presumably an ``argparse.Namespace`` --
            confirm against the caller) exposing every attribute read below.

    Returns:
        The tuple ``(config, args)``. ``config`` is the same object that was
        passed in, now updated.

    Raises:
        AttributeError: If ``args`` lacks one of the attributes read here.
        KeyError: If ``config`` lacks one of the sections written to.
    """
    # General options are stored as strings (e.g. so path-like values become
    # plain text in the config).
    for key in (
        "stage",
        "corpus_type",
        "source_path",
        "aux_path",
        "preprocessed_path",
    ):
        value = getattr(args, key)
        if value is not None:
            config["general"][key] = str(value)

    # Dataset split sizes: copied verbatim when supplied.
    for key in ("n_train", "n_val", "n_test"):
        value = getattr(args, key)
        if value is not None:
            config["preprocess"][key] = value

    # Training hyper-parameters: copied verbatim when supplied.
    for key in ("alpha", "beta", "learning_rate", "epoch"):
        value = getattr(args, key)
        if value is not None:
            config["train"][key] = value

    # Copied unconditionally (even when None). These look like boolean
    # on/off flags that always carry a value -- NOTE(review): confirm with
    # the argument parser definition.
    for key in ("load_pretrained", "early_stopping"):
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    # Stored as a string, like the general options above.
    for key in ("pretrained_path",):
        value = getattr(args, key)
        if value is not None:
            config["train"][key] = str(value)

    return config, args