import numpy as np

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments

from dataset import MyDataset
from data_collator import MyDataCollator

# Training configuration
model_name = 'bert-base-uncased'
batch_size = 16
num_epochs = 3

# Load the tokenizer once and share it between the datasets and the collator
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_data = MyDataset('train.csv', tokenizer)
val_data = MyDataset('val.csv', tokenizer)

data_collator = MyDataCollator(tokenizer)

# Sequence classification head with 8 output labels
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=8)

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # load_best_model_at_end requires the evaluation and save strategies to
    # match, so both run once per epoch.
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    save_on_each_node=True,  # only relevant for multi-node training
)

def compute_metrics(eval_pred):
    # predictions are logits; label_ids may be class indices or one-hot vectors
    preds = np.argmax(eval_pred.predictions, axis=-1)
    labels = eval_pred.label_ids
    if labels.ndim > 1:
        labels = np.argmax(labels, axis=-1)
    # Report the fraction of correct predictions, not the raw count
    return {'accuracy': float((preds == labels).mean())}


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

trainer.train()
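
# Optional follow-up, a minimal sketch rather than part of the original script:
# with load_best_model_at_end=True the trainer holds the best checkpoint after
# training, so it can be evaluated and saved directly. The './best_model' path
# is an assumed placeholder.
eval_metrics = trainer.evaluate()
print(eval_metrics)
trainer.save_model('./best_model')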