Spaces:
Running
Running
import torch | |
from modules.autoencoder import AutoEncoder, AutoEncoderParams | |
from modules.conditioner import HFEmbedder | |
from safetensors.torch import load_file as load_sft | |
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Build the T5 text embedder and move it to ``device``.

    Args:
        device: Target handed to ``.to()`` on the freshly built embedder.
        max_length: Forwarded to ``HFEmbedder`` as ``max_length``. Values of
            64, 128, 256 and 512 should work, provided the input sequence is
            short enough.

    Returns:
        The result of calling ``.to(device)`` on an ``HFEmbedder`` created for
        ``google/t5-v1_1-xxl`` with ``torch_dtype=torch.bfloat16``.
    """
    t5_embedder = HFEmbedder(
        "google/t5-v1_1-xxl",
        max_length=max_length,
        torch_dtype=torch.bfloat16,
    )
    return t5_embedder.to(device)
def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Build the CLIP text embedder and move it to ``device``.

    Args:
        device: Target handed to ``.to()`` on the freshly built embedder.

    Returns:
        The result of calling ``.to(device)`` on an ``HFEmbedder`` created for
        ``openai/clip-vit-large-patch14`` with ``max_length=77`` and
        ``torch_dtype=torch.bfloat16``.
    """
    clip_embedder = HFEmbedder(
        "openai/clip-vit-large-patch14",
        max_length=77,
        torch_dtype=torch.bfloat16,
    )
    return clip_embedder.to(device)
def load_clap(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Build the CLAP embedder and move it to ``device``.

    Args:
        device: Target handed to ``.to()`` on the freshly built embedder.
        max_length: Accepted for signature parity with ``load_t5`` but not
            used; the embedder is always created with ``max_length=256``.

    Returns:
        The result of calling ``.to(device)`` on an ``HFEmbedder`` created for
        ``laion/larger_clap_music`` with ``max_length=256`` and
        ``torch_dtype=torch.bfloat16``.
    """
    # NOTE(review): the `max_length` argument (default 512) is ignored and 256
    # is hard-coded below. Confirm which value is intended before wiring the
    # parameter through, since that would change the sequence length callers
    # currently get.
    clap_embedder = HFEmbedder(
        "laion/larger_clap_music",
        max_length=256,
        torch_dtype=torch.bfloat16,
    )
    return clap_embedder.to(device)
def load_ae(ckpt_path, device: str | torch.device = "cuda") -> AutoEncoder:
    """Construct the autoencoder, load its weights, and move it to ``device``.

    The architecture is fixed by the hard-coded ``AutoEncoderParams`` below;
    only the checkpoint location and the target device are configurable.

    Args:
        ckpt_path: Path to a safetensors checkpoint, passed to
            ``safetensors.torch.load_file``.
        device: Target handed to ``ae.to()`` after the weights are loaded.

    Returns:
        The ``AutoEncoder`` instance with the checkpoint's tensors assigned.

    Warns:
        UserWarning: If the checkpoint lacks keys the model expects, or
            contains keys the model does not have. Loading stays non-strict,
            so these mismatches are reported instead of raising.
    """
    # Local import so this block does not depend on the file-level imports.
    import warnings

    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    # Loading the autoencoder
    ae = AutoEncoder(ae_params)
    sd = load_sft(ckpt_path)

    # strict=False tolerates key mismatches; surface them so a wrong or
    # truncated checkpoint does not pass silently.
    missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
    if missing:
        warnings.warn(
            f"load_ae: {len(missing)} missing key(s) when loading {ckpt_path}: "
            + ", ".join(missing),
            stacklevel=2,
        )
    if unexpected:
        warnings.warn(
            f"load_ae: {len(unexpected)} unexpected key(s) when loading {ckpt_path}: "
            + ", ".join(unexpected),
            stacklevel=2,
        )

    ae.to(device)
    return ae