import gc
import os

import torch
from safetensors.torch import load_file

from .clip import CLIPModel
from .t5 import T5EncoderModel
from .transformer import WanModel
from .vae import WanVAE


def download_model(model_id):
    # Treat model_id as a local path if it exists; otherwise pull the
    # repository from the Hugging Face Hub and return the snapshot path.
    if not os.path.exists(model_id):
        from huggingface_hub import snapshot_download
        model_id = snapshot_download(repo_id=model_id)
    return model_id


def get_vae(model_path, device="cuda", weight_dtype=torch.float32) -> WanVAE:
    # Load the Wan VAE, freeze it, and put it in eval mode for inference.
    vae = WanVAE(model_path).to(device).to(weight_dtype)
    vae.vae.requires_grad_(False)
    vae.vae.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return vae


def get_transformer(model_path, device="cuda", weight_dtype=torch.bfloat16) -> WanModel:
    # Build the diffusion transformer from its config, then stream in the
    # sharded .safetensors weights one file at a time to limit peak memory.
    config_path = os.path.join(model_path, "config.json")
    transformer = WanModel.from_config(config_path).to(weight_dtype).to(device)

    for file in os.listdir(model_path):
        if file.endswith(".safetensors"):
            file_path = os.path.join(model_path, file)
            state_dict = load_file(file_path)
            transformer.load_state_dict(state_dict, strict=False)
            # Free the shard before loading the next one.
            del state_dict
            gc.collect()
            torch.cuda.empty_cache()

    transformer.requires_grad_(False)
    transformer.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return transformer


def get_text_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> T5EncoderModel:
    # UMT5-XXL text encoder plus its tokenizer, frozen for inference.
    t5_model = os.path.join(model_path, "models_t5_umt5-xxl-enc-bf16.pth")
    tokenizer_path = os.path.join(model_path, "google", "umt5-xxl")
    text_encoder = T5EncoderModel(checkpoint_path=t5_model, tokenizer_path=tokenizer_path).to(device).to(weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return text_encoder


def get_image_encoder(model_path, device="cuda", weight_dtype=torch.bfloat16) -> CLIPModel:
    # CLIP image encoder (open-clip XLM-RoBERTa-Large / ViT-Huge-14), frozen for inference.
    checkpoint_path = os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth")
    tokenizer_path = os.path.join(model_path, "xlm-roberta-large")
    image_enc = CLIPModel(checkpoint_path, tokenizer_path).to(weight_dtype).to(device)
    image_enc.requires_grad_(False)
    image_enc.eval()
    gc.collect()
    torch.cuda.empty_cache()
    return image_enc
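

# --- Illustrative usage sketch ---
# A minimal example of how these loaders might be composed. The repo id below
# is an assumption used only for illustration; any checkpoint directory with
# the layout the helpers above expect (config.json, *.safetensors shards, the
# UMT5 and CLIP .pth files, and the bundled tokenizer folders) would work the
# same way.
if __name__ == "__main__":
    model_path = download_model("Wan-AI/Wan2.1-I2V-14B-480P")  # hypothetical repo id
    vae = get_vae(model_path)
    transformer = get_transformer(model_path)
    text_encoder = get_text_encoder(model_path)
    image_encoder = get_image_encoder(model_path)
    print("Loaded VAE, transformer, text encoder, and image encoder from", model_path)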