dcfidalgo committed on
Commit
435eb54
1 Parent(s): b37798c

add training script

Files changed (1)
  1. zeroshot_training_script.py +247 -0
zeroshot_training_script.py ADDED
@@ -0,0 +1,247 @@
#!/usr/bin/env python
# coding: utf-8

# # Creating a Zero-Shot classifier based on BETO
#
# This notebook/script fine-tunes a BETO (Spanish BERT, 'dccuchile/bert-base-spanish-wwm-cased') model on the Spanish XNLI dataset.
# The fine-tuned model can then be fed to a Hugging Face zero-shot classification pipeline to obtain a zero-shot classifier (see the sketch after the training call below).

# In[ ]:


from datasets import load_dataset, Dataset, load_metric, load_from_disk
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments
import torch
from pathlib import Path
# from ray import tune
# from ray.tune.suggest.hyperopt import HyperOptSearch
# from ray.tune.schedulers import ASHAScheduler


# # Prepare the datasets

# In[ ]:


xnli_es = load_dataset("xnli", "es")


# In[ ]:


xnli_es


# >joeddav
# >Aug '20
# >
# >@rsk97 In addition, just make sure the model used is trained on an NLI task and that the **last output label corresponds to entailment** while the **first output label corresponds to contradiction**.
#
# => We switch the original `label` ids accordingly and store them in a `labels` column, the column name expected by an `AutoModelForSequenceClassification`.

# In[ ]:


# see markdown above: XNLI's original ids are 0=entailment, 1=neutral, 2=contradiction,
# so we swap 0 and 2 to make the last id (2) correspond to entailment
def switch_label_id(row):
    if row["label"] == 0:
        return {"labels": 2}
    elif row["label"] == 2:
        return {"labels": 0}
    else:
        return {"labels": 1}

for split in xnli_es:
    xnli_es[split] = xnli_es[split].map(switch_label_id)

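# Optional sanity check: the original ClassLabel order should be
# ['entailment', 'neutral', 'contradiction'], which is what the swap above assumes.
# print(xnli_es["train"].features["label"].names)
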

# ## Tokenize data

# In[ ]:


tokenizer = AutoTokenizer.from_pretrained("dccuchile/bert-base-spanish-wwm-cased")


# In a first attempt I padded all examples to the maximum length of the dataset (379). However, training takes substantially longer with all that padding; it's better to pass the tokenizer to the `Trainer` and let it pad each batch dynamically.

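# Passing the tokenizer to the `Trainer` makes it fall back to dynamic per-batch padding.
# An equivalent, more explicit setup would be a sketch like this (not used below):
#
# from transformers import DataCollatorWithPadding
# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# trainer = Trainer(..., data_collator=data_collator)
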
# In[ ]:


# Figured out the max length of the dataset manually
# max_length = 379
def tokenize(row):
    return tokenizer(row["premise"], row["hypothesis"], truncation=True, max_length=512)  # , padding="max_length", max_length=max_length)


# In[ ]:


data = {}
for split in xnli_es:
    data[split] = xnli_es[split].map(
        tokenize,
        remove_columns=["hypothesis", "premise", "label"],
        batched=True,
        batch_size=128
    )


# In[ ]:


train_path = str(Path("./train_ds").absolute())
valid_path = str(Path("./valid_ds").absolute())

data["train"].save_to_disk(train_path)
data["validation"].save_to_disk(valid_path)


# In[ ]:


# We can use `datasets.Dataset`s directly, so the custom wrapper below is not needed

# class XnliDataset(torch.utils.data.Dataset):
#     def __init__(self, data):
#         self.data = data

#     def __getitem__(self, idx):
#         item = {key: torch.tensor(val) for key, val in self.data[idx].items()}
#         return item

#     def __len__(self):
#         return len(self.data)


# In[ ]:


def trainable(config):
    metric = load_metric("xnli", "es")

    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    model = AutoModelForSequenceClassification.from_pretrained("dccuchile/bert-base-spanish-wwm-cased", num_labels=3)

    training_args = TrainingArguments(
        output_dir='./results',          # output directory
        do_train=True,
        do_eval=True,
        evaluation_strategy="steps",
        eval_steps=500,
        load_best_model_at_end=True,
        metric_for_best_model="eval_accuracy",
        num_train_epochs=config["epochs"],                      # total number of training epochs
        per_device_train_batch_size=config["batch_size"],      # batch size per device during training
        per_device_eval_batch_size=config["batch_size_eval"],  # batch size for evaluation
        warmup_steps=config["warmup_steps"],    # 500
        weight_decay=config["weight_decay"],    # 0.001 # strength of weight decay
        learning_rate=config["learning_rate"],  # 5e-05
        logging_dir='./logs',            # directory for storing logs
        logging_steps=250,
        # save_steps=500,  # ignored when using load_best_model_at_end
        save_total_limit=10,
        no_cuda=False,
        disable_tqdm=True,
    )

    # train_dataset = XnliDataset(load_from_disk(config["train_path"]))
    # valid_dataset = XnliDataset(load_from_disk(config["valid_path"]))
    train_dataset = load_from_disk(config["train_path"])
    valid_dataset = load_from_disk(config["valid_path"])

    trainer = Trainer(
        model,
        tokenizer=tokenizer,
        args=training_args,              # training arguments, defined above
        train_dataset=train_dataset,     # training dataset
        eval_dataset=valid_dataset,      # evaluation dataset
        compute_metrics=compute_metrics,
    )

    trainer.train()


# In[ ]:


trainable(
    {
        "train_path": train_path,
        "valid_path": valid_path,
        "batch_size": 16,
        "batch_size_eval": 64,
        "warmup_steps": 500,
        "weight_decay": 0.001,
        "learning_rate": 5e-5,
        "epochs": 3,
    }
)

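# Once training has finished, the "zero-shot pipeline" step from the intro amounts to
# loading a saved checkpoint into the zero-shot classification pipeline. A minimal sketch
# (the checkpoint path and the Spanish example are only placeholders):
#
# from transformers import pipeline
# classifier = pipeline(
#     "zero-shot-classification",
#     model="./results/checkpoint-XXXX",  # replace with an actual checkpoint from the run
#     tokenizer=tokenizer,
# )
# classifier(
#     "El equipo ganó el partido por tres goles a cero",
#     candidate_labels=["deportes", "política", "economía"],
#     hypothesis_template="Este ejemplo es {}.",
# )
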

# # HPO

# In[ ]:


# config = {
#     "train_path": train_path,
#     "valid_path": valid_path,
#     "warmup_steps": tune.randint(0, 500),
#     "weight_decay": tune.loguniform(0.00001, 0.1),
#     "learning_rate": tune.loguniform(5e-6, 5e-4),
#     "epochs": tune.choice([2, 3, 4])
# }


# # In[ ]:


# analysis = tune.run(
#     trainable,
#     config=config,
#     metric="eval_accuracy",
#     mode="max",
#     #search_alg=HyperOptSearch(),
#     #scheduler=ASHAScheduler(),
#     num_samples=1,
# )


# # In[ ]:


# def model_init():
#     return AutoModelForSequenceClassification.from_pretrained("dccuchile/bert-base-spanish-wwm-cased", num_labels=3)

# trainer = Trainer(
#     args=training_args,              # training arguments, defined above
#     train_dataset=train_dataset,     # training dataset
#     eval_dataset=valid_dataset,      # evaluation dataset
#     model_init=model_init,
#     compute_metrics=compute_metrics,
# )


# # In[ ]:


# best_trial = trainer.hyperparameter_search(
#     direction="maximize",
#     backend="ray",
#     n_trials=2,
#     # Choose among many libraries:
#     # https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
#     search_alg=HyperOptSearch(mode="max", metric="accuracy"),
#     # Choose among schedulers:
#     # https://docs.ray.io/en/latest/tune/api_docs/schedulers.html
#     scheduler=ASHAScheduler(mode="max", metric="accuracy"),
#     local_dir="tune_runs",
# )