Spaces:
Sleeping
Sleeping
| from prepare_data import Food101DataModule, CustomFood101, get_model_components | |
| from models import EffNetV2_S , EffNetb2 | |
| import pytorch_lightning as pl | |
| from pytorch_lightning import Trainer, LightningModule | |
| from pytorch_lightning.loggers import CSVLogger | |
| from pytorch_lightning.callbacks import EarlyStopping ,ModelCheckpoint | |
| from typing import Optional | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import pandas as pd | |
| from typing import List | |
# --- Script-level configuration for the evaluation section below ---
DATA_DIR = "data"  # Root directory where the Food-101 dataset lives
MODEL_NAME = "EfficientNet_V2_S"  # Must match a key in run_training_session's model registry
BATCH_SIZE = 32  # Batch size used when building the test DataLoader
SUBSET_FRACTION = 0.2  # Using a smaller subset for quick testing
CHECKPOINT_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"  # Path to your trained model checkpoint
def plot_confusion_matrix(
    cm: np.ndarray,
    class_names: List[str],
    figsize: tuple = (25, 25),
    save_path: str = "confusion_matrix.png",
):
    """
    Creates and saves a multi-class confusion matrix plot.

    The matrix is row-normalized so each cell shows the fraction of
    true-class samples assigned to each predicted class, rendered as a
    heatmap, saved to ``save_path``, and shown interactively.

    Args:
        cm (np.ndarray): Square confusion matrix of raw counts, e.g. from
            torchmetrics or scikit-learn.
        class_names (List[str]): Class labels in the same order as the
            rows/columns of ``cm``.
        figsize (tuple, optional): The size of the figure. Defaults to (25, 25).
        save_path (str, optional): Output image path. Defaults to
            "confusion_matrix.png", matching the previous hard-coded behavior.
    """
    # 1. Normalize each row to percentages; the epsilon guards against
    # division by zero for classes with no true samples.
    cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)

    # 2. A labelled DataFrame gives the heatmap named axes for free.
    df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)

    # 3. Create the plot. Cell annotations stay off: 101x101 numbers
    # would be unreadable.
    plt.figure(figsize=figsize)
    heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues')

    # 4. Format the tick labels so 101 class names remain legible.
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()

    # 5. Save before show(): show() may clear the current figure.
    plt.savefig(save_path, dpi=300)
    print(f"Confusion matrix plot saved to {save_path}")
    plt.show()
