Spaces:
Sleeping
Sleeping
from transformers import Trainer, TrainingArguments | |
def get_training_args(output_dir="outputs/model"):
    """Build the ``TrainingArguments`` used for fine-tuning.

    Evaluation and checkpointing both run once per epoch, and the best
    checkpoint (ranked by the ``"f1"`` metric) is reloaded when training
    finishes.

    Args:
        output_dir: Directory passed to ``TrainingArguments`` as
            ``output_dir``. Defaults to ``"outputs/model"``.

    Returns:
        A configured ``TrainingArguments`` instance.
    """
    # Local import keeps this function self-contained; only needed for the
    # keyword-compatibility probe below.
    import inspect

    # ``evaluation_strategy`` was renamed to ``eval_strategy`` in newer
    # transformers releases and the old spelling was later dropped. Pick
    # whichever keyword the installed version actually accepts so the call
    # does not fail with an unexpected-keyword TypeError.
    accepted = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted:
        eval_key = "eval_strategy"
    else:
        eval_key = "evaluation_strategy"

    return TrainingArguments(
        output_dir=output_dir,
        # Evaluation and save cadence are kept identical ("epoch") because
        # load_best_model_at_end needs a checkpoint for every evaluation.
        save_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=3,
        weight_decay=0.01,
        # NOTE(review): the log directory is fixed and does not follow
        # ``output_dir`` -- confirm this is intended.
        logging_dir="outputs/logs",
        logging_steps=10,
        load_best_model_at_end=True,
        # NOTE(review): presumably the compute_metrics callback given to the
        # Trainer must report an "f1" entry -- verify against the caller.
        metric_for_best_model="f1",
        **{eval_key: "epoch"},
    )
def train_model(model, args, train_dataset, val_dataset, compute_metrics):
    """Construct a ``Trainer``, run training, and hand the trainer back.

    Args:
        model: Forwarded to ``Trainer`` as ``model``.
        args: Forwarded to ``Trainer`` as ``args``.
        train_dataset: Forwarded to ``Trainer`` as ``train_dataset``.
        val_dataset: Forwarded to ``Trainer`` as ``eval_dataset``.
        compute_metrics: Forwarded to ``Trainer`` as ``compute_metrics``.

    Returns:
        The ``Trainer`` instance after ``train()`` has been called on it.
        The value returned by ``train()`` itself is discarded.
    """
    trainer_kwargs = {
        "model": model,
        "args": args,
        "train_dataset": train_dataset,
        # The validation split is what the Trainer evaluates on.
        "eval_dataset": val_dataset,
        "compute_metrics": compute_metrics,
    }
    hf_trainer = Trainer(**trainer_kwargs)
    hf_trainer.train()
    return hf_trainer