import numpy as np
from datasets import DatasetDict, load_dataset
from evaluate import load as load_metric
# NOTE(review): wildcard import kept because code outside this view may rely on
# it. Names used below: Seq2SeqTrainingArguments, Seq2SeqTrainer, AutoTokenizer,
# AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, EvalPrediction, PrinterCallback.
from transformers import *


def train(
    batch_size: int,
    model_name: str = "t5-small",
    max_steps: int = 10_000,
    max_length: int = 64,
) -> None:
    """Fine-tune a seq2seq model on streamed WMT19 German->English, then push it to the Hub.

    Args:
        batch_size: Per-device training batch size. It must evenly divide the
            effective batch size of 512; the remaining factor is supplied via
            gradient accumulation. Evaluation uses twice this batch size.
        model_name: Hub id or local path of the model and tokenizer to load.
            The last path component names the output directory.
        max_steps: Number of optimizer steps to train for.
        max_length: Maximum token length used when tokenizing source and
            target text, and the maximum length of generated predictions
            during evaluation.

    Raises:
        ValueError: If ``batch_size`` is not a positive divisor of 512.
    """
    total_batch_size_per_step = 512
    # Explicit check rather than ``assert`` so it still runs under ``python -O``;
    # the ``<= 0`` test also prevents a ZeroDivisionError below.
    if batch_size <= 0 or total_batch_size_per_step % batch_size != 0:
        raise ValueError(
            f"batch_size must be a positive divisor of {total_batch_size_per_step}, "
            f"got {batch_size}"
        )
    grad_acc_steps = total_batch_size_per_step // batch_size

    model_name_for_path = model_name.split("/")[-1]
    output_dir = f"wmt19-ende-{model_name_for_path}"

    args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        learning_rate=1e-4,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        gradient_accumulation_steps=grad_acc_steps,
        max_steps=max_steps,
        weight_decay=1e-2,
        optim="adamw_torch_fused",
        lr_scheduler_type="constant",
        evaluation_strategy="steps",
        eval_steps=100,
        save_strategy="steps",
        save_steps=100,
        save_total_limit=1,
        save_safetensors=True,
        metric_for_best_model="bleu",
        push_to_hub=True,
        bf16=True,
        bf16_full_eval=True,
        seed=42,
        predict_with_generate=True,
        # Without this, generation uses the default generation max_length
        # (20 tokens unless the model's generation config overrides it), which
        # truncates translations relative to the max_length-token labels and
        # drags BLEU down through the brevity penalty.
        generation_max_length=max_length,
        log_level="error",
        logging_steps=1,
        logging_dir=output_dir,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    bleu = load_metric("bleu")

    def compute_metrics(eval_preds: EvalPrediction) -> dict:
        """Decode generated token ids and labels, then score them with BLEU (0-100 scale)."""
        predictions, label_ids = eval_preds
        if isinstance(predictions, tuple):
            predictions = predictions[0]
        # Labels (and, depending on batching, generated predictions) may be
        # padded with -100, which is not a decodable token id. Swap it for the
        # pad token; np.where returns new arrays instead of mutating the inputs.
        pad_id = tokenizer.pad_token_id
        predictions = np.where(predictions != -100, predictions, pad_id)
        label_ids = np.where(label_ids != -100, label_ids, pad_id)

        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        references = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
        bleu_outputs = bleu.compute(predictions=decoded_predictions, references=references)
        return {
            "bleu": 100 * bleu_outputs["bleu"],
            "brevity_penalty": bleu_outputs["brevity_penalty"],
        }

    def map_fn(inputs):
        """Tokenize a batch: German ("de") -> ``input_ids``, English ("en") -> ``labels``."""

        def tokenize_column(lang):
            # One list of token-id lists per batch, truncated to max_length.
            return tokenizer(
                [pair[lang] for pair in inputs["translation"]],
                return_attention_mask=False,
                max_length=max_length,
                truncation=True,
            ).input_ids

        return {
            "input_ids": tokenize_column("de"),
            "labels": tokenize_column("en"),
        }

    def get_dataset_split(split):
        """Stream one split of WMT19 de-en and tokenize it lazily in batches."""
        return load_dataset("wmt19", "de-en", split=split, streaming=True).map(
            map_fn, batched=True
        )

    def apply_length_filter(dataset):
        """Keep only examples whose source and target both have at least 8 tokens."""
        return dataset.filter(
            lambda e: len(e["input_ids"]) >= 8 and len(e["labels"]) >= 8
        )

    trainer = Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        args=args,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        # The length filter is applied to training data only; evaluation
        # runs on the full validation split.
        train_dataset=apply_length_filter(get_dataset_split("train")),
        eval_dataset=get_dataset_split("validation"),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    # Drop the log-printing callback (if registered) to keep stdout quiet
    # given logging_steps=1.
    trainer.remove_callback(PrinterCallback)
    trainer.train()
    trainer.push_to_hub()