# musika_api / parse_test.py
# nakas's picture
# musika clone
# 050507e
import argparse
from typing import Any
import tensorflow as tf
class EasyDict(dict):
    """Dictionary whose items can also be read, written and deleted as attributes.

    ``d.key`` is equivalent to ``d["key"]`` for reading, assignment and
    deletion, so the same object can be used both as a plain ``dict`` and as
    a simple settings namespace.
    """

    def __getattr__(self, name: str) -> Any:
        """Return the item stored under ``name``.

        Only called when normal attribute lookup fails, so real ``dict``
        methods and attributes are never shadowed by stored keys.

        Raises:
            AttributeError: if ``name`` is not a key of the dictionary.
        """
        try:
            return self[name]
        except KeyError:
            # Suppress the internal KeyError: callers (and ``hasattr`` /
            # ``getattr`` with a default) only care that the attribute is
            # missing, not how the lookup is implemented.
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``value`` as the dictionary item ``name``."""
        self[name] = value

    def __delattr__(self, name: str) -> None:
        """Delete the dictionary item ``name``.

        Raises:
            KeyError: if ``name`` is not a key of the dictionary (the error
                of the underlying ``del self[name]`` is propagated as-is).
        """
        del self[name]
def str2bool(v):
    """Convert a command-line style truth value to a ``bool``.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        v: A ``bool`` (returned unchanged) or a value whose string form is
            one of ``yes/true/t/y/1`` or ``no/false/f/n/0`` (case-insensitive).

    Returns:
        ``True`` or ``False`` according to the recognised token.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised token.
    """
    if isinstance(v, bool):
        return v
    # Normalise once via ``str`` so that non-string inputs (e.g. ``1`` or
    # ``None``) are either recognised or rejected with the documented error
    # instead of failing on a missing ``.lower`` method.
    token = str(v).lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def _build_parser():
    """Build the argument parser that declares every supported option."""
    # (flag, type converter, default, help text) for each option, listed in
    # the order in which the options are declared.
    option_specs = (
        ("--hop", int, 256, "Hop size (window size = 4*hop)"),
        ("--mel_bins", int, 256, "Mel bins in mel-spectrograms"),
        ("--sr", int, 44100, "Sampling Rate"),
        (
            "--small",
            str2bool,
            False,
            "If True, use model with shorter available context, useful for small datasets",
        ),
        ("--latdepth", int, 64, "Depth of generated latent vectors"),
        (
            "--coorddepth",
            int,
            64,
            "Dimension of latent coordinate and style random vectors",
        ),
        (
            "--base_channels",
            int,
            128,
            "Base channels for generator and discriminator architectures",
        ),
        ("--shape", int, 128, "Length of spectrograms time axis"),
        ("--window", int, 64, "Generator spectrogram window (must divide shape)"),
        ("--mu_rescale", float, -25.0, "Spectrogram mu used to normalize"),
        ("--sigma_rescale", float, 75.0, "Spectrogram sigma used to normalize"),
        (
            "--load_path_1",
            str,
            "checkpoints/techno/",
            "Path of pretrained networks weights 1",
        ),
        (
            "--load_path_2",
            str,
            "checkpoints/metal/",
            "Path of pretrained networks weights 2",
        ),
        (
            "--load_path_3",
            str,
            "checkpoints/misc/",
            "Path of pretrained networks weights 3",
        ),
        ("--dec_path", str, "checkpoints/ae/", "Path of pretrained decoders weights"),
        (
            "--testing",
            str2bool,
            True,
            "True if optimizers weight do not need to be loaded",
        ),
        ("--cpu", str2bool, False, "True if you wish to use cpu"),
        (
            "--mixed_precision",
            str2bool,
            True,
            "True if your GPU supports mixed precision",
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, converter, default, help_text in option_specs:
        parser.add_argument(flag, type=converter, default=default, help=help_text)
    return parser


def _configure_device(args):
    """Choose CPU or GPU execution, set ``args.datatype`` and report the choice."""
    args.datatype = tf.float32
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus or args.cpu:
        # No GPU found, or CPU explicitly requested: force CPU mode, turn
        # mixed precision off and hide every GPU from TensorFlow.
        args.cpu = True
        args.mixed_precision = False
        tf.config.set_visible_devices([], "GPU")
        message = "Using CPU..."
    elif args.mixed_precision:
        args.datatype = tf.float16
        message = "Using GPU with mixed precision enabled..."
    else:
        message = "Using GPU without mixed precision..."
    print()
    print(message)
    print()


def params_args(args):
    """Fill ``args`` with command-line options and derived runtime settings.

    The options are parsed from the process command line (``sys.argv``) and
    every parsed value is copied onto ``args`` as an attribute. Derived values
    are then added:

    * ``latlen``: 128 when ``small`` is true, otherwise 256.
    * ``coordlen``: ``(latlen // 2) * 3``.
    * ``datatype``: ``tf.float16`` when running on GPU with mixed precision,
      otherwise ``tf.float32``.

    When no GPU is available or ``cpu`` is true, ``cpu`` is set to ``True``,
    ``mixed_precision`` to ``False`` and all GPUs are hidden from TensorFlow.
    The selected device mode is printed to standard output.

    Args:
        args: Object supporting attribute assignment (e.g. an ``EasyDict``);
            it is modified in place.

    Returns:
        The same ``args`` object, populated.
    """
    tmp_args = _build_parser().parse_args()
    # Copy every parsed option in declaration order, so the list of copied
    # attributes can never drift from the options the parser defines.
    for name, value in vars(tmp_args).items():
        setattr(args, name, value)

    args.latlen = 128 if args.small else 256
    args.coordlen = (args.latlen // 2) * 3

    print()

    _configure_device(args)
    return args
def parse_args():
    """Parse the command line into a fresh ``EasyDict`` of settings.

    Returns:
        The result of ``params_args`` applied to a newly created, empty
        ``EasyDict``.
    """
    return params_args(EasyDict())