File size: 1,938 Bytes
3ed0796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
from .dc_ae import MyAutoencoderDC as AutoencoderDC
from .sd_vae import MyAutoencoderKL as AutoencoderKL


# dc-ae
def get_dc_ae(vae_dir, dtype, device):
    """Load a pretrained DC-AE from ``vae_dir`` and switch it to eval mode.

    Args:
        vae_dir: Path or identifier forwarded to ``AutoencoderDC.from_pretrained``.
        dtype: Target dtype for the model weights.
        device: Target device for the model.

    Returns:
        The loaded autoencoder, moved to ``dtype``/``device`` and in eval mode.
    """
    # NOTE: parameters are intentionally NOT frozen here (requires_grad is left
    # unchanged); a previous variant froze them, reportedly to avoid functorch
    # issues, but that step is disabled.
    model = AutoencoderDC.from_pretrained(vae_dir)
    model = model.to(dtype=dtype, device=device)
    model.eval()
    return model


def dc_ae_encode(dc_ae, images):
    """Encode ``images`` with a DC-AE and standardize the resulting latents.

    Gradient tracking is disabled for the whole computation.

    Args:
        dc_ae: Autoencoder exposing ``encode`` (whose result has a ``.latent``
            field) plus ``mean`` and ``std`` attributes.
        images: Batch passed unchanged to ``dc_ae.encode``.

    Returns:
        ``(latent - dc_ae.mean) / dc_ae.std``.
    """
    with torch.no_grad():
        raw_latent = dc_ae.encode(images).latent
        centered = raw_latent - dc_ae.mean
        return centered / dc_ae.std


def dc_ae_decode(dc_ae, latents, slice_vae=False):
    """Undo latent standardization and decode with a DC-AE.

    Gradient tracking is disabled for the whole computation.

    Args:
        dc_ae: Autoencoder exposing ``_decode`` plus ``mean`` and ``std``.
        latents: Standardized latents; dimension 0 is treated as the batch.
        slice_vae: When truthy and the batch has more than one element, decode
            one element at a time and concatenate the results along dim 0.

    Returns:
        Whatever ``dc_ae._decode`` produces for the unscaled latents, either for
        the whole batch or concatenated across per-element calls.
    """
    with torch.no_grad():
        unscaled = latents * dc_ae.std + dc_ae.mean
        # Fast path: single call for the whole batch.
        if not slice_vae or unscaled.shape[0] <= 1:
            return dc_ae._decode(unscaled)
        # Sliced path: one decode per batch element (split size 1).
        pieces = [dc_ae._decode(piece) for piece in unscaled.split(1)]
        return torch.cat(pieces)


# sd-vae
def get_sd_vae(vae_dir, dtype, device):
    """Load a pretrained SD VAE (KL autoencoder) and switch it to eval mode.

    Args:
        vae_dir: Path or identifier forwarded to ``AutoencoderKL.from_pretrained``.
        dtype: Target dtype for the model weights.
        device: Target device for the model.

    Returns:
        The loaded autoencoder, moved to ``dtype``/``device`` and in eval mode.
    """
    # NOTE: parameters are intentionally NOT frozen here (requires_grad is left
    # unchanged); a previous variant froze them, reportedly to avoid functorch
    # issues, but that step is disabled.
    model = AutoencoderKL.from_pretrained(vae_dir)
    model = model.to(dtype=dtype, device=device)
    model.eval()
    return model


def sd_vae_encode(sd_vae, images):
    """Encode ``images`` with an SD VAE and standardize a sampled latent.

    A latent is drawn with ``latent_dist.sample()``, so repeated calls may
    differ if that distribution is stochastic. Gradient tracking is disabled.

    Args:
        sd_vae: Autoencoder exposing ``encode`` (whose result has a
            ``latent_dist`` with ``sample()``) plus ``mean`` and ``std``.
        images: Batch passed unchanged to ``sd_vae.encode``.

    Returns:
        ``(sample - sd_vae.mean) / sd_vae.std``.
    """
    with torch.no_grad():
        latent_dist = sd_vae.encode(images).latent_dist
        drawn = latent_dist.sample()
        return (drawn - sd_vae.mean) / sd_vae.std


def sd_vae_decode(sd_vae, latents, slice_vae=False):
    """Undo latent standardization and decode with an SD VAE.

    Gradient tracking is disabled for the whole computation.

    Args:
        sd_vae: Autoencoder exposing ``_decode`` (whose result has a
            ``.sample`` field) plus ``mean`` and ``std``.
        latents: Standardized latents; dimension 0 is treated as the batch.
        slice_vae: When truthy and the batch has more than one element, decode
            one element at a time and concatenate the results along dim 0.

    Returns:
        The ``.sample`` output(s) of ``sd_vae._decode`` for the unscaled latents.
    """
    with torch.no_grad():
        unscaled = latents * sd_vae.std + sd_vae.mean
        # Fast path: single call for the whole batch.
        if not slice_vae or unscaled.shape[0] <= 1:
            return sd_vae._decode(unscaled).sample
        # Sliced path: one decode per batch element (split size 1).
        pieces = [sd_vae._decode(piece).sample for piece in unscaled.split(1)]
        return torch.cat(pieces)