# Last change: "Re-enable LOME" (commit 2890e34) by Gosse Minnema
# Fine-tuning configuration layered on top of the ACE base config.
#
# Data settings are forwarded unchanged from ace.jsonnet, while the model and
# the vocabulary are loaded from an existing location (PRETRAINED_PATH)
# rather than being built from scratch.
local env = import "../env.jsonnet";
local base = import "ace.jsonnet";

# Values obtained through the env helper. The second argument is presumably
# the fallback used when the variable is not provided -- see env.jsonnet.
local checkpoint_path = env.str("PRETRAINED_PATH", "cache/ace/best");
local finetune_lr = env.json("FT_LR", 5e-5);

# training
local cuda_devices = base.cuda_devices;
local device_count = std.length(cuda_devices);
local single_device = device_count == 1;
local multi_device = device_count > 1;

local base_trainer = base.trainer;

# Top-level keys copied verbatim from the base config.
local inherited_keys = [
  "dataset_reader",
  "train_data_path",
  "validation_data_path",
  "test_data_path",
  "data_loader",
  "validation_data_loader",
];

{ [key]: base[key] for key in inherited_keys }
+ {
  datasets_for_vocab_creation: ["train"],

  # Model weights come from the pretrained location.
  model: {
    type: "from_archive",
    archive_file: checkpoint_path,
  },

  # The vocabulary is read from the "vocabulary" sub-directory of the same
  # pretrained location.
  vocabulary: {
    type: "from_files",
    directory: checkpoint_path + "/vocabulary",
  },

  trainer: {
    num_epochs: base_trainer.num_epochs,
    patience: base_trainer.patience,
    validation_metric: "+arg-c_f",
    num_gradient_accumulation_steps: base_trainer.num_gradient_accumulation_steps,
    optimizer: {
      type: "transformer",
      base: {
        type: "adam",
        lr: finetune_lr,
      },
      # Zero learning rate for the embeddings (presumably keeps them fixed
      # during fine-tuning -- confirm against the "transformer" optimizer).
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: base_trainer.optimizer.layer_fix,
    },
  }
  # With exactly one device, pin the trainer to it.
  + (if single_device then { cuda_device: cuda_devices[0] } else {}),
}
# With more than one device, emit a "distributed" section listing all of them.
+ (if multi_device then { distributed: { cuda_devices: cuda_devices } } else {})
# Test-set evaluation is only switched on in the single-device case.
+ (if single_device then { evaluate_on_test: true } else {})