import torch
import torchvision
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix

# Plotting/analysis imports, used by the confusion-matrix helper referenced in on_test_end
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


class EffNetV2_S(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.

    This module encapsulates the EfficientNetV2-S model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. Defaults to None.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,   # True = Stage 1, False = Stage 2
        unfreeze_from_block: int = -3,  # Only used if freeze_features=True
    ):
        super().__init__()
        self.save_hyperparameters()
        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]

        # Load pretrained weights
        weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_v2_s(weights=weights)

        # ---- Freezing strategy ----
        if freeze_features:
            # Stage 1: freeze the whole backbone first...
            for param in self.model.parameters():
                param.requires_grad = False
            # ...then unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[unfreeze_from_block:].parameters():
                param.requires_grad = True
        else:
            # Stage 2: unfreeze everything
            for param in self.model.parameters():
                param.requires_grad = True

        # Classifier head (EfficientNetV2-S outputs 1280 features)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True),
        )

        # Loss & metrics
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
        self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_accuracy(logits, y)
        self.train_f1(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_accuracy(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        self.val_conf_matrix.update(logits, y)

    def on_validation_epoch_end(self):
        cm = self.val_conf_matrix.compute()
        # Add a small epsilon (1e-6) to the denominator for numerical stability.
        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
        print("\n--- Per-Class Validation Accuracy ---")
        for i, acc in enumerate(per_class_acc):
            self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)
            print(f"{self.class_names[i]:<20}: {acc.item():.4f}")
        print("------------------------------------")
        self.val_conf_matrix.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.test_conf_matrix.update(logits, y)

    def on_test_end(self):
        cm = self.test_conf_matrix.compute()
        print("\nGenerating final confusion matrix plot...")
        # Assuming plot_confusion_matrix is defined elsewhere
        # plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
        self.test_conf_matrix.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6,
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}

class EffNetb2(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNet-B2.

    This module encapsulates the EfficientNet-B2 model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. Defaults to None.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,
        unfreeze_from_block: int = -3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.class_names = class_names if class_names is not None else [str(i) for i in range(num_classes)]

        # Model setup
        weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_b2(weights=weights)

        # --- Flexible Freezing Strategy ---
        if self.hparams.freeze_features:
            # Stage 1: freeze the whole backbone first...
            for param in self.model.parameters():
                param.requires_grad = False
            # ...then unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():
                param.requires_grad = True
        else:
            # Stage 2: unfreeze everything
            for param in self.model.parameters():
                param.requires_grad = True

        # Classifier head (EfficientNet-B2 outputs 1408 features)
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(in_features=1408, out_features=self.hparams.num_classes),
        )

        # Loss & metrics
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
        self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_accuracy(logits, y)
        self.train_f1(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_accuracy(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        self.val_conf_matrix.update(logits, y)

    def on_validation_epoch_end(self):
        cm = self.val_conf_matrix.compute()
        # Add a small epsilon (1e-6) to the denominator for numerical stability.
        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
        print("\n--- Per-Class Validation Accuracy ---")
        for i, acc in enumerate(per_class_acc):
            class_name = self.class_names[i]
            self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)
            print(f"{class_name:<20}: {acc.item():.4f}")
        print("------------------------------------")
        self.val_conf_matrix.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.test_conf_matrix.update(logits, y)

    def on_test_end(self):
        cm = self.test_conf_matrix.compute()
        print("\nGenerating final confusion matrix plot...")
        # Assuming plot_confusion_matrix is defined elsewhere
        # plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
        self.test_conf_matrix.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }
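

# `plot_confusion_matrix` is referenced (commented out) in on_test_end above but
# defined elsewhere in the project. The helper below is a minimal sketch of one
# plausible implementation (an assumption, not the original code), wired to the
# seaborn / matplotlib / pandas / numpy imports at the top of this file.
def plot_confusion_matrix(cm: np.ndarray, class_names: list, figsize: tuple = (30, 30)):
    """Render a row-normalized confusion-matrix heatmap (hypothetical helper)."""
    # Normalize each row so cells show the fraction of that true class;
    # the epsilon guards against division by zero for empty classes.
    cm_norm = cm / (cm.sum(axis=1, keepdims=True) + 1e-6)
    df = pd.DataFrame(cm_norm, index=class_names, columns=class_names)
    plt.figure(figsize=figsize)
    sns.heatmap(df, annot=False, cmap="viridis")
    plt.xlabel("Predicted class")
    plt.ylabel("True class")
    plt.tight_layout()
    plt.show()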