import torch
from modules.autoencoder import AutoEncoder, AutoEncoderParams
from modules.conditioner import HFEmbedder
from safetensors.torch import load_file as load_sft


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    # max_length values of 64, 128, 256, or 512 should all work, as long as
    # your sequences fit within the chosen limit
    return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_clap(device: str | torch.device = "cuda", max_length: int = 256) -> HFEmbedder:
    # Pass max_length through rather than hardcoding it inside the call
    return HFEmbedder("laion/larger_clap_music", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_ae(ckpt_path: str, device: str | torch.device = "cuda") -> AutoEncoder:
    # VAE hyperparameters; the scale/shift factors match the FLUX autoencoder
    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )
    # Load the autoencoder weights and surface any mismatched keys instead of
    # silently discarding them
    ae = AutoEncoder(ae_params)
    sd = load_sft(ckpt_path)
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if missing or unexpected:
        print(f"load_ae: {len(missing)} missing / {len(unexpected)} unexpected keys")
    ae.to(device)
    return ae
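

# Minimal usage sketch. The checkpoint path below is a placeholder, not part
# of this module; point it at your own ae.safetensors file.
if __name__ == "__main__":
    t5 = load_t5(max_length=256)
    clip = load_clip()
    clap = load_clap()
    ae = load_ae("path/to/ae.safetensors")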