import os
import sys

# Make the project root importable before the project-local imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time

import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torchmetrics as tm
from torch import optim

from utils import configs

from .backbone_model import CLIPModel, TorchModel


class ImageClassificationLightningModule(pl.LightningModule):
    """LightningModule that trains a backbone image classifier.

    Wraps either a torchvision-style backbone (``TorchModel``) or a CLIP
    backbone (``CLIPModel``) and logs macro-averaged accuracy / precision /
    recall / F1 plus per-batch inference time during training.

    Binary problems (``num_classes`` of 1 or 2) use ``BCEWithLogitsLoss``
    and binary-task metrics; anything larger uses ``CrossEntropyLoss`` and
    multiclass metrics.
    """

    def __init__(
        self,
        num_classes: int = len(configs.CLASS_CHARACTERS) - 1,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.0,
        name_model: str = "resnet50",
        freeze_model: bool = True,
        pretrained_model: bool = True,
        scheduler_period: int = 20,
    ):
        """Configure the model, loss, and metrics.

        Args:
            num_classes: Number of target classes; 1 or 2 selects the
                binary path (BCE loss, binary metrics).
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay (L2 penalty).
            name_model: Backbone identifier; anything other than ``"clip"``
                is forwarded to ``TorchModel``.
            freeze_model: Whether to freeze the backbone weights.
            pretrained_model: Whether to load pretrained backbone weights.
            scheduler_period: Cosine LR schedule period in epochs
                (previously hard-coded to 20).
        """
        super().__init__()
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.name_model = name_model
        self.scheduler_period = scheduler_period
        # NOTE(review): with num_classes == 2 the backbone head is built
        # with 2 output logits, but BCEWithLogitsLoss targets are shaped
        # [B, 1] in training_step — confirm the head emits a single logit
        # for the binary case, otherwise loss shapes will not match.
        self.criterion = (
            nn.BCEWithLogitsLoss()
            if self.num_classes in (1, 2)
            else nn.CrossEntropyLoss()
        )
        self.create_models()
        self.create_metrics_models()

    def create_models(self):
        """Instantiate the backbone selected by ``name_model``."""
        if self.name_model != "clip":
            self.model = TorchModel(
                self.name_model,
                self.freeze_model,
                self.pretrained_model,
                self.num_classes,
            )
        else:
            self.model = CLIPModel(
                configs.CLIP_NAME_MODEL,
                self.freeze_model,
                self.pretrained_model,
                self.num_classes,
            )

    def create_metrics_models(self):
        """Instantiate the torchmetrics objects used in ``training_step``.

        BUGFIX: the original constructed ``task="multiclass"`` metrics with
        ``num_classes=1`` for the binary case, which torchmetrics rejects
        (multiclass requires num_classes >= 2). Binary problems must use
        ``task="binary"`` instead.
        """
        if self.num_classes in (1, 2):
            metric_kwargs = {"task": "binary"}
        else:
            metric_kwargs = {
                "task": "multiclass",
                "num_classes": self.num_classes,
                "average": "macro",
            }
        self.metrics_accuracy = tm.Accuracy(**metric_kwargs)
        self.metrics_precision = tm.Precision(**metric_kwargs)
        self.metrics_recall = tm.Recall(**metric_kwargs)
        self.metrics_f1 = tm.F1Score(**metric_kwargs)

    def configure_optimizers(self):
        """Return the Adam optimizer and a cosine-annealing LambdaLR.

        The lambda decays the LR from 1.0 to 0.1 over ``scheduler_period``
        epochs following a half-cosine curve.
        """
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        period = self.scheduler_period
        lr_scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda x: (((1 + np.cos(x * np.pi / period)) / 2) ** 1.0) * 0.9
            + 0.1,
        )
        # "monitor" is only consulted by Lightning for ReduceLROnPlateau;
        # it is harmless (ignored) for LambdaLR and kept for compatibility.
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "metrics_f1_score",
        }

    def forward(self, x):
        """Run the backbone; returns raw (un-normalized) logits."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log metrics + inference time."""
        x, y = batch
        # BCE needs float targets shaped [B, 1]; CE keeps int class indices.
        y = y.unsqueeze(1).float() if self.num_classes in (1, 2) else y
        start_time = time.perf_counter()
        preds_y = self(x)
        inference_time = time.perf_counter() - start_time
        loss = self.criterion(preds_y, y)
        # Update and log every metric epoch-wise with identical flags.
        named_metrics = {
            "metrics_accuracy": self.metrics_accuracy,
            "metrics_precision": self.metrics_precision,
            "metrics_recall": self.metrics_recall,
            "metrics_f1_score": self.metrics_f1,
        }
        for log_name, metric in named_metrics.items():
            metric(preds_y, y)
            self.log(log_name, metric, on_step=False, on_epoch=True, prog_bar=True)
        self.log(
            "metrics_inference_time",
            inference_time,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss