File size: 4,516 Bytes
49bceed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
import time
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torchmetrics as tm
from torch import optim
from utils import configs
from .backbone_model import CLIPModel, TorchModel
class ImageClassificationLightningModule(pl.LightningModule):
    """Lightning module that trains an image-classification backbone.

    When ``num_classes`` is 1 or 2 the task is treated as binary: the loss is
    ``BCEWithLogitsLoss`` against a ``(batch, 1)`` float target, so the
    backbone is expected to emit one logit per sample (NOTE(review): confirm
    in ``backbone_model`` that ``TorchModel``/``CLIPModel`` do this for 1 and
    2 classes). Otherwise ``CrossEntropyLoss`` is used over ``num_classes``
    logits with integer class-index targets.

    Accuracy, precision, recall and F1 (macro-averaged in the multiclass
    case) plus the wall-clock time of the forward pass are logged per
    training epoch.

    Args:
        num_classes: Number of target classes; 1 and 2 both select binary mode.
            The default is derived from ``configs.CLASS_CHARACTERS`` (why one
            entry is subtracted is not visible here).
        learning_rate: Initial learning rate for Adam.
        weight_decay: Weight decay passed to Adam.
        name_model: Backbone name. ``"clip"`` selects ``CLIPModel`` built from
            ``configs.CLIP_NAME_MODEL``; any other value is passed to
            ``TorchModel``.
        freeze_model: Forwarded unchanged to the backbone constructor.
        pretrained_model: Forwarded unchanged to the backbone constructor.
    """

    def __init__(
        self,
        num_classes: int = len(configs.CLASS_CHARACTERS) - 1,
        learning_rate: float = 3e-4,
        weight_decay: float = 0.0,
        name_model: str = "resnet50",
        freeze_model: bool = True,
        pretrained_model: bool = True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.name_model = name_model
        self.criterion = (
            nn.BCEWithLogitsLoss() if self.is_binary else nn.CrossEntropyLoss()
        )
        self.create_models()
        self.create_metrics_models()

    @property
    def is_binary(self) -> bool:
        """Return True when the module runs in single-logit binary mode."""
        return self.num_classes in (1, 2)

    def create_models(self):
        """Instantiate the backbone as ``self.model`` based on ``name_model``."""
        if self.name_model != "clip":
            self.model = TorchModel(
                self.name_model,
                self.freeze_model,
                self.pretrained_model,
                self.num_classes,
            )
        else:
            self.model = CLIPModel(
                configs.CLIP_NAME_MODEL,
                self.freeze_model,
                self.pretrained_model,
                self.num_classes,
            )

    def create_metrics_models(self):
        """Create accuracy / precision / recall / F1 metrics for the task.

        torchmetrics' ``multiclass`` task requires ``num_classes >= 2`` and
        integer targets, so binary mode must use the ``binary`` task instead
        (which takes per-sample probabilities/logits and 0/1 targets).
        """
        if self.is_binary:
            metric_kwargs = {"task": "binary"}
        else:
            metric_kwargs = {
                "task": "multiclass",
                "num_classes": self.num_classes,
                "average": "macro",
            }
        self.metrics_accuracy = tm.Accuracy(**metric_kwargs)
        self.metrics_precision = tm.Precision(**metric_kwargs)
        self.metrics_recall = tm.Recall(**metric_kwargs)
        self.metrics_f1 = tm.F1Score(**metric_kwargs)

    def configure_optimizers(self):
        """Return Adam plus a cosine ``LambdaLR`` schedule for Lightning.

        ``monitor`` names the epoch-level training F1 logged in
        ``training_step``.
        """
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )

        def cosine_factor(step):
            # Multiplier on the base LR: 1.0 at step 0, cosine-decaying to
            # 0.1 at step 20. NOTE(review): cos has period 40, so after step
            # 20 the multiplier climbs back toward 1.0 — confirm this cyclic
            # behaviour is intended rather than a clamp at 0.1.
            return ((1 + np.cos(step * np.pi / 20)) / 2) * 0.9 + 0.1

        lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_factor)
        return {
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
            "monitor": "metrics_f1_score",
        }

    def forward(self, x):
        """Run the backbone on ``x`` and return its raw output (logits)."""
        output = self.model(x)
        return output

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and update/log the epoch metrics.

        Args:
            batch: ``(x, y)`` pair; ``y`` is assumed to hold integer class
                indices of shape ``(batch,)`` — TODO confirm against the
                datamodule.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar training loss.
        """
        x, y = batch

        start_time = time.perf_counter()
        preds_y = self(x)
        # NOTE: on CUDA, kernels launch asynchronously, so without a device
        # synchronize this mostly measures launch overhead, not compute time.
        inference_time = time.perf_counter() - start_time

        if self.is_binary:
            # BCEWithLogitsLoss needs a float target shaped like the logits.
            loss = self.criterion(preds_y, y.unsqueeze(1).float())
            # Binary metrics get explicit probabilities (so torchmetrics never
            # mistakes logits that happen to lie in [0, 1] for probabilities)
            # and an integer 0/1 target of matching shape.
            metric_preds = preds_y.detach().sigmoid()
            metric_target = y.unsqueeze(1).long()
        else:
            loss = self.criterion(preds_y, y)
            metric_preds = preds_y.detach()
            metric_target = y

        epoch_metrics = {
            "metrics_accuracy": self.metrics_accuracy,
            "metrics_precision": self.metrics_precision,
            "metrics_recall": self.metrics_recall,
            "metrics_f1_score": self.metrics_f1,
        }
        for log_name, metric in epoch_metrics.items():
            metric(metric_preds, metric_target)
            self.log(
                log_name,
                metric,
                on_step=False,
                on_epoch=True,
                prog_bar=True,
            )

        self.log(
            "metrics_inference_time",
            inference_time,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )
        return loss
|