def run_training_session(
    model_name: str = "EfficientNet_V2_S",
    batch_size: int = 32,
    data_dir: str = 'data',
    subset_fraction: float = 1.0,
    checkpoint_path: str = "checkpoints/",
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    freeze_features: bool = True,
    early_stopping_patience: int = 5,
    max_epochs: int = 100,
    accelerator: str = 'auto',
    resume_from_checkpoint: Optional[str] = None
) -> Trainer:
    """Run one complete training session and return the fitted Trainer.

    Pipeline: resolve the model class from its registry name, build the
    Food-101 DataModule with the architecture's own transforms, wire up
    CSV logging plus callbacks (early stopping on val_loss, best-model
    checkpointing on val_acc), then fit.

    Args:
        model_name (str): Architecture key; must be a registered model.
        batch_size (int): The number of samples per batch.
        data_dir (str): The root directory for the dataset.
        subset_fraction (float): Fraction of the dataset used for training.
        checkpoint_path (str): Directory where checkpoints are written.
        lr (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        freeze_features (bool): Fine-tuning switch controlling whether the
            feature extractor stays frozen (two-stage training).
        early_stopping_patience (int): Epochs without val_loss improvement
            before training stops.
        max_epochs (int): The maximum number of epochs to train for.
        accelerator (str): Hardware accelerator ('auto', 'cpu', 'gpu').
        resume_from_checkpoint (Optional[str]): Checkpoint file to resume
            from. Defaults to None.

    Returns:
        Trainer: The PyTorch Lightning Trainer object after fitting.

    Raises:
        ValueError: If ``model_name`` is not a registered architecture.
    """
    # Map human-readable names onto their LightningModule classes.
    registry = {
        "EfficientNet_V2_S": EffNetV2_S,
        "EfficientNet_B2": EffNetb2,
    }
    model_cls = registry.get(model_name)
    if model_cls is None:
        raise ValueError(f"Model '{model_name}' is not a recognized class.")

    # Each architecture ships its own train/val transforms.
    components = get_model_components(model_name)

    datamodule = Food101DataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        train_transforms=components["train_transforms"],
        val_transforms=components["val_transforms"],
        subset_fraction=subset_fraction,
    )
    datamodule.prepare_data()
    datamodule.setup()

    # Instantiate the model with the class list discovered by the datamodule.
    model = model_cls(
        num_classes=len(datamodule.classes),
        class_names=datamodule.classes,
        lr=lr,
        weight_decay=weight_decay,
        freeze_features=freeze_features,
    )

    # Stop when val_loss stagnates; keep only the best-val_acc checkpoint.
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=early_stopping_patience,
            mode="min",
        ),
        ModelCheckpoint(
            dirpath=checkpoint_path,
            filename="best-model-{epoch:02d}-{val_acc:.4f}",
            save_top_k=1,
            monitor="val_acc",
            mode="max",
        ),
    ]

    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=callbacks,
        logger=CSVLogger(save_dir="logs/", name=model_name),
    )
    trainer.fit(model, datamodule=datamodule, ckpt_path=resume_from_checkpoint)
    return trainer
| # =================================================================== | |
| # Main Execution Block | |
| # =================================================================== | |
| if __name__ == "__main__": | |
| # --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE --- | |
| config = { | |
| "model_name": "EfficientNet_V2_S", | |
| "batch_size": 32, | |
| "lr": 1e-4, | |
| "epochs": 50, | |
| "subset_fraction": 1.0, # Use 1.0 for the full dataset | |
| "freeze_features": True, | |
| "early_stopping_patience": 10 | |
| } | |
| # --- 2. PRINT CONFIGURATION AND START TRAINING --- | |
| print("--- Starting Training Session ---") | |
| for key, value in config.items(): | |
| print(f" {key}: {value}") | |
| print("---------------------------------") | |
| run_training_session( | |
| model_name=config["model_name"], | |
| batch_size=config["batch_size"], | |
| lr=config["lr"], | |
| max_epochs=config["epochs"], | |
| subset_fraction=config["subset_fraction"], | |
| freeze_features=config["freeze_features"], | |
| early_stopping_patience=config["early_stopping_patience"] | |
| ) | |
| print("\n--- Training Session Complete ---") | |
| print("\n--- Starting Evaluation on Test Set ---") | |
| print(f"Loading model from checkpoint: {CHECKPOINT_PATH}") | |
| # Step 1: Set up the DataModule for the test set | |
| components = get_model_components(MODEL_NAME) | |
| val_transforms = components["val_transforms"] | |
| datamodule = Food101DataModule( | |
| data_dir=DATA_DIR, | |
| batch_size=BATCH_SIZE, | |
| val_transforms=val_transforms | |
| ) | |
| # This prepares the test dataloader specifically | |
| datamodule.setup(stage='test') | |
| # Step 2: Load the trained model from the checkpoint file | |
| model = EffNetV2_S.load_from_checkpoint(CHECKPOINT_PATH) | |
| model.class_names = datamodule.classes | |
| model.eval() # Set the model to evaluation mode | |
| # Step 3: Create a Trainer and run the test | |
| trainer = pl.Trainer(accelerator='auto') | |
| # This call will run the test_step and automatically trigger the | |
| # on_test_end hook in your model, which generates the plot. | |
| trainer.test(model, datamodule=datamodule) | |
| print("\nEvaluation complete. The confusion matrix plot has been saved.") |