# realfake/realfake/models.py
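"""LightningModule, LightningDataModule, and configuration classes for binary real-vs-fake image classification."""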
from __future__ import annotations

import json
from pathlib import Path

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torchmetrics
from pydantic import BaseModel, Field

from realfake.data import get_dss, get_dls
from realfake.utils import Args

N_CLASSES = 2


class AcceleratorParams(BaseModel):
    """PyTorch Lightning accelerator parameters."""

    name: str = Field("gpu")
    devices: int = Field(4)
    strategy: str = Field("dp")
    precision: int = Field(16)
    override_float32_matmul: bool = Field(True)
    float32_matmul: str = Field("medium")


class RealFakeParams(Args):
    """Training configuration for the real/fake classifier."""

    jsonl_file: Path
    dry_run: bool = Field(False)
    model_name: str = Field("convnext_tiny")
    batch_size: int = Field(256)
    freeze_epochs: int = Field(3)
    epochs: int = Field(6)
    base_lr: float = Field(1e-3)
    pretrained: bool = Field(True)
    progress_bar: bool = Field(False)
    accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)


class RealFakeDataModule(pl.LightningDataModule):
    """Builds train/validation datasets and dataloaders from a JSONL file of records."""

    def __init__(self, jsonl_records: Path, batch_size: int, num_workers: int = 0):
        super().__init__()
        self.jsonl_records = jsonl_records
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dss = self.dls = None

    def setup(self, stage=None):
        # One JSON object per line, each describing a single image record.
        records = [json.loads(line) for line in self.jsonl_records.open()]
        self.dss = get_dss(records)
        self.dls = get_dls(*self.dss, self.batch_size, self.num_workers)

    def train_dataloader(self):
        return self.dls[0]

    def val_dataloader(self):
        return self.dls[1]


class RealFakeClassifier(pl.LightningModule):
    """Binary real/fake classifier built on a timm backbone."""

    def __init__(self, params: RealFakeParams):
        super().__init__()
        self.params = params
        # Binary cross-entropy over one-hot targets; the backbone emits N_CLASSES logits.
        self.ce = nn.BCEWithLogitsLoss()
        self.model = timm.create_model(
            params.model_name, pretrained=params.pretrained, num_classes=N_CLASSES,
        )
        self.acc = torchmetrics.Accuracy(task="binary")
    def forward(self, batch):
        x, y = batch["image"], batch["label"]
        # Targets are one-hot encoded to match the two-logit output head.
        y = torch.nn.functional.one_hot(y, num_classes=N_CLASSES).float()
        out = self.model(x)
        loss = self.ce(out, y)
        return loss, out, y

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.forward(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, out, y = self.forward(batch)
        # Recover hard class indices from logits and one-hot targets for the accuracy metric.
        y_pred = out.sigmoid().argmax(dim=-1)
        y_true = y.argmax(dim=-1)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        return {"gt": y_true, "yhat": y_pred}

    def validation_step_end(self, outputs):
        self.acc.update(outputs["yhat"], outputs["gt"])

    def validation_epoch_end(self, outputs):
        self.log("val_acc", self.acc.compute(), on_epoch=True)
        self.acc.reset()
    def configure_optimizers(self):
        adamw = torch.optim.AdamW(self.parameters(), lr=self.params.base_lr)
        one_cycle = torch.optim.lr_scheduler.OneCycleLR(
            adamw,
            max_lr=self.params.base_lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # OneCycleLR is sized in optimizer steps, so it has to advance once per batch
        # rather than once per epoch.
        return {
            "optimizer": adamw,
            "lr_scheduler": {"scheduler": one_cycle, "interval": "step"},
        }
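

# A minimal training sketch, not part of the original module: it assumes a hypothetical
# JSONL path ("data/records.jsonl") and wires the classes above into a standard
# pytorch_lightning Trainer run; the project's real entry point may differ.
if __name__ == "__main__":
    params = RealFakeParams(jsonl_file=Path("data/records.jsonl"))  # hypothetical path
    datamodule = RealFakeDataModule(params.jsonl_file, params.batch_size)
    model = RealFakeClassifier(params)
    trainer = pl.Trainer(
        accelerator=params.accelerator.name,
        devices=params.accelerator.devices,
        strategy=params.accelerator.strategy,
        precision=params.accelerator.precision,
        max_epochs=params.epochs,
        enable_progress_bar=params.progress_bar,
        fast_dev_run=params.dry_run,
    )
    trainer.fit(model, datamodule=datamodule)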