import lightning as L
import timm
import torch
import torch.nn.functional as F
from torch import optim
from torchmetrics.classification import Accuracy, F1Score

class ResnetClassifier(L.LightningModule):
    def __init__(
        self,
        base_model: str = "efficientnet_b0",
        pretrained: bool = True,
        num_classes: int = 2,  # binary classification with two classes
        lr: float = 1e-3,
        weight_decay: float = 1e-5,
        factor: float = 0.1,
        patience: int = 10,
        min_lr: float = 1e-6,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Build the backbone via timm (EfficientNet-B0 by default, despite the class name)
        self.model = timm.create_model(
            base_model, pretrained=pretrained, num_classes=num_classes
        )
        # Accuracy and F1 metrics for binary classification, one instance per stage
        self.train_acc = Accuracy(task="binary")
        self.val_acc = Accuracy(task="binary")
        self.test_acc = Accuracy(task="binary")
        self.train_f1 = F1Score(task="binary")
        self.val_f1 = F1Score(task="binary")
        self.test_f1 = F1Score(task="binary")

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, stage):
        x, y = batch
        logits = self(x)  # model output shape: [batch_size, num_classes]
        loss = F.cross_entropy(logits, y)  # cross-entropy over the two logits
        preds = torch.argmax(logits, dim=1)  # predicted class (0 or 1)
        # Update the stage-specific metrics
        acc = getattr(self, f"{stage}_acc")
        f1 = getattr(self, f"{stage}_f1")
        acc(preds, y)
        f1(preds, y)
        # Log loss and metrics; Lightning aggregates metric objects per epoch
        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True)
        self.log(f"{stage}_acc", acc, prog_bar=True, on_epoch=True)
        self.log(f"{stage}_f1", f1, prog_bar=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        optimizer = optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode="min",
            factor=self.hparams.factor,
            patience=self.hparams.patience,
            min_lr=self.hparams.min_lr,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",  # reduce the LR when validation loss plateaus
                "interval": "epoch",
            },
        }


if __name__ == "__main__":
    model = ResnetClassifier()
    print(model)
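
    # Illustrative smoke-test sketch, not part of the original file: it runs a
    # short Trainer fit on random tensors so the module's wiring (loss, metrics,
    # scheduler monitor) can be checked end to end. The fake data shapes and
    # trainer limits below are assumptions; use a real image DataModule in practice.
    from torch.utils.data import DataLoader, TensorDataset

    images = torch.randn(64, 3, 224, 224)  # fake RGB images at EfficientNet-B0's default input size
    labels = torch.randint(0, 2, (64,))  # fake binary labels (0 or 1)
    loader = DataLoader(TensorDataset(images, labels), batch_size=16)

    trainer = L.Trainer(
        max_epochs=1,
        limit_train_batches=2,  # keep the smoke test fast
        limit_val_batches=2,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)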