# DanielKiani's picture
# Initial commit of Food101 Classification
# 43124a6
from prepare_data import Food101DataModule, CustomFood101, get_model_components
from models import EffNetV2_S , EffNetb2
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping ,ModelCheckpoint
from typing import Optional
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from typing import List
# --- Global configuration used by the evaluation section of the main block ---
DATA_DIR: str = "data"  # Root directory where the Food-101 dataset is stored
MODEL_NAME: str = "EfficientNet_V2_S"  # Architecture key passed to get_model_components
BATCH_SIZE: int = 32  # Samples per batch for the test dataloader
SUBSET_FRACTION: float = 0.2 # Using a smaller subset for quick testing
CHECKPOINT_PATH: str = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt" # Path to your trained model checkpoint
def plot_confusion_matrix(
    cm: np.ndarray,
    class_names: List[str],
    figsize: tuple = (25, 25),
    save_path: str = 'confusion_matrix.png',
):
    """
    Creates and saves a multi-class confusion matrix plot.

    This function normalizes the confusion matrix row-wise to show prediction
    percentages for each true class, visualizes it as a heatmap, and saves
    the resulting figure to a file.

    Args:
        cm (np.ndarray): The confusion matrix from torchmetrics or scikit-learn.
        class_names (List[str]): A list of class names for the labels.
        figsize (tuple, optional): The size of the figure. Defaults to (25, 25).
        save_path (str, optional): Output file for the figure.
            Defaults to 'confusion_matrix.png' (backward compatible with the
            previously hard-coded path).
    """
    # 1. Normalize each row so cells show the fraction of that true class.
    # Add a small epsilon to prevent division by zero for empty rows.
    cm_normalized = cm.astype('float') / (cm.sum(axis=1, keepdims=True) + 1e-6)
    # 2. Create a DataFrame so the heatmap axes carry the class labels
    df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)
    # 3. Create the plot
    plt.figure(figsize=figsize)
    heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues')  # Annotations off for 101 classes
    # 4. Format the plot: small fonts so 101 labels remain readable
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()
    # 5. Save the figure (before show, which clears the canvas on some backends)
    plt.savefig(save_path, dpi=300)
    print(f"Confusion matrix plot saved to {save_path}")
    plt.show()
def run_training_session(
    model_name: str = "EfficientNet_V2_S",
    batch_size: int = 32,
    data_dir: str = 'data',
    subset_fraction: float = 1.0,
    checkpoint_path: str = "checkpoints/",
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    freeze_features: bool = True,
    early_stopping_patience: int = 5,
    max_epochs: int = 100,
    accelerator: str = 'auto',
    resume_from_checkpoint: Optional[str] = None
) -> Trainer:
    """
    Runs one end-to-end training session for the requested architecture.

    Pipeline: resolve the model class, build transforms and the DataModule,
    instantiate the model, wire up logging/early-stopping/checkpointing,
    then fit.

    Args:
        model_name (str): The name of the model architecture to train.
        batch_size (int): The number of samples per batch.
        data_dir (str): The root directory for the dataset.
        subset_fraction (float): The fraction of the dataset to use for training.
        checkpoint_path (str): Directory to save model checkpoints.
        lr (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        freeze_features (bool): Flag to control the fine-tuning strategy
            (e.g., for two-stage training).
        early_stopping_patience (int): Number of epochs with no improvement
            after which training will be stopped.
        max_epochs (int): The maximum number of epochs to train for.
        accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').
        resume_from_checkpoint (Optional[str]): Path to a checkpoint file to
            resume training from. Defaults to None.

    Returns:
        Trainer: The PyTorch Lightning Trainer object after fitting is complete.

    Raises:
        ValueError: If `model_name` is not a recognized architecture.
    """
    # Map of supported architecture names to their LightningModule classes
    registry = {
        "EfficientNet_V2_S": EffNetV2_S,
        "EfficientNet_B2": EffNetb2,
    }
    try:
        model_cls = registry[model_name]
    except KeyError:
        raise ValueError(f"Model '{model_name}' is not a recognized class.")

    # Model-specific train/val transforms
    parts = get_model_components(model_name)

    # DataModule wraps dataset download, splitting, and dataloaders
    dm = Food101DataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        train_transforms=parts["train_transforms"],
        val_transforms=parts["val_transforms"],
        subset_fraction=subset_fraction
    )
    dm.prepare_data()
    dm.setup()

    # Build the model; class count comes from the prepared datamodule
    model = model_cls(
        num_classes=len(dm.classes),
        class_names=dm.classes,
        lr=lr,
        weight_decay=weight_decay,
        freeze_features=freeze_features
    )

    # Trainer with CSV logging, early stopping on val_loss, and
    # best-checkpoint retention on val_acc
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        logger=CSVLogger(save_dir="logs/", name=model_name),
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                patience=early_stopping_patience,
                mode="min"
            ),
            ModelCheckpoint(
                dirpath=checkpoint_path,
                filename="best-model-{epoch:02d}-{val_acc:.4f}",
                save_top_k=1,
                monitor="val_acc",
                mode="max"
            ),
        ],
    )

    # Fit (optionally resuming from an existing checkpoint)
    trainer.fit(model, datamodule=dm, ckpt_path=resume_from_checkpoint)
    return trainer
