tomaarsen HF staff committed on
Commit
f489e24
1 Parent(s): fa3f3bd

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +59 -0
train.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset
2
+ from span_marker import SpanMarkerModel, Trainer
3
+ from transformers import TrainingArguments
4
+
5
+
6
def main() -> None:
    """Fine-tune a SpanMarker model on CoNLL-2003 and save test-set metrics.

    The routine runs four steps in order:

    1. Load the ``conll2003`` dataset and read the label names from the
       ``ner_tags`` feature of its ``train`` split.
    2. Build a ``SpanMarkerModel`` on top of the ``xlm-roberta-large`` encoder.
    3. Train with the ``Trainer`` on the ``train`` split, evaluating on the
       ``validation`` split, then save the model to a ``checkpoint-final``
       directory.
    4. Evaluate on the ``test`` split and save the resulting metrics.

    Takes no arguments and returns nothing; all configuration is hard-coded.
    """
    # The dataset is expected to provide "tokens" and "ner_tags" columns;
    # the label list comes from the "ner_tags" feature definition.
    conll = load_dataset("conll2003")
    label_names = conll["train"].features["ner_tags"].feature.names

    # SpanMarker-specific hyperparameters, kept apart from the encoder name.
    span_marker_options = {
        "model_max_length": 128,
        "marker_max_length": 64,
        "entity_max_length": 6,
    }
    # Wrap a pretrained BERT-style encoder in a SpanMarker model.
    span_marker = SpanMarkerModel.from_pretrained(
        "xlm-roberta-large",
        labels=label_names,
        **span_marker_options,
    )

    # Optimisation-related settings.
    optimisation = {
        "learning_rate": 1e-5,
        "per_device_train_batch_size": 4,
        "per_device_eval_batch_size": 4,
        "gradient_accumulation_steps": 2,
        "num_train_epochs": 3,
        "weight_decay": 0.01,
        "warmup_ratio": 0.1,
        "bf16": True,
    }
    # Logging, evaluation, checkpointing and data-loading settings.
    bookkeeping = {
        "logging_first_step": True,
        "logging_steps": 50,
        "evaluation_strategy": "steps",
        "save_strategy": "steps",
        "eval_steps": 1000,
        "dataloader_num_workers": 2,
    }
    # Assemble the 🤗 transformers training arguments from the two groups.
    training_args = TrainingArguments(
        output_dir="models/span_marker_xlm_roberta_large_conll03",
        **optimisation,
        **bookkeeping,
    )

    # Wire the model, arguments and data splits together, then train.
    span_trainer = Trainer(
        model=span_marker,
        args=training_args,
        train_dataset=conll["train"],
        eval_dataset=conll["validation"],
    )
    span_trainer.train()
    span_trainer.save_model("models/span_marker_xlm_roberta_large_conll03/checkpoint-final")

    # Score the trained model on the held-out test split and save the metrics.
    test_metrics = span_trainer.evaluate(conll["test"], metric_key_prefix="test")
    span_trainer.save_metrics("test", test_metrics)
56
+
57
+
58
# Start training only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()