# Spaces:
# Build error
# Build error
# Shared environment helpers and the base ACE config that this file reuses.
local env = import "../env.jsonnet";
local base = import "ace.jsonnet";
# Debug flag, passed through unchanged to `model.debug`.
local debug = false;
# re-train
# PRETRAINED_PATH (default "cache/fn/best"): archive used as `archive_file`
# for every `_pretrained` submodule in the model section.
local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best");
# RT_LR (default 5e-5): learning rate given to the `_span_finder.*` and
# `_span_extractor.*` optimizer parameter groups.
local rt_lr = env.json("RT_LR", 5e-5);
# module
# Device list is inherited from the base config; its length picks
# single-device vs. distributed settings in the trainer section.
local cuda_devices = base.cuda_devices;
{
  # Builds a `_pretrained` spec that loads submodule `path` from the
  # re-training archive (`pretrained_path`) and keeps its weights trainable.
  local pretrained(path) = {
    _pretrained: {
      archive_file: pretrained_path,
      module_path: path,
      freeze: false,
    },
  },

  # Data reading and loading are taken unchanged from the base ACE config.
  dataset_reader: base.dataset_reader,
  train_data_path: base.train_data_path,
  validation_data_path: base.validation_data_path,
  test_data_path: base.test_data_path,
  datasets_for_vocab_creation: ["train"],
  data_loader: base.data_loader,
  validation_data_loader: base.validation_data_loader,

  model: {
    type: "span",
    # These three components are loaded from the pretrained archive (not frozen).
    word_embedding: pretrained("word_embedding"),
    span_extractor: pretrained("_span_extractor"),
    span_finder: pretrained("_span_finder"),
    # Span typing has no `_pretrained` entry, so it is configured fresh here,
    # reusing only the base config's hidden sizes.
    span_typing: {
      type: 'mlp',
      hidden_dims: base.model.span_typing.hidden_dims,
    },
    metrics: [{type: "srl"}],
    typing_loss_factor: base.model.typing_loss_factor,
    label_dim: base.model.label_dim,
    max_decoding_spans: 128,
    max_recursion_depth: 2,
    debug: debug,
  },

  trainer: {
    num_epochs: base.trainer.num_epochs,
    patience: base.trainer.patience,
    # Only emitted for single-device runs; multi-device runs use `distributed`.
    [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0],
    validation_metric: "+arg-c_f",
    num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,
    optimizer: {
      type: "transformer",
      base: {
        type: "adam",
        # `base` here is the imported base config, not this field.
        lr: base.trainer.optimizer.base.lr,
      },
      embeddings_lr: 0.0,
      encoder_lr: 1e-5,
      pooler_lr: 1e-5,
      layer_fix: base.trainer.optimizer.layer_fix,
      # The re-trained span finder/extractor get their own learning rate (RT_LR).
      parameter_groups: [
        [['_span_finder.*'], {'lr': rt_lr}],
        [['_span_extractor.*'], {'lr': rt_lr}],
      ],
    },
  },

  # Emitted only when more than one device is configured.
  [if std.length(cuda_devices) > 1 then "distributed"]: {
    cuda_devices: cuda_devices,
  },
  # Test-set evaluation is requested only for single-device runs.
  [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true,
}