File size: 3,145 Bytes
ea847ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from __future__ import annotations
import random
import albumentations as A
import numpy as np
import PIL.Image
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from realfake.config import LABELS
# Images are first resized to IMG_RESIZE x IMG_RESIZE, then cropped down to
# IMG_CROP x IMG_CROP (randomly during training, centered for validation).
IMG_RESIZE = 256
IMG_CROP = 224
class DictDataset(Dataset):
    """Map-style dataset over a list of record dicts.

    Each record must hold a ``"path"`` key pointing to an image file and may
    hold a ``"label"`` key whose (string) value is mapped to a class index
    via ``LABELS``.
    """

    def __init__(self, records: list[dict], transform_x=None):
        # transform_x is an albumentations-style callable: it is invoked as
        # transform_x(image=...)["image"].
        self.records = records
        self.transform_x = transform_x

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        record = self.records[idx]
        # Context manager closes the file handle promptly; PIL opens files
        # lazily and otherwise keeps them open until garbage collection,
        # which leaks descriptors with many workers. np.asarray forces the
        # pixel data to be read before the file is closed.
        with PIL.Image.open(record["path"]) as img:
            image = np.asarray(img)
        if self.transform_x is not None:
            image = self.transform_x(image=image)["image"]
        item = {"image": image}
        if "label" in record:
            # Inference-time records carry no label; include it only when present.
            item["label"] = LABELS[record["label"]]
        return item
def get_augs(train: bool = True) -> A.Compose:
    """Build the albumentations pipeline for training or validation.

    Both pipelines resize to IMG_RESIZE, crop to IMG_CROP, normalize the
    channel layout (grayscale -> 3ch, RGBA -> RGB) and emit a normalized
    tensor; the training pipeline additionally applies random augmentations.
    """
    # Tail shared by both modes: fix up channels, normalize, to tensor.
    tail = [
        ExpandChannels(),
        RGBAtoRGB(),
        A.Normalize(),
        ToTensorV2(),
    ]
    if train:
        head = [
            A.Resize(IMG_RESIZE, IMG_RESIZE),
            A.RandomCrop(IMG_CROP, IMG_CROP),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(),
            A.Affine(),
            A.Rotate(),
            A.CoarseDropout(),
        ]
    else:
        head = [
            A.Resize(IMG_RESIZE, IMG_RESIZE),
            A.CenterCrop(IMG_CROP, IMG_CROP),
        ]
    return A.Compose(head + tail)
class ExpandChannels(A.ImageOnlyTransform):
    """Expand a grayscale image up to three channels.

    Both HxW and HxWx1 inputs are replicated along the channel axis so the
    output is HxWx3; images that already have 3+ channels pass through
    unchanged.
    """

    def __init__(self, always_apply: bool = True, p: float = 1.0):
        # The previous implementation accepted these arguments but silently
        # discarded them, hard-coding (True, 1.0). Keep those values as the
        # defaults — a random channel fixup makes no sense — while honoring
        # explicit overrides from the caller.
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        if image.ndim == 2:
            # HxW -> HxWx3
            image = np.repeat(image[..., None], 3, axis=2)
        elif image.shape[2] == 1:
            # HxWx1 -> HxWx3
            image = np.repeat(image, 3, axis=2)
        return image
class RGBAtoRGB(A.ImageOnlyTransform):
    """Drop the alpha channel of an RGBA image, yielding RGB."""

    def __init__(self, always_apply: bool = True, p: float = 1.0):
        # The previous implementation accepted these arguments but silently
        # discarded them, hard-coding (True, 1.0). Keep those as defaults
        # (the fixup must normally always run) while honoring overrides.
        super().__init__(always_apply, p)

    def apply(self, image, **params):
        # Guard on ndim so a 2-D grayscale input no longer raises
        # IndexError when probing image.shape[2].
        if image.ndim == 3 and image.shape[2] == 4:
            image = image[:, :, :3]
        return image
def get_dss(records: list) -> tuple[DictDataset, DictDataset]:
    """Partition records by their ``"valid"`` flag into train/valid datasets.

    Training records are shuffled once and get the augmenting pipeline;
    validation records keep their order and get the deterministic pipeline.
    """
    train_records: list = []
    valid_records: list = []
    for record in records:
        bucket = valid_records if record["valid"] else train_records
        bucket.append(record)
    # Sanity check: every record landed in exactly one split.
    assert len(train_records) + len(valid_records) == len(records)
    random.shuffle(train_records)
    return (
        DictDataset(train_records, transform_x=get_augs(train=True)),
        DictDataset(valid_records, transform_x=get_augs(train=False)),
    )
def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]:
    """Wrap the train/valid datasets in DataLoaders.

    The training loader reshuffles every epoch (the one-time shuffle in
    get_dss only fixes an initial order); the validation loader keeps a
    fixed order so metrics are comparable across epochs and runs.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
    return train_dl, valid_dl
|