# ===================================================================
# Main Execution Block
# ===================================================================
if __name__ == "__main__":
    # --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---
    config = {
        "model_name": "EfficientNet_V2_S",
        "batch_size": 32,
        "lr": 1e-4,
        "epochs": 50,
        "subset_fraction": 1.0,  # Use 1.0 for the full dataset
        "freeze_features": True,
        "early_stopping_patience": 10
    }

    # --- 2. PRINT CONFIGURATION AND START TRAINING ---
    print("--- Starting Training Session ---")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print("---------------------------------")
    trainer = run_training_session(
        model_name=config["model_name"],
        batch_size=config["batch_size"],
        lr=config["lr"],
        max_epochs=config["epochs"],
        subset_fraction=config["subset_fraction"],
        freeze_features=config["freeze_features"],
        early_stopping_patience=config["early_stopping_patience"]
    )
    print("\n--- Training Session Complete ---")
    print("\n--- Starting Evaluation on Test Set ---")

    # Prefer the best checkpoint produced by the run above; fall back to the
    # hard-coded CHECKPOINT_PATH only if this run saved nothing. Previously the
    # script always loaded the stale hard-coded path, ignoring the new run.
    best_ckpt = ""
    if trainer.checkpoint_callback is not None:
        best_ckpt = trainer.checkpoint_callback.best_model_path
    ckpt_to_load = best_ckpt or CHECKPOINT_PATH
    print(f"Loading model from checkpoint: {ckpt_to_load}")

    # Step 1: Set up the DataModule for the test set, using the transforms
    # that match the architecture actually trained (config, not MODEL_NAME).
    components = get_model_components(config["model_name"])
    val_transforms = components["val_transforms"]
    datamodule = Food101DataModule(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        val_transforms=val_transforms
    )
    # This prepares the test dataloader specifically
    datamodule.setup(stage='test')

    # Step 2: Load the trained model from the checkpoint file. Resolve the
    # class from the config so evaluating e.g. EfficientNet_B2 doesn't
    # silently load weights into the wrong architecture.
    model_class = {
        "EfficientNet_V2_S": EffNetV2_S,
        "EfficientNet_B2": EffNetb2,
    }[config["model_name"]]
    model = model_class.load_from_checkpoint(ckpt_to_load)
    model.class_names = datamodule.classes
    model.eval()  # Set the model to evaluation mode

    # Step 3: Create a Trainer and run the test.
    # This call will run the test_step and automatically trigger the
    # on_test_end hook in your model, which generates the plot.
    trainer = pl.Trainer(accelerator='auto')
    trainer.test(model, datamodule=datamodule)
    print("\nEvaluation complete. The confusion matrix plot has been saved.")