realfake / realfake /data.py
devforfu
Init
ea847ad
from __future__ import annotations
import random
import albumentations as A
import numpy as np
import PIL.Image
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from realfake.config import LABELS
# Side length in pixels of the square resize that get_augs applies to every
# image before cropping.
IMG_RESIZE = 256
# Side length in pixels of the square crop get_augs takes after resizing
# (a random crop for the training pipeline, a centre crop otherwise).
IMG_CROP = 224
class DictDataset(Dataset):
    """Map-style dataset over a list of record dicts.

    Each record must provide a ``"path"`` key naming an image file and may
    provide a ``"label"`` key, which is mapped through ``LABELS`` before being
    returned as ``item["label"]``.

    Args:
        records: Records to serve, indexed positionally.
        transform_x: Optional albumentations-style callable, invoked as
            ``transform_x(image=array)`` and expected to return a dict with
            an ``"image"`` entry. When None, the decoded array is returned.
    """

    def __init__(self, records: list[dict], transform_x=None):
        self.records = records
        self.transform_x = transform_x

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        # Open inside a context manager so the file handle is released
        # immediately rather than whenever the Image object is collected
        # (this adds up quickly with many DataLoader workers).
        with PIL.Image.open(record["path"]) as img:
            # Decode to RGB up front: palette ("P"), gray+alpha ("LA"),
            # CMYK and 1-bit images cannot be fixed correctly later by
            # simply repeating or slicing channels. For "L" and "RGBA"
            # this matches replicating the gray channel / dropping alpha.
            if img.mode != "RGB":
                img = img.convert("RGB")
            # asarray copies the pixel data out, so it stays valid after
            # the file is closed.
            image = np.asarray(img)
        if self.transform_x is not None:
            image = self.transform_x(image=image)["image"]
        item = {"image": image}
        if "label" in record:
            item["label"] = LABELS[record["label"]]
        return item
def get_augs(train: bool = True, resize: int = IMG_RESIZE, crop: int = IMG_CROP) -> A.Compose:
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines do the same things in order:
    - coerce the image to three channels;
    - resize it to ``resize x resize``;
    - take a ``crop x crop`` patch;
    - normalise it and convert it to a tensor.

    The training pipeline takes a random crop and adds random flips,
    brightness/contrast, affine, rotation and coarse-dropout augmentations.
    The evaluation pipeline takes a deterministic centre crop.

    Args:
        train: Build the augmented training pipeline (True) or the
            deterministic evaluation pipeline (False).
        resize: Side length in pixels of the square resize.
            Defaults to IMG_RESIZE.
        crop: Side length in pixels of the square crop.
            Defaults to IMG_CROP.

    Returns:
        The composed transform, used as ``augs(image=array)["image"]``.
    """
    # Channel coercion runs first so every later augmentation sees a plain
    # HxWx3 image (no work wasted on an alpha channel). ExpandChannels must
    # precede RGBAtoRGB so the latter always receives a 3-D array.
    channel_fixes = [ExpandChannels(), RGBAtoRGB()]
    to_model_input = [A.Normalize(), ToTensorV2()]
    if train:
        spatial_and_color = [
            A.Resize(resize, resize),
            A.RandomCrop(crop, crop),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(),
            A.Affine(),
            A.Rotate(),
            A.CoarseDropout(),
        ]
    else:
        spatial_and_color = [
            A.Resize(resize, resize),
            A.CenterCrop(crop, crop),
        ]
    return A.Compose(channel_fixes + spatial_and_color + to_model_input)
class ExpandChannels(A.ImageOnlyTransform):
    """Expand a grayscale image to three channels.

    ``HxW`` and ``HxWx1`` inputs have their single channel replicated three
    times. ``HxWx2`` (gray + alpha) inputs have the alpha dropped and the gray
    channel replicated. Images with any other channel count are returned
    unchanged. The transform is always applied.

    Args:
        always_apply: Accepted for backward compatibility; ignored.
        p: Accepted for backward compatibility; ignored (the transform is
            always applied).
    """

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        # Pass p by keyword: the positional order of BasicTransform.__init__
        # differs between albumentations releases. p=1.0 on its own already
        # guarantees the transform is applied on every call.
        super().__init__(p=1.0)

    def apply(self, image, **params):
        if image.ndim == 2:
            return np.repeat(image[..., None], 3, axis=2)
        if image.shape[2] in (1, 2):
            # Keep only the gray channel (drops alpha when there are two).
            return np.repeat(image[..., :1], 3, axis=2)
        return image
class RGBAtoRGB(A.ImageOnlyTransform):
    """Drop the alpha channel of an ``HxWx4`` image.

    Any other input, including 2-D grayscale arrays, is returned unchanged.
    The transform is always applied.

    Args:
        always_apply: Accepted for backward compatibility; ignored.
        p: Accepted for backward compatibility; ignored (the transform is
            always applied).
    """

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        # Pass p by keyword: the positional order of BasicTransform.__init__
        # differs between albumentations releases. p=1.0 on its own already
        # guarantees the transform is applied on every call.
        super().__init__(p=1.0)

    def apply(self, image, **params):
        # Check ndim first: indexing shape[2] on a 2-D array raises IndexError.
        if image.ndim == 3 and image.shape[2] == 4:
            return image[:, :, :3]
        return image
def get_dss(records: list, seed: int | None = None) -> tuple[DictDataset, DictDataset]:
    """Split records into training and validation datasets.

    A record goes to the validation set when ``record["valid"]`` is truthy
    and to the training set otherwise, so every record needs a ``"valid"``
    key. Training records are shuffled once into a new list; ``records``
    itself is not reordered.

    Args:
        records: Record dicts as consumed by DictDataset, each also carrying
            a ``"valid"`` flag.
        seed: If given, shuffle with a private ``random.Random(seed)`` for a
            reproducible order that leaves the global RNG untouched. If None
            (default), use the global ``random`` module.

    Returns:
        ``(train_ds, valid_ds)``. They use the training and evaluation
        pipelines from ``get_augs`` respectively.
    """
    # Single pass; every record lands in exactly one list, so no
    # completeness check is needed.
    train_records, valid_records = [], []
    for record in records:
        (valid_records if record["valid"] else train_records).append(record)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(train_records)
    train_ds = DictDataset(train_records, transform_x=get_augs(train=True))
    valid_ds = DictDataset(valid_records, transform_x=get_augs(train=False))
    return train_ds, valid_ds
def get_dls(
    train_ds: DictDataset,
    valid_ds: DictDataset,
    bs: int,
    num_workers: int,
    *,
    shuffle_train: bool = True,
) -> tuple[DataLoader, DataLoader]:
    """Wrap the datasets in DataLoaders.

    Args:
        train_ds: Training dataset.
        valid_ds: Validation dataset.
        bs: Batch size for both loaders.
        num_workers: Worker processes for both loaders (0 loads in the main
            process).
        shuffle_train: Reshuffle the training set at the start of every
            epoch (default). A one-off shuffle of the records alone would
            fix one batch order that then repeats every epoch.

    Returns:
        ``(train_dl, valid_dl)``. The validation loader keeps dataset order.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers, shuffle=shuffle_train)
    valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
    return train_dl, valid_dl