File size: 2,455 Bytes
baa8e90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import torch
import hashlib
import argparse
import numpy as np
from torchvision import transforms
from diffusers import AutoencoderKL
from tqdm import tqdm
from PIL import Image

from vae import get_vae

def parse_args():
	"""Parse the command-line options of the latent preprocessing script.

	Returns:
		The argparse namespace with the attributes `res` (int), `fac` (float),
		`ver` ("v1" or "xl"), `vae` (path string or None) and `src` (folder).
	"""
	# (flags, keyword settings) for every option, in the order shown by --help
	option_table = (
		(("-r", "--res"), dict(type=int, default=512, help="Source resolution")),
		(("-f", "--fac"), dict(type=float, default=1.5, help="Upscale factor")),
		(("-v", "--ver"), dict(choices=["v1","xl"], default="v1", help="SD version")),
		(("--vae",), dict(help="Path to VAE (Optional)")),
		(("--src",), dict(default="raw", help="Source folder with images")),
	)
	cli = argparse.ArgumentParser(description="Preprocess images into latents")
	for flags, settings in option_table:
		cli.add_argument(*flags, **settings)
	return cli.parse_args()

def encode(vae, img):
	"""Encode an image into one sampled VAE latent.

	The image is turned into a tensor with a leading batch axis, moved to the
	CUDA device, mapped with `x*2.0-1.0` and passed to `vae.encode`; a single
	sample is then drawn from the returned `latent_dist`.

	Args:
		vae: Model whose `encode(...)` result exposes `latent_dist.sample()`
			(presumably a diffusers AutoencoderKL already on the GPU - confirm
			against `get_vae`).
		img: PIL image to encode.

	Returns:
		The sampled latent as a detached torch tensor on the CPU. Callers in
		this file hand it straight to `np.save`.
	"""
	pixels = transforms.ToTensor()(img).unsqueeze(0)
	pixels = pixels.to("cuda") # move to GPU
	# Inference only: without no_grad an autograd graph would be recorded for
	# every encoded image, holding on to GPU memory for no benefit.
	with torch.no_grad():
		posterior = vae.encode(pixels*2.0-1.0)
		latent = posterior.latent_dist.sample()
	return latent.cpu().detach()

def scale(path, res):
	"""Load an image, resize its shorter side to `res` and crop a square.

	The image is converted to RGB, resized with Lanczos resampling so that the
	shorter side equals `res` (aspect ratio kept up to integer truncation), and
	then cropped to the `res` x `res` region anchored at the top-left corner.

	Args:
		path: Path of the image file to open.
		res: Edge length in pixels of the square result.

	Returns:
		The cropped RGB PIL image, or None when the shorter side of the source
		image is below 256 pixels.
	"""
	# The context manager releases the file handle; convert() hands back an
	# independent in-memory image that stays valid after the file is closed.
	with Image.open(path) as source:
		# Reject small images first so they are never converted.
		if min(source.height, source.width) < 256:
			return None
		img = source.convert('RGB')

	if img.width > img.height:
		target = (int(img.width/img.height*res), res)
	elif img.height > img.width:
		target = (res, int(img.height/img.width*res))
	else:
		target = (res, res)
	img = img.resize(target, Image.LANCZOS)
	return img.crop([0,0,res,res])

def process_folder(vae, src_dir, ver, res):
	"""Encode every image file of a folder into `latents/{ver}_{res}px`.

	Each output is named after the MD5 of the source file's bytes, so files
	whose latent already exists are skipped on later runs.

	Args:
		vae: Model forwarded to `encode`.
		src_dir: Folder whose entries are processed.
		ver: Version tag used in the output folder name.
		res: Resolution forwarded to `scale` and used in the output folder name.
	"""
	dst_dir = f"latents/{ver}_{res}px"
	# makedirs also creates the parent "latents" folder when it is missing
	os.makedirs(dst_dir, exist_ok=True)

	for file in tqdm(os.listdir(src_dir)):
		src = os.path.join(src_dir, file)
		# sub-directories cannot be hashed or opened as images
		if not os.path.isfile(src):
			continue
		# MD5 only names the output file here (content-addressed cache key)
		with open(src, 'rb') as handle:
			md5 = hashlib.md5(handle.read()).hexdigest()
		dst = os.path.join(dst_dir, f"{md5}.npy")
		if os.path.isfile(dst):
			continue
		img = scale(src, res)
		# scale() returns None for images it rejects as too small
		if img is None:
			continue
		latent = encode(vae, img)
		np.save(dst, latent)

def process_res(vae, src_dir, ver, res):
	"""Encode the source folder at one resolution, plus the optional test image.

	After the folder is processed, `test.png` in the working directory (if it
	exists) is encoded to `test_{ver}_{res}px.npy` unless that file is already
	present. The CUDA cache is released on every path before returning.

	Args:
		vae: Model forwarded to `process_folder` and `encode`.
		src_dir: Folder with the source images.
		ver: Version tag used in output names.
		res: Resolution to process.
	"""
	process_folder(vae, src_dir, ver, res)
	# test image, optional
	test_dst = f"test_{ver}_{res}px.npy"
	if os.path.isfile("test.png") and not os.path.isfile(test_dst):
		img = scale("test.png", res)
		# scale() returns None when the image is too small to use
		if img is not None:
			latent = encode(vae, img)
			np.save(test_dst, latent)
	# No early return above, so the cache is always released
	torch.cuda.empty_cache()

if __name__ == "__main__":
	# Parse first so that --help or invalid arguments exit before any folder
	# is created or any model is loaded.
	args = parse_args()
	os.makedirs("latents", exist_ok=True)
	vae = get_vae(args.ver, args.vae)
	vae.to("cuda")
	# Latents are produced at the source resolution and at the upscaled one
	dst_res = int(args.res*args.fac)
	process_res(vae, args.src, args.ver, args.res)
	# When the factor leaves the resolution unchanged, a second pass would
	# only re-hash every file and skip it.
	if dst_res != args.res:
		process_res(vae, args.src, args.ver, dst_res)