# Gosse Minnema
# Re-enable LOME
# 2890e34
local env = import "../env.jsonnet";
local base = import "ace.jsonnet";
# Tunable settings for re-training on ACE events. Values come from the
# ../env.jsonnet helpers with the literal defaults shown
# (NOTE(review): exact lookup semantics of env.str / env.json live in env.jsonnet — confirm there).
local
  # ACE event data location; the data paths actually used below come from `base`.
  dataset_path = env.str("DATA_PATH", "data/ace/events"),
  # Forwarded verbatim to the model's `debug` option.
  debug = false,
  # Archive referenced by every `_pretrained` module entry in the model section.
  pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best"),
  # Learning rate for the `_span_finder.*` / `_span_extractor.*` parameter groups.
  rt_lr = env.json("RT_LR", 5e-5),
  # Device list shared with the base ACE configuration.
  cuda_devices = base.cuda_devices;
{
  # Model-section entry that points a submodule at the pretrained archive.
  # `_pretrained` presumably copies `module_path` out of `archive_file`
  # (TODO confirm against the loader); "freeze": false leaves it trainable.
  local from_pretrained(module_path) = {
    "_pretrained": {
      "archive_file": pretrained_path,
      "module_path": module_path,
      "freeze": false,
    },
  },
  # Optimizer parameter group: parameters whose names match `pattern`
  # are trained with the re-training learning rate `rt_lr`.
  local retrain_group(pattern) = [[pattern], {'lr': rt_lr}],

  # Data handling is inherited unchanged from the base ACE configuration.
  dataset_reader: base.dataset_reader,
  train_data_path: base.train_data_path,
  validation_data_path: base.validation_data_path,
  test_data_path: base.test_data_path,
  datasets_for_vocab_creation: ["train"],
  data_loader: base.data_loader,
  validation_data_loader: base.validation_data_loader,

  model: {
    type: "span",
    # These three modules are taken from the pretrained archive ...
    word_embedding: from_pretrained("word_embedding"),
    span_extractor: from_pretrained("_span_extractor"),
    span_finder: from_pretrained("_span_finder"),
    # ... while the typing head is configured directly here, sized like the base model.
    span_typing: {
      type: 'mlp',
      hidden_dims: base.model.span_typing.hidden_dims,
    },
    metrics: [{type: "srl"}],
    typing_loss_factor: base.model.typing_loss_factor,
    label_dim: base.model.label_dim,
    max_decoding_spans: 128,
    max_recursion_depth: 2,
    debug: debug,
  },

  trainer: {
    num_epochs: base.trainer.num_epochs,
    patience: base.trainer.patience,
    # Only a single-device run pins `cuda_device`; multi-device runs use `distributed` below.
    [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0],
    validation_metric: "+arg-c_f",
    num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,
    optimizer: {
      type: "transformer",
      # Note: the field named `base` is distinct from the imported `base` config it reads from.
      base: {
        type: "adam",
        lr: base.trainer.optimizer.base.lr,
      },
      # Presumably 0.0 keeps the transformer embeddings fixed (TODO confirm optimizer semantics).
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: base.trainer.optimizer.layer_fix,
      parameter_groups: [
        retrain_group('_span_finder.*'),
        retrain_group('_span_extractor.*'),
      ],
    },
  },

  # More than one device: emit a `distributed` section listing them.
  # Exactly one device: also request evaluation on the test split.
  # (Object locals are not visible in this object's own computed field names,
  # so the device count is computed inline here.)
  [if std.length(cuda_devices) > 1 then "distributed"]: {
    "cuda_devices": cuda_devices,
  },
  [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true,
}