ALYYAN's picture
commiting backend files
4821854
raw
history blame
2.91 kB
# src/vitClassifier/components/model_training.py
import torch
from datasets import load_from_disk
from transformers import (ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator)
from vitClassifier.entity.config_entity import TrainingConfig
from vitClassifier import logger
import evaluate
class ModelTraining:
def __init__(self, config: TrainingConfig):
self.config = config
def train(self):
# --- NEW: Explicitly define the device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
# --- Load datasets (no change) ---
train_data = load_from_disk(str(self.config.train_dataset_path))
val_data = load_from_disk(str(self.config.val_dataset_path))
id2label = {i: label for i, label in enumerate(train_data.features['label'].names)}
label2id = {label: i for i, label in id2label.items()}
model = ViTForImageClassification.from_pretrained(
self.config.model_name, num_labels=len(id2label), id2label=id2label,
label2id=label2id, ignore_mismatched_sizes=True
)
# --- NEW: Move the model to the correct device ---
model.to(device)
# --- TrainingArguments (no change) ---
args = TrainingArguments(
output_dir=str(self.config.root_dir),
learning_rate=self.config.learning_rate,
per_device_train_batch_size=self.config.batch_size,
per_device_eval_batch_size=self.config.batch_size,
num_train_epochs=self.config.epochs,
weight_decay=self.config.weight_decay,
warmup_steps=self.config.warmup_steps,
save_strategy='epoch',
eval_strategy='epoch',
load_best_model_at_end=True,
metric_for_best_model="accuracy",
save_total_limit=1,
report_to="none"
)
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = predictions.argmax(axis=1)
return metric.compute(predictions=predictions, references=labels)
data_collator = DefaultDataCollator()
processor = ViTImageProcessor.from_pretrained(self.config.model_name)
trainer = Trainer(
model, # The model is now already on the GPU
args,
train_dataset=train_data,
eval_dataset=val_data,
data_collator=data_collator,
compute_metrics=compute_metrics,
tokenizer=processor,
)
logger.info("Starting model fine-tuning with validation...")
trainer.train()
trainer.save_model(str(self.config.trained_model_path))
logger.info("Model fine-tuning complete and best model saved.")