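# Re-training config for ACE event extraction. It extends ace.jsonnet but initializes
# the word embedding, span extractor, and span finder from a pretrained archive
# (PRETRAINED_PATH) and fine-tunes them with a dedicated learning rate (RT_LR).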
local env = import "../env.jsonnet";
local base = import "ace.jsonnet";

local dataset_path = env.str("DATA_PATH", "data/ace/events");

local debug = false;

# re-training: pretrained archive to initialize from, and the learning rate for the re-used modules
local pretrained_path = env.str("PRETRAINED_PATH", "cache/fn/best");
local rt_lr = env.json("RT_LR", 5e-5);

# devices
local cuda_devices = base.cuda_devices;

{
    dataset_reader: base.dataset_reader,
    train_data_path: base.train_data_path,
    validation_data_path: base.validation_data_path,
    test_data_path: base.test_data_path,

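    # build the vocabulary from the training split only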
    datasets_for_vocab_creation: ["train"],

    data_loader: base.data_loader,
    validation_data_loader: base.validation_data_loader,

    model: {
        type: "span",
        word_embedding: {
            "_pretrained": {
                "archive_file": pretrained_path,
                "module_path": "word_embedding",
                "freeze": false,
            }
        },
        span_extractor: {
            "_pretrained": {
                "archive_file": pretrained_path,
                "module_path": "_span_extractor",
                "freeze": false,
            }
        },
        span_finder: {
            "_pretrained": {
                "archive_file": pretrained_path,
                "module_path": "_span_finder",
                "freeze": false,
            }
        },
        span_typing: {
            type: "mlp",
            hidden_dims: base.model.span_typing.hidden_dims,
        },
        metrics: [{type: "srl"}],

        typing_loss_factor: base.model.typing_loss_factor,
        label_dim: base.model.label_dim,
        max_decoding_spans: 128,
        max_recursion_depth: 2,
        debug: debug,
    },

    trainer: {
        num_epochs: base.trainer.num_epochs,
        patience: base.trainer.patience,
        [if std.length(cuda_devices) == 1 then "cuda_device"]: cuda_devices[0],
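        # model selection: keep the checkpoint with the highest arg-c F1 from the SRL metric ("+" = higher is better)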
        validation_metric: "+arg-c_f",
        num_gradient_accumulation_steps: base.trainer.num_gradient_accumulation_steps,
        optimizer: {
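            # the "transformer" optimizer assigns per-module learning rates: embedding
            # layers get lr 0.0 (effectively frozen), encoder and pooler layers 1e-5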
            type: "transformer",
            base: {
                type: "adam",
                lr: base.trainer.optimizer.base.lr,
            },
            embeddings_lr: 0.0,
            encoder_lr: 1e-5,
            pooler_lr: 1e-5,
            layer_fix: base.trainer.optimizer.layer_fix,
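            # the re-trained span finder and span extractor get their own rate (RT_LR)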
            parameter_groups: [
                [["_span_finder.*"], { "lr": rt_lr }],
                [["_span_extractor.*"], { "lr": rt_lr }],
            ]
        }
    },

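    # with more than one CUDA device, train in distributed mode; with a single
    # device, also evaluate on the test set once training finishes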
    [if std.length(cuda_devices) > 1 then "distributed"]: {
        "cuda_devices": cuda_devices
    },
    [if std.length(cuda_devices) == 1 then "evaluate_on_test"]: true
}