StupidGame's picture
Upload 1941 files
baa8e90
raw
history blame contribute delete
No virus
2.46 kB
import os
import torch
import hashlib
import argparse
import numpy as np
from torchvision import transforms
from diffusers import AutoencoderKL
from tqdm import tqdm
from PIL import Image
from vae import get_vae
def parse_args():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with ``res`` (int, default 512), ``fac`` (float,
        default 1.5), ``ver`` ("v1" or "xl", default "v1"), ``vae`` (str or
        None) and ``src`` (str, default "raw").
    """
    cli = argparse.ArgumentParser(description="Preprocess images into latents")
    # (flags, keyword options), listed in the order they show up in --help.
    option_table = (
        (("-r", "--res"), dict(type=int, default=512, help="Source resolution")),
        (("-f", "--fac"), dict(type=float, default=1.5, help="Upscale factor")),
        (("-v", "--ver"), dict(choices=["v1", "xl"], default="v1", help="SD version")),
        (("--vae",), dict(help="Path to VAE (Optional)")),
        (("--src",), dict(default="raw", help="Source folder with images")),
    )
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def encode(vae, img):
    """Encode an image into a sampled VAE latent.

    Args:
        vae: Model exposing ``encode(x)`` whose result provides
            ``latent_dist.sample()``. It must already live on the CUDA
            device, because the input tensor is moved to "cuda" here.
        img: PIL image (anything ``transforms.ToTensor`` accepts).

    Returns:
        A CPU tensor holding one latent sample, detached from autograd.
    """
    pixels = transforms.ToTensor()(img).unsqueeze(0)  # add a batch dimension
    pixels = pixels.to("cuda")  # move to GPU
    # Inference only: without no_grad the encoder forward pass records an
    # autograd graph and keeps its activations alive on the GPU for nothing.
    with torch.no_grad():
        # Shift/scale the tensor by *2 - 1 before encoding (maps [0, 1] to
        # [-1, 1] for the usual 8-bit PIL input).
        posterior = vae.encode(pixels * 2.0 - 1.0)
        latent = posterior.latent_dist.sample()
    return latent.cpu().detach()
def scale(path, res):
    """Load an image, resize its shorter side to ``res`` and crop a square.

    The image is converted to RGB, resized with Lanczos resampling so that the
    shorter side equals ``res`` (the longer side keeps the aspect ratio,
    truncated to an int), and then cropped to the top-left ``res`` x ``res``
    region.

    Args:
        path: Path of the image file to open.
        res: Side length of the square output, in pixels.

    Returns:
        The cropped PIL image, or None when either side of the source image
        is shorter than 256 pixels.
    """
    picture = Image.open(path).convert('RGB')
    width, height = picture.width, picture.height

    # Too small to be useful; the caller receives None.
    if height < 256 or width < 256:
        return None

    if width > height:
        new_size = (int(width / height * res), res)
    elif height > width:
        new_size = (res, int(height / width * res))
    else:
        new_size = (res, res)

    resized = picture.resize(new_size, Image.LANCZOS)
    return resized.crop([0, 0, res, res])
def process_folder(vae, src_dir, ver, res):
    """Encode every image file in ``src_dir`` and save the latents to disk.

    Each latent is written to ``latents/{ver}_{res}px/{md5}.npy``, where
    ``md5`` is the MD5 hex digest of the source file's bytes. Files whose
    latent already exists are skipped, so the function can be re-run safely.

    Args:
        vae: Model passed through to ``encode``.
        src_dir: Folder containing the source images.
        ver: Version tag used in the output folder name.
        res: Target square resolution handed to ``scale``.
    """
    dst_dir = f"latents/{ver}_{res}px"
    # makedirs also creates the parent "latents" folder when it is missing.
    os.makedirs(dst_dir, exist_ok=True)
    for file in tqdm(os.listdir(src_dir)):
        src = os.path.join(src_dir, file)
        if not os.path.isfile(src):
            continue  # sub-directories cannot be hashed or opened as images
        # Content hash names the output, so duplicates map to one latent.
        with open(src, 'rb') as handle:
            md5 = hashlib.md5(handle.read()).hexdigest()
        dst = os.path.join(dst_dir, f"{md5}.npy")
        if os.path.isfile(dst):
            continue
        img = scale(src, res)
        if img is None:
            continue  # scale() rejects images with a side shorter than 256 px
        latent = encode(vae, img)
        np.save(dst, latent)
def process_res(vae, src_dir, ver, res):
    """Encode a folder at one resolution, plus the optional test image.

    After ``process_folder`` runs, ``test.png`` (if present in the working
    directory) is encoded to ``test_{ver}_{res}px.npy`` unless that file
    already exists. The CUDA cache is released at the end in every case.

    Args:
        vae: Model passed through to ``process_folder`` and ``encode``.
        src_dir: Folder containing the source images.
        ver: Version tag used in output names.
        res: Target square resolution.
    """
    process_folder(vae, src_dir, ver, res)
    # test image, optional
    test_dst = f"test_{ver}_{res}px.npy"
    if os.path.isfile("test.png") and not os.path.isfile(test_dst):
        img = scale("test.png", res)
        # scale() returns None for images with a side shorter than 256 px.
        if img is not None:
            latent = encode(vae, img)
            np.save(test_dst, latent)
    # Runs even when the test latent already existed, so cached GPU memory
    # is always released before the next resolution is processed.
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Output root for all latent folders; created before arguments are parsed.
    os.makedirs("latents", exist_ok=True)

    options = parse_args()
    model = get_vae(options.ver, options.vae)
    model.to("cuda")

    # Encode at the source resolution first, then at the upscaled one.
    upscaled = int(options.res * options.fac)
    for side in (options.res, upscaled):
        process_res(model, options.src, options.ver, side)