# Transformers-Fine-Tuner / fine_tuner.py
import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Set seed for reproducibility
set_seed(42)

def fine_tune_model(dataset_url, model_name, epochs, batch_size, learning_rate):
"""
Fine-tunes a pre-trained transformer model on a custom dataset.
Parameters:
- dataset_url (str): URL or path to the dataset.
- model_name (str): Name of the pre-trained model.
- epochs (int): Number of training epochs.
- batch_size (int): Batch size for training.
- learning_rate (float): Learning rate for the optimizer.
Returns:
- dict: Status message containing training completion status.
"""
    # Load the dataset
    dataset = load_dataset(dataset_url)

    # Load the tokenizer and the pre-trained model for sequence classification
    # (2 labels for binary classification)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    # Tokenize the dataset so the Trainer receives model-ready inputs
    # (assumes the raw text is stored in a "text" column)
    tokenized_dataset = dataset.map(
        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"),
        batched=True,
    )
# Define the training arguments
training_args = TrainingArguments(
output_dir='./results', # Directory for storing results
num_train_epochs=epochs, # Number of training epochs
per_device_train_batch_size=batch_size, # Batch size for training
learning_rate=learning_rate, # Learning rate for the optimizer
logging_dir='./logs', # Directory for storing logs
logging_steps=10, # Log every 10 steps
evaluation_strategy="epoch", # Evaluate every epoch
save_strategy="epoch", # Save checkpoint every epoch
load_best_model_at_end=True, # Load the best model at the end of training
metric_for_best_model="accuracy", # Metric to monitor for selecting the best model
greater_is_better=True, # Set to True if higher metric values are better
)
    # Compute accuracy so metric_for_best_model="accuracy" has a value to track
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return {"accuracy": (predictions == labels).mean()}

    # Initialize the Trainer with the model, arguments, tokenized datasets, and metrics
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset['train'],        # Training split
        eval_dataset=tokenized_dataset['validation'],    # Validation split
        compute_metrics=compute_metrics,
    )
# Train the model
trainer.train()
# Return a status message after training completes
return {"status": "Training complete"}