import argparse
from typing import Any

import tensorflow as tf


class EasyDict(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    ``d.key`` is equivalent to ``d["key"]`` and ``d.key = v`` to
    ``d["key"] = v``. Reading a missing key as an attribute raises
    ``AttributeError`` so that ``getattr``/``hasattr`` behave normally.
    """

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal lookup fails, so real
        # dict methods (``keys``, ``items``, ...) still take precedence.
        try:
            return self[name]
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


def _str2bool(value):
    """Parse a command-line boolean such as ``"True"``, ``"false"`` or ``"1"``.

    ``argparse`` with ``type=bool`` evaluates ``bool(text)``, which is True for
    every non-empty string (including ``"False"``); this converter reads the
    text instead. The empty string stays False, matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_parser():
    """Return the ArgumentParser describing every supported option."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hop",
        type=int,
        default=256,
        help="Hop size (window size = 4*hop)",
    )
    parser.add_argument(
        "--mel_bins",
        type=int,
        default=256,
        help="Mel bins in mel-spectrograms",
    )
    parser.add_argument(
        "--sr",
        type=int,
        default=22050,
        help="Sampling Rate",
    )
    parser.add_argument(
        "--latlen",
        type=int,
        default=256,
        help="Length of generated latent vectors",
    )
    parser.add_argument(
        "--latdepth",
        type=int,
        default=64,
        help="Depth of generated latent vectors",
    )
    parser.add_argument(
        "--shape",
        type=int,
        default=128,
        help="Length of spectrograms time axis",
    )
    parser.add_argument(
        "--window",
        type=int,
        default=64,
        help="Generator spectrogram window (must divide shape)",
    )
    # float, not int: the defaults are floats and non-integer values must be
    # accepted on the command line.
    parser.add_argument(
        "--mu_rescale",
        type=float,
        default=-25.0,
        help="Spectrogram mu used to normalize",
    )
    parser.add_argument(
        "--sigma_rescale",
        type=float,
        default=75.0,
        help="Spectrogram sigma used to normalize",
    )
    parser.add_argument(
        "--load_path_techno",
        type=str,
        default="checkpoints/techno/",
        help="Path of pretrained networks weights (techno)",
    )
    parser.add_argument(
        "--load_path_classical",
        type=str,
        default="checkpoints/classical/",
        help="Path of pretrained networks weights (classical)",
    )
    parser.add_argument(
        "--dec_path_techno",
        type=str,
        default="checkpoints/techno/",
        help="Path of pretrained decoders weights (techno)",
    )
    parser.add_argument(
        "--dec_path_classical",
        type=str,
        default="checkpoints/classical/",
        help="Path of pretrained decoders weights (classical)",
    )
    # Boolean flags use _str2bool so that e.g. "--cpu False" really means False.
    parser.add_argument(
        "--testing",
        type=_str2bool,
        default=True,
        help="True if optimizers weight do not need to be loaded",
    )
    parser.add_argument(
        "--cpu",
        type=_str2bool,
        default=False,
        help="True if you wish to use cpu",
    )
    parser.add_argument(
        "--mixed_precision",
        type=_str2bool,
        default=True,
        help="True if your GPU supports mixed precision",
    )
    return parser


def _configure_device(args):
    """Choose CPU/GPU and the compute dtype, updating *args* in place.

    Falls back to CPU (hiding all GPUs from TensorFlow and disabling mixed
    precision) when no GPU is listed or ``args.cpu`` is set. Sets
    ``args.datatype`` to ``tf.float16`` with mixed precision, else
    ``tf.float32``.
    """
    args.datatype = tf.float32
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus or args.cpu:
        args.cpu = True
        args.mixed_precision = False
        tf.config.set_visible_devices([], "GPU")
        print()
        print("Using CPU...")
        print()
    if args.mixed_precision:
        args.datatype = tf.float16
        print()
        print("Using GPU with mixed precision enabled...")
        print()
    elif not args.cpu:
        print()
        print("Using GPU without mixed precision...")
        print()


def params_args(args, argv=None):
    """Parse command-line options into *args* and select the compute device.

    Args:
        args: Attribute-assignable container (e.g. ``EasyDict``). It receives
            one attribute per command-line option (same name as the flag)
            plus ``datatype``.
        argv: Optional list of argument strings to parse instead of the
            process command line; ``None`` (the default) parses ``sys.argv``.

    Returns:
        The same *args* object, updated in place.
    """
    tmp_args = _build_parser().parse_args(argv)
    # Copy every parsed option across; avoids a hand-maintained list that
    # can drift out of sync with the parser definition.
    for key, value in vars(tmp_args).items():
        setattr(args, key, value)
    print()
    _configure_device(args)
    return args


def parse_args(argv=None):
    """Return an ``EasyDict`` of parsed options with device settings applied.

    *argv* is forwarded to ``params_args``; ``None`` parses ``sys.argv``.
    """
    args = EasyDict()
    return params_args(args, argv)