|
from datasets import DatasetDict, load_dataset |
|
from evaluate import load as load_metric |
|
from transformers import * |
|
|
|
|
|
def train(batch_size: int, model_name: str = "t5-small", max_steps: int = 10_000) -> None:
    """Fine-tune a seq2seq model on WMT19 German->English translation.

    Streams the WMT19 ``de-en`` corpus, trains with a fixed effective batch
    size of 512 via gradient accumulation, evaluates BLEU every 100 steps,
    and pushes the result to the Hugging Face Hub.

    Args:
        batch_size: Per-device train batch size; must evenly divide 512.
        model_name: Hub model id to fine-tune (default ``t5-small``).
        max_steps: Total optimizer steps (required: the streamed dataset has
            no known length, so epoch-based training is not possible).

    Raises:
        ValueError: If 512 is not a multiple of ``batch_size``.
    """
    # Keep the effective batch size per optimizer step fixed at 512 samples,
    # independent of the per-device batch size.
    total_batch_size_per_step = 512
    grad_acc_steps = total_batch_size_per_step // batch_size
    # `assert` is stripped under -O, so validate explicitly.
    if grad_acc_steps * batch_size != total_batch_size_per_step:
        raise ValueError(
            f"batch_size={batch_size} must evenly divide "
            f"{total_batch_size_per_step}"
        )

    model_name_for_path = model_name.split("/")[-1]
    output_dir = f"wmt19-ende-{model_name_for_path}"
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        gradient_accumulation_steps=grad_acc_steps,
        max_steps=max_steps,
        weight_decay=1e-2,
        optim="adamw_torch_fused",
        lr_scheduler_type="constant",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        save_safetensors=True,
        # Fix: metric_for_best_model only takes effect together with
        # load_best_model_at_end; without it, save_total_limit=1 would rotate
        # the best checkpoint away instead of keeping it.
        load_best_model_at_end=True,
        metric_for_best_model="bleu",
        greater_is_better=True,  # BLEU: higher is better
        push_to_hub=True,
        bf16=True,
        bf16_full_eval=True,
        seed=42,
        # Required so eval predictions are generated sequences (for BLEU),
        # not raw logits.
        predict_with_generate=True,
        log_level="error",
        logging_steps=1,
        logging_dir=output_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    bleu = load_metric("bleu")

    def compute_metrics(eval_preds: EvalPrediction):
        """Compute corpus BLEU (as a percentage) on decoded predictions."""
        # With predict_with_generate=True these are generated token ids,
        # not logits — named accordingly.
        generated_ids, label_ids = eval_preds
        # Labels are padded with -100 (ignored by the loss); swap in the pad
        # token id so they can be decoded.
        label_ids[label_ids == -100] = tokenizer.pad_token_id

        references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        bleu_outputs = bleu.compute(predictions=predictions, references=references)
        return {
            "bleu": 100 * bleu_outputs["bleu"],  # report on a 0-100 scale
            "brevity_penalty": bleu_outputs["brevity_penalty"],
        }

    def tokenize_batch(batch):
        """Tokenize a batched example dict: German source, English labels."""

        # NOTE: the original inner helper shadowed its enclosing function's
        # name; renamed for clarity.
        def encode(lang: str):
            return tokenizer(
                [pair[lang] for pair in batch["translation"]],
                return_attention_mask=False,
                max_length=64,
                truncation=True,
            ).input_ids

        return {
            "input_ids": encode("de"),
            "labels": encode("en"),
        }

    def get_dataset_split(split: str):
        # Stream to avoid downloading the full WMT19 corpus up front.
        return load_dataset("wmt19", "de-en", split=split, streaming=True).map(
            tokenize_batch, batched=True
        )

    def apply_length_filter(dataset):
        # Drop very short pairs (fewer than 8 tokens on either side).
        return dataset.filter(
            lambda e: len(e["input_ids"]) >= 8 and len(e["labels"]) >= 8
        )

    trainer = Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        args=args,
        # Pads inputs and pads labels with -100 for the loss.
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        train_dataset=apply_length_filter(get_dataset_split("train")),
        eval_dataset=get_dataset_split("validation"),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    # Silence per-step stdout printing; metrics still go to logging_dir.
    trainer.remove_callback(PrinterCallback)

    trainer.train()
    trainer.push_to_hub()
|
|