Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,534 Bytes
9e426da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import pathlib
import torch
import random
import numpy as np
from torchvision.io.image import read_image
import torchvision.transforms as tvtf
from torch.utils.data import Dataset
class CenterCrop:
    """Callable transform performing an ADM-style square center crop.

    Given a PIL image, it first repeatedly halves the image with box
    filtering while its shorter side is at least twice the target. It then
    rescales with bicubic filtering so the shorter side becomes (after
    rounding) ``size``. Finally it cuts out the central ``size x size``
    window and returns it as a new PIL image.

    Ported from guided-diffusion:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        size: Side length, in pixels, of the square output.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, image):
        target = self.size
        img = image

        # Cheap box-filtered halving first: once the loop exits the shorter
        # side is below 2 * target, so the bicubic pass shrinks by < 2x.
        while min(*img.size) >= 2 * target:
            width, height = img.size
            img = img.resize((width // 2, height // 2), resample=Image.BOX)

        factor = target / min(*img.size)
        rescaled = tuple(round(side * factor) for side in img.size)
        img = img.resize(rescaled, resample=Image.BICUBIC)

        # Crop in array space; shape[0] is rows (height), shape[1] columns.
        pixels = np.array(img)
        top = (pixels.shape[0] - target) // 2
        left = (pixels.shape[1] - target) // 2
        window = pixels[top:top + target, left:left + target]
        return Image.fromarray(window)
from PIL import Image
# Glob patterns ImageDataset passes to Path.rglob to discover image files.
IMG_EXTENSIONS = ("*.png", "*.JPEG", "*.jpeg", "*.jpg")
def test_collate(batch):
    """Collate a batch of equally-shaped tensors into one tensor.

    Intended as a DataLoader ``collate_fn`` for ``ImageDataset``, whose items
    are bare image tensors.

    Args:
        batch: Sequence of tensors that all have the same shape.

    Returns:
        A tensor with a new leading batch dimension (``torch.stack``).
    """
    return torch.stack(batch)


# Despite its ``test_`` prefix this is a collate function, not a pytest test.
# Opt out of collection so importing the name into a test module does not
# yield a spurious "fixture 'batch' not found" error.
test_collate.__test__ = False
class ImageDataset(Dataset):
    """Image files found recursively under ``root``, served as uint8 tensors.

    All files below ``root`` matching one of ``IMG_EXTENSIONS`` are collected
    once and shuffled with the global ``random`` module. Each item is a
    ``torch.uint8`` tensor of shape ``(3, image_size[0], image_size[0])``.
    The image is converted to RGB, then center-cropped by ``CenterCrop``.
    Files that fail to load are reported and replaced by an all-zero tensor
    of that same shape, so one bad file does not abort an epoch.

    Args:
        root: Directory to scan recursively.
        image_size: ``(height, width)`` pair or a single int. Only the first
            entry is used, as the side length of the square crop.
    """

    def __init__(self, root, image_size=(224, 224)):
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.root = pathlib.Path(root)
        # dict.fromkeys drops duplicates while keeping discovery order. Where
        # globbing is case-insensitive (e.g. Windows), "*.jpeg" and "*.JPEG"
        # would otherwise both yield the same file.
        found = dict.fromkeys(
            path for pattern in IMG_EXTENSIONS for path in self.root.rglob(pattern)
        )
        images = [str(path) for path in found]
        random.shuffle(images)
        self.images = images
        # Named static methods instead of lambdas keep the transform (and so
        # the dataset) picklable, which spawn-started DataLoader workers need.
        self.transform = tvtf.Compose(
            [
                CenterCrop(image_size[0]),
                tvtf.ToTensor(),
                tvtf.Lambda(ImageDataset._to_uint8),
                tvtf.Lambda(ImageDataset._to_three_channels),
            ]
        )
        self.size = image_size

    @staticmethod
    def _to_uint8(x):
        """Map a float tensor scaled to [0, 1] (as from ToTensor) back to uint8."""
        # Round before casting: the cast truncates. A divide-by-255 round
        # trip that lands just below an integer k would otherwise give k - 1.
        return (x * 255).round().to(torch.uint8)

    @staticmethod
    def _to_three_channels(x):
        """Broadcast a 1-channel tensor to 3 channels; 3-channel input is unchanged."""
        return x.expand(3, -1, -1)

    def __getitem__(self, idx):
        path = self.images[idx]
        try:
            # convert("RGB") normalises palette, RGBA, CMYK and grayscale
            # files to 3 channels. Without it, 4-channel images fail in
            # expand() and silently become placeholders. The with-block
            # closes the file handle even if decoding raises.
            with Image.open(path) as img:
                rgb = img.convert("RGB")
            image = self.transform(rgb)
        except Exception as err:  # best effort: a bad file must not stop an epoch
            print(f"{path}: {err!r}")
            side = self.size[0]
            # Same shape CenterCrop produces, so the batch can still be stacked.
            image = torch.zeros(3, side, side, dtype=torch.uint8)
        return image

    def __len__(self):
        return len(self.images)