|
|
import torch |
|
|
from diffusers import AutoencoderKLCogVideoX, CogVideoXDDIMScheduler, CogVideoXTransformer3DModel |
|
|
from transformers import AutoTokenizer, T5EncoderModel |
|
|
|
|
|
from finetrainers.models.cogvideox import CogVideoXModelSpecification |
|
|
|
|
|
|
|
|
class DummyCogVideoXModelSpecification(CogVideoXModelSpecification):
    """CogVideoX model specification whose loaders return small stand-in models.

    The text encoder and tokenizer are fetched from the
    "hf-internal-testing/tiny-random-t5" checkpoint, while the VAE and the
    transformer are instantiated directly from compact hard-coded
    configurations after seeding torch's global RNG with 0.
    """

    def __init__(self, **kwargs):
        """Pass every keyword argument through to the parent specification."""
        super().__init__(**kwargs)

    def load_condition_models(self):
        """Load the T5 text encoder and its tokenizer.

        Returns:
            A dict with the keys ``"text_encoder"`` and ``"tokenizer"``.
        """
        checkpoint = "hf-internal-testing/tiny-random-t5"
        encoder = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=self.text_encoder_dtype)
        tok = AutoTokenizer.from_pretrained(checkpoint)
        return {"text_encoder": encoder, "tokenizer": tok}

    def load_latent_models(self):
        """Construct the VAE and remember its config on the instance.

        Side effects:
            Seeds torch's global RNG with 0 and sets ``self.vae_config``.

        Returns:
            A dict with the single key ``"vae"``.
        """
        # Seed before construction so the model is built from a fixed RNG state.
        torch.manual_seed(0)

        num_blocks = 4
        vae_kwargs = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ("CogVideoXDownBlock3D",) * num_blocks,
            "up_block_types": ("CogVideoXUpBlock3D",) * num_blocks,
            "block_out_channels": (8,) * num_blocks,
            "latent_channels": 4,
            "layers_per_block": 1,
            "norm_num_groups": 2,
            "temporal_compression_ratio": 4,
        }
        autoencoder = AutoencoderKLCogVideoX(**vae_kwargs)

        # Cast to the dtype requested by the specification.
        autoencoder.to(self.vae_dtype)
        self.vae_config = autoencoder.config
        return {"vae": autoencoder}

    def load_diffusion_models(self):
        """Construct the transformer and a DDIM scheduler.

        Side effects:
            Seeds torch's global RNG with 0 and sets ``self.transformer_config``.

        Returns:
            A dict with the keys ``"transformer"`` and ``"scheduler"``.
        """
        # Seed before construction so the model is built from a fixed RNG state.
        torch.manual_seed(0)

        transformer_kwargs = {
            "num_attention_heads": 4,
            "attention_head_dim": 16,
            "in_channels": 4,
            "out_channels": 4,
            "time_embed_dim": 2,
            "text_embed_dim": 32,
            "num_layers": 2,
            "sample_width": 24,
            "sample_height": 24,
            "sample_frames": 9,
            "patch_size": 2,
            "temporal_compression_ratio": 4,
            "max_text_seq_length": 16,
            "use_rotary_positional_embeddings": True,
        }
        denoiser = CogVideoXTransformer3DModel(**transformer_kwargs)

        # Cast to the dtype requested by the specification.
        denoiser.to(self.transformer_dtype)
        self.transformer_config = denoiser.config

        return {"transformer": denoiser, "scheduler": CogVideoXDDIMScheduler()}
|
|
|