bri25yu committed on
Commit
473e279
1 Parent(s): ccf6f6b

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +78 -0
train.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import DatasetDict, load_dataset
2
+ from evaluate import load as load_metric
3
+ from transformers import *
4
+
5
+
6
def train(batch_size: int, model_name: str = "t5-small", max_steps: int = 10_000) -> None:
    """Fine-tune a seq2seq model on the WMT19 ``de-en`` pair and push it to the Hub.

    German sentences are tokenized as ``input_ids`` and English sentences as
    ``labels``. Training uses a constant learning rate with gradient
    accumulation so that ``batch_size * gradient_accumulation_steps == 512``.
    Every 100 steps the model is evaluated (BLEU via generation) and
    checkpointed.

    Args:
        batch_size: Per-device train batch size. Must be a positive divisor
            of 512. The eval batch size is twice this value.
        model_name: Checkpoint name passed to ``from_pretrained`` for both the
            tokenizer and the model. Only the part after the last ``/`` is
            used to build the output directory name.
        max_steps: Total number of optimizer steps to train for.

    Raises:
        ValueError: If ``batch_size`` is not a positive divisor of 512.
    """
    total_batch_size_per_step = 512
    # Explicit validation instead of `assert`: asserts are stripped under
    # `python -O`, and batch_size == 0 would otherwise surface as a bare
    # ZeroDivisionError.
    if batch_size <= 0 or total_batch_size_per_step % batch_size != 0:
        raise ValueError(
            f"batch_size must be a positive divisor of {total_batch_size_per_step}, "
            f"got {batch_size!r}"
        )
    # NOTE(review): this targets 512 examples per step on a single device;
    # with several devices the effective batch is larger -- confirm intent.
    grad_acc_steps = total_batch_size_per_step // batch_size

    model_name_for_path = model_name.split("/")[-1]
    output_dir = f"wmt19-ende-{model_name_for_path}"
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        gradient_accumulation_steps=grad_acc_steps,
        max_steps=max_steps,
        weight_decay=1e-2,
        optim="adamw_torch_fused",
        lr_scheduler_type="constant",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        save_safetensors=True,
        metric_for_best_model="bleu",
        push_to_hub=True,
        bf16=True,
        bf16_full_eval=True,
        seed=42,
        predict_with_generate=True,
        log_level="error",
        logging_steps=1,
        logging_dir=output_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Imported locally so this function does not rely on a file-level numpy
    # import that the module does not currently have.
    import numpy as np

    bleu = load_metric("bleu")

    def compute_metrics(eval_preds: EvalPrediction):
        """Decode generated ids and labels, then score them with BLEU."""
        generated_ids, label_ids = eval_preds
        if isinstance(generated_ids, tuple):
            generated_ids = generated_ids[0]

        # -100 cannot be decoded by the tokenizer. It marks ignored label
        # positions and may also be used as padding when the Trainer gathers
        # generated sequences of different lengths, so sanitize both arrays.
        # np.where returns new arrays, leaving the caller's data untouched.
        pad_id = tokenizer.pad_token_id
        generated_ids = np.where(generated_ids != -100, generated_ids, pad_id)
        label_ids = np.where(label_ids != -100, label_ids, pad_id)

        references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        bleu_outputs = bleu.compute(predictions=predictions, references=references)
        return {
            "bleu": 100 * bleu_outputs["bleu"],
            "brevity_penalty": bleu_outputs["brevity_penalty"],
        }

    def map_fn(inputs):
        """Tokenize a batch of translation pairs (German -> inputs, English -> labels)."""

        def tokenize_language(lang):
            # Truncate to 64 tokens; attention masks are not requested here.
            texts = [pair[lang] for pair in inputs["translation"]]
            return tokenizer(
                texts,
                return_attention_mask=False,
                max_length=64,
                truncation=True,
            ).input_ids

        return {
            "input_ids": tokenize_language("de"),
            "labels": tokenize_language("en"),
        }

    def get_dataset_split(split):
        """Stream the requested WMT19 de-en split and tokenize it in batches."""
        dataset = load_dataset("wmt19", "de-en", split=split, streaming=True)
        return dataset.map(map_fn, batched=True)

    def apply_length_filter(dataset):
        """Keep only examples whose source and target both have at least 8 tokens."""
        return dataset.filter(
            lambda example: len(example["input_ids"]) >= 8 and len(example["labels"]) >= 8
        )

    trainer = Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        args=args,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        # Only the training split is length-filtered; validation is used as-is.
        train_dataset=apply_length_filter(get_dataset_split("train")),
        eval_dataset=get_dataset_split("validation"),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.remove_callback(PrinterCallback)

    trainer.train()
    trainer.push_to_hub()