from __future__ import annotations

import json
from pathlib import Path

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torchmetrics
from pydantic import BaseModel, Field

from realfake.data import get_dss, get_dls
from realfake.utils import Args

N_CLASSES = 2


class AcceleratorParams(BaseModel):
    """PyTorch Lightning accelerator parameters."""
    name: str = Field("gpu")
    devices: int = Field(4)
    strategy: str = Field("dp")
    precision: int = Field(16)
    override_float32_matmul: bool = Field(True)
    float32_matmul: str = Field("medium")


class RealFakeParams(Args):
    jsonl_file: Path
    dry_run: bool = Field(False)
    model_name: str = Field("convnext_tiny")
    batch_size: int = Field(256)
    freeze_epochs: int = Field(3)
    epochs: int = Field(6)
    base_lr: float = Field(1e-3)
    pretrained: bool = Field(True)
    progress_bar: bool = Field(False)
    accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)


class RealFakeDataModule(pl.LightningDataModule):

    def __init__(self, jsonl_records: Path, batch_size: int, num_workers: int = 0):
        super().__init__()
        self.jsonl_records = jsonl_records
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dss = self.dls = None

    def setup(self, stage=None):
        # Each line of the JSONL file describes one training example.
        records = [json.loads(line) for line in self.jsonl_records.open()]
        self.dss = get_dss(records)
        self.dls = get_dls(*self.dss, self.batch_size, self.num_workers)

    def train_dataloader(self):
        return self.dls[0]

    def val_dataloader(self):
        return self.dls[1]


class RealFakeClassifier(pl.LightningModule):
    """Binary real-vs-fake classifier on top of a timm backbone.

    Dataloaders are provided by RealFakeDataModule; the module itself only
    defines the model, the loss, and the optimization logic.
    """

    def __init__(self, params: RealFakeParams):
        super().__init__()
        self.params = params
        # Targets are one-hot encoded floats, so BCEWithLogitsLoss is used
        # rather than CrossEntropyLoss.
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.model = timm.create_model(params.model_name,
                                       pretrained=params.pretrained,
                                       num_classes=N_CLASSES)
        self.acc = torchmetrics.Accuracy(task="binary")

    def forward(self, batch):
        x, y = batch["image"], batch["label"]
        y = torch.nn.functional.one_hot(y, num_classes=N_CLASSES).float()
        out = self.model(x)
        loss = self.loss_fn(out, y)
        return loss, out, y

    def training_step(self, batch, batch_idx):
        loss, _, _ = self.forward(batch)
        self.log("train_loss", loss, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, out, y = self.forward(batch)
        # sigmoid is monotonic, so argmax over probabilities equals argmax
        # over raw logits; it is kept here for readability.
        y_pred = out.sigmoid().argmax(dim=-1)
        y_true = y.argmax(dim=-1)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        return {"gt": y_true, "yhat": y_pred}

    def validation_step_end(self, outputs):
        self.acc.update(outputs["yhat"], outputs["gt"])

    def validation_epoch_end(self, outputs):
        self.log("val_acc", self.acc.compute(), on_epoch=True)
        self.acc.reset()

    def configure_optimizers(self):
        adamw = torch.optim.AdamW(self.parameters(), lr=self.params.base_lr)
        one_cycle = torch.optim.lr_scheduler.OneCycleLR(
            adamw,
            max_lr=self.params.base_lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        # OneCycleLR is sized in optimizer steps, so it must be stepped per
        # batch; Lightning's default scheduler interval is "epoch".
        scheduler = {"scheduler": one_cycle, "interval": "step"}
        return [adamw], [scheduler]
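

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of how these pieces could be wired into a pl.Trainer.
# The entrypoint below is an assumption: the original training script is not
# shown here, RealFakeParams is constructed directly with keyword arguments
# (assuming the Args base class behaves like a pydantic BaseModel), and the
# JSONL path and num_workers=4 are hypothetical example values.
if __name__ == "__main__":
    params = RealFakeParams(jsonl_file=Path("data/records.jsonl"))  # hypothetical path

    if params.accelerator.override_float32_matmul:
        # Trade a little matmul precision for throughput on recent GPUs.
        torch.set_float32_matmul_precision(params.accelerator.float32_matmul)

    dm = RealFakeDataModule(params.jsonl_file, params.batch_size, num_workers=4)
    model = RealFakeClassifier(params)

    trainer = pl.Trainer(
        accelerator=params.accelerator.name,
        devices=params.accelerator.devices,
        strategy=params.accelerator.strategy,
        precision=params.accelerator.precision,
        max_epochs=params.epochs,
        enable_progress_bar=params.progress_bar,
        # dry_run is interpreted here as a single-batch smoke test.
        fast_dev_run=params.dry_run,
    )
    trainer.fit(model, datamodule=dm)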