import torch
from .dc_ae import MyAutoencoderDC as AutoencoderDC
from .sd_vae import MyAutoencoderKL as AutoencoderKL

# dc-ae
def get_dc_ae(vae_dir, dtype, device):
    # Load a pretrained DC-AE checkpoint and put it in inference mode.
    dc_ae = AutoencoderDC.from_pretrained(vae_dir).to(dtype=dtype, device=device)
    dc_ae.eval()
    # Optionally freeze all parameters to avoid functorch issues:
    # for param in dc_ae.parameters():
    #     param.requires_grad = False
    return dc_ae

def dc_ae_encode(dc_ae, images):
    # Encode images to latents, then normalize with the autoencoder's
    # latent statistics.
    with torch.no_grad():
        z = dc_ae.encode(images).latent
        latents = (z - dc_ae.mean) / dc_ae.std
    return latents

def dc_ae_decode(dc_ae, latents, slice_vae=False):
    # Denormalize the latents, then decode; with slice_vae=True the batch is
    # decoded one sample at a time to reduce peak memory.
    with torch.no_grad():
        z = latents * dc_ae.std + dc_ae.mean
        if slice_vae and z.size(0) > 1:
            decoded_slices = [dc_ae._decode(z_slice) for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = dc_ae._decode(z)
    return decoded
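
# Usage sketch (illustrative only): the checkpoint path, dtype, and tensor
# shapes below are assumptions, not part of this module; substitute your own
# DC-AE directory and image batch.
#
#     dc_ae = get_dc_ae("path/to/dc-ae", torch.float16, "cuda")
#     images = torch.randn(4, 3, 512, 512, dtype=torch.float16, device="cuda")
#     latents = dc_ae_encode(dc_ae, images)
#     recon = dc_ae_decode(dc_ae, latents, slice_vae=True)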

# sd-vae
def get_sd_vae(vae_dir, dtype, device):
    # Load a pretrained SD-VAE checkpoint and put it in inference mode.
    sd_vae = AutoencoderKL.from_pretrained(vae_dir).to(dtype=dtype, device=device)
    sd_vae.eval()
    # Optionally freeze all parameters to avoid functorch issues:
    # for param in sd_vae.parameters():
    #     param.requires_grad = False
    return sd_vae

def sd_vae_encode(sd_vae, images):
    # Sample latents from the encoder posterior, then normalize with the
    # VAE's latent statistics.
    with torch.no_grad():
        posterior = sd_vae.encode(images)
        z = posterior.latent_dist.sample()
        latents = (z - sd_vae.mean) / sd_vae.std
    return latents

def sd_vae_decode(sd_vae, latents, slice_vae=False):
    # Denormalize the latents, then decode; with slice_vae=True the batch is
    # decoded one sample at a time to reduce peak memory.
    with torch.no_grad():
        z = latents * sd_vae.std + sd_vae.mean
        if slice_vae and z.shape[0] > 1:
            decoded_slices = [sd_vae._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = sd_vae._decode(z).sample
    return decoded
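
# Usage sketch (illustrative only): the path and shapes are assumptions; any
# checkpoint directory that MyAutoencoderKL can load should work here.
#
#     sd_vae = get_sd_vae("path/to/sd-vae", torch.float32, "cpu")
#     images = torch.randn(2, 3, 256, 256)
#     latents = sd_vae_encode(sd_vae, images)
#     recon = sd_vae_decode(sd_vae, latents, slice_vae=True)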