Ashwin B
Move project to Hugging Face Space
0b6b733
raw
history blame contribute delete
831 Bytes
import inspect

from transformers import Trainer, TrainingArguments
def get_training_args(
    output_dir: str = "outputs/model",
    *,
    logging_dir: str = "outputs/logs",
    **overrides: object,
) -> "TrainingArguments":
    """Build the ``TrainingArguments`` used for fine-tuning.

    Evaluation and checkpointing both run once per epoch, and the
    checkpoint with the best ``f1`` is reloaded when training ends.

    Args:
        output_dir: Directory where checkpoints are written.
        logging_dir: Directory for training logs. Defaults to the
            previously hard-coded ``"outputs/logs"`` so existing callers
            see no change.
        **overrides: Extra ``TrainingArguments`` fields. They take
            precedence over the defaults set here.

    Returns:
        A configured ``transformers.TrainingArguments`` instance.
    """
    # transformers 4.41 renamed ``evaluation_strategy`` to ``eval_strategy``,
    # and a later release removed the old name. Use whichever keyword the
    # installed version accepts, so the same code runs on old and new releases.
    accepted = inspect.signature(TrainingArguments).parameters
    strategy_key = (
        "eval_strategy" if "eval_strategy" in accepted else "evaluation_strategy"
    )

    settings = {
        "output_dir": output_dir,
        # load_best_model_at_end requires the eval and save strategies to match.
        strategy_key: "epoch",
        "save_strategy": "epoch",
        "learning_rate": 2e-5,
        "per_device_train_batch_size": 16,
        "per_device_eval_batch_size": 16,
        "num_train_epochs": 3,
        "weight_decay": 0.01,
        "logging_dir": logging_dir,
        "logging_steps": 10,
        "load_best_model_at_end": True,
        # The compute_metrics callback must report an "f1" entry for
        # best-model selection to find this metric.
        "metric_for_best_model": "f1",
    }
    settings.update(overrides)
    return TrainingArguments(**settings)
def train_model(model, args, train_dataset, val_dataset, compute_metrics):
    """Fit ``model`` with a Hugging Face ``Trainer`` and return that trainer.

    Args:
        model: The model to train.
        args: Training configuration handed to ``Trainer`` as ``args``.
        train_dataset: Dataset used for training.
        val_dataset: Dataset used for evaluation (``Trainer``'s
            ``eval_dataset``).
        compute_metrics: Callback that turns evaluation predictions into a
            metrics dict.

    Returns:
        The ``Trainer`` after ``train()`` has finished, so callers can run
        further evaluation, prediction, or saving.
    """
    trainer_kwargs = dict(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    hf_trainer = Trainer(**trainer_kwargs)
    hf_trainer.train()
    return hf_trainer