|
from __future__ import annotations |
|
import random |
|
|
|
import albumentations as A |
|
import numpy as np |
|
import PIL.Image |
|
from albumentations.pytorch.transforms import ToTensorV2 |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
from realfake.config import LABELS |
|
|
|
IMG_RESIZE = 256  # side length images are resized to before cropping

IMG_CROP = 224  # final square crop size (224 is the common ImageNet input size)
|
|
|
|
|
class DictDataset(Dataset):
    """Dataset over a list of record dicts.

    Each record must carry a ``"path"`` key pointing at an image file and may
    carry a ``"label"`` key, which is mapped through the ``LABELS`` lookup.
    Items are returned as dicts: ``{"image": ..., "label": ...}`` (label only
    when present in the record).
    """

    def __init__(self, records: list[dict], transform_x=None):
        self.records = records
        # Albumentations-style callable: transform_x(image=arr)["image"]
        self.transform_x = transform_x

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        record = self.records[idx]
        # PIL opens files lazily and keeps the handle alive until GC; with many
        # DataLoader workers that can exhaust file descriptors. The context
        # manager closes the handle promptly — np.asarray forces a full decode
        # first, so the pixel data survives the close.
        with PIL.Image.open(record["path"]) as img:
            image = np.asarray(img)
        if self.transform_x is not None:
            image = self.transform_x(image=image)["image"]
        item = {"image": image}
        if "label" in record:
            item["label"] = LABELS[record["label"]]
        return item
|
|
|
|
|
def get_augs(train: bool = True) -> A.Compose:
    """Build the image augmentation pipeline.

    Both modes resize to ``IMG_RESIZE`` and crop to ``IMG_CROP``; the training
    pipeline uses a random crop plus stochastic augmentations, while the
    evaluation pipeline is fully deterministic (center crop only). Both end by
    normalizing channel counts, applying ``A.Normalize`` and converting to a
    tensor.
    """
    steps: list = [A.Resize(IMG_RESIZE, IMG_RESIZE)]
    if train:
        steps.append(A.RandomCrop(IMG_CROP, IMG_CROP))
        steps.extend([
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(),
            A.Affine(),
            A.Rotate(),
            A.CoarseDropout(),
        ])
    else:
        steps.append(A.CenterCrop(IMG_CROP, IMG_CROP))
    # Shared tail: force 3-channel RGB, normalize, convert to CHW tensor.
    steps.extend([
        ExpandChannels(),
        RGBAtoRGB(),
        A.Normalize(),
        ToTensorV2(),
    ])
    return A.Compose(steps)
|
|
|
|
|
class ExpandChannels(A.ImageOnlyTransform):
    """Expands image up to three channels if the image is grayscale.

    Applied unconditionally: channel expansion must be deterministic or
    downstream transforms (``A.Normalize``, ``ToTensorV2``) would see
    inconsistent shapes.
    """

    # Bug fix: the original accepted always_apply/p but silently ignored them,
    # hard-coding (True, 1.0). Defaults now state the values actually used and
    # are passed through, so no-arg construction behaves identically.
    def __init__(self, always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        if image.ndim == 2:
            # H×W grayscale → H×W×3 by repeating the single plane.
            image = np.repeat(image[..., None], 3, axis=2)
        elif image.shape[2] == 1:
            # H×W×1 → H×W×3.
            image = np.repeat(image, 3, axis=2)
        # NOTE(review): 2-channel (LA) images pass through unchanged — presumably
        # they do not occur in this dataset; confirm upstream.
        return image
|
|
|
|
|
class RGBAtoRGB(A.ImageOnlyTransform):
    """Converts RGBA image to RGB by dropping the alpha channel.

    Applied unconditionally: channel stripping must be deterministic or
    downstream transforms would see inconsistent shapes. The alpha channel is
    discarded, not composited — transparent pixels keep their RGB values.
    """

    # Bug fix: the original accepted always_apply/p but silently ignored them,
    # hard-coding (True, 1.0). Defaults now state the values actually used and
    # are passed through, so no-arg construction behaves identically.
    def __init__(self, always_apply: bool = True, p: float = 1.0):
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        if image.shape[2] == 4:
            # Keep only the first three planes (a view, no copy).
            image = image[:, :, :3]
        return image
|
|
|
|
|
def get_dss(records: list) -> tuple[DictDataset, DictDataset]:
    """Split records into train/validation datasets by their ``"valid"`` flag.

    Training records are shuffled once and paired with the stochastic
    augmentation pipeline; validation records keep their order and get the
    deterministic pipeline.
    """
    train_records: list = []
    valid_records: list = []
    for record in records:
        bucket = valid_records if record["valid"] else train_records
        bucket.append(record)
    assert len(train_records) + len(valid_records) == len(records)
    random.shuffle(train_records)
    return (
        DictDataset(train_records, transform_x=get_augs(train=True)),
        DictDataset(valid_records, transform_x=get_augs(train=False)),
    )
|
|
|
|
|
def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]: |
|
train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers) |
|
valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False) |
|
return train_dl, valid_dl |
|
|