from datasets import load_dataset
from transformers import TrainingArguments

from span_marker import SpanMarkerModel, Trainer


def main() -> None:
    # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
    dataset = "Babelscape/multinerd"
    train_dataset = load_dataset(dataset, split="train")
    # Note: this shuffle is unseeded, so the 3,000-sentence evaluation sample differs between runs
    eval_dataset = load_dataset(dataset, split="validation").shuffle().select(range(3000))
    labels = [
        "O",
        "B-PER",
        "I-PER",
        "B-ORG",
        "I-ORG",
        "B-LOC",
        "I-LOC",
        "B-ANIM",
        "I-ANIM",
        "B-BIO",
        "I-BIO",
        "B-CEL",
        "I-CEL",
        "B-DIS",
        "I-DIS",
        "B-EVE",
        "I-EVE",
        "B-FOOD",
        "I-FOOD",
        "B-INST",
        "I-INST",
        "B-MEDIA",
        "I-MEDIA",
        "B-MYTH",
        "I-MYTH",
        "B-PLANT",
        "I-PLANT",
        "B-TIME",
        "I-TIME",
        "B-VEHI",
        "I-VEHI",
    ]

    # Initialize a SpanMarker model using a pretrained encoder (here XLM-RoBERTa)
    model_name = "xlm-roberta-base"
    model = SpanMarkerModel.from_pretrained(
        model_name,
        labels=labels,
        # SpanMarker hyperparameters:
        model_max_length=256,  # maximum model input length, in tokens
        marker_max_length=128,
        entity_max_length=6,  # spans longer than 6 words are never proposed as entities
    )

    # Prepare the 🤗 transformers training arguments
    args = TrainingArguments(
        output_dir="models/span_marker_xlm_roberta_base_multinerd",
        # Training Hyperparameters:
        learning_rate=1e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        # gradient_accumulation_steps=2,
        num_train_epochs=1,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
        # Other Training parameters
        logging_first_step=True,
        logging_steps=50,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=1000,
        save_total_limit=2,
        dataloader_num_workers=2,
    )

    # Initialize the trainer using our model, training args & dataset, and train
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    trainer.save_model("models/span_marker_xlm_roberta_base_multinerd/checkpoint-final")

    test_dataset = load_dataset(dataset, split="test")
    # Compute & save the metrics on the test set
    metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
    trainer.save_metrics("test", metrics)


if __name__ == "__main__":
    main()

"""
This SpanMarker model will ignore 2.239322% of all annotated entities in the train dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words and the maximum model input length of 256 tokens.
These are the frequencies of the missed entities due to maximum entity length out of 4111958 total entities:
- 35814 missed entities with 7 words (0.870972%)
- 21246 missed entities with 8 words (0.516688%)
- 12680 missed entities with 9 words (0.308369%)
- 7308 missed entities with 10 words (0.177726%)
- 4414 missed entities with 11 words (0.107345%)
- 2474 missed entities with 12 words (0.060166%)
- 1894 missed entities with 13 words (0.046061%)
- 1130 missed entities with 14 words (0.027481%)
- 744 missed entities with 15 words (0.018094%)
- 582 missed entities with 16 words (0.014154%)
- 344 missed entities with 17 words (0.008366%)
- 226 missed entities with 18 words (0.005496%)
- 84 missed entities with 19 words (0.002043%)
- 46 missed entities with 20 words (0.001119%)
- 20 missed entities with 21 words (0.000486%)
- 20 missed entities with 22 words (0.000486%)
- 12 missed entities with 23 words (0.000292%)
- 18 missed entities with 24 words (0.000438%)
- 2 missed entities with 25 words (0.000049%)
- 4 missed entities with 26 words (0.000097%)
- 4 missed entities with 27 words (0.000097%)
- 2 missed entities with 31 words (0.000049%)
- 8 missed entities with 32 words (0.000195%)
- 6 missed entities with 33 words (0.000146%)
- 2 missed entities with 34 words (0.000049%)
- 4 missed entities with 36 words (0.000097%)
- 8 missed entities with 37 words (0.000195%)
- 2 missed entities with 38 words (0.000049%)
- 2 missed entities with 41 words (0.000049%)
- 2 missed entities with 72 words (0.000049%)
Additionally, a total of 2978 (0.072423%) entities were missed due to the maximum input length.

This SpanMarker model won't be able to predict 2.501087% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 6 words.
These are the frequencies of the missed entities due to maximum entity length out of 4598 total entities:
- 45 missed entities with 7 words (0.978686%)
- 27 missed entities with 8 words (0.587212%)
- 21 missed entities with 9 words (0.456720%)
- 9 missed entities with 10 words (0.195737%)
- 3 missed entities with 12 words (0.065246%)
- 4 missed entities with 13 words (0.086994%)
- 3 missed entities with 14 words (0.065246%)
- 1 missed entities with 15 words (0.021749%)
- 1 missed entities with 16 words (0.021749%)
- 1 missed entities with 20 words (0.021749%)
"""

"""
wandb: Run summary:
wandb:                      eval/loss 0.00594
wandb:          eval/overall_accuracy 0.98181
wandb:                eval/overall_f1 0.90333
wandb:         eval/overall_precision 0.91259
wandb:            eval/overall_recall 0.89427
wandb:                   eval/runtime 21.4308
wandb:        eval/samples_per_second 154.171
wandb:          eval/steps_per_second 4.853
wandb:                      test/loss 0.00559
wandb:          test/overall_accuracy 0.98247
wandb:                test/overall_f1 0.91314
wandb:         test/overall_precision 0.91994
wandb:            test/overall_recall 0.90643
wandb:                   test/runtime 2202.6894
wandb:        test/samples_per_second 169.652
wandb:          test/steps_per_second 5.302
wandb:                    train/epoch 1.0
wandb:              train/global_step 93223
wandb:            train/learning_rate 0.0
wandb:                     train/loss 0.0049
wandb:               train/total_flos 7.851073325660897e+17
wandb:               train/train_loss 0.01782
wandb:            train/train_runtime 41756.9748
wandb: train/train_samples_per_second 71.44
wandb:   train/train_steps_per_second 2.233
"""