Spaces:
Sleeping
Sleeping
| # src/vitClassifier/components/model_training.py | |
| import torch | |
| from datasets import load_from_disk | |
| from transformers import (ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer, DefaultDataCollator) | |
| from vitClassifier.entity.config_entity import TrainingConfig | |
| from vitClassifier import logger | |
| import evaluate | |
class ModelTraining:
    """Fine-tune a ViT image classifier with the Hugging Face ``Trainer``.

    All settings (dataset paths, base model name, hyper-parameters and output
    locations) are read from the ``TrainingConfig`` passed to the constructor.
    """

    def __init__(self, config: "TrainingConfig"):
        """Store the training configuration.

        Args:
            config: Object providing ``train_dataset_path``,
                ``val_dataset_path``, ``model_name``, ``root_dir``,
                ``trained_model_path``, ``learning_rate``, ``batch_size``,
                ``epochs``, ``weight_decay`` and ``warmup_steps``.
        """
        self.config = config

    @staticmethod
    def _processor_kwarg_name(trainer_cls) -> str:
        """Return the keyword ``trainer_cls`` uses to receive the processor.

        Newer ``transformers`` releases accept the image processor as
        ``processing_class`` and deprecate the older ``tokenizer`` keyword.
        Inspecting the constructor keeps this module working with both.

        Args:
            trainer_cls: The trainer class whose ``__init__`` is inspected.

        Returns:
            ``"processing_class"`` when the constructor declares it,
            otherwise ``"tokenizer"``.
        """
        import inspect  # local import: leaves the file's import block untouched

        try:
            params = inspect.signature(trainer_cls.__init__).parameters
        except (TypeError, ValueError):
            # Signature not introspectable: fall back to the legacy keyword.
            return "tokenizer"
        return "processing_class" if "processing_class" in params else "tokenizer"

    def train(self):
        """Fine-tune the model, evaluating each epoch, and save the best one.

        Loads the training and validation datasets from disk, builds the
        label mappings from the training set's ``label`` feature, fine-tunes
        ``config.model_name`` and writes the resulting model to
        ``config.trained_model_path``.
        """
        cfg = self.config

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        train_data = load_from_disk(str(cfg.train_dataset_path))
        val_data = load_from_disk(str(cfg.val_dataset_path))

        # Class names come from the dataset so the model config carries
        # human-readable labels.
        id2label = dict(enumerate(train_data.features['label'].names))
        label2id = {label: i for i, label in id2label.items()}

        model = ViTForImageClassification.from_pretrained(
            cfg.model_name,
            num_labels=len(id2label),
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True,
        )
        # Place the model on the selected device (GPU when available, else CPU).
        model.to(device)

        args = TrainingArguments(
            output_dir=str(cfg.root_dir),
            learning_rate=cfg.learning_rate,
            per_device_train_batch_size=cfg.batch_size,
            per_device_eval_batch_size=cfg.batch_size,
            num_train_epochs=cfg.epochs,
            weight_decay=cfg.weight_decay,
            warmup_steps=cfg.warmup_steps,
            # Save and evaluate on the same schedule so the best checkpoint
            # can be reloaded at the end of training.
            save_strategy='epoch',
            eval_strategy='epoch',
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            save_total_limit=1,
            report_to="none",
        )

        metric = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            """Turn raw predictions into class ids and score them."""
            predictions, labels = eval_pred
            predicted_ids = predictions.argmax(axis=1)
            return metric.compute(predictions=predicted_ids, references=labels)

        processor = ViTImageProcessor.from_pretrained(cfg.model_name)
        # Pass the processor under whichever keyword this Trainer accepts.
        processor_kwargs = {self._processor_kwarg_name(Trainer): processor}

        trainer = Trainer(
            model,
            args,
            train_dataset=train_data,
            eval_dataset=val_data,
            data_collator=DefaultDataCollator(),
            compute_metrics=compute_metrics,
            **processor_kwargs,
        )

        logger.info("Starting model fine-tuning with validation...")
        trainer.train()
        trainer.save_model(str(cfg.trained_model_path))
        logger.info("Model fine-tuning complete and best model saved.")