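"""Lightning training module for the real/fake image classifier.

Defines the pydantic configs, a JSONL-backed data module, and a timm-based
classifier trained with a one-cycle schedule. A minimal usage sketch is at
the bottom of the file.
"""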
from __future__ import annotations

import json
from pathlib import Path

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torchmetrics
from pydantic import BaseModel, Field

from realfake.data import get_dss, get_dls
from realfake.utils import Args

N_CLASSES = 2  # real vs. fake

class AcceleratorParams(BaseModel):
    """PyTorch Lightning accelerator parameters."""

    name: str = Field("gpu")
    devices: int = Field(4)
    strategy: str = Field("dp")
    precision: int = Field(16)
    override_float32_matmul: bool = Field(True)
    float32_matmul: str = Field("medium")

class RealFakeParams(Args):
    """Training configuration for the real-vs-fake classifier."""

    jsonl_file: Path
    dry_run: bool = Field(False)
    model_name: str = Field("convnext_tiny")
    batch_size: int = Field(256)
    freeze_epochs: int = Field(3)
    epochs: int = Field(6)
    base_lr: float = Field(1e-3)
    pretrained: bool = Field(True)
    progress_bar: bool = Field(False)
    accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)

class RealFakeDataModule(pl.LightningDataModule):
    """Builds train/val dataloaders from a JSONL file of image records."""

    def __init__(self, jsonl_records: Path, batch_size: int, num_workers: int = 0):
        super().__init__()
        self.jsonl_records = jsonl_records
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dss = self.dls = None

    def setup(self, stage=None):
        records = [json.loads(line) for line in self.jsonl_records.open()]
        self.dss = get_dss(records)
        self.dls = get_dls(*self.dss, self.batch_size, self.num_workers)

    def train_dataloader(self):
        return self.dls[0]

    def val_dataloader(self):
        return self.dls[1]

class RealFakeClassifier(pl.LightningModule):
    """timm backbone fine-tuned to separate real from generated images."""

    def __init__(self, params: RealFakeParams):
        super().__init__()
        self.params = params
        # One-hot targets with BCE-with-logits; CrossEntropyLoss on integer
        # labels would be the more conventional choice for two classes.
        self.bce = nn.BCEWithLogitsLoss()
        self.model = timm.create_model(params.model_name, pretrained=params.pretrained, num_classes=N_CLASSES)
        self.acc = torchmetrics.Accuracy(task="binary")
    def forward(self, batch):
        # Shared step: returns (loss, logits, one-hot targets).
        x, y = batch["image"], batch["label"]
        y = torch.nn.functional.one_hot(y, num_classes=N_CLASSES).float()
        out = self.model(x)
        loss = self.bce(out, y)
        return loss, out, y
    def training_step(self, batch, batch_idx):
        loss, _, _ = self.forward(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss
    def validation_step(self, batch, batch_idx):
        loss, out, y = self.forward(batch)
        # sigmoid is monotonic, so this matches argmax over the raw logits
        y_pred = out.sigmoid().argmax(dim=-1)
        y_true = y.argmax(dim=-1)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        return {"gt": y_true, "yhat": y_pred}
    def validation_step_end(self, outputs):
        # With strategy="dp", outputs from all devices are gathered here
        # before the metric is updated.
        self.acc.update(outputs["yhat"], outputs["gt"])
    def validation_epoch_end(self, outputs):
        self.log("val_acc", self.acc.compute(), on_epoch=True)
        self.acc.reset()
    def configure_optimizers(self):
        adamw = torch.optim.AdamW(self.parameters(), lr=self.params.base_lr)
        one_cycle = torch.optim.lr_scheduler.OneCycleLR(
            adamw,
            max_lr=self.params.base_lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # OneCycleLR is sized in optimizer steps, so it must advance per step
        # rather than per epoch (Lightning's default interval).
        return [adamw], [{"scheduler": one_cycle, "interval": "step"}]
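
# Minimal usage sketch, not part of the module's public API: wires the config,
# data module, and classifier into a Trainer. The "records.jsonl" path and the
# single-GPU accelerator override below are illustrative assumptions, not
# project defaults.
if __name__ == "__main__":
    params = RealFakeParams(
        jsonl_file=Path("records.jsonl"),
        accelerator=AcceleratorParams(devices=1),
    )

    if params.accelerator.override_float32_matmul:
        torch.set_float32_matmul_precision(params.accelerator.float32_matmul)

    dm = RealFakeDataModule(params.jsonl_file, params.batch_size)
    model = RealFakeClassifier(params)

    trainer = pl.Trainer(
        accelerator=params.accelerator.name,
        devices=params.accelerator.devices,
        strategy=params.accelerator.strategy,
        precision=params.accelerator.precision,
        max_epochs=params.epochs,
        enable_progress_bar=params.progress_bar,
    )
    trainer.fit(model, datamodule=dm)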