File size: 11,557 Bytes
3631068 edd29d3 3631068 416c7bb 3631068 416c7bb 3631068 68dd12a 3631068 68dd12a 3631068 2d6a578 416c7bb 2d6a578 416c7bb 2d6a578 416c7bb 2d6a578 416c7bb 2d6a578 416c7bb 2d6a578 416c7bb 2d6a578 3631068 afbe904 3631068 afbe904 3631068 afbe904 3631068 afbe904 3631068 912d566 3631068 afbe904 3631068 5c43f60 d1e10d9 3631068 912d566 a02f133 d1e10d9 3631068 69c8e55 5c43f60 3631068 69c8e55 3631068 d1e10d9 5c43f60 3631068 6ff7b63 d358c49 3631068 d1e10d9 3631068 69c8e55 3631068 d1e10d9 5c43f60 3631068 6ff7b63 69c8e55 3631068 d1e10d9 3631068 5c43f60 d1e10d9 5c43f60 d1e10d9 5c43f60 3631068 912d566 ef4052d 49d8194 3631068 49d8194 912d566 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 |
import torchmetrics
from . import config
from typing import Tuple, Dict, List, Any
import numpy as np
import torch
import torchvision
import torch.nn as nn
import pytorch_lightning as ptl
class DeepFontBaseline(nn.Module):
    """DeepFont-style CNN that maps an RGB image to ``config.FONT_COUNT`` logits.

    Architecture: two conv/BN/ReLU/max-pool stages (64 and 128 channels),
    three 3x3 conv/ReLU layers at 256 channels, then a three-layer MLP
    (4096 -> 4096 -> ``config.FONT_COUNT``).

    The first linear layer expects a flattened ``256 * 12 * 12`` feature map,
    so the input's spatial size must reduce to 12x12 through the stride-2 11x11
    conv and the two 2x2 poolings (e.g. a 105x105 input); other sizes will fail
    at that layer.
    """

    def __init__(self) -> None:
        super().__init__()
        # Stem: large-kernel strided conv followed by a second conv stage,
        # each with BatchNorm, ReLU and 2x2 max pooling.
        feature_layers = [
            nn.Conv2d(3, 64, 11, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        # Three size-preserving 3x3 convs ending at 256 channels (no BatchNorm).
        for in_channels in (128, 256, 256):
            feature_layers.extend([nn.Conv2d(in_channels, 256, 3, 1, 1), nn.ReLU()])

        # Fully connected classifier head.
        head_layers = [
            nn.Flatten(),
            nn.Linear(256 * 12 * 12, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, config.FONT_COUNT),
        ]

        # A single flat Sequential keeps the original parameter names (model.0 ... model.19).
        self.model = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, X):
        """Return the font logits for the image batch ``X``."""
        return self.model(X)
class _ResNetRegressorBase(nn.Module):
    """Shared head logic for the torchvision ResNet font regressors.

    The backbone's ``fc`` layer is replaced by a linear layer with
    ``config.FONT_COUNT + 12`` outputs. On the last axis of the result:

    * the first ``config.FONT_COUNT + 2`` entries are returned unchanged
      (the font and direction logits that ``FontDetectorLoss`` consumes);
    * the remaining 10 entries are squashed with ``sigmoid`` (range (0, 1)),
      or with ``tanh`` (range (-1, 1)) when ``regression_use_tanh`` is true.

    Args:
        backbone: a torchvision ResNet whose ``fc`` attribute is an ``nn.Linear``.
        regression_use_tanh: use ``tanh`` instead of ``sigmoid`` for the
            regression outputs.
    """

    def __init__(self, backbone: nn.Module, regression_use_tanh: bool = False):
        super().__init__()
        self.model = backbone
        # Size the new head from the backbone (512 for ResNet-18/34, 2048 for
        # ResNet-50/101) instead of hard-coding the width per subclass.
        self.model.fc = nn.Linear(self.model.fc.in_features, config.FONT_COUNT + 12)
        self.regression_use_tanh = regression_use_tanh

    def forward(self, X):
        """Run the backbone and squash the regression slice of its output."""
        X = self.model(X)
        split = config.FONT_COUNT + 2
        logits, regression = X[..., :split], X[..., split:]
        if self.regression_use_tanh:
            regression = regression.tanh()  # range (-1, 1)
        else:
            regression = regression.sigmoid()  # range (0, 1)
        # Build a new tensor instead of writing into the autograd output in place.
        return torch.cat([logits, regression], dim=-1)


class ResNet18Regressor(_ResNetRegressorBase):
    """ResNet-18 backbone (optionally with torchvision's default weights) plus the shared regressor head."""

    def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
        weights = torchvision.models.ResNet18_Weights.DEFAULT if pretrained else None
        super().__init__(torchvision.models.resnet18(weights=weights), regression_use_tanh)


class ResNet34Regressor(_ResNetRegressorBase):
    """ResNet-34 backbone (optionally with torchvision's default weights) plus the shared regressor head."""

    def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
        weights = torchvision.models.ResNet34_Weights.DEFAULT if pretrained else None
        super().__init__(torchvision.models.resnet34(weights=weights), regression_use_tanh)


class ResNet50Regressor(_ResNetRegressorBase):
    """ResNet-50 backbone (optionally with torchvision's default weights) plus the shared regressor head."""

    def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
        weights = torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None
        super().__init__(torchvision.models.resnet50(weights=weights), regression_use_tanh)


class ResNet101Regressor(_ResNetRegressorBase):
    """ResNet-101 backbone (optionally with torchvision's default weights) plus the shared regressor head."""

    def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
        weights = torchvision.models.ResNet101_Weights.DEFAULT if pretrained else None
        super().__init__(torchvision.models.resnet101(weights=weights), regression_use_tanh)
class FontDetectorLoss(nn.Module):
    """Weighted multi-task loss for the font detector.

    The last axis of the prediction ``y_hat`` is split into three groups:
    ``config.FONT_COUNT`` font logits, 2 direction logits, and the remaining
    regression outputs. The terms are:

    * cross-entropy between the font logits and ``y[..., 0]`` (cast to long);
    * cross-entropy between the direction logits and ``y[..., 1]`` (cast to long);
    * mean squared error between the regression outputs and ``y[..., 2:]``.

    ``forward`` returns ``lambda_font * font + lambda_direction * direction +
    lambda_regression * regression``. When ``font_classification_only`` is truthy,
    it returns only the weighted font term.
    """

    def __init__(
        self, lambda_font, lambda_direction, lambda_regression, font_classification_only
    ):
        super().__init__()
        self.category_loss = nn.CrossEntropyLoss()
        self.regression_loss = nn.MSELoss()
        self.lambda_font = lambda_font
        self.lambda_direction = lambda_direction
        self.lambda_regression = lambda_regression
        # NOTE: the attribute name carries a historical misspelling; it is kept
        # verbatim because code outside this class may read it.
        self.font_classfiication_only = font_classification_only

    def forward(self, y_hat, y):
        n_fonts = config.FONT_COUNT

        font_term = self.lambda_font * self.category_loss(
            y_hat[..., :n_fonts], y[..., 0].long()
        )
        # Early exit: the direction and regression heads are ignored entirely.
        if self.font_classfiication_only:
            return font_term

        direction_logits = y_hat[..., n_fonts : n_fonts + 2]
        direction_term = self.lambda_direction * self.category_loss(
            direction_logits, y[..., 1].long()
        )

        regression_pred = y_hat[..., n_fonts + 2 :]
        regression_term = self.lambda_regression * self.regression_loss(
            regression_pred, y[..., 2:]
        )

        return font_term + direction_term + regression_term
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Learning-rate schedule combining linear warmup with cosine decay.

    At scheduler step ``t`` (``self.last_epoch``), every param group's base LR
    is multiplied by::

        0.5 * (1 + cos(pi * min(t, max_iters) / max_iters)) * (t / warmup if t <= warmup else 1)

    The warmup factor is applied only when ``warmup > 0``. The cosine term
    reaches 0 at ``t == max_iters`` and stays there, rather than rising again.

    Args:
        optimizer: the wrapped optimizer.
        warmup: number of linear-warmup steps; ``0`` (or less) disables warmup.
        max_iters: number of steps over which the cosine decays to zero.

    Raises:
        ValueError: if ``max_iters`` is not positive.
    """

    def __init__(self, optimizer, warmup, max_iters):
        if max_iters <= 0:
            raise ValueError(f"max_iters must be positive, got {max_iters}")
        self.warmup = warmup
        self.max_num_iters = max_iters
        # The base-class constructor performs an initial step (which calls
        # get_lr), so the attributes above must be set before it runs.
        super().__init__(optimizer)

    def get_lr(self):
        """Return the LR for each param group at the current step."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """Return the multiplicative LR factor for step ``epoch``."""
        # Clamp so the cosine stays at its minimum past max_num_iters instead of
        # climbing back up.
        progress_step = min(epoch, self.max_num_iters)
        lr_factor = 0.5 * (1 + np.cos(np.pi * progress_step / self.max_num_iters))
        # Guard warmup == 0: previously `0 <= 0` triggered a division by zero.
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor
class FontDetector(ptl.LightningModule):
    """Lightning module that trains and evaluates a font detector model.

    The wrapped ``model`` output is sliced on its last axis:
    ``[..., :config.FONT_COUNT]`` gives the font logits and
    ``[..., config.FONT_COUNT : config.FONT_COUNT + 2]`` gives the direction logits.
    The targets give the font id in ``y[..., 0]`` and the direction id in
    ``y[..., 1]``. The training objective is ``FontDetectorLoss``. When
    ``font_classification_only`` is true, no direction metrics are created or
    logged.

    The learning rate is driven manually by ``CosineWarmupScheduler``: it is
    stepped once per optimizer step (see ``optimizer_step``) with
    ``num_warmup_iters`` warmup steps and ``num_iters`` total steps.
    """

    def __init__(
        self,
        model: nn.Module,
        lambda_font: float,
        lambda_direction: float,
        lambda_regression: float,
        font_classification_only: bool,
        lr: float,
        betas: Tuple[float, float],
        num_warmup_iters: int,
        num_iters: int,
        num_epochs: int,
    ):
        super().__init__()
        self.model = model
        self.loss = FontDetectorLoss(
            lambda_font, lambda_direction, lambda_regression, font_classification_only
        )
        # One metric object per stage so train/val/test accumulate independently.
        self.font_accur_train = torchmetrics.Accuracy(
            task="multiclass", num_classes=config.FONT_COUNT
        )
        self.font_accur_val = torchmetrics.Accuracy(
            task="multiclass", num_classes=config.FONT_COUNT
        )
        self.font_accur_test = torchmetrics.Accuracy(
            task="multiclass", num_classes=config.FONT_COUNT
        )
        if not font_classification_only:
            self.direction_accur_train = torchmetrics.Accuracy(
                task="multiclass", num_classes=2
            )
            self.direction_accur_val = torchmetrics.Accuracy(
                task="multiclass", num_classes=2
            )
            self.direction_accur_test = torchmetrics.Accuracy(
                task="multiclass", num_classes=2
            )
        self.lr = lr
        self.betas = betas
        self.num_warmup_iters = num_warmup_iters
        self.num_iters = num_iters
        self.num_epochs = num_epochs
        # -1 means "not resumed". on_load_checkpoint overwrites it with the
        # checkpoint's epoch so configure_optimizers can fast-forward the LR.
        self.load_epoch = -1
        self.font_classification_only = font_classification_only

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _font_pair(y_hat, y):
        """Return (font logits, integer font labels) for the accuracy metrics."""
        # .long() matches the cast FontDetectorLoss applies to the same column;
        # torchmetrics' multiclass metrics expect integer targets.
        return y_hat[..., : config.FONT_COUNT], y[..., 0].long()

    @staticmethod
    def _direction_pair(y_hat, y):
        """Return (direction logits, integer direction labels) for the accuracy metrics."""
        return y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1].long()

    def _evaluation_step(self, batch, loss_name, font_metric, direction_metric):
        """Shared val/test step: log the loss and accumulate accuracy state.

        ``direction_metric`` is ``None`` when only font classification is trained.
        """
        X, y = batch
        y_hat = self.forward(X)
        loss = self.loss(y_hat, y)
        self.log(loss_name, loss, prog_bar=True, sync_dist=True)
        font_metric.update(*self._font_pair(y_hat, y))
        if direction_metric is not None:
            direction_metric.update(*self._direction_pair(y_hat, y))
        return {"loss": loss}

    def _log_and_reset(self, name, metric):
        """Log the value accumulated in ``metric`` over the epoch, then clear it."""
        self.log(name, metric.compute(), sync_dist=True)
        metric.reset()

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> Dict[str, Any]:
        X, y = batch
        y_hat = self.forward(X)
        loss = self.loss(y_hat, y)
        self.log("train_loss", loss, prog_bar=True, sync_dist=True)
        # Calling the metric (rather than .update) both accumulates epoch state
        # and returns this batch's accuracy, which is logged per step.
        self.log(
            "train_font_accur",
            self.font_accur_train(*self._font_pair(y_hat, y)),
            sync_dist=True,
        )
        if not self.font_classification_only:
            self.log(
                "train_direction_accur",
                self.direction_accur_train(*self._direction_pair(y_hat, y)),
                sync_dist=True,
            )
        return {"loss": loss}

    def on_train_epoch_end(self) -> None:
        self._log_and_reset("train_font_accur", self.font_accur_train)
        if not self.font_classification_only:
            self._log_and_reset("train_direction_accur", self.direction_accur_train)

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> Dict[str, Any]:
        direction_metric = (
            None if self.font_classification_only else self.direction_accur_val
        )
        return self._evaluation_step(
            batch, "val_loss", self.font_accur_val, direction_metric
        )

    def on_validation_epoch_end(self):
        self._log_and_reset("val_font_accur", self.font_accur_val)
        if not self.font_classification_only:
            self._log_and_reset("val_direction_accur", self.direction_accur_val)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        direction_metric = (
            None if self.font_classification_only else self.direction_accur_test
        )
        return self._evaluation_step(
            batch, "test_loss", self.font_accur_test, direction_metric
        )

    def on_test_epoch_end(self) -> None:
        self._log_and_reset("test_font_accur", self.font_accur_test)
        if not self.font_classification_only:
            self._log_and_reset("test_direction_accur", self.direction_accur_test)

    def configure_optimizers(self):
        """Create Adam and the per-iteration cosine-warmup scheduler.

        Only the optimizer is returned to Lightning. The scheduler is kept on
        ``self.scheduler`` and stepped manually in ``optimizer_step``.
        """
        optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.lr, betas=self.betas
        )
        self.scheduler = CosineWarmupScheduler(
            optimizer, self.num_warmup_iters, self.num_iters
        )
        print("Load epoch:", self.load_epoch)
        # Fast-forward the schedule past (load_epoch + 1) epochs' worth of steps
        # (num_iters / num_epochs steps per epoch). With the default
        # load_epoch == -1 this performs no steps.
        # NOTE(review): assumes on_load_checkpoint has already run when resuming.
        for _ in range(self.num_iters * (self.load_epoch + 1) // self.num_epochs):
            self.scheduler.step()
        print("Current learning rate set to:", self.scheduler.get_last_lr())
        return optimizer

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer,
        optimizer_idx: int = 0,
        *args,
        **kwargs
    ):
        # NOTE(review): the positional layout differs between Lightning
        # versions. In 2.x the 4th positional argument is the optimizer closure,
        # which lands in `optimizer_idx` and is forwarded unchanged to super().
        # Confirm against the installed version before changing this signature.
        super().optimizer_step(
            epoch, batch_idx, optimizer, optimizer_idx, *args, **kwargs
        )
        # Log the LR that was just used, then advance the per-iteration schedule.
        self.log("lr", self.scheduler.get_last_lr()[0])
        self.scheduler.step()

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Remember the resumed epoch so configure_optimizers can fast-forward the LR.
        self.load_epoch = checkpoint["epoch"]
|