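"""Factory helpers for the Kandinsky-4 text-to-video pipeline.

``get_T2V_pipeline`` downloads any missing checkpoints, builds the DiT, text embedder,
VAE and scheduler, and wires them into a ``Kandinsky4T2VPipeline``; ``get_default_conf``
assembles the default OmegaConf configuration for the 512-resolution model.
"""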
import os
from typing import Optional, Union

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler
from huggingface_hub import hf_hub_download, snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.distributed.device_mesh import init_device_mesh

from .model.dit import get_dit, parallelize
from .model.text_embedders import get_text_embedder
from .t2v_pipeline import Kandinsky4T2VPipeline


def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = "./weights/",
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    scheduler_path: Optional[str] = None,
    conf_path: Optional[str] = None,
) -> Kandinsky4T2VPipeline:
    """Build a ``Kandinsky4T2VPipeline``, downloading any missing weights into ``cache_dir``.

    ``device_map`` is either a single device (applied to every component) or a dict
    with ``dit``, ``vae`` and ``text_embedder`` keys.
    """
    assert resolution in [512], "Only the 512 resolution checkpoint is supported."

    # Normalize device_map: a single device is applied to every pipeline component.
    if not isinstance(device_map, dict):
        device_map = {
            "dit": device_map,
            "vae": device_map,
            "text_embedder": device_map,
        }

    # Pick up the distributed context if the script was launched with torchrun;
    # fall back to a single-process setup otherwise.
    try:
        local_rank, world_size = int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"])
    except (KeyError, ValueError):
        local_rank, world_size = 0, 1

    if world_size > 1:
        # Tensor-parallel run: each rank keeps its shard of the DiT on its local GPU.
        device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tensor_parallel",))
        device_map["dit"] = torch.device(f"cuda:{local_rank}")

    os.makedirs(cache_dir, exist_ok=True)

    # Download any weights that were not supplied explicitly.
    if dit_path is None:
        dit_path = hf_hub_download(
            repo_id="ai-forever/kandinsky4", filename=f"kandinsky4_distil_{resolution}.pt", local_dir=cache_dir
        )

    if vae_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="vae/*", local_dir=cache_dir
        )
        vae_path = os.path.join(cache_dir, "vae")

    if scheduler_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="scheduler/*", local_dir=cache_dir
        )
        scheduler_path = os.path.join(cache_dir, "scheduler")

    if text_encoder_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="text_encoder/*", local_dir=cache_dir
        )
        text_encoder_path = os.path.join(cache_dir, "text_encoder")

    if tokenizer_path is None:
        snapshot_download(
            repo_id="THUDM/CogVideoX-5b", allow_patterns="tokenizer/*", local_dir=cache_dir
        )
        tokenizer_path = os.path.join(cache_dir, "tokenizer")

    if conf_path is None:
        conf = get_default_conf(vae_path, text_encoder_path, tokenizer_path, scheduler_path, dit_path)
    else:
        conf = OmegaConf.load(conf_path)

    # Diffusion transformer (DiT), loaded in bf16 on its target device.
    dit = get_dit(conf.dit)
    dit = dit.to(dtype=torch.bfloat16, device=device_map["dit"])

    noise_scheduler = CogVideoXDDIMScheduler.from_pretrained(conf.dit.scheduler)

    if world_size > 1:
        # Shard the DiT across the tensor-parallel mesh.
        dit = parallelize(dit, device_mesh["tensor_parallel"])

    # Text embedder and VAE are only moved to their target devices on rank 0.
    text_embedder = get_text_embedder(conf)
    text_embedder = text_embedder.freeze()
    if local_rank == 0:
        text_embedder = text_embedder.to(device=device_map["text_embedder"], dtype=torch.bfloat16)

    vae = AutoencoderKLCogVideoX.from_pretrained(conf.vae.checkpoint_path)
    vae = vae.eval()
    if local_rank == 0:
        vae = vae.to(device_map["vae"], dtype=torch.bfloat16)

    return Kandinsky4T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
    )


def get_default_conf(
    vae_path: str,
    text_encoder_path: str,
    tokenizer_path: str,
    scheduler_path: str,
    dit_path: str,
) -> DictConfig:
    """Assemble the default 512-resolution pipeline config from the given checkpoint paths."""
    dit_params = {
        "in_visual_dim": 16,
        "in_text_dim": 4096,
        "out_visual_dim": 16,
        "time_dim": 512,
        "patch_size": [1, 2, 2],
        "model_dim": 3072,
        "ff_dim": 12288,
        "num_blocks": 21,
        "axes_dims": [16, 24, 24],
    }

    conf = {
        "vae": {
            "checkpoint_path": vae_path,
        },
        "text_embedder": {
            "emb_size": 4096,
            "tokens_lenght": 224,  # key spelling kept as-is; downstream readers of the config may expect this exact name
            "params": {
                "checkpoint_path": text_encoder_path,
                "tokenizer_path": tokenizer_path,
            },
        },
        "dit": {
            "scheduler": scheduler_path,
            "checkpoint_path": dit_path,
            "params": dit_params,
        },
        "resolution": 512,
    }

    return DictConfig(conf)
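

# Illustrative usage sketch, not part of the module's behavior: the device, cache
# directory and single-GPU setup below are assumptions for the example. Actual video
# generation is performed through the returned Kandinsky4T2VPipeline (see
# t2v_pipeline.py for its call signature).
if __name__ == "__main__":
    # Single-device setup: one device is broadcast to the DiT, VAE and text embedder.
    pipeline = get_T2V_pipeline(
        device_map=torch.device("cuda:0"),
        resolution=512,
        cache_dir="./weights/",
    )

    # Per-component placement is also possible, e.g. keeping the heavy DiT alone on one GPU:
    # pipeline = get_T2V_pipeline(
    #     device_map={"dit": "cuda:0", "vae": "cuda:1", "text_embedder": "cuda:1"},
    # )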