tomaarsen committed
Commit
9b02706
1 Parent(s): 2c6bf5e

Upload train.py with huggingface_hub

Files changed (1):
  train.py +87 -0
train.py ADDED
@@ -0,0 +1,87 @@
from pathlib import Path
import shutil
from datasets import load_dataset, concatenate_datasets
from transformers import TrainingArguments
from span_marker import SpanMarkerModel, Trainer
from span_marker.model_card import SpanMarkerModelCardData

import os

os.environ["CODECARBON_LOG_LEVEL"] = "error"


def main() -> None:
    # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
    dataset_id = "midas/inspec"
    dataset_name = "Inspec"
    dataset = load_dataset(dataset_id, "extraction")
    dataset = dataset.rename_columns({"document": "tokens", "doc_bio_tags": "ner_tags"})
    # Map string labels to integer labels instead
    real_labels = ["O", "B", "I"]
    dataset = dataset.map(
        lambda sample: {"ner_tags": [real_labels.index(tag) for tag in sample]},
        input_columns="ner_tags",
    )
    # Use more readable labels
    labels = ["O", "B-KEY", "I-KEY"]
    # Train using the train + validation sets
    train_dataset = concatenate_datasets((dataset["train"], dataset["validation"]))

    # Initialize a SpanMarker model using a pretrained BERT-style encoder
    encoder_id = "bert-base-uncased"
    model_id = "tomaarsen/span-marker_bert-base-uncased-keyphrase-inspec"
    model = SpanMarkerModel.from_pretrained(
        encoder_id,
        labels=labels,
        # SpanMarker hyperparameters:
        model_max_length=256,
        marker_max_length=128,
        entity_max_length=8,
        # Model card variables
        model_card_data=SpanMarkerModelCardData(
            model_id=model_id,
            encoder_id=encoder_id,
            dataset_name=dataset_name,
            dataset_id=dataset_id,
            license="apache-2.0",
            language="en",
        ),
    )

    # Prepare the 🤗 transformers training arguments
    output_dir = Path("models") / model_id
    args = TrainingArguments(
        output_dir=output_dir,
        hub_model_id=model_id,
        run_name="bbu_keyphrase",
        # Training Hyperparameters:
        learning_rate=5e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        weight_decay=0.01,
        warmup_ratio=0.1,
        bf16=True,  # Replace `bf16` with `fp16` if your hardware can't use bf16.
        # Other Training parameters
        logging_first_step=True,
        logging_steps=50,
        evaluation_strategy="no",
        save_total_limit=2,
        dataloader_num_workers=2,
    )

    # Initialize the trainer using our model, training args & dataset, and train
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
    )
    trainer.train()

    # Compute & save the metrics on the test set
    metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
    trainer.save_metrics("test", metrics)

    # Save the final checkpoint alongside a copy of this training script
    trainer.save_model(output_dir / "checkpoint-final")
    shutil.copy2(__file__, output_dir / "checkpoint-final" / "train.py")


if __name__ == "__main__":
    main()
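
For reference, a checkpoint produced by this script can be loaded back for keyphrase extraction. Below is a minimal inference sketch; the example sentence is invented, and loading by model id assumes the checkpoint was pushed to the Hub under the script's `model_id` (otherwise, point `from_pretrained` at the local `checkpoint-final` directory saved above):

from span_marker import SpanMarkerModel

# Load the trained checkpoint: the Hub model id, or the local "checkpoint-final" path
model = SpanMarkerModel.from_pretrained("tomaarsen/span-marker_bert-base-uncased-keyphrase-inspec")

# Predict keyphrase spans in free-form text (the sentence below is a made-up example)
entities = model.predict(
    "A fuzzy logic controller is proposed for the autonomous navigation of mobile robots."
)
for entity in entities:
    print(entity["span"], entity["label"], round(entity["score"], 3))

Each prediction is a dict containing the extracted span, its label ("KEY" with the labels used here), a confidence score, and character offsets into the input text.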