from __gin__ import dynamic_registration

from gamadhani import src
from gamadhani.src import dataset
from gamadhani.src import model_transformer
from gamadhani.src import task_functions
from gamadhani.utils import utils
import torch.optim

# Model and training hyperparameters.
MODEL_DIM = 512
EMB_DIM = 512
NUM_TOKENS = 7928
NUM_QUANTIZERS = 1
DROPOUT_RATE = 0.3
NUM_HEADS = 8
SEQ_LEN = 1200
HEAD_DIM = 32
NUM_LAYERS = 8
LR = 1e-3

# Transformer prior architecture.
model_transformer.XTransformerPrior:
  num_tokens = %NUM_TOKENS
  seq_len = %SEQ_LEN
  model_dim = %MODEL_DIM
  emb_dim = %EMB_DIM
  head_dim = %HEAD_DIM
  num_layers = %NUM_LAYERS
  num_heads = %NUM_HEADS
  dropout_rate = %DROPOUT_RATE

# Pitch read/downsample task used by the dataset.
src.dataset.Task:
  read_fn = @src.task_functions.pitch_read_downsample
  invert_fn = @src.task_functions.invert_pitch_read_downsample
  kwargs = {
      "seq_len": %SEQ_LEN,
      "decoder_key": "pitch",
      "min_norm_pitch": -4915,
      "time_downsample": 2,
      "pitch_downsample": 10,
      "base_tonic": 440.
  }

# Dataset wiring.
src.dataset.SequenceDataset:
  task = @dataset.Task()
  apply_transform = False

# Optimizer and learning-rate schedule.
model_transformer.XTransformerPrior.configure_optimizers:
  optimizer_cls = @torch.optim.AdamW
  scheduler_cls = @utils.build_warmed_exponential_lr_scheduler

utils.build_warmed_exponential_lr_scheduler:
  start_factor = .01
  peak_iteration = 10000
  cycle_length = 394600
  eta_min = 0.1
  eta_max = %LR

# Reproducibility.
utils.set_seed:
  seed = 2023

torch.optim.AdamW:
  lr = %LR
  betas = (.9, .98)
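
# Usage note (kept as comments so this file remains a valid gin config): a config like
# this is typically parsed with the gin-config library before the model is built, so the
# constructor picks up the bindings above. The sketch below is a minimal, hedged example;
# the config file name "pitch_prior.gin" is an assumption, not a path from this repo.
#
#   import gin
#   from gamadhani.src import model_transformer
#
#   gin.parse_config_file("pitch_prior.gin")       # register the bindings in this file
#   model = model_transformer.XTransformerPrior()  # constructor args filled from gin bindings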