|
import os
from itertools import batched

import click
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL
from tqdm import tqdm
|
|
|
|
|
def _load_image(path, transform):
    """Open the image at *path*, force 3-channel RGB and apply *transform*."""
    # The context manager releases the file handle right away; convert()
    # loads the pixel data, so the result stays usable after the file closes.
    with PIL.Image.open(path) as image:
        # Grayscale / RGBA / palette inputs would otherwise produce 1 or 4
        # channels and fail in the 3-channel Normalize step.
        return transform(image.convert("RGB"))


def _latent_filename(image_name):
    """Return the ``.npy`` file name under which *image_name*'s latent is saved."""
    # splitext swaps only the real extension, whatever it is, instead of
    # rewriting ".jpg"/".png" substrings anywhere inside the name.
    stem, _ = os.path.splitext(image_name)
    return stem + ".npy"


@click.command()
@click.option("--model_name", type=str, default="stabilityai/stable-diffusion-2-1")
@click.option("--swim_dir", type=str, default="datasets/swim_data")
@click.option("--batch_size", type=int, default=1)
def compute_latent(model_name: str, swim_dir: str, batch_size: int):
    """Encode every dataset image with the model's VAE and save the latents.

    For each of the ``train`` and ``val`` splits, the images found in
    ``<swim_dir>/<split>/images`` are resized and center-cropped to 512x512,
    normalized with mean 0.5 / std 0.5 per channel, encoded in batches, and
    the mode of the resulting latent distribution is written to
    ``<swim_dir>/<split>/latents/<image stem>.npy`` (one file per image).

    Args:
        model_name: Identifier passed to ``AutoencoderKL.from_pretrained``;
            the weights are read from its ``vae`` subfolder.
        swim_dir: Dataset root containing ``train/images`` and ``val/images``.
        batch_size: Number of images encoded per forward pass (must be >= 1).
    """
    # Use the GPU when present; fall back to CPU instead of crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(device)
    model.eval()

    transforms = T.Compose(
        [
            T.Resize(512),
            T.CenterCrop(512),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    for split in ["train", "val"]:
        image_dir = os.path.join(swim_dir, split, "images")
        output = os.path.join(swim_dir, split, "latents")
        os.makedirs(output, exist_ok=True)

        # os.listdir order is arbitrary; sorting makes runs reproducible.
        names = sorted(os.listdir(image_dir))
        # Ceiling division gives tqdm a total without materializing batches.
        num_batches = (len(names) + batch_size - 1) // batch_size

        for image_names in tqdm(
            batched(names, batch_size), total=num_batches, desc=split
        ):
            images = [
                _load_image(os.path.join(image_dir, name), transforms)
                for name in image_names
            ]

            with torch.no_grad():
                batch = torch.stack(images).to(device)
                # mode() takes the distribution's mode rather than a random
                # sample, so the stored latents are deterministic.
                latents = model.encode(batch).latent_dist.mode()
            latents = latents.cpu().numpy()

            for name, latent in zip(image_names, latents):
                np.save(os.path.join(output, _latent_filename(name)), latent)
|
|
|
|
|
# Script entry point: click parses the command-line options and invokes
# the command with them.
if __name__ == "__main__":
    compute_latent()
|
|