"""
Parts of this file have been adapted from
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
"""
import pytorch_lightning as pl
import torch.nn.functional as F
from argparse import ArgumentParser
from torch import Tensor
from torch.optim import AdamW, Optimizer, RAdam
from torch.optim.lr_scheduler import _LRScheduler
from transformers import get_scheduler, PreTrainedModel


class ImageClassificationNet(pl.LightningModule):
@staticmethod
def add_model_specific_args(parent_parser: ArgumentParser) -> ArgumentParser:
parser = parent_parser.add_argument_group("Classification Model")
parser.add_argument(
"--optimizer",
type=str,
default="AdamW",
choices=["AdamW", "RAdam"],
help="The optimizer to use to train the model.",
)
parser.add_argument(
"--weight_decay",
type=float,
default=1e-2,
help="The optimizer's weight decay.",
)
parser.add_argument(
"--lr",
type=float,
default=5e-5,
help="The initial learning rate for the model.",
)
return parent_parser
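
    # Usage note (a sketch, not part of the original file): these flags are
    # meant to be merged into a script-level parser, e.g.
    #
    #   parser = ArgumentParser()
    #   parser = ImageClassificationNet.add_model_specific_args(parser)
    #   args = parser.parse_args(["--optimizer", "RAdam", "--lr", "1e-4"])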

    def __init__(
self,
model: PreTrainedModel,
num_train_steps: int,
optimizer: str = "AdamW",
weight_decay: float = 1e-2,
lr: float = 5e-5,
):
"""A PyTorch Lightning Module for a HuggingFace model used for image classification.
Args:
model (PreTrainedModel): a pretrained model for image classification
num_train_steps (int): number of training steps
optimizer (str): optimizer to use
weight_decay (float): weight decay for optimizer
lr (float): the learning rate used for training
"""
super().__init__()
        # Save all init args as hyperparameters; the model itself is excluded
        # so its weights are not duplicated inside hparams
        self.save_hyperparameters(ignore=["model"])
self.model = model

    def forward(self, x: Tensor) -> Tensor:
        # Return only the classification logits from the HuggingFace output
        return self.model(x).logits

    def configure_optimizers(self) -> tuple[list[Optimizer], list[_LRScheduler]]:
# Set the optimizer class based on the hyperparameter
if self.hparams.optimizer == "AdamW":
optim_class = AdamW
elif self.hparams.optimizer == "RAdam":
optim_class = RAdam
        else:
            raise ValueError(f"Unknown optimizer: {self.hparams.optimizer}")
# Create the optimizer and the learning rate scheduler
optimizer = optim_class(
self.parameters(),
weight_decay=self.hparams.weight_decay,
lr=self.hparams.lr,
)
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=self.hparams.num_train_steps,
)
return [optimizer], [lr_scheduler]
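
    # Note (an assumption about typical usage, not stated in this file):
    # num_train_steps is usually len(train_dataloader) * max_epochs, so the
    # linear schedule decays the learning rate to zero by the end of training.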

    def _calculate_loss(self, batch: tuple[Tensor, Tensor], mode: str) -> Tensor:
        imgs, labels = batch
        # Forward pass and cross-entropy loss on the classification logits
        preds = self.model(imgs).logits
        loss = F.cross_entropy(preds, labels)
        # Top-1 accuracy for logging
        acc = (preds.argmax(dim=-1) == labels).float().mean()
self.log(f"{mode}_loss", loss)
self.log(f"{mode}_acc", acc)
return loss

    def training_step(self, batch: tuple[Tensor, Tensor], _: int) -> Tensor:
loss = self._calculate_loss(batch, mode="train")
return loss

    def validation_step(self, batch: tuple[Tensor, Tensor], _: int):
self._calculate_loss(batch, mode="val")

    def test_step(self, batch: tuple[Tensor, Tensor], _: int):
self._calculate_loss(batch, mode="test")