import os
from itertools import batched

import click
import numpy as np
import PIL
import PIL.Image
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL
from tqdm import tqdm
@click.command()
@click.option("--model_name", type=str, default="stabilityai/stable-diffusion-2-1")
@click.option("--swim_dir", type=str, default="datasets/swim_data")
@click.option("--batch_size", type=int, default=1)
def compute_latent(model_name: str, swim_dir: str, batch_size: int) -> None:
    """Precompute VAE latents for every image in the SWIM train/val splits.

    Loads the VAE from ``model_name``, encodes each 512x512 center-cropped
    image on the GPU, and saves the latent-distribution mode of each image
    as a ``.npy`` file under ``<swim_dir>/<split>/latents``.
    """
    model = AutoencoderKL.from_pretrained(model_name, subfolder="vae").cuda()
    model.eval()
    # Create the output folders for the latent vectors up front.
    os.makedirs(os.path.join(swim_dir, "train/latents"), exist_ok=True)
    os.makedirs(os.path.join(swim_dir, "val/latents"), exist_ok=True)
    transforms = T.Compose(
        [
            T.Resize(512),
            T.CenterCrop(512),
            T.ToTensor(),
            # Map [0, 1] pixels to [-1, 1], the input range the SD VAE expects.
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )
    for split in ["train", "val"]:
        output = os.path.join(swim_dir, split, "latents")
        image_dir = os.path.join(swim_dir, split, "images")
        # sorted() makes the processing order deterministic across runs.
        for image_names in tqdm(
            list(batched(sorted(os.listdir(image_dir)), batch_size))
        ):
            images = []
            for name in image_names:
                # Context manager closes the file handle promptly (no fd leak);
                # convert("RGB") forces 3 channels so grayscale/RGBA/palette
                # images don't crash the 3-channel Normalize above.
                with PIL.Image.open(os.path.join(image_dir, name)) as img:
                    images.append(transforms(img.convert("RGB")))
            with torch.no_grad():
                batch = torch.stack(images).cuda()
                # Use the distribution mode for a deterministic latent.
                latents = model.encode(batch).latent_dist.mode()
            latents = latents.cpu().numpy()
            for name, latent in zip(image_names, latents):
                # splitext handles any extension (.jpeg, uppercase, ...),
                # not just the literal ".jpg"/".png" suffixes.
                stem = os.path.splitext(name)[0]
                np.save(os.path.join(output, stem + ".npy"), latent)
# Script entry point: click parses the CLI options and invokes compute_latent.
if __name__ == "__main__":
    compute_latent()