# bri25yu - commit 473e279: "Create train.py"
from datasets import DatasetDict, load_dataset
from evaluate import load as load_metric
from transformers import *
def train(batch_size: int, model_name: str = "t5-small", max_steps: int = 10_000) -> None:
    """Fine-tune a seq2seq model on WMT19 German->English and push it to the Hub.

    Each optimizer step targets 512 examples. This is reached by gradient
    accumulation over per-device batches of ``batch_size``. Every 100 steps the
    model is evaluated with generation, scored with BLEU, and checkpointed. The
    final model is pushed to the Hugging Face Hub.

    Args:
        batch_size: Per-device training batch size; must be a positive divisor
            of 512. Evaluation uses twice this size.
        model_name: Name passed to ``from_pretrained`` for both the model and
            the tokenizer (e.g. ``"t5-small"`` or ``"org/model"``).
        max_steps: Number of optimizer steps to train for.

    Raises:
        ValueError: If ``batch_size`` is not a positive divisor of 512.
    """
    total_batch_size_per_step = 512
    # Raise instead of assert so the check survives `python -O`. This also
    # rejects non-positive sizes: 0 used to hit ZeroDivisionError, and negative
    # sizes slipped through the divisibility check.
    if batch_size <= 0 or total_batch_size_per_step % batch_size != 0:
        raise ValueError(
            f"batch_size must be a positive divisor of {total_batch_size_per_step}, "
            f"got {batch_size}"
        )
    # Effective batch = batch_size * grad_acc_steps (times the device count
    # when training on more than one device).
    grad_acc_steps = total_batch_size_per_step // batch_size

    # Token budget shared by source/target tokenization and by generation.
    max_length = 64
    # Training pairs with fewer tokens than this on either side are dropped.
    min_length = 8

    model_name_for_path = model_name.split("/")[-1]
    output_dir = f"wmt19-ende-{model_name_for_path}"
    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        gradient_accumulation_steps=grad_acc_steps,
        max_steps=max_steps,
        weight_decay=1e-2,
        optim="adamw_torch_fused",
        lr_scheduler_type="constant",
        # NOTE(review): newer transformers releases rename this argument to
        # `eval_strategy`; confirm against the pinned transformers version.
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        save_safetensors=True,
        metric_for_best_model="bleu",
        push_to_hub=True,
        bf16=True,
        bf16_full_eval=True,
        seed=42,
        predict_with_generate=True,
        # Let generated translations be as long as the (truncated) references.
        # Otherwise generation uses the model's default max length, which can
        # be much shorter and deflates BLEU through the brevity penalty.
        generation_max_length=max_length,
        log_level="error",
        logging_steps=1,
        logging_dir=output_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bleu = load_metric("bleu")

    def compute_metrics(eval_preds: EvalPrediction):
        """Decode generated and reference token ids and score them with BLEU (0-100)."""
        predictions, label_ids = eval_preds
        # Some models return a tuple whose first element holds the generated ids.
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # -100 marks ignored/padded positions and is not a decodable token id.
        # Labels carry it from the collator, and predictions may carry it when
        # batches are padded and concatenated during evaluation.
        predictions[predictions == -100] = tokenizer.pad_token_id
        label_ids[label_ids == -100] = tokenizer.pad_token_id
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        bleu_outputs = bleu.compute(predictions=decoded_predictions, references=references)
        return {
            "bleu": 100 * bleu_outputs["bleu"],
            "brevity_penalty": bleu_outputs["brevity_penalty"],
        }

    def tokenize_language(batch, lang: str):
        """Tokenize the ``lang`` side of every translation pair in ``batch``."""
        texts = [pair[lang] for pair in batch["translation"]]
        return tokenizer(
            texts, return_attention_mask=False, max_length=max_length, truncation=True
        ).input_ids

    def map_fn(inputs):
        """Batched map: German source -> ``input_ids``, English target -> ``labels``."""
        return {
            "input_ids": tokenize_language(inputs, "de"),
            "labels": tokenize_language(inputs, "en"),
        }

    def get_dataset_split(split: str):
        """Stream one split of WMT19 de-en and tokenize it with ``map_fn``."""
        return load_dataset("wmt19", "de-en", split=split, streaming=True).map(
            map_fn, batched=True
        )

    def apply_length_filter(dataset):
        """Drop examples with fewer than ``min_length`` tokens on either side."""
        return dataset.filter(
            lambda example: len(example["input_ids"]) >= min_length
            and len(example["labels"]) >= min_length
        )

    trainer = Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        args=args,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        # Only the training split is length-filtered; evaluation scores every
        # validation pair.
        train_dataset=apply_length_filter(get_dataset_split("train")),
        eval_dataset=get_dataset_split("validation"),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    # Presumably removed so the per-step logs (logging_steps=1) are not
    # echoed to stdout.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    trainer.push_to_hub()