from __future__ import annotations import random import albumentations as A import numpy as np import PIL.Image from albumentations.pytorch.transforms import ToTensorV2 from torch.utils.data import Dataset, DataLoader from realfake.config import LABELS IMG_RESIZE = 256 IMG_CROP = 224 class DictDataset(Dataset): def __init__(self, records: list[dict], transform_x=None): self.records = records self.transform_x = transform_x def __len__(self): return len(self.records) def __getitem__(self, idx): record = self.records[idx] image = np.asarray(PIL.Image.open(record["path"])) if self.transform_x is not None: image = self.transform_x(image=image)["image"] item = {"image": image} if "label" in record: item["label"] = LABELS[record["label"]] return item def get_augs(train: bool = True) -> A.Compose: if train: return A.Compose([ A.Resize(IMG_RESIZE, IMG_RESIZE), A.RandomCrop(IMG_CROP, IMG_CROP), A.HorizontalFlip(), A.VerticalFlip(), A.RandomBrightnessContrast(), A.Affine(), A.Rotate(), A.CoarseDropout(), ExpandChannels(), RGBAtoRGB(), A.Normalize(), ToTensorV2(), ]) else: return A.Compose([ A.Resize(IMG_RESIZE, IMG_RESIZE), A.CenterCrop(IMG_CROP, IMG_CROP), ExpandChannels(), RGBAtoRGB(), A.Normalize(), ToTensorV2(), ]) class ExpandChannels(A.ImageOnlyTransform): """Expands image up to three channes if the image is grayscale.""" def __init__(self, always_apply: bool = False, p: float = 0.5): super().__init__(True, 1.0) def apply(self, image, **params): if image.ndim == 2: image = np.repeat(image[..., None], 3, axis=2) elif image.shape[2] == 1: image = np.repeat(image, 3, axis=2) return image class RGBAtoRGB(A.ImageOnlyTransform): """Converts RGBA image to RGB.""" def __init__(self, always_apply: bool = False, p: float = 0.5): super().__init__(True, 1.0) def apply(self, image, **params): if image.shape[2] == 4: image = image[:, :, :3] return image def get_dss(records: list) -> tuple[DictDataset, DictDataset]: train_records = [x for x in records if not x["valid"]] valid_records = [x for x in records if x["valid"]] assert len(train_records) + len(valid_records) == len(records) random.shuffle(train_records) train_ds = DictDataset(train_records, transform_x=get_augs(train=True)) valid_ds = DictDataset(valid_records, transform_x=get_augs(train=False)) return train_ds, valid_ds def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]: train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers) valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False) return train_dl, valid_dl