File size: 3,817 Bytes
1a030c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch.nn as nn
from torch import Tensor
from pathlib import Path
import torch
import random
import torchvision.io as VIO
import torchvision.transforms.functional as VF
from dataclasses import dataclass
from tqdm.auto import tqdm
# Directory of raw PNGs from the anime-faces dataset:
# https://huggingface.co/datasets/huggan/anime-faces
RAW_IMAGES_PATH = Path(
    "~/Downloads/datasets/anime/anime-faces/images").expanduser()

# Pixel side lengths of the two stored image resolutions; the cache-file
# names below hard-code the same two values.
RESOLUTIONS = [64, 8]

# Cached stacked-tensor files written by ImageDB.make_tensor_version and read
# back by ImageDB.__init__. Plain strings: there is nothing to interpolate.
AS_TENSORS_64 = Path("data/all_images_64.bin")
AS_TENSORS_8 = Path("data/all_images_8.bin")
@dataclass
class ImageBatch:
    """A batch of paired 8x8 and 64x64 images together with a loss tensor.

    Images are channel-last. ``as_1d`` and ``as_2d`` only change how the
    spatial axes are laid out (flattened vs. height/width); they return new
    batches whose image tensors are views of the originals and reuse the
    same ``loss`` object.
    """

    im8: Tensor  # (n, 8, 8, C), or (n, 64, C) after as_1d
    im64: Tensor  # (n, 64, 64, C), or (n, 4096, C) after as_1d
    loss: Tensor  # passed through untouched by as_1d / as_2d

    @property
    def n_batch(self):
        """Number of samples, read from the leading axis of ``im8``."""
        return self.im8.shape[0]

    def _viewed(self, shape8, shape64):
        """Build a batch whose images are viewed with the given spatial shapes."""
        n = self.n_batch
        small = self.im8.view(n, *shape8, self.im8.shape[-1])
        large = self.im64.view(n, *shape64, self.im64.shape[-1])
        return ImageBatch(im8=small, im64=large, loss=self.loss)

    def as_1d(self):
        """Flatten spatial axes: (n, H, W, C) -> (n, H*W, C)."""
        return self._viewed((8 * 8,), (64 * 64,))

    def as_2d(self):
        """Restore spatial axes: (n, H*W, C) -> (n, H, W, C)."""
        return self._viewed((8, 8), (64, 64))
