|
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``d["key"]`` and ``d.key`` refer to the same entry, because the
    instance's ``__dict__`` is aliased to the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage to the mapping so item access and
        # attribute access share one underlying storage.
        self.__dict__ = self

    def override(self, attrs):
        """Merge ``attrs`` into this mapping in place and return ``self``.

        Args:
            attrs: a ``dict`` whose entries are copied in; a list, tuple or
                set whose elements are each passed back to ``override`` in
                iteration order (so later elements win); or ``None``, which
                is a no-op.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            NotImplementedError: if ``attrs`` (or a nested element) is of
                any other type.
        """
        if isinstance(attrs, dict):
            # dict.update accepts any hashable key. The former
            # ``__dict__.update(**attrs)`` raised TypeError on non-str keys.
            self.update(attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError(
                "override() does not support type %r" % (type(attrs).__name__,)
            )
        return self
|
|
|
|
|
|
|
def _chord_params(**overrides):
    """Build a fresh ``AttrDict`` of chord-model hyperparameters.

    The three module-level configs below are identical except for
    ``in_channels`` and ``d_cond``. Keeping the shared values here prevents
    them from drifting apart. Every call builds new list objects, so no two
    configs share mutable state.

    Args:
        **overrides: entries to replace. Each key must already be one of the
            defaults below.

    Returns:
        A new ``AttrDict`` holding the defaults with ``overrides`` applied,
        in the same key order as the defaults.

    Raises:
        KeyError: if an override names a key that is not among the defaults
            (guards against silently adding a misspelled hyperparameter).
    """
    params = AttrDict(
        batch_size=16,
        max_epoch=10,
        learning_rate=5e-5,
        max_grad_norm=10,
        fp16=True,

        in_channels=2,
        out_channels=2,
        channels=64,
        attention_levels=[2, 3],
        n_res_blocks=2,
        channel_multipliers=[1, 2, 4, 4],
        n_heads=4,
        tf_layers=1,
        d_cond=12,

        linear_start=0.00085,
        linear_end=0.0120,
        n_steps=1000,
        # NOTE(review): 0.18215 matches the Stable Diffusion VAE latent
        # scale; presumably reused deliberately. Confirm.
        latent_scaling_factor=0.18215,
    )
    unknown = set(overrides) - set(params)
    if unknown:
        raise KeyError("unknown chord params: %s" % sorted(unknown))
    # Updating existing keys keeps their original position in the dict.
    params.update(overrides)
    return params


params_chord = _chord_params()

# The *_cond variants differ only in input width and d_cond. The extra
# in_channels are presumably conditioning channels concatenated onto the
# input; confirm against the model code.
params_chord_cond = _chord_params(in_channels=4, d_cond=2)

params_chord_lsh_cond = _chord_params(in_channels=6, d_cond=2)