import argparse from typing import Any import tensorflow as tf class EasyDict(dict): def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError: raise AttributeError(name) def __setattr__(self, name: str, value: Any) -> None: self[name] = value def __delattr__(self, name: str) -> None: del self[name] def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") def params_args(args): parser = argparse.ArgumentParser() parser.add_argument( "--hop", type=int, default=256, help="Hop size (window size = 4*hop)", ) parser.add_argument( "--mel_bins", type=int, default=256, help="Mel bins in mel-spectrograms", ) parser.add_argument( "--sr", type=int, default=44100, help="Sampling Rate", ) parser.add_argument( "--small", type=str2bool, default=False, help="If True, use model with shorter available context, useful for small datasets", ) parser.add_argument( "--latdepth", type=int, default=64, help="Depth of generated latent vectors", ) parser.add_argument( "--coorddepth", type=int, default=64, help="Dimension of latent coordinate and style random vectors", ) parser.add_argument( "--base_channels", type=int, default=128, help="Base channels for generator and discriminator architectures", ) parser.add_argument( "--shape", type=int, default=128, help="Length of spectrograms time axis", ) parser.add_argument( "--window", type=int, default=64, help="Generator spectrogram window (must divide shape)", ) parser.add_argument( "--mu_rescale", type=float, default=-25.0, help="Spectrogram mu used to normalize", ) parser.add_argument( "--sigma_rescale", type=float, default=75.0, help="Spectrogram sigma used to normalize", ) parser.add_argument( "--load_path_1", type=str, default="checkpoints/techno/", help="Path of pretrained networks weights 1", ) parser.add_argument( "--load_path_2", type=str, default="checkpoints/metal/", help="Path of pretrained networks weights 2", ) parser.add_argument( "--load_path_3", type=str, default="checkpoints/misc/", help="Path of pretrained networks weights 3", ) parser.add_argument( "--dec_path", type=str, default="checkpoints/ae/", help="Path of pretrained decoders weights", ) parser.add_argument( "--testing", type=str2bool, default=True, help="True if optimizers weight do not need to be loaded", ) parser.add_argument( "--cpu", type=str2bool, default=False, help="True if you wish to use cpu", ) parser.add_argument( "--mixed_precision", type=str2bool, default=True, help="True if your GPU supports mixed precision", ) tmp_args = parser.parse_args() args.hop = tmp_args.hop args.mel_bins = tmp_args.mel_bins args.sr = tmp_args.sr args.small = tmp_args.small args.latdepth = tmp_args.latdepth args.coorddepth = tmp_args.coorddepth args.base_channels = tmp_args.base_channels args.shape = tmp_args.shape args.window = tmp_args.window args.mu_rescale = tmp_args.mu_rescale args.sigma_rescale = tmp_args.sigma_rescale args.load_path_1 = tmp_args.load_path_1 args.load_path_2 = tmp_args.load_path_2 args.load_path_3 = tmp_args.load_path_3 args.dec_path = tmp_args.dec_path args.testing = tmp_args.testing args.cpu = tmp_args.cpu args.mixed_precision = tmp_args.mixed_precision if args.small: args.latlen = 128 else: args.latlen = 256 args.coordlen = (args.latlen // 2) * 3 print() args.datatype = tf.float32 gpuls = tf.config.list_physical_devices("GPU") if len(gpuls) == 0 or args.cpu: args.cpu = True args.mixed_precision = False tf.config.set_visible_devices([], "GPU") print() print("Using CPU...") print() if args.mixed_precision: args.datatype = tf.float16 print() print("Using GPU with mixed precision enabled...") print() if not args.mixed_precision and not args.cpu: print() print("Using GPU without mixed precision...") print() return args def parse_args(): args = EasyDict() return params_args(args)