# Spaces: Paused  (status text left over from the hosting web page; not part of the module)
import argparse
import datetime
import logging
from pathlib import Path
from typing import Any, List, Literal, Tuple

from pydantic import BaseModel, ValidationInfo, field_validator
class Args(BaseModel):
    """Validated training configuration, normally built via ``Args.parse_args()``.

    Field declaration order matters: the validators below read earlier fields
    through ``info.data``, which only contains fields that were declared (and
    successfully validated) before the field currently being validated.

    NOTE(review): the ``@field_validator`` targets were reconstructed from the
    validator names and the per-field comments -- confirm against upstream.
    """

    ########## Model ##########
    model_path: Path
    model_name: str
    model_type: Literal["i2v", "t2v", "i2vFlow"]  # i2vFlow for FloVD
    training_type: Literal["lora", "sft", "controlnet"] = "lora"
    additional_save_blocks: List[str] | None = None
    depth_ckpt_path: str

    ########## Output ##########
    # The timestamp is evaluated once, when the class body is executed.
    output_dir: Path = Path("train_results/{:%Y-%m-%d-%H-%M-%S}".format(datetime.datetime.now()))
    report_to: Literal["tensorboard", "wandb", "all"] | None = None
    tracker_name: str = "finetrainer-cogvideo"
    run_name: str = "CogVideoX"

    ########## Data ###########
    data_root: Path
    caption_column: Path
    image_column: Path | None = None
    video_column: Path

    ########## Training #########
    resume_from_checkpoint: Path | None = None

    seed: int | None = None
    train_epochs: int
    train_steps: int | None = None
    checkpointing_steps: int = 200
    checkpointing_limit: int = 10

    batch_size: int
    gradient_accumulation_steps: int = 1

    train_resolution: Tuple[int, int, int]  # shape: (frames, height, width)

    #### deprecated args: video_resolution_buckets
    # if use bucket for training, should not be None
    # Note1: At least one frame rate in the bucket must be less than or equal to the frame rate of any video in the dataset
    # Note2: For cogvideox, cogvideox1.5
    #   The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
    #   The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
    # video_resolution_buckets: List[Tuple[int, int, int]] | None = None

    mixed_precision: Literal["no", "fp16", "bf16"]

    learning_rate: float = 2e-5
    optimizer: str = "adamw"
    beta1: float = 0.9
    beta2: float = 0.95
    beta3: float = 0.98
    epsilon: float = 1e-8
    weight_decay: float = 1e-4
    max_grad_norm: float = 1.0

    lr_scheduler: str = "constant_with_warmup"
    lr_warmup_steps: int = 100
    lr_num_cycles: int = 1
    lr_power: float = 1.0

    num_workers: int = 8
    pin_memory: bool = True

    gradient_checkpointing: bool = True
    enable_slicing: bool = True
    enable_tiling: bool = True
    nccl_timeout: int = 1800

    ########## Lora ##########
    rank: int = 128
    lora_alpha: int = 64
    target_modules: List[str] = ["to_q", "to_k", "to_v", "to_out.0"]

    ########## Validation ##########
    # The five fields below have no default, so they must always be passed
    # (possibly as None); parse_args() always supplies them.
    do_validation: bool = False
    validation_steps: int | None  # if set, should be a multiple of checkpointing_steps
    validation_dir: Path | None  # if set do_validation, should not be None
    validation_prompts: str | None  # if set do_validation, should not be None
    validation_images: str | None  # if set do_validation and model_type == i2v, should not be None
    validation_videos: str | None  # if set do_validation and model_type == v2v, should not be None
    gen_fps: int = 15
    max_scene: int = 8

    ########## Controlnet ##########
    controlnet_transformer_num_layers: int = 8
    controlnet_input_channels: int = 16
    controlnet_weights: float = 1.0
    controlnet_guidance_start: float = 0.0
    controlnet_guidance_end: float = 1.0
    controlnet_out_proj_dim_factor: int = 64
    controlnet_out_proj_zero_init: bool = True
    enable_time_sampling: bool = True
    time_sampling_type: str = 'truncated_normal'
    time_sampling_mean: float = 0.95
    time_sampling_std: float = 0.1
    use_valid_mask: bool = False
    notextinflow: bool = False

    #### deprecated args: gen_video_resolution
    # 1. If set do_validation, should not be None
    # 2. Suggest selecting the bucket from `video_resolution_buckets` that is closest to the resolution you have chosen for fine-tuning
    #    or the resolution recommended by the model
    # 3. Note: For cogvideox, cogvideox1.5
    #   The frame rate set in the bucket must be an integer multiple of 8 (spatial_compression_rate[4] * path_t[2] = 8)
    #   The height and width set in the bucket must be an integer multiple of 8 (temporal_compression_rate[8])
    # gen_video_resolution: Tuple[int, int, int] | None  # shape: (frames, height, width)

    @field_validator("image_column")
    @classmethod
    def validate_image_column(cls, v: Path | None, info: ValidationInfo) -> Path | None:
        """Warn (never fail) when an i2v model is configured without an image column."""
        values = info.data
        if values.get("model_type") == "i2v" and not v:
            logging.warning(
                "No `image_column` specified for i2v model. Will automatically extract first frames from videos as conditioning images."
            )
        return v

    @field_validator("validation_dir", "validation_prompts")
    @classmethod
    def validate_validation_required_fields(cls, v: Any, info: ValidationInfo) -> Any:
        """Require the decorated field to be non-empty when ``do_validation`` is on.

        Raises:
            ValueError: if ``do_validation`` is truthy and the value is falsy.
        """
        values = info.data
        if values.get("do_validation") and not v:
            field_name = info.field_name
            raise ValueError(f"{field_name} must be specified when do_validation is True")
        return v

    @field_validator("validation_images")
    @classmethod
    def validate_validation_images(cls, v: str | None, info: ValidationInfo) -> str | None:
        """Require ``validation_images`` when validating an i2v model.

        Raises:
            ValueError: if ``do_validation`` is truthy, ``model_type`` is
                ``"i2v"`` and the value is falsy.
        """
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "i2v" and not v:
            raise ValueError("validation_images must be specified when do_validation is True and model_type is i2v")
        return v

    @field_validator("validation_videos")
    @classmethod
    def validate_validation_videos(cls, v: str | None, info: ValidationInfo) -> str | None:
        """Require ``validation_videos`` when validating a v2v model.

        NOTE(review): ``model_type`` is declared as
        ``Literal["i2v", "t2v", "i2vFlow"]``, so the ``"v2v"`` comparison can
        never be true as the schema stands.

        Raises:
            ValueError: if ``do_validation`` is truthy, ``model_type`` is
                ``"v2v"`` and the value is falsy.
        """
        values = info.data
        if values.get("do_validation") and values.get("model_type") == "v2v" and not v:
            raise ValueError("validation_videos must be specified when do_validation is True and model_type is v2v")
        return v

    @field_validator("validation_steps")
    @classmethod
    def validate_validation_steps(cls, v: int | None, info: ValidationInfo) -> int | None:
        """Check ``validation_steps`` is set and aligned with checkpointing.

        Only enforced when ``do_validation`` is truthy.

        Raises:
            ValueError: if the value is None, or is not a multiple of a
                non-zero ``checkpointing_steps``.
        """
        values = info.data
        if values.get("do_validation"):
            if v is None:
                raise ValueError("validation_steps must be specified when do_validation is True")
            if values.get("checkpointing_steps") and v % values["checkpointing_steps"] != 0:
                raise ValueError("validation_steps must be a multiple of checkpointing_steps")
        return v

    @field_validator("train_resolution")
    @classmethod
    def validate_train_resolution(cls, v: Tuple[int, int, int], info: ValidationInfo) -> Tuple[int, int, int]:
        """Check the ``(frames, height, width)`` triple against model constraints.

        String parsing of the ``FxHxW`` command-line form (and its
        format-error message) lives in ``parse_args``; the value reaching this
        method is unpacked directly as three numbers.

        Raises:
            ValueError: if ``frames - 1`` is not a multiple of 8, or if
                ``model_name`` is one of the cogvideox-5b names and the
                resolution is not 480x720.
        """
        frames, height, width = v

        # Check if (frames - 1) is multiple of 8
        if (frames - 1) % 8 != 0:
            raise ValueError("Number of frames - 1 must be a multiple of 8")

        # Check resolution for cogvideox-5b models
        model_name = info.data.get("model_name", "")
        if model_name in ("cogvideox-5b-i2v", "cogvideox-5b-t2v") and (height, width) != (480, 720):
            raise ValueError("For cogvideox-5b models, height must be 480 and width must be 720")

        return v

    @field_validator("mixed_precision")
    @classmethod
    def validate_mixed_precision(cls, v: str, info: ValidationInfo) -> str:
        """Warn (never fail) when fp16 is chosen and ``model_path`` lacks "cogvideox-2b"."""
        if v == "fp16" and "cogvideox-2b" not in str(info.data.get("model_path", "")).lower():
            logging.warning(
                "All CogVideoX models except cogvideox-2b were trained with bfloat16. "
                "Using fp16 precision may lead to training instability."
            )
        return v

    @classmethod
    def parse_args(cls):
        """Parse command line arguments and return Args instance.

        Raises:
            ValueError: if ``--train_resolution`` is not ``FxHxW`` with integer
                parts, or if ``--additional_save_blocks`` is given with a
                training type other than ``lora``. Model construction may
                additionally raise pydantic's validation error.
        """

        def str2bool(value: str) -> bool:
            # argparse's `type=bool` treats every non-empty string (including
            # "False") as True, so boolean flags need an explicit converter.
            lowered = value.strip().lower()
            if lowered in {"true", "1", "yes", "y", "on"}:
                return True
            if lowered in {"false", "0", "no", "n", "off", ""}:
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

        parser = argparse.ArgumentParser()

        # Required arguments
        parser.add_argument("--model_path", type=str, required=True)
        parser.add_argument("--model_name", type=str, required=True)
        parser.add_argument("--model_type", type=str, required=True)
        parser.add_argument(
            "--depth_ckpt_path",
            type=str,
            required=False,
            default="./ckpt/others/depth_anything_v2_metric_hypersim_vitb.pth",
            help="Path to the checkpoint of the depth estimation networks",
        )
        parser.add_argument("--training_type", type=str, required=True)
        parser.add_argument("--additional_save_blocks", type=str, required=False, default=None)
        parser.add_argument("--output_dir", type=str, required=True)
        parser.add_argument("--data_root", type=str, required=True)
        parser.add_argument("--caption_column", type=str, required=True)
        parser.add_argument("--video_column", type=str, required=True)
        parser.add_argument("--train_resolution", type=str, required=True)
        parser.add_argument("--report_to", type=str, required=True)
        parser.add_argument("--run_name", type=str, required=False, default='CogVideoX')

        # Training hyperparameters
        parser.add_argument("--seed", type=int, default=42)
        parser.add_argument("--train_epochs", type=int, default=10)
        parser.add_argument("--train_steps", type=int, default=None)
        parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
        parser.add_argument("--batch_size", type=int, default=1)
        parser.add_argument("--learning_rate", type=float, default=2e-5)
        parser.add_argument("--optimizer", type=str, default="adamw")
        parser.add_argument("--beta1", type=float, default=0.9)
        parser.add_argument("--beta2", type=float, default=0.95)
        parser.add_argument("--beta3", type=float, default=0.98)
        parser.add_argument("--epsilon", type=float, default=1e-8)
        parser.add_argument("--weight_decay", type=float, default=1e-4)
        parser.add_argument("--max_grad_norm", type=float, default=1.0)

        # Learning rate scheduler
        parser.add_argument("--lr_scheduler", type=str, default="constant_with_warmup")
        parser.add_argument("--lr_warmup_steps", type=int, default=100)
        parser.add_argument("--lr_num_cycles", type=int, default=1)
        parser.add_argument("--lr_power", type=float, default=1.0)

        # Data loading
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--pin_memory", type=str2bool, default=True)
        parser.add_argument("--image_column", type=str, default=None)

        # Model configuration
        parser.add_argument("--mixed_precision", type=str, default="no")
        parser.add_argument("--gradient_checkpointing", type=str2bool, default=True)
        parser.add_argument("--enable_slicing", type=str2bool, default=True)
        parser.add_argument("--enable_tiling", type=str2bool, default=True)
        parser.add_argument("--nccl_timeout", type=int, default=1800)

        # LoRA parameters
        parser.add_argument("--rank", type=int, default=128)
        parser.add_argument("--lora_alpha", type=int, default=64)
        parser.add_argument("--target_modules", type=str, nargs="+", default=["to_q", "to_k", "to_v", "to_out.0"])

        # Checkpointing
        parser.add_argument("--checkpointing_steps", type=int, default=200)
        parser.add_argument("--checkpointing_limit", type=int, default=10)
        parser.add_argument("--resume_from_checkpoint", type=str, default=None)

        # Validation
        parser.add_argument("--do_validation", type=str2bool, default=False)
        parser.add_argument("--validation_steps", type=int, default=None)
        parser.add_argument("--validation_dir", type=str, default=None)
        parser.add_argument("--validation_prompts", type=str, default=None)
        parser.add_argument("--validation_images", type=str, default=None)
        parser.add_argument("--validation_videos", type=str, default=None)
        parser.add_argument("--gen_fps", type=int, default=15)
        parser.add_argument("--max_scene", type=int, default=8)

        # Controlnet
        parser.add_argument("--controlnet_transformer_num_layers", type=int, default=8)
        parser.add_argument("--controlnet_input_channels", type=int, default=16)
        parser.add_argument("--controlnet_weights", type=float, default=1.0)
        parser.add_argument("--controlnet_guidance_start", type=float, default=0.0)
        parser.add_argument("--controlnet_guidance_end", type=float, default=1.0)
        parser.add_argument("--controlnet_out_proj_dim_factor", type=int, default=64)
        parser.add_argument("--controlnet_out_proj_zero_init", type=str2bool, default=True)
        parser.add_argument("--enable_time_sampling", type=str2bool, default=True)
        parser.add_argument("--time_sampling_type", type=str, default='truncated_normal')
        parser.add_argument("--time_sampling_mean", type=float, default=0.95)
        parser.add_argument("--time_sampling_std", type=float, default=0.1)
        parser.add_argument("--use_valid_mask", type=str2bool, default=False)
        parser.add_argument("--notextinflow", type=str2bool, default=False)

        args = parser.parse_args()

        # Convert the "FxHxW" train_resolution string to a tuple of ints.
        try:
            frames, height, width = (int(part) for part in args.train_resolution.split("x"))
        except ValueError as err:
            # Covers both a wrong number of parts and non-integer parts.
            raise ValueError("train_resolution must be in format 'frames x height x width'") from err
        args.train_resolution = (frames, height, width)

        if args.additional_save_blocks is not None:
            args.additional_save_blocks = args.additional_save_blocks.split(',')

        # Use additional trainable blocks only for 'lora' setting.
        # Raised explicitly (not asserted) so the check survives `python -O`.
        if args.training_type != 'lora' and args.additional_save_blocks is not None:
            raise ValueError("--additional_save_blocks is only supported when --training_type is 'lora'")

        return cls(**vars(args))