|
import argparse |
|
import multiprocessing |
|
import os |
|
from os.path import join, exists |
|
from functools import partial |
|
from io import BytesIO |
|
import shutil |
|
|
|
import lmdb |
|
from PIL import Image |
|
from torchvision.datasets import LSUNClass |
|
from torchvision.transforms import functional as trans_fn |
|
from tqdm import tqdm |
|
|
|
from multiprocessing import Process, Queue |
|
|
|
|
|
def resize_and_convert(img, size, resample, quality=100):
    """Resize *img* to *size*, center-crop it square, and return the
    webp-encoded bytes.

    Args:
        img: a PIL image.
        size: target edge length in pixels for resize and center-crop.
        resample: PIL resampling filter (e.g. Image.LANCZOS).
        quality: webp encoder quality (default 100).

    Returns:
        bytes: the webp-compressed image payload.
    """
    cropped = trans_fn.center_crop(trans_fn.resize(img, size, resample), size)
    out = BytesIO()
    cropped.save(out, format="webp", quality=quality)
    return out.getvalue()
|
|
|
|
|
def resize_multiple(img,
                    sizes=(128, 256, 512, 1024),
                    resample=Image.LANCZOS,
                    quality=100):
    """Encode *img* once per entry in *sizes*.

    Args:
        img: a PIL image.
        sizes: edge lengths to produce, in order.
        resample: PIL resampling filter forwarded to resize.
        quality: webp encoder quality forwarded to the encoder.

    Returns:
        list[bytes]: webp payloads, one per requested size.
    """
    return [resize_and_convert(img, edge, resample, quality) for edge in sizes]
|
|
|
|
|
def resize_worker(idx, img, sizes, resample):
    """Worker entry: force RGB, then encode *img* at every size in *sizes*.

    Returns:
        tuple: (idx, list of webp byte strings) so callers can reorder
        results produced out of order.
    """
    rgb = img.convert("RGB")
    encoded = resize_multiple(rgb, sizes=sizes, resample=resample)
    return idx, encoded
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
class ConvertDataset(Dataset):
    """Adapts an image dataset (e.g. LSUNClass) so each item is returned as
    webp-encoded bytes, resized and center-cropped to a fixed edge length.

    The wrapped dataset must yield (PIL image, label) pairs; the label is
    discarded.
    """

    def __init__(self, data, size=256, quality=90) -> None:
        # `size` and `quality` default to the original hard-coded values,
        # so existing callers are unaffected.
        self.data = data
        self.size = size
        self.quality = quality

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, _ = self.data[index]
        # avoid shadowing the builtin `bytes`
        payload = resize_and_convert(img, self.size, Image.LANCZOS,
                                     quality=self.quality)
        return payload
|
|
|
|
|
if __name__ == "__main__":
    # Convert LSUN's original lmdb into our own lmdb layout, which is
    # somehow more performant.
    # NOTE: the duplicate `from tqdm import tqdm` was removed — tqdm is
    # already imported at the top of the file.

    src_path = 'datasets/bedroom_train_lmdb'
    out_path = 'datasets/bedroom256.lmdb'
    # Edge length used both for encoding and in the lmdb key prefix
    # ("256-0000000"), matching ConvertDataset's default.
    image_size = 256

    dataset = LSUNClass(root=os.path.expanduser(src_path))
    dataset = ConvertDataset(dataset)
    # Identity collate_fn: each batch is a plain list of webp byte strings,
    # not a stacked tensor.
    loader = DataLoader(dataset,
                        batch_size=50,
                        num_workers=12,
                        collate_fn=lambda x: x,
                        shuffle=False)

    target = os.path.expanduser(out_path)
    # lmdb environments are directories by default; wipe any stale output.
    if os.path.exists(target):
        shutil.rmtree(target)

    with lmdb.open(target, map_size=1024 ** 4, readahead=False) as env:
        i = 0
        with tqdm(total=len(dataset)) as progress:
            for batch in loader:
                # One write transaction per batch keeps transactions small
                # while amortizing commit overhead over 50 images.
                with env.begin(write=True) as txn:
                    for img_bytes in batch:
                        key = f"{image_size}-{str(i).zfill(7)}".encode("utf-8")
                        txn.put(key, img_bytes)
                        i += 1
                        progress.update()

        # Record the total image count so readers can size themselves.
        with env.begin(write=True) as txn:
            txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
|
|