tomaarsen HF staff commited on
Commit
7817dee
·
1 Parent(s): db46fd3

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +165 -0
train.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from transformers import TrainingArguments
3
+
4
+ from span_marker import SpanMarkerModel, Trainer
5
+
6
+
7
+ def main() -> None:
8
+ # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
9
+ dataset = "Babelscape/multinerd"
10
+ train_dataset = load_dataset(dataset, split="train")
11
+ eval_dataset = load_dataset(dataset, split="validation").shuffle().select(range(3000))
12
+ labels = [
13
+ "O",
14
+ "B-PER",
15
+ "I-PER",
16
+ "B-ORG",
17
+ "I-ORG",
18
+ "B-LOC",
19
+ "I-LOC",
20
+ "B-ANIM",
21
+ "I-ANIM",
22
+ "B-BIO",
23
+ "I-BIO",
24
+ "B-CEL",
25
+ "I-CEL",
26
+ "B-DIS",
27
+ "I-DIS",
28
+ "B-EVE",
29
+ "I-EVE",
30
+ "B-FOOD",
31
+ "I-FOOD",
32
+ "B-INST",
33
+ "I-INST",
34
+ "B-MEDIA",
35
+ "I-MEDIA",
36
+ "B-MYTH",
37
+ "I-MYTH",
38
+ "B-PLANT",
39
+ "I-PLANT",
40
+ "B-TIME",
41
+ "I-TIME",
42
+ "B-VEHI",
43
+ "I-VEHI",
44
+ ]
45
+
46
+ # Initialize a SpanMarker model using a pretrained BERT-style encoder
47
+ model_name = "bert-base-multilingual-cased"
48
+ model = SpanMarkerModel.from_pretrained(
49
+ model_name,
50
+ labels=labels,
51
+ # SpanMarker hyperparameters:
52
+ model_max_length=256,
53
+ marker_max_length=128,
54
+ entity_max_length=8,
55
+ )
56
+
57
+ # Prepare the 🤗 transformers training arguments
58
+ args = TrainingArguments(
59
+ output_dir="models/span_marker_mbert_base_multinerd",
60
+ # Training Hyperparameters:
61
+ learning_rate=5e-5,
62
+ per_device_train_batch_size=32,
63
+ per_device_eval_batch_size=32,
64
+ # gradient_accumulation_steps=2,
65
+ num_train_epochs=1,
66
+ weight_decay=0.01,
67
+ warmup_ratio=0.1,
68
+ bf16=True, # Replace `bf16` with `fp16` if your hardware can't use bf16.
69
+ # Other Training parameters
70
+ logging_first_step=True,
71
+ logging_steps=50,
72
+ evaluation_strategy="steps",
73
+ save_strategy="steps",
74
+ eval_steps=1000,
75
+ save_total_limit=2,
76
+ dataloader_num_workers=2,
77
+ )
78
+
79
+ # Initialize the trainer using our model, training args & dataset, and train
80
+ trainer = Trainer(
81
+ model=model,
82
+ args=args,
83
+ train_dataset=train_dataset,
84
+ eval_dataset=eval_dataset,
85
+ )
86
+ trainer.train()
87
+ trainer.save_model("models/span_marker_mbert_base_multinerd/checkpoint-final")
88
+
89
+ test_dataset = load_dataset(dataset, split="test")
90
+ # Compute & save the metrics on the test set
91
+ metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
92
+ trainer.save_metrics("test", metrics)
93
+
94
+ trainer.create_model_card(language="multilingual", license="apache-2.0")
95
+
96
+
97
+ if __name__ == "__main__":
98
+ main()
99
+
100
+ """
101
+ Logs:
102
+
103
+ Training set:
104
+ This SpanMarker model will ignore 0.793782% of all annotated entities in the train dataset. This is caused by the SpanMarkerModel maximum entity length of 8 words and the maximum model input length of 256 tokens.
105
+ These are the frequencies of the missed entities due to maximum entity length out of 4111958 total entities:
106
+ - 12680 missed entities with 9 words (0.308369%)
107
+ - 7308 missed entities with 10 words (0.177726%)
108
+ - 4414 missed entities with 11 words (0.107345%)
109
+ - 2474 missed entities with 12 words (0.060166%)
110
+ - 1894 missed entities with 13 words (0.046061%)
111
+ - 1130 missed entities with 14 words (0.027481%)
112
+ - 744 missed entities with 15 words (0.018094%)
113
+ - 582 missed entities with 16 words (0.014154%)
114
+ - 344 missed entities with 17 words (0.008366%)
115
+ - 226 missed entities with 18 words (0.005496%)
116
+ - 84 missed entities with 19 words (0.002043%)
117
+ - 46 missed entities with 20 words (0.001119%)
118
+ - 20 missed entities with 21 words (0.000486%)
119
+ - 20 missed entities with 22 words (0.000486%)
120
+ - 12 missed entities with 23 words (0.000292%)
121
+ - 18 missed entities with 24 words (0.000438%)
122
+ - 2 missed entities with 25 words (0.000049%)
123
+ - 4 missed entities with 26 words (0.000097%)
124
+ - 4 missed entities with 27 words (0.000097%)
125
+ - 2 missed entities with 31 words (0.000049%)
126
+ - 8 missed entities with 32 words (0.000195%)
127
+ - 6 missed entities with 33 words (0.000146%)
128
+ - 2 missed entities with 34 words (0.000049%)
129
+ - 4 missed entities with 36 words (0.000097%)
130
+ - 8 missed entities with 37 words (0.000195%)
131
+ - 2 missed entities with 38 words (0.000049%)
132
+ - 2 missed entities with 41 words (0.000049%)
133
+ - 2 missed entities with 72 words (0.000049%)
134
+ Additionally, a total of 598 (0.014543%) entities were missed due to the maximum input length.
135
+
136
+ Validation set:
137
+ This SpanMarker model won't be able to predict 0.656224% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 8 words.
138
+ These are the frequencies of the missed entities due to maximum entity length out of 4724 total entities:
139
+ - 19 missed entities with 9 words (0.402202%)
140
+ - 7 missed entities with 10 words (0.148180%)
141
+ - 1 missed entities with 11 words (0.021169%)
142
+ - 1 missed entities with 12 words (0.021169%)
143
+ - 2 missed entities with 13 words (0.042337%)
144
+ - 1 missed entities with 16 words (0.021169%)
145
+
146
+ Testing set:
147
+ This SpanMarker model won't be able to predict 0.794755% of all annotated entities in the evaluation dataset. This is caused by the SpanMarkerModel maximum entity length of 8 words and the maximum model input length of 256 tokens.
148
+ These are the frequencies of the missed entities due to maximum entity length out of 511856 total entities:
149
+ - 1520 missed entities with 9 words (0.296959%)
150
+ - 870 missed entities with 10 words (0.169970%)
151
+ - 640 missed entities with 11 words (0.125035%)
152
+ - 346 missed entities with 12 words (0.067597%)
153
+ - 224 missed entities with 13 words (0.043762%)
154
+ - 172 missed entities with 14 words (0.033603%)
155
+ - 72 missed entities with 15 words (0.014066%)
156
+ - 66 missed entities with 16 words (0.012894%)
157
+ - 16 missed entities with 17 words (0.003126%)
158
+ - 14 missed entities with 18 words (0.002735%)
159
+ - 14 missed entities with 19 words (0.002735%)
160
+ - 2 missed entities with 20 words (0.000391%)
161
+ - 12 missed entities with 21 words (0.002344%)
162
+ - 2 missed entities with 24 words (0.000391%)
163
+ - 2 missed entities with 25 words (0.000391%)
164
+ Additionally, a total of 96 (0.018755%) entities were missed due to the maximum input length.
165
+ """