musika / parse_test.py
marcop's picture
Add first version
d6c7221
raw
history blame
No virus
4.28 kB
import argparse
from typing import Any
import tensorflow as tf
class EasyDict(dict):
    """Dictionary whose keys can also be used as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` writes it, and ``del d.key``
    removes it. Reading a missing key through attribute syntax raises
    ``AttributeError`` (so ``hasattr``/``getattr(..., default)`` work), while
    deleting a missing key raises the underlying ``KeyError``.
    """

    def __getattr__(self, name: str) -> Any:
        # Python only calls __getattr__ after normal lookup fails, so real
        # dict methods (keys, items, ...) still win over same-named keys.
        try:
            item = self[name]
        except KeyError:
            raise AttributeError(name)
        else:
            return item

    def __setattr__(self, name: str, value: Any) -> None:
        # Route attribute assignment into the mapping instead of __dict__.
        self[name] = value

    def __delattr__(self, name: str) -> None:
        # Missing keys propagate as KeyError from the dict itself.
        del self[name]
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"``), so boolean options need an
    explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def params_args(args, argv=None):
    """Fill ``args`` from the command line and choose the TensorFlow device setup.

    Args:
        args: Object accepting attribute assignment (``parse_args`` passes an
            ``EasyDict``). It is populated in place and returned.
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps reading ``sys.argv``.

    Returns:
        The same ``args`` object, with one attribute per option below plus
        ``datatype``: ``tf.float16`` when mixed precision stays enabled,
        otherwise ``tf.float32``.

    Side effects:
        Calls ``tf.config.set_visible_devices([], "GPU")`` when no GPU is
        listed or ``--cpu`` is true (forcing ``cpu=True`` and
        ``mixed_precision=False``), and prints the chosen configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hop",
        type=int,
        default=256,
        help="Hop size (window size = 4*hop)",
    )
    parser.add_argument(
        "--mel_bins",
        type=int,
        default=256,
        help="Mel bins in mel-spectrograms",
    )
    parser.add_argument(
        "--sr",
        type=int,
        default=22050,
        help="Sampling Rate",
    )
    parser.add_argument(
        "--latlen",
        type=int,
        default=256,
        help="Length of generated latent vectors",
    )
    parser.add_argument(
        "--latdepth",
        type=int,
        default=64,
        help="Depth of generated latent vectors",
    )
    parser.add_argument(
        "--shape",
        type=int,
        default=128,
        help="Length of spectrograms time axis",
    )
    parser.add_argument(
        "--window",
        type=int,
        default=64,
        help="Generator spectrogram window (must divide shape)",
    )
    # The defaults are floats, so the parsed type must be float too;
    # type=int rejected values such as "-25.5".
    parser.add_argument(
        "--mu_rescale",
        type=float,
        default=-25.0,
        help="Spectrogram mu used to normalize",
    )
    parser.add_argument(
        "--sigma_rescale",
        type=float,
        default=75.0,
        help="Spectrogram sigma used to normalize",
    )
    parser.add_argument(
        "--load_path_techno",
        type=str,
        default="checkpoints/techno/",
        help="Path of pretrained networks weights (techno)",
    )
    parser.add_argument(
        "--load_path_classical",
        type=str,
        default="checkpoints/classical/",
        help="Path of pretrained networks weights (classical)",
    )
    parser.add_argument(
        "--dec_path_techno",
        type=str,
        default="checkpoints/techno/",
        help="Path of pretrained decoders weights (techno)",
    )
    parser.add_argument(
        "--dec_path_classical",
        type=str,
        default="checkpoints/classical/",
        help="Path of pretrained decoders weights (classical)",
    )
    # Boolean options: "--flag False" must really give False (type=bool
    # made it True). A bare "--flag" with no value means True.
    parser.add_argument(
        "--testing",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="True if optimizers weight do not need to be loaded",
    )
    parser.add_argument(
        "--cpu",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="True if you wish to use cpu",
    )
    parser.add_argument(
        "--mixed_precision",
        type=_str2bool,
        nargs="?",
        const=True,
        default=True,
        help="True if your GPU supports mixed precision",
    )
    tmp_args = parser.parse_args(argv)

    # Copy every parsed option onto ``args`` under its CLI name.
    for key, value in vars(tmp_args).items():
        setattr(args, key, value)

    print()

    args.datatype = tf.float32
    gpuls = tf.config.list_physical_devices("GPU")
    # Without a visible GPU, or when the CPU is requested, hide all GPUs
    # and turn mixed precision off.
    if not gpuls or args.cpu:
        args.cpu = True
        args.mixed_precision = False
        tf.config.set_visible_devices([], "GPU")
        print()
        print("Using CPU...")
        print()
    if args.mixed_precision:
        args.datatype = tf.float16
        print()
        print("Using GPU with mixed precision enabled...")
        print()
    if not args.mixed_precision and not args.cpu:
        print()
        print("Using GPU without mixed precision...")
        print()
    return args
def parse_args():
    """Parse the command line into a new ``EasyDict`` via ``params_args``.

    Returns:
        The populated ``EasyDict`` returned by ``params_args``.
    """
    return params_args(EasyDict())