class ImageDB:
    """All dataset images held in memory at two resolutions (64 and 8).

    On first use the raw PNGs are converted into the cached tensor files
    ``AS_TENSORS_64`` / ``AS_TENSORS_8``; later instances just load them.
    The split is positional: the last ``n_val`` images are the validation
    split and the rest are the training split.
    """

    def __init__(self, val_ratio=0.05, dtype=None) -> None:
        """Load the cached image tensors, building them first if needed.

        Args:
            val_ratio: fraction of images reserved for validation. The count
                is rounded down, so it can be 0 for small datasets.
            dtype: dtype the images are cast to; ``None`` means
                ``torch.bfloat16``.
        """
        # Rebuild when EITHER cache file is missing -- both are loaded below.
        if not (AS_TENSORS_64.exists() and AS_TENSORS_8.exists()):
            self.make_tensor_version()
        print("Load tensors file")
        self.dtype = dtype or torch.bfloat16
        self.all_images_64 = torch.load(AS_TENSORS_64).to(self.dtype)
        self.all_images_8 = torch.load(AS_TENSORS_8).to(self.dtype)
        self.n_val = int(len(self.all_images_64) * val_ratio)

    def split(self, s: str):
        """Return ``{8: images_8, 64: images_64}`` for one split.

        Args:
            s: ``"train"`` or ``"valid"``.

        Returns:
            Dict of slices (views) of the full image tensors.

        Raises:
            ValueError: if ``s`` is not a known split name.
        """
        # Slice at an explicit boundary rather than ``[:-n_val]`` /
        # ``[-n_val:]``: with n_val == 0 those would yield an EMPTY train
        # split and put the WHOLE dataset in validation.
        n_train = len(self.all_images_64) - self.n_val
        if s == "train":
            part = slice(None, n_train)
        elif s == "valid":
            part = slice(n_train, None)
        else:
            raise ValueError(f"Invalid split {s}")
        return {
            8: self.all_images_8[part],
            64: self.all_images_64[part]
        }

    @property
    def train_ds(self):
        """Shorthand for ``split("train")``."""
        return self.split("train")

    @property
    def valid_ds(self):
        """Shorthand for ``split("valid")``."""
        return self.split("valid")

    @torch.no_grad()
    def make_tensor_version(self, path=None):
        """Convert every ``*.png`` in ``path`` into the two cached tensor files.

        Args:
            path: directory of raw PNGs (``Path`` or ``str``); ``None`` means
                ``RAW_IMAGES_PATH``, looked up at call time.

        Returns:
            ``{8: t8, 64: t64}``: the stacked (N, H, W, C) tensors that were
            saved to ``AS_TENSORS_8`` / ``AS_TENSORS_64``.

        Raises:
            FileNotFoundError: if ``path`` contains no ``*.png`` files.
        """
        path = Path(RAW_IMAGES_PATH if path is None else path)
        # Sorted so the image order -- and hence the positional train/valid
        # split -- is reproducible across machines and filesystems.
        items = sorted(path.glob("*.png"))
        if not items:
            raise FileNotFoundError(f"No *.png images found in {path}")
        all_tensors = [load_single_image(item) for item in tqdm(items)]
        t64 = torch.stack([t[64] for t in all_tensors])
        t8 = torch.stack([t[8] for t in all_tensors])
        # torch.save does not create missing parent directories.
        for target in (AS_TENSORS_64, AS_TENSORS_8):
            target.parent.mkdir(parents=True, exist_ok=True)
        torch.save(t64, AS_TENSORS_64)
        torch.save(t8, AS_TENSORS_8)
        return {8: t8, 64: t64}

    def random_batch(self, bs: int, split: str = "train", device="cuda"):
        """Sample up to ``bs`` distinct images from ``split``.

        Args:
            bs: batch size; if it exceeds the split size, every image in the
                split is returned (in random order).
            split: ``"train"`` or ``"valid"``.
            device: device the batch tensors are moved to (default CUDA).

        Returns:
            ImageBatch whose ``loss`` is the placeholder ``tensor(-1)``.
        """
        split_dict = self.split(split)
        im8 = split_dict[8]
        im64 = split_dict[64]
        # Draw distinct indices directly instead of shuffling the full index
        # list on every call; min() keeps the old truncate-to-split-size
        # behaviour for oversized batches.
        n_items = len(im8)
        keys = random.sample(range(n_items), min(bs, n_items))
        return ImageBatch(
            im64=im64[keys].to(device),
            im8=im8[keys].to(device),
            loss=torch.tensor(-1))
def load_single_image(path: Path):
    """Load one image file as channel-last float tensors at two resolutions.

    Args:
        path: image file readable by ``torchvision.io.read_image``.

    Returns:
        ``{64: im, 8: im8}`` -- (H, W, 3) float tensors scaled to [0, 1].
        ``im`` keeps the file's own size (the ``64`` key assumes the source
        images are already 64x64 -- TODO confirm for other datasets);
        ``im8`` is a nearest-exact 8x8 downsample of it.
    """
    # Force 3-channel RGB: downstream code consumes exactly 3 channels
    # (RGBToModel is nn.Linear(3, ...)). A stray grayscale or RGBA PNG read
    # in UNCHANGED mode would give a different channel count and break the
    # torch.stack in ImageDB.make_tensor_version.
    im = VIO.read_image(str(path), mode=VIO.ImageReadMode.RGB)
    # uint8 [0, 255] -> float [0, 1]
    im = im / 255.0
    # resize to 8x8
    im8 = VF.resize(im, [8, 8], VF.InterpolationMode.NEAREST_EXACT)
    # C H W -> H W C
    im = im.permute(1, 2, 0).contiguous()
    im8 = im8.permute(1, 2, 0).contiguous()
    return {64: im, 8: im8}
class RGBToModel(nn.Module):
    """Per-pixel linear projection from 3 RGB values to ``d_model`` features."""

    def __init__(self, d_model, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # One shared Linear applied over the last axis of the input.
        self.fc = nn.Linear(3, d_model, **factory_kwargs)

    def forward(self, x):
        """Map a ``(..., 3)`` tensor to ``(..., d_model)``."""
        projected = self.fc(x)
        return projected
class ModelToRGB(nn.Module):
    """Decode ``d_model`` features to 3 RGB values squashed into [0, 1]."""

    def __init__(self, d_model, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Registration order (norm, then fc) fixes the state_dict key order.
        self.norm = nn.LayerNorm(d_model, **factory_kwargs)
        self.fc = nn.Linear(d_model, 3, **factory_kwargs)

    def forward(self, x):
        """LayerNorm over the last axis, project to 3 channels, then sigmoid."""
        return self.fc(self.norm(x)).sigmoid()
|