# main.py: uploaded to Hugging Face by dnnsdunca.
# Commit "Create main.py" (774e73a, verified), 1.37 kB.
import numpy as np
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

from dataset import MyDataset
from data_collator import MyDataCollator
# Hyperparameters: pretrained checkpoint name, per-device batch size
# (used for both training and evaluation), and number of training epochs.
model_name, batch_size, num_epochs = 'bert-base-uncased', 16, 3
# Load the tokenizer once and share it between both datasets and the
# collator, rather than re-loading the same checkpoint for each consumer.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load data
train_data = MyDataset('train.csv', tokenizer)
val_data = MyDataset('val.csv', tokenizer)

# Create data collator
data_collator = MyDataCollator(tokenizer)
# Sequence-classification model built from the `model_name` checkpoint,
# configured for 8 output labels.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=8,
)
# Create training arguments.
# `load_best_model_at_end=True` requires the checkpoint (save) strategy to
# match the evaluation strategy. Step-based saving combined with per-epoch
# evaluation makes TrainingArguments raise a ValueError, so both happen
# once per epoch here.
# NOTE: `eval_strategy` is the current name of this argument; transformers
# releases older than 4.41 call it `evaluation_strategy` instead.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    eval_strategy='epoch',
    save_strategy='epoch',
    # Keep at most two checkpoints on disk.
    save_total_limit=2,
    load_best_model_at_end=True,
    # Must name a key returned by the Trainer's compute_metrics function
    # ('accuracy'); the 'eval_' prefix is optional.
    metric_for_best_model='accuracy',
    greater_is_better=True,
    save_on_each_node=True,
)
# Metric function passed to the Trainer.
def compute_metrics(eval_pred):
    """Return classification accuracy for one evaluation pass.

    Args:
        eval_pred: The object Trainer passes to ``compute_metrics``. Its
            ``predictions`` attribute holds per-class logits and its
            ``label_ids`` attribute holds the reference labels, both as
            NumPy arrays.

    Returns:
        ``{'accuracy': float}``: the fraction of examples whose arg-max
        logit equals the reference label.
    """
    logits = eval_pred.predictions
    # Some models return extra outputs alongside the logits; take the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = eval_pred.label_ids
    # Accept one-hot / per-class label arrays as well as plain class indices.
    if labels.ndim > 1:
        labels = np.argmax(labels, axis=-1)
    preds = np.argmax(logits, axis=-1)
    # Mean of the per-example matches gives a ratio, not a raw count.
    return {'accuracy': float(np.mean(preds == labels))}


# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)

# Train model
trainer.train()