File size: 3,817 Bytes
1a030c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch.nn as nn
from torch import Tensor
from pathlib import Path
import torch
import random
import torchvision.io as VIO
import torchvision.transforms.functional as VF
from dataclasses import dataclass
from tqdm.auto import tqdm

# https://huggingface.co/datasets/huggan/anime-faces
# Directory holding the source PNGs; a leading "~" is expanded to the
# current user's home directory.
RAW_IMAGES_PATH = Path(
    "~/Downloads/datasets/anime/anime-faces/images").expanduser()
# Side lengths the images are stored at (the 64 / 8 keys used below).
RESOLUTIONS = [64, 8]

# Cached stacked-tensor files. These are relative paths, so they resolve
# against the current working directory, not this module's location.
AS_TENSORS_64 = Path("data/all_images_64.bin")
AS_TENSORS_8 = Path("data/all_images_8.bin")


@dataclass
class ImageBatch:
    """A batch of paired 8x8 and 64x64 channels-last images plus a loss value.

    ``as_1d`` / ``as_2d`` switch between the flattened ``(batch, H*W, C)``
    layout and the grid ``(batch, H, W, C)`` layout; both return a new
    ``ImageBatch`` built from views and share ``loss`` with the original.
    """

    im8: Tensor  # small images: (batch, 8, 8, C) or flattened (batch, 64, C)
    im64: Tensor  # large images: (batch, 64, 64, C) or flattened (batch, 4096, C)
    loss: Tensor

    @property
    def n_batch(self):
        """Leading (batch) dimension of ``im8``."""
        return self.im8.shape[0]

    def as_1d(self):
        """Return a batch whose pixel grids are flattened to ``(batch, H*W, C)``."""
        n = self.n_batch
        flat_small = self.im8.view(n, 8 * 8, self.im8.shape[-1])
        flat_large = self.im64.view(n, 64 * 64, self.im64.shape[-1])
        return ImageBatch(im8=flat_small, im64=flat_large, loss=self.loss)

    def as_2d(self):
        """Return a batch whose pixels are laid out as ``(batch, H, W, C)`` grids."""
        n = self.n_batch
        grid_small = self.im8.view(n, 8, 8, self.im8.shape[-1])
        grid_large = self.im64.view(n, 64, 64, self.im64.shape[-1])
        return ImageBatch(im8=grid_small, im64=grid_large, loss=self.loss)


class ImageDB:
    """Image tensors at 64x64 and 8x8 resolution, split into train / valid.

    On first use the PNGs under ``RAW_IMAGES_PATH`` are converted and cached
    to ``AS_TENSORS_64`` / ``AS_TENSORS_8``; later runs load those files.
    The last ``n_val`` images (in stored order) form the validation split.
    """

    def __init__(self, val_ratio=0.05, dtype=None) -> None:
        """Load (building if necessary) the cached image tensors.

        Args:
            val_ratio: fraction of images reserved for validation; the count
                is ``int(len * val_ratio)``, i.e. rounded down.
            dtype: dtype the images are cast to; defaults to ``torch.bfloat16``.
        """
        # Rebuild when *either* cache file is missing: checking only the 64px
        # file made torch.load crash if the 8px file alone was absent.
        if not (AS_TENSORS_64.exists() and AS_TENSORS_8.exists()):
            self.make_tensor_version()
        print("Load tensors file")
        self.dtype = dtype or torch.bfloat16
        self.all_images_64 = torch.load(AS_TENSORS_64).to(self.dtype)
        self.all_images_8 = torch.load(AS_TENSORS_8).to(self.dtype)
        self.n_val = int(len(self.all_images_64) * val_ratio)

    def split(self, s: str):
        """Return ``{8: images_8, 64: images_64}`` for ``"train"`` or ``"valid"``.

        Raises:
            ValueError: if ``s`` is not a known split name.
        """
        # Slice at an explicit boundary. The previous ``[:-n_val]`` /
        # ``[-n_val:]`` form returned an EMPTY train set and the WHOLE
        # dataset as validation whenever n_val == 0.
        n_train = len(self.all_images_64) - self.n_val
        if s == "train":
            return {
                8: self.all_images_8[:n_train],
                64: self.all_images_64[:n_train],
            }
        if s == "valid":
            return {
                8: self.all_images_8[n_train:],
                64: self.all_images_64[n_train:],
            }
        raise ValueError(f"Invalid split {s}")

    @property
    def train_ds(self):
        """Shortcut for ``self.split("train")``."""
        return self.split("train")

    @property
    def valid_ds(self):
        """Shortcut for ``self.split("valid")``."""
        return self.split("valid")

    @torch.no_grad()
    def make_tensor_version(self, path=RAW_IMAGES_PATH):
        """Convert every ``*.png`` in ``path`` and save the stacked tensors.

        Returns:
            ``{8: t8, 64: t64}`` — the tensors that were written to
            ``AS_TENSORS_8`` / ``AS_TENSORS_64``.

        Raises:
            FileNotFoundError: if ``path`` contains no ``.png`` files.
        """
        # Sort so the stored order -- and hence the tail-based validation
        # split -- does not depend on the filesystem's directory order.
        items = sorted(path.glob("*.png"))
        if not items:
            # torch.stack([]) would otherwise fail with an opaque error.
            raise FileNotFoundError(f"No .png images found in {path}")
        all_tensors = [load_single_image(item) for item in tqdm(items)]
        t64 = torch.stack([t[64] for t in all_tensors])
        t8 = torch.stack([t[8] for t in all_tensors])
        # torch.save does not create missing parent directories.
        AS_TENSORS_64.parent.mkdir(parents=True, exist_ok=True)
        AS_TENSORS_8.parent.mkdir(parents=True, exist_ok=True)
        torch.save(t64, AS_TENSORS_64)
        torch.save(t8, AS_TENSORS_8)
        return {8: t8, 64: t64}

    def random_batch(self, bs: int, split: str = "train", *, device="cuda"):
        """Sample up to ``bs`` distinct images (without replacement) from ``split``.

        Args:
            bs: number of images; the whole split is returned if it is smaller.
            split: ``"train"`` or ``"valid"``.
            device: device the returned tensors are moved to. The default
                ``"cuda"`` keeps the previous behavior; pass ``"cpu"`` to run
                without a GPU.

        Returns:
            ``ImageBatch`` whose ``loss`` is the placeholder ``torch.tensor(-1)``.
        """
        split_dict = self.split(split)
        im8 = split_dict[8]
        im64 = split_dict[64]
        # Draw only the needed indices instead of shuffling the full index
        # list on every call; ``min`` preserves the "take everything" result
        # when bs exceeds the split size.
        keys = random.sample(range(len(im8)), min(bs, len(im8)))
        return ImageBatch(
            im64=im64[keys].to(device),
            im8=im8[keys].to(device),
            loss=torch.tensor(-1))


def load_single_image(path: Path):
    """Load one image as float tensors at 64x64 and 8x8 resolution.

    Returns:
        ``{64: im64, 8: im8}``; both are channels-last ``(H, W, 3)`` tensors
        scaled by 1/255 (so in ``[0, 1]`` for 8-bit images).
    """
    # Force 3 channels: grayscale or RGBA PNGs would otherwise decode to 1 or
    # 4 channels, breaking torch.stack across images and the 3-input
    # RGBToModel layer.
    im = VIO.read_image(str(path), mode=VIO.ImageReadMode.RGB)
    im = im / 255.0
    # Downstream code views the large image as exactly 64x64, so normalize
    # other source sizes; 64x64 inputs are left untouched.
    if im.shape[-2:] != (64, 64):
        im = VF.resize(
            im, [64, 64], VF.InterpolationMode.BILINEAR, antialias=True)
    # resize to 8x8
    im8 = VF.resize(im, [8, 8], VF.InterpolationMode.NEAREST_EXACT)
    # C H W -> H W C
    im = im.permute(1, 2, 0).contiguous()
    im8 = im8.permute(1, 2, 0).contiguous()

    return {64: im, 8: im8}


class RGBToModel(nn.Module):
    """Project per-pixel RGB values (last axis of size 3) to ``d_model`` features."""

    def __init__(self, d_model, device=None, dtype=None):
        super(RGBToModel, self).__init__()
        # One affine map 3 -> d_model, applied over the last axis of the input.
        self.fc = nn.Linear(
            in_features=3,
            out_features=d_model,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        projected = self.fc(x)
        return projected


class ModelToRGB(nn.Module):
    """Map ``d_model`` features to 3 output channels squashed by a sigmoid."""

    def __init__(self, d_model, device=None, dtype=None):
        super(ModelToRGB, self).__init__()
        factory = dict(device=device, dtype=dtype)
        # Normalize the features, then read out 3 values per position.
        self.norm = nn.LayerNorm(d_model, **factory)
        self.fc = nn.Linear(d_model, 3, **factory)

    def forward(self, x):
        # LayerNorm -> Linear -> sigmoid, as a single chain.
        return self.fc(self.norm(x)).sigmoid()