File size: 3,145 Bytes
ea847ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import annotations
import random

import albumentations as A
import numpy as np
import PIL.Image
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset, DataLoader

from realfake.config import LABELS

# Height and width passed to A.Resize in get_augs (applied before cropping).
IMG_RESIZE = 256
# Height and width passed to A.RandomCrop / A.CenterCrop in get_augs.
IMG_CROP = 224


class DictDataset(Dataset):
    """Dataset backed by a list of record dictionaries.

    Each record must provide a ``"path"`` entry that ``PIL.Image.open`` can
    read. A record may also carry a ``"label"`` entry, which is translated
    through the ``LABELS`` mapping when the item is built.

    Args:
        records: Record dictionaries, one per sample.
        transform_x: Optional callable invoked as ``transform_x(image=array)``;
            the ``"image"`` entry of its result replaces the raw array.
    """

    def __init__(self, records: list[dict], transform_x=None):
        self.records = records
        self.transform_x = transform_x

    def __len__(self):
        """Return the number of records."""
        return len(self.records)

    def __getitem__(self, idx):
        """Load the image for record ``idx`` and return it as an item dict.

        Returns:
            A dict with an ``"image"`` entry and, when the record has a
            ``"label"`` key, a ``"label"`` entry looked up in ``LABELS``.
        """
        record = self.records[idx]
        # Open the file in a context manager so the handle is released right
        # away instead of lingering until garbage collection; np.asarray
        # reads the pixel data while the file is still open.
        with PIL.Image.open(record["path"]) as img:
            image = np.asarray(img)
        if self.transform_x is not None:
            image = self.transform_x(image=image)["image"]
        item = {"image": image}
        if "label" in record:
            item["label"] = LABELS[record["label"]]
        return item


def get_augs(train: bool = True) -> A.Compose:
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines start with a resize to ``IMG_RESIZE`` and end with the same
    tail: channel expansion, alpha removal, normalization and tensor
    conversion. The training pipeline uses a random crop followed by flips,
    brightness/contrast, affine, rotation and coarse-dropout augmentations;
    the evaluation pipeline uses only a center crop.

    Args:
        train: Select the training pipeline when true, otherwise the
            evaluation pipeline.

    Returns:
        The composed transform.
    """
    resize = A.Resize(IMG_RESIZE, IMG_RESIZE)

    if train:
        steps = [
            resize,
            A.RandomCrop(IMG_CROP, IMG_CROP),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.RandomBrightnessContrast(),
            A.Affine(),
            A.Rotate(),
            A.CoarseDropout(),
        ]
    else:
        steps = [resize, A.CenterCrop(IMG_CROP, IMG_CROP)]

    # Shared tail: force three channels, then normalize and convert to tensor.
    finalize = [ExpandChannels(), RGBAtoRGB(), A.Normalize(), ToTensorV2()]
    return A.Compose([*steps, *finalize])


class ExpandChannels(A.ImageOnlyTransform):
    """Expand a single-channel (grayscale) image to three channels.

    A 2-D array, or a 3-D array whose last axis has length one, is repeated
    three times along axis 2. Any other input is returned unchanged.
    """

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        # The constructor arguments are accepted but not used: fixed values
        # are forwarded positionally to the base class.
        # NOTE(review): assumes the base signature is (always_apply, p) --
        # confirm against the installed albumentations version.
        super().__init__(True, 1.0)

    def apply(self, image, **params):
        """Return ``image`` with a single channel repeated three times."""
        if image.ndim == 2:
            # Give the 2-D array an explicit channel axis of length one.
            image = image[..., None]
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        return image


class RGBAtoRGB(A.ImageOnlyTransform):
    """Drop the fourth channel of a four-channel (RGBA) image.

    The alpha channel is discarded by slicing; no compositing is performed.
    Inputs that are not 3-D with four channels are returned unchanged.
    """

    def __init__(self, always_apply: bool = False, p: float = 0.5):
        # The constructor arguments are accepted but not used: fixed values
        # are forwarded positionally to the base class.
        # NOTE(review): assumes the base signature is (always_apply, p) --
        # confirm against the installed albumentations version.
        super().__init__(True, 1.0)

    def apply(self, image, **params):
        """Return ``image`` restricted to its first three channels."""
        # Check ndim first so a 2-D grayscale array passes through instead of
        # raising IndexError on shape[2].
        if image.ndim == 3 and image.shape[2] == 4:
            image = image[:, :, :3]
        return image


def get_dss(records: list) -> tuple[DictDataset, DictDataset]:
    """Split records into training and validation datasets.

    A record goes to the validation set when its ``"valid"`` entry is truthy
    and to the training set otherwise. The training records are shuffled once
    in place before the datasets are built.

    Args:
        records: Record dictionaries, each with a ``"valid"`` entry.

    Returns:
        ``(train_ds, valid_ds)``, using the training and evaluation
        augmentation pipelines respectively.
    """
    train_records, valid_records = [], []
    for rec in records:
        bucket = valid_records if rec["valid"] else train_records
        bucket.append(rec)
    assert len(train_records) + len(valid_records) == len(records)

    random.shuffle(train_records)
    return (
        DictDataset(train_records, transform_x=get_augs(train=True)),
        DictDataset(valid_records, transform_x=get_augs(train=False)),
    )


def get_dls(train_ds: DictDataset, valid_ds: DictDataset, bs: int, num_workers: int) -> tuple[DataLoader, DataLoader]:
    """Wrap the training and validation datasets in data loaders.

    Args:
        train_ds: Dataset used for training.
        valid_ds: Dataset used for validation.
        bs: Batch size for both loaders.
        num_workers: Number of worker processes for both loaders.

    Returns:
        ``(train_dl, valid_dl)``. The training loader reshuffles every epoch;
        the validation loader keeps the dataset order.
    """
    # shuffle=True draws a fresh sample order each epoch; without it the
    # training loader would replay the same order on every pass.
    train_dl = DataLoader(train_ds, batch_size=bs, num_workers=num_workers, shuffle=True)
    valid_dl = DataLoader(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
    return train_dl, valid